## Fine-tune a pretrained model with your dataset

Synthetic data is agnostic to the training framework. Here we show one way to use it with Torchvision, but once you have created your own custom dataset, you can plug that data into your own training workflow.

In [None]:
from PIL import Image
import os
import numpy as np
import torch
import torch.utils.data
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision import transforms as T
import json
import shutil

We start by defining the number of epochs and classes for our training script. We want to identify 10 kinds of fruit in the images, and we will start by training for 15 epochs. We can come back and adjust the number of epochs as we iterate on our training script.

In [None]:
epochs = 15
num_classes = 10

To look at the data we generated, open a terminal in the Jupyter window by pressing the "+" button in the top left. Our data generation script writes to `/dli/task/data/fruit_data_$DATE` by default. For now, the data directory is set to an example dataset that you may choose to use instead.

In [None]:
data_dir = "/dli/task/data/fruit_data"
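Before training, it can help to confirm that the folder holds matching numbers of images, labels, and boxes. The sketch below is one way to count files by extension. `count_by_extension` is a hypothetical helper, not part of the course code, and it walks subfolders as well, since the `FruitDataset` class defined later moves files into per-extension folders.

```python
import os
from collections import Counter


def count_by_extension(root):
    """Count the files under root by extension (hypothetical helper)."""
    counts = Counter()
    for _dirpath, _dirnames, filenames in os.walk(root):
        for filename in filenames:
            # splitext returns e.g. ('.png'); drop the leading dot
            ext = os.path.splitext(filename)[1][1:]
            if ext:  # skip files that have no extension
                counts[ext] += 1
    return dict(counts)
```

Calling `count_by_extension(data_dir)` should report the same count for `png`, `json`, and `npy` if every image has one label file and one box file.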

Next, we define our output file, which is where we will save our PyTorch model. An example model is also saved at `/dli/task/data/model.pth`.

In [None]:
output_file = "/dli/task/model.pth"

Our system today has a single A10 GPU. This gives us a powerful compute engine for training as well as state-of-the-art tech for our graphics applications.

In [None]:
!nvidia-smi

We define the device for training so that we use the A10 when it is available and fall back to the CPU otherwise.

In [None]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

Below we define our `FruitDataset` class, which the dataloader will read from during training. Comments throughout the code explain each step.

In [None]:
class FruitDataset(torch.utils.data.Dataset):
    # This function is run once when instantiating the Dataset object
    def __init__(self, root, transforms):
        self.root = root
        self.transforms = transforms

        # In the first portion of this code we are taking our single dataset folder 
        # and splitting it into three folders based on the file types.
        # This is just a preprocessing step.
        for file_ in os.listdir(root):
            name, ext = os.path.splitext(file_)
            ext = ext[1:]
            # Skip folders and files without an extension
            if ext == '':
                continue

            # Create the per-extension folder on first use, then move the file into it
            os.makedirs(os.path.join(root, ext), exist_ok=True)
            shutil.move(os.path.join(root, file_), os.path.join(root, ext, file_))

        self.imgs = list(sorted(os.listdir(os.path.join(root, "png"))))
        self.label = list(sorted(os.listdir(os.path.join(root, "json"))))
        self.box = list(sorted(os.listdir(os.path.join(root, "npy"))))
        # We have our three attributes with the img, label, and box data

    # Loads and returns a sample from the dataset at the given index idx
    def __getitem__(self, idx):
        img_path = os.path.join(self.root, "png", self.imgs[idx])
        img = Image.open(img_path).convert("RGB")

        label_path = os.path.join(self.root, "json", self.label[idx])

        with open(label_path, "r") as json_data:
            json_labels = json.load(json_data)
        
        box_path = os.path.join(self.root, "npy", self.box[idx])
        dat = np.load(str(box_path))   

        boxes = []
        areas = []
        labels = []
        for i in dat:
            obj_val = i[0]
            xmin = float(np.min(i[1]))
            xmax = float(np.max(i[3]))
            ymin = float(np.min(i[2]))
            ymax = float(np.max(i[4]))
            # Keep a box, its area, and its label only if the box has positive width and height,
            # so that the three lists always stay the same length
            if (ymax > ymin) and (xmax > xmin):
                boxes.append([xmin, ymin, xmax, ymax])
                areas.append((xmax - xmin) * (ymax - ymin))
                labels.append(json_labels.get(str(obj_val)).get('class'))

        # Labels for the dataset
        # Note: torchvision detection models treat label 0 as the background class
        static_labels = {
            'apple' : 0,
            'avocado' : 1,
            'kiwi' : 2,
            'lime' : 3,
            'lychee' : 4,
            'pomegranate' : 5,
            'onion' : 6,
            'strawberry' : 7,
            'lemon' : 8,
            'orange' : 9
        }

        # Transforming the class names into the static integer labels
        labels_out = [static_labels[fruit] for fruit in labels]

        target = {}
        # reshape keeps the (N, 4) box shape even when an image has no valid boxes
        target["boxes"] = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        target["labels"] = torch.as_tensor(labels_out, dtype=torch.int64)
        target["image_id"] = torch.tensor([idx])
        # One area per box, rather than only the area of the last box
        target["area"] = torch.as_tensor(areas, dtype=torch.float32)

        if self.transforms is not None:
            img = self.transforms(img)
        return img, target

    # Finally we have a function for the number of samples in our dataset
    def __len__(self):
        return len(self.imgs)
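The box handling inside `__getitem__` can be checked in isolation. The following is a minimal sketch in plain Python, assuming each row has the layout `(id, x_min, y_min, x_max, y_max)` that the indexing in the class implies. `filter_boxes` is a hypothetical helper and the sample rows are made up for illustration.

```python
def filter_boxes(rows):
    """Drop zero-area boxes and return (ids, boxes, areas) (hypothetical helper).

    Each row is assumed to be (id, x_min, y_min, x_max, y_max).
    """
    ids, boxes, areas = [], [], []
    for obj_id, x_min, y_min, x_max, y_max in rows:
        # A box is only usable if it has positive width and height
        if x_max > x_min and y_max > y_min:
            ids.append(obj_id)
            boxes.append([x_min, y_min, x_max, y_max])
            areas.append((x_max - x_min) * (y_max - y_min))
    return ids, boxes, areas


rows = [
    (0, 10, 20, 50, 80),  # valid: 40 wide, 60 tall
    (1, 30, 30, 30, 90),  # degenerate: zero width
    (2, 5, 5, 15, 25),    # valid: 10 wide, 20 tall
]
ids, boxes, areas = filter_boxes(rows)
print(ids)    # [0, 2]
print(areas)  # [2400, 200]
```

Keeping the three lists in step means every box that reaches the model has exactly one label and one area.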

Next, we define a function for the transformations we wish to perform on the images. We convert each PIL image to a tensor and then convert its dtype to float, which also rescales the pixel values to the range [0, 1].

In [None]:
def get_transform(train):
    transforms = []
    transforms.append(T.PILToTensor())
    transforms.append(T.ConvertImageDtype(torch.float))
    return T.Compose(transforms)

Next, we create a function to collate our samples into batches.

In [None]:
def collate_fn(batch):
    return tuple(zip(*batch))
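The default collate function tries to stack samples into a single tensor, which fails when the number of boxes differs between samples. Transposing the batch with `zip` sidesteps that by keeping images and targets as separate tuples. A small illustration, with placeholder strings standing in for the image tensors and target dictionaries:

```python
def collate_fn(batch):
    # Turn [(img0, tgt0), (img1, tgt1), ...] into ((img0, img1, ...), (tgt0, tgt1, ...))
    return tuple(zip(*batch))


# Placeholder strings stand in for image tensors and target dicts
batch = [("img0", "target0"), ("img1", "target1"), ("img2", "target2")]
imgs, targets = collate_fn(batch)
print(imgs)     # ('img0', 'img1', 'img2')
print(targets)  # ('target0', 'target1', 'target2')
```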

Next, we create our model. In this example, we use the pretrained `fasterrcnn_resnet50_fpn` model from torchvision and replace its box predictor with one sized for our classes.

In [None]:
def create_model(num_classes): 
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights='DEFAULT')
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes) 
    return model

At this point, we are ready to create our dataset from our custom `FruitDataset` class and our synthetic data. The dataset is then passed into our DataLoader.

In [None]:
dataset = FruitDataset(data_dir, get_transform(train=True))
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)

Next, we create our model with our 10 fruit classes and transfer it to the GPU for training. We use `torch.optim.SGD`, PyTorch's stochastic gradient descent optimizer.

In [None]:
model = create_model(num_classes)
model.to(device)
    
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.001)
len_dataloader = len(data_loader)

In our final section, we train our model. We track the loss and print it as we train.

In [None]:
model.train()
for epoch in range(epochs):
    for i, (imgs, annotations) in enumerate(data_loader, start=1):
        imgs = list(img.to(device) for img in imgs)
        annotations = [{k: v.to(device) for k, v in t.items()} for t in annotations]

        # In training mode the model returns a dictionary of losses
        loss_dict = model(imgs, annotations)
        losses = sum(loss for loss in loss_dict.values())

        # Clear the gradients from the previous iteration before backpropagating
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        print(f'Epoch: {epoch + 1}/{epochs}, Iteration: {i}/{len_dataloader}, Loss: {losses.item():.4f}')
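Per-iteration losses are noisy, so an epoch-level average is often easier to read. Below is a minimal sketch of a running average. `RunningMean` is a hypothetical helper, and the loss values are made up for illustration rather than taken from a real training run.

```python
class RunningMean:
    """Track the mean of the values seen so far (hypothetical helper)."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value):
        self.total += float(value)
        self.count += 1

    @property
    def mean(self):
        # Avoid dividing by zero before the first update
        return self.total / self.count if self.count else 0.0


epoch_loss = RunningMean()
for batch_loss in [2.0, 1.5, 1.0]:  # made-up values, not real training output
    epoch_loss.update(batch_loss)
print(f"Mean loss: {epoch_loss.mean:.2f}")  # Mean loss: 1.50
```

In the training loop, calling `epoch_loss.update(losses.item())` after each step and printing once per epoch would summarize the same information.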

Our final step is to save the model!

In [None]:
torch.save(model, output_file)
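`torch.save(model, ...)` pickles the whole model object, so loading it later needs the same class definitions to be importable. A common alternative is to save only the `state_dict`. The sketch below shows that pattern on a tiny stand-in `nn.Linear` model and a temporary file; both are assumptions for illustration. With the detection model, you would rebuild it with `create_model(num_classes)` before loading the weights.

```python
import os
import tempfile

import torch
from torch import nn

# A tiny stand-in model; the same calls apply to the Faster R-CNN model
model = nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as tmp_dir:
    weights_path = os.path.join(tmp_dir, "weights.pth")  # hypothetical file name

    # Save only the learned parameters, not the pickled model object
    torch.save(model.state_dict(), weights_path)

    # Rebuild the same architecture, then load the saved parameters into it
    restored = nn.Linear(4, 2)
    restored.load_state_dict(torch.load(weights_path, map_location="cpu"))

# Check that every parameter tensor survived the round trip
same = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(), restored.state_dict().values())
)
print(same)  # True
```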