In [None]:
import super_gradients

In [None]:
import pycocotools

In [None]:
!pip install Cmake

In [None]:
!pip install cython 

In [None]:
!pip install git+https://github.com/philferriere/cocoapi.git#subdirectory=PythonAPI

In [None]:
!pip install super-gradients

In [None]:
from super_gradients.training import models
from super_gradients.common.object_names import Models

In [None]:
# Build a YOLO-NAS-S detector with a 3-class head.
# NOTE(review): no pretrained_weights argument is passed, so the network presumably starts from
# random initialisation even though the Trainer experiment is named "transfer_learning_..." —
# confirm whether pretrained_weights="coco" was intended.
# NOTE(review): the dataset folder is named "..._2_class_..." and only two class names are given
# to set_dataset_processing_params later; confirm that num_classes=3 is intended.
model = models.get(Models.YOLO_NAS_S, num_classes=3)

In [None]:
import torch
from torch.utils.data import Dataset
import json
import os
from PIL import Image
from torchvision import transforms, utils
import numpy as np
import glob
import random


class CustomDataset(Dataset):
    """
    A PyTorch Dataset of detection samples, to be batched by a PyTorch DataLoader.

    Expected layout beneath ``data_folder``::

        images/<split>/*.jpg   # input images
        labels/<split>/*.txt   # one whitespace-separated label file per image, same file stem

    Each item is ``(image_tensor, labels)``:

    * ``image_tensor`` -- float tensor of shape (3, H, W) holding raw 0-255 RGB values, where
      (W, H) is ``image_size``;
    * ``labels`` -- float32 tensor with one row per line of the label file; an empty label file
      yields shape (0, 5).
    """

    def __init__(self, data_folder, split, keep_difficult=False,
                 max_images=20, seed=None, image_size=(320, 320)):
        """
        :param data_folder: root folder of the dataset (layout described in the class docstring);
            a trailing slash is optional
        :param split: split, one of 'TRAIN' or 'TEST' (case-insensitive)
        :param keep_difficult: stored but not used by this class
        :param max_images: randomly keep at most this many images; None keeps every image.
            The default of 20 preserves the original quick-experiment subsampling.
        :param seed: optional seed for the random subsample; None uses the global ``random`` state
        :param image_size: (width, height) every image is resized to
        :raises AssertionError: if ``split`` is not 'train'/'test'
        :raises ValueError: if no ``.jpg`` image is found for the split
        """
        self.split = split.lower()

        assert self.split in {'train', 'test'}, f"split must be 'train' or 'test', got {split!r}"

        self.data_folder = data_folder
        self.keep_difficult = keep_difficult
        self.image_size = image_size

        pattern = os.path.join(data_folder, "images", self.split, "*.jpg")
        # glob order is filesystem-dependent; sort so a seeded subsample is reproducible.
        images = sorted(path.replace("\\", "/") for path in glob.glob(pattern))
        if not images:
            raise ValueError(f"No .jpg images found matching {pattern!r}")

        if max_images is not None:
            # Cap k at the population size: random.sample raises when k > len(population).
            rng = random if seed is None else random.Random(seed)
            images = rng.sample(images, min(max_images, len(images)))
        self.images = images

    def _label_path(self, i):
        """Return the label file for image ``i``: <data_folder>/labels/<split>/<image stem>.txt."""
        # Built from the known layout rather than string-replacing "images"/"jpg", which would
        # also rewrite those substrings anywhere else in the path.
        stem = os.path.splitext(os.path.basename(self.images[i]))[0]
        return os.path.join(self.data_folder, "labels", self.split, stem + ".txt").replace("\\", "/")

    @staticmethod
    def _load_labels(label_path):
        """Read a whitespace-separated label file as a 2-D float array, one row per line.

        A single-line file gives shape (1, C) and an empty file gives shape (0, 5), instead of
        the malformed (1, 0) array produced by expanding numpy's 1-D empty result.
        """
        import warnings  # local import: only needed to silence numpy's empty-file warning

        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)
            labels = np.loadtxt(label_path, ndmin=2)
        if labels.size == 0:
            # NOTE(review): 5 columns assumes YOLO txt rows (class, cx, cy, w, h) — confirm.
            return np.zeros((0, 5))
        return labels

    def __getitem__(self, i):
        """Return ``(image_tensor, labels)`` for sample ``i`` (see the class docstring)."""
        # Context manager releases the file handle deterministically.
        with Image.open(self.images[i], mode='r') as img:
            image = img.resize(self.image_size).convert("RGB")
        # HWC uint8 -> CHW float; pixel values stay in the 0-255 range (no normalisation here).
        image_tensor = torch.tensor(np.array(image)).permute(2, 0, 1).float()
        # NOTE(review): label rows are returned exactly as stored; YOLO txt files are usually
        # normalised to [0, 1] — confirm this matches the target format the loss expects.
        labels = self._load_labels(self._label_path(i))
        return image_tensor, torch.tensor(labels, dtype=torch.float)

    def __len__(self):
        """Number of (sub-sampled) images in this split."""
        return len(self.images)

