In [1]:
from PIL import Image
import os
import numpy as np
import torch
import torch.utils.data
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision import transforms as T
import json
import shutil

We start by defining the epochs and classes for our training script. We have 10 total fruits we wish to identify in the images and will start by running the training for 15 epochs. After you are done with the initial training, feel free to change the number of epochs to see how it changes your loss function value.

In [2]:
epochs = 15
num_classes = 10

In [3]:
data_dir = "/home/ubuntu/notebooks/data/fruit_data"

In [4]:
output_file = "/home/ubuntu/notebooks/my_model.pth"

In [5]:
!nvidia-smi

Thu Dec 26 03:54:50 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.12             Driver Version: 535.104.12   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A10G                    On  | 00000000:00:1E.0 Off |                    0 |
|  0%   23C    P0              56W / 300W |    780MiB / 23028MiB |      3%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [6]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [7]:
class FruitDataset(torch.utils.data.Dataset):
    # This function is run once when instantiating the Dataset object
    def __init__(self, root, transforms):
        self.root = root
        self.transforms = transforms

        # In the first portion of this code we are taking our single dataset folder 
        # and splitting it into three folders based on the file types.
        # This is just a preprocessing step.
        list_ = os.listdir(root)
        for file_ in list_:
            name, ext = os.path.splitext(file_)
            ext = ext[1:]
            if ext == '':
                continue

            if os.path.exists(root+ '/' + ext):
                shutil.move(root+'/'+file_, root+'/'+ext+'/'+file_)

            else:
                os.makedirs(root+'/'+ext)
                shutil.move(root+'/'+file_, root+'/'+ext+'/'+file_)

        self.imgs = list(sorted(os.listdir(os.path.join(root, "png"))))
        self.label = list(sorted(os.listdir(os.path.join(root, "json"))))
        self.box = list(sorted(os.listdir(os.path.join(root, "npy"))))
        # We have our three attributes with the img, label, and box data

    # Loads and returns a sample from the dataset at the given index idx
    def __getitem__(self, idx):
        img_path = os.path.join(self.root, "png", self.imgs[idx])
        img = Image.open(img_path).convert("RGB")

        label_path = os.path.join(self.root, "json", self.label[idx])

        with open(os.path.join('root', label_path), "r") as json_data:
            json_labels = json.load(json_data)
        
        box_path = os.path.join(self.root, "npy", self.box[idx])
        dat = np.load(str(box_path))   

        boxes = []
        labels = []
        for i in dat:
            obj_val = i[0]
            xmin = torch.as_tensor(np.min(i[1]), dtype=torch.float32)
            xmax = torch.as_tensor(np.max(i[3]), dtype=torch.float32)
            ymin = torch.as_tensor(np.min(i[2]), dtype=torch.float32)
            ymax = torch.as_tensor(np.max(i[4]), dtype=torch.float32)
            if (ymax > ymin) & (xmax > xmin):
                boxes.append([xmin, ymin, xmax, ymax])
                area = (xmax - xmin) * (ymax - ymin)
            labels += [json_labels.get(str(obj_val)).get('class')]

        label_dict = {}

        # Labels for the dataset
        static_labels = {
            'apple' : 0,
            'avocado' : 1,
            'kiwi' : 2,
            'lime' : 3,
            'lychee' : 4,
            'pomegranate' : 5,
            'onion' : 6,
            'strawberry' : 7,
            'lemon' : 8,
            'orange' : 9
        }

        labels_out = []
        # Transforming the input labels into a static label dictionary to use
        for i in range(len(labels)):
            label_dict[i] = labels[i]

        for i in label_dict:
            fruit = label_dict[i]
            final_fruit_label = static_labels[fruit]
            labels_out += [final_fruit_label]

        target = {}
        target["boxes"] = torch.as_tensor(boxes, dtype=torch.float32)
        target["labels"] = torch.as_tensor(labels_out, dtype=torch.int64)
        target["image_id"] = torch.tensor([idx]) 
        target["area"] = area

        if self.transforms is not None:
            img= self.transforms(img)
        return img, target

    # Finally we have a function for the number of samples in our dataset
    def __len__(self):
        return len(self.imgs)

In [8]:
def get_transform(train):
    transforms = []
    transforms.append(T.PILToTensor())
    transforms.append(T.ConvertImageDtype(torch.float))
    return T.Compose(transforms)

Create a function to collate our samples. 

In [9]:
def collate_fn(batch):
    return tuple(zip(*batch))

In [10]:
def create_model(num_classes): 
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights='DEFAULT')
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes) 
    return model

In [11]:
dataset = FruitDataset(data_dir, get_transform(train=True))
data_loader = torch.utils.data.DataLoader(
dataset, batch_size=16, shuffle=True, collate_fn= collate_fn) 

Next, we create our model with the 10 classes we have of fruit and transfer it to the GPU for training. We use [PyTorch SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html) (stochastic gradient descent) as the optimizer.

In [12]:
model = create_model(num_classes)
model.to(device)
    
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.001)
len_dataloader = len(data_loader)

Downloading: "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth" to /root/.cache/torch/hub/checkpoints/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth
100.0%


In [13]:
model.train()
ep = 0
for epoch in range(epochs):
    optimizer.zero_grad()
    ep += 1
    i = 0    
    for imgs, annotations in data_loader:
        i += 1
        imgs = list(img.to(device) for img in imgs)
        annotations = [{k: v.to(device) for k, v in t.items()} for t in annotations]
        loss_dict = model(imgs, annotations)
        losses = sum(loss for loss in loss_dict.values())

        losses.backward()
        optimizer.step()

        print(f'Epoch: {ep} Iteration: {i}/{len_dataloader}, Loss: {losses}')

Epoch: 1 Iteration: 1/7, Loss: 3.978541374206543
Epoch: 1 Iteration: 2/7, Loss: 3.3974666595458984
Epoch: 1 Iteration: 3/7, Loss: 3.380915641784668
Epoch: 1 Iteration: 4/7, Loss: 2.351797103881836
Epoch: 1 Iteration: 5/7, Loss: 1.9960970878601074
Epoch: 1 Iteration: 6/7, Loss: 2.01206111907959
Epoch: 1 Iteration: 7/7, Loss: 1.975547432899475
Epoch: 2 Iteration: 1/7, Loss: 2.0979387760162354
Epoch: 2 Iteration: 2/7, Loss: 2.014479398727417
Epoch: 2 Iteration: 3/7, Loss: 1.9199408292770386
Epoch: 2 Iteration: 4/7, Loss: 1.8923652172088623
Epoch: 2 Iteration: 5/7, Loss: 1.768477439880371
Epoch: 2 Iteration: 6/7, Loss: 1.9157696962356567
Epoch: 2 Iteration: 7/7, Loss: 1.7139626741409302
Epoch: 3 Iteration: 1/7, Loss: 1.8533666133880615
Epoch: 3 Iteration: 2/7, Loss: 1.8163641691207886
Epoch: 3 Iteration: 3/7, Loss: 2.0194742679595947
Epoch: 3 Iteration: 4/7, Loss: 1.789918065071106
Epoch: 3 Iteration: 5/7, Loss: 1.7051714658737183
Epoch: 3 Iteration: 6/7, Loss: 1.7313717603683472
Epoch: 3 

In [14]:
torch.save(model, output_file)