In [None]:
train_dataset = CustomDataset("dataset_deteccao/axial_t1wce_2_class_corrigida/",split="train")
val_dataset = CustomDataset("dataset_deteccao/axial_t1wce_2_class_corrigida/",split="test")

In [None]:
from torch.utils.data import Dataset, DataLoader
from super_gradients.training.utils.collate_fn.detection_collate_fn import DetectionCollateFN

train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=0, collate_fn=DetectionCollateFN())
val_dataloader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=0, collate_fn=DetectionCollateFN())

In [None]:
from super_gradients.training import training_hyperparams

# NOTE(review): prettyformatter is not among the packages installed at the top of this notebook —
# presumably installed separately; confirm or add it to the install cell.
from prettyformatter import pprint

# Start from the stock COCO YOLO-NAS-S training recipe and print it for inspection.
train_params = training_hyperparams.get('coco2017_yolo_nas_s')
print('Training parameters:')
pprint(train_params, json=True)

In [None]:
# Adapt the COCO recipe to this dataset: a single epoch (presumably a quick smoke test) and a loss
# configured for 3 classes, matching num_classes=3 passed to models.get.
# NOTE(review): the recipe's validation metrics may still be configured for COCO's class count;
# the commented-out valid_metrics_list override below suggests this was being investigated —
# confirm num_cls there.
train_params['max_epochs'] = 1
# train_params['lr_warmup_epochs'] = 0
# train_params['lr_cooldown_epochs'] = 0
train_params['criterion_params']['num_classes'] = 3
# train_params['average_best_models'] = False
# train_params['initial_lr'] = 0.01
# train_params['cosine_final_lr_ratio'] = 0.9
# train_params['mixed_precision'] = False
# train_params['phase_callbacks'] = []
# train_params['lr_warmup_steps'] = 10
# train_params['valid_metrics_list'] = [{"DetectionMetrics": {"post_prediction_callback": super_gradients.training.utils.ssd_utils.SSDPostPredictCallback(), "num_cls": 3}}]

In [None]:
# Re-print the recipe to confirm the overrides took effect.
pprint(train_params, json=True)

In [None]:
# NOTE(review): MultiGPUMode is imported but not used in this notebook.
from super_gradients.training import Trainer, MultiGPUMode

In [None]:
# Root directory for Trainer checkpoints (passed as ckpt_root_dir when the Trainer is created).
CHECKPOINT_DIR = '.'

In [None]:
# num_gpus=0: presumably forces training onto the CPU — confirm.
super_gradients.setup_device(num_gpus=0)

In [None]:
from super_gradients.training.processing import ImagePermute, ComposeProcessing

# Preprocessing used by model.predict(): only an axis permutation (2, 0, 1), which turns an HWC
# array into CHW.
# NOTE(review): no resize or normalisation is applied here, whereas CustomDataset resizes training
# images to 320x320. predict() is later called on a dataset tensor that is already CHW — confirm
# it is not permuted a second time.
image_processor = ComposeProcessing(
    [
        # Resize(320),
        ImagePermute(permutation=(2, 0, 1)),
    ]
)


In [None]:
# Configure predict(): iou/conf thresholds (0.5 each), class names and the preprocessing above.
# NOTE(review): only 2 class names are given, but the model was built with num_classes=3 — confirm.
model.set_dataset_processing_params(iou=0.5,class_names=['negative','positive'],conf=0.5,image_processor=image_processor)

In [None]:
# Trainer for this experiment; checkpoints are stored under CHECKPOINT_DIR.
trainer = Trainer(experiment_name='transfer_learning_object_detection_yolo_nas_s', ckpt_root_dir=CHECKPOINT_DIR)

In [None]:
# Train with the adapted recipe, validating on val_dataloader (the "test" split).
trainer.train(model=model, training_params=train_params, train_loader=train_dataloader, valid_loader=val_dataloader)

In [None]:
val_dataset.__getitem__(0)[0]

In [None]:
predicoes = model.predict(val_dataset.__getitem__(0)[0])

In [None]:
predicoes.show()