In [28]:
import torch
import torchvision
from torchvision import transforms as T
import torch.optim as optim
import cv2
import time
from torch.utils.data import Dataset, DataLoader
import os
import xml.etree.ElementTree as ET
from PIL import Image
from torchvision.transforms import v2
from torchvision import tv_tensors
import matplotlib.pyplot as plt
from torchvision.io import read_file, read_image
from typing import Union
import random
from collections import defaultdict, Counter
import numpy as np
from tqdm.notebook import tqdm, tnrange

In [2]:
def parse_xml(xml_file: str):
    """Read object labels and bounding boxes from an annotation XML file.

    Every ``<object>`` element found under the root contributes one label
    (the text of its ``<name>`` child) and one box (the ``xmin``/``ymin``/
    ``xmax``/``ymax`` children of its ``<bndbox>`` child).

    Args:
        xml_file: Path of the annotation file (anything ``ET.parse`` accepts).

    Returns:
        A ``(labels, bboxes)`` tuple. ``labels`` is a list of class-name
        strings and ``bboxes`` is a parallel list of
        ``[xmin, ymin, xmax, ymax]`` integer lists. Both lists are empty when
        the file contains no ``<object>`` element.

    Raises:
        AttributeError: If an object lacks ``<name>`` or one of the four
            ``<bndbox>`` coordinates.
        ValueError: If a coordinate is not numeric.
    """
    root = ET.parse(xml_file).getroot()

    labels = []
    bboxes = []

    for obj in root.iter('object'):
        labels.append(obj.find('name').text)

        # Go through float() first so coordinates written as "12.0" are
        # accepted as well; fractional parts are truncated toward zero.
        bboxes.append([
            int(float(obj.find(f"bndbox/{tag}").text))
            for tag in ("xmin", "ymin", "xmax", "ymax")
        ])

    return labels, bboxes

In [3]:
class LabelTransform():
    """Map class-name strings to the integer ids used as detection labels.

    Id 0 is reserved for the ``"__background__"`` entry, so the user classes
    are numbered from 1 in the order they were given.

    Attributes:
        labels: ``["__background__"]`` followed by the user classes; indexing
            it with an id gives back the class name.
        index2labels: Dict from class name to integer id. Despite its name it
            is keyed by label; the name is kept for backward compatibility.
    """

    def __init__(self, classes : list):
        """Build the name <-> id tables.

        Args:
            classes: Non-empty sequence of class names (list, tuple, ...),
                without the background class.

        Raises:
            TypeError: If ``classes`` is a single string rather than a
                sequence of names.
            AssertionError: If ``classes`` is empty.
        """
        if isinstance(classes, str):
            # list("abc") would silently turn one name into three classes.
            raise TypeError("classes must be a sequence of class names, not a single string")
        # Copy into a list so tuples and other iterables work, not only lists.
        classes = list(classes)
        assert len(classes)>0, "Number of classes should not be empty"
        self.labels = ["__background__"] + classes
        self.index2labels = {l:idx for idx,l in enumerate(self.labels)}

    def __len__(self):
        """Return the number of user classes (background excluded)."""
        return len(self.labels)-1

    def __call__(self, labels: list):
        """Convert a list of class names into the matching list of ids.

        Raises:
            KeyError: If a name is not among the configured classes.
        """
        try:
            return [self.index2labels[l] for l in labels]
        except KeyError as err:
            raise KeyError(
                f"Unknown label {err.args[0]!r}; expected one of {self.labels[1:]}"
            ) from None

In [64]:
class SSDTransform():
    '''
    Joint image/target transform used by the SSD dataset.

    When ``training`` is True, every call applies exactly one option, drawn
    uniformly from "no augmentation" plus each augmentation that was enabled
    at construction time:
        - ColorJitter (always enabled)
        - Rotation by an angle sampled from ``rotate_1``
        - Rotation by an angle sampled from ``rotate_2``
        - HorizontalFlip (only when ``hflip`` is truthy)

    Whatever ``training`` is, the result is finally converted with
    ``v2.ToImage`` followed by ``v2.ToDtype(torch.float32, scale=True)``.

    Disabled augmentations are simply left out of the draw. With all of them
    enabled the selection is the same five-way split as before.

    NOTE(review): "clockwise"/"anticlockwise" are only labels here; the real
    direction depends on the sign convention of ``v2.RandomRotation`` and on
    the ranges passed in -- confirm against the torchvision docs.
    '''

    # Kept unchanged for backward compatibility; the sampling intervals are
    # now derived from the augmentations that are actually enabled.
    TRANSFORMS = ["horizontal_flip", "color_jitter", "clockwise_rotate", "anitclockwise_rotate"]

    def __init__(self, training: bool = False, hflip : bool = False, rotate_1 : Union[int, tuple] = None, rotate_2: Union[int, tuple] = None):
        '''
        Args:
            training: Apply a random augmentation on each call when True.
            hflip: Enable the horizontal-flip option.
            rotate_1: ``degrees`` argument for the first ``v2.RandomRotation``;
                ``None`` disables it.
            rotate_2: ``degrees`` argument for the second ``v2.RandomRotation``;
                ``None`` disables it.
        '''
        self.training = training
        self.color_jitter = v2.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2)

        # Every augmentation attribute always exists now; a disabled one is
        # None instead of being undefined (previously an AttributeError
        # waiting to happen in __call__).
        self.hflip = hflip
        self.horizontal_flip = v2.RandomHorizontalFlip(p=1.0) if hflip else None
        self.clockwise_rotate = rotate_1
        self.random_rotate_clockwise = v2.RandomRotation(rotate_1) if rotate_1 is not None else None
        self.anticlockwise_rotate = rotate_2
        self.random_rotate_anticlockwise = v2.RandomRotation(rotate_2) if rotate_2 is not None else None

        # Slot 0 (None) means "no augmentation". The remaining slots keep the
        # original dispatch order: jitter, rotation 1, rotation 2, flip.
        optional_ops = [
            self.color_jitter,
            self.random_rotate_clockwise,
            self.random_rotate_anticlockwise,
            self.horizontal_flip,
        ]
        self._choices = [None] + [op for op in optional_ops if op is not None]

        # Equal-width sub-intervals of [0, 1], one per available choice.
        n_choices = len(self._choices)
        self.intervals = list(np.arange(n_choices) * (1/n_choices)) + [1.0]

        self.default_transforms = v2.Compose(
            [v2.ToImage(), v2.ToDtype(dtype=torch.float32, scale=True)]
        )

    def __call__(self, image, targets):
        '''
        Transform ``image`` and ``targets`` together.

        Returns:
            The ``(image, targets)`` pair after the optional augmentation and
            the default tensor/dtype conversion.
        '''
        if self.training:
            p = random.random()

            # Find the sub-interval containing p and apply the matching choice.
            for op, lo, hi in zip(self._choices, self.intervals[:-1], self.intervals[1:]):
                if lo < p <= hi:
                    if op is not None:
                        image, targets = op(image, targets)
                    break

        image, targets = self.default_transforms(image, targets)
        return image, targets



In [57]:
class SSDDataset(Dataset):
    """Detection dataset reading images and same-named XML annotations from one folder.

    Every regular file in ``img_folder`` that does not end in ``.xml`` is
    treated as an image. Its annotation is looked up at the same path with the
    extension swapped for ``.xml``; an image without such a file yields an
    empty target.

    Each item is ``(image, target)`` where ``target`` is a dict with:
        - ``"boxes"``: ``tv_tensors.BoundingBoxes`` in XYXY format, float32,
          shape ``(N, 4)`` (``(0, 4)`` when there are no objects)
        - ``"labels"``: int64 tensor of shape ``(N,)``
    """

    def __init__(self, img_folder: str, label_transform: LabelTransform = None, transform: SSDTransform = None):
        """
        Args:
            img_folder: Directory holding the images and their ``.xml`` files.
            label_transform: Callable mapping class names to integer ids.
                Required, despite the ``None`` default.
            transform: Optional callable applied as ``transform(image, target)``.

        Raises:
            AssertionError: If ``label_transform`` is None.
        """
        assert label_transform is not None, "Label transform should not be None"

        # Sorted so that index -> image is reproducible (os.listdir order is
        # arbitrary); sub-directories are skipped because they cannot be
        # opened as images.
        candidates = (os.path.join(img_folder, name) for name in os.listdir(img_folder))
        self.img_paths = sorted(
            path for path in candidates
            if os.path.isfile(path) and not path.lower().endswith(".xml")
        )
        self.transform = transform
        self.label_transform = label_transform

    def __len__(self):
        """Return the number of images found in the folder."""
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` with its target and apply ``transform`` if set."""
        img_path = self.img_paths[idx]
        # Swap only the trailing extension. str.replace() would also rewrite
        # the same text elsewhere in the path, or every position when the
        # file has no extension at all.
        xml_path = os.path.splitext(img_path)[0] + ".xml"

        # convert() returns an independent in-memory copy, so the file handle
        # can be closed right away.
        with Image.open(img_path) as img_file:
            image = img_file.convert("RGB")
        W,H = image.size

        labels,bboxes = [], []
        if os.path.exists(xml_path):
            labels, bboxes = parse_xml(xml_path)

        if self.label_transform is not None and labels:
            labels = self.label_transform(labels)

        # Explicit dtypes and reshape(-1, 4) keep the target well-formed even
        # when the image has no annotated objects.
        boxes = torch.as_tensor(bboxes, dtype=torch.float32).reshape(-1, 4)

        target = {}
        target["boxes"] =  tv_tensors.BoundingBoxes(boxes, format="XYXY", canvas_size=(H,W))
        target["labels"] = torch.as_tensor(labels, dtype=torch.int64)

        if self.transform is not None:
            image, target = self.transform(image, target)

        return image, target
                

In [58]:
def collate_fn(batch):
    """Regroup a list of per-sample tuples into one tuple per field.

    For ``[(img0, tgt0), (img1, tgt1)]`` this returns
    ``[(img0, img1), (tgt0, tgt1)]``. Nothing is stacked into a single tensor.
    An empty batch gives an empty list.
    """
    per_field = zip(*batch)
    return [*per_field]

### Training

In [None]:
%time
model = torchvision.models.detection.ssd300_vgg16(pretrained = True)

In [65]:
# Foreground class names. LabelTransform prepends "__background__", so these
# map to label ids 1 and 2.
CLASSES = ["signature", "people"]

In [29]:
# Folder with the training images; SSDDataset looks for a same-named .xml
# annotation next to each image.
# NOTE(review): absolute, machine-specific path -- adjust before running elsewhere.
train_img_dir = "D:/00_Projects/ImageProcessing/ssd/train/"

In [66]:
# Class-name -> integer-id mapper, shared by the train and test datasets.
label_transform = LabelTransform(CLASSES)

In [31]:
# training=True is what switches the configured augmentations on. With the
# default (False), SSDTransform only converts the image to a float tensor and
# the flip/rotation settings below are never used.
transform = SSDTransform(training=True, hflip=True, rotate_1=(0,10), rotate_2=(-10,0))

In [33]:
# Training dataset: each image in train_img_dir with its XML annotation (when
# one exists), passed through `transform`.
train_ds = SSDDataset(train_img_dir,label_transform, transform)

In [34]:
# Sanity check: number of images found in the training folder.
len(train_ds)

35

In [35]:
# Shuffled batches of 2 samples. collate_fn keeps images and targets as
# per-field tuples instead of stacking them into one tensor.
train_dl = DataLoader(dataset=train_ds, batch_size=2, shuffle=True, collate_fn=collate_fn)

In [36]:
# Adam is handed every parameter of the model, with a learning rate of 1e-3.
optimizer = optim.Adam(model.parameters(), lr= 1e-3)

In [45]:
# Number of full passes over train_dl made by the training loop.
EPOCHS = 5

In [46]:
# Put the model in training mode; the loop below reads a dict of losses from
# its forward pass.
model.train()

etl = tnrange(EPOCHS)  # outer progress bar: one tick per epoch
for epoch in etl:
    epoch_trn_losses = []
    tdl = tqdm(train_dl, total=len(train_dl), leave=True)  # inner bar: one tick per batch
    for batch in tdl:
        images, targets = batch

        # Forward pass, then add the two loss terms into one scalar.
        losses = model(images, targets)
        loss = losses['bbox_regression'] + losses['classification']

        # Read the Python float once and reuse it for display and bookkeeping.
        batch_loss = loss.item()
        tdl.set_postfix(loss=batch_loss)

        # Backward pass and parameter update; gradients are cleared afterwards
        # so the next batch starts from zero.
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        epoch_trn_losses.append(batch_loss)

    # Show the mean batch loss of the finished epoch on the outer bar.
    epoch_trn_loss = np.mean(epoch_trn_losses)
    etl.set_postfix(epoch_loss=epoch_trn_loss)
        

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

### Testing

In [47]:
# Switch the model to evaluation mode before running predictions. As the last
# expression of the cell, the call also prints the model structure.
model.eval()

SSD(
  (backbone): SSDFeatureExtractorVGG(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
      (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (6): ReLU(inplace=True)
      (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (8): ReLU(inplace=True)
      (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace=True)
      (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (13): ReLU(inplace=True)
      (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (15): ReLU(inplace=

In [59]:
# Folder with the held-out test images.
# NOTE(review): absolute, machine-specific path -- adjust before running elsewhere.
test_img_dir = "D:/00_Projects/ImageProcessing/ssd/test/"

In [67]:
# Defaults leave training=False, so only the tensor/dtype conversion runs and
# no augmentation is applied.
test_transform = SSDTransform()

In [68]:
# Test dataset; reuses the training label mapping so the ids stay consistent.
test_ds = SSDDataset(test_img_dir,label_transform, test_transform)

In [69]:
# One image per batch, in dataset order (no shuffling).
test_dl = DataLoader(dataset=test_ds, batch_size=1, shuffle=False, collate_fn=collate_fn)

In [70]:
# Take the first test batch for a quick smoke test. collate_fn makes it an
# [images, targets] pair. next(iter(...)) fails right here with StopIteration
# if the loader is empty, instead of silently leaving `x` undefined.
x = next(iter(test_dl))

In [73]:
# Inference only: turn autograd off so the forward pass does not build a graph.
with torch.no_grad():
    predictions = model(x[0])
predictions

[{'boxes': tensor([[ 24.0123, 190.4505, 296.4680, 396.7329],
          [410.1459,  42.3025, 611.8274, 148.4481],
          [445.9295,  69.1654, 646.8311, 175.2687],
          [386.7603, 194.3256, 659.3284, 399.4393],
          [410.4062,  95.9208, 614.2113, 203.8828],
          [241.3191,  55.5308, 525.8620, 204.5639],
          [  0.0000,  50.1355, 265.2478, 229.2480],
          [385.6357,  41.2019, 660.7468, 247.0745],
          [330.4872,   6.1896, 582.2155, 179.6060],
          [249.3162, 195.1672, 523.9681, 398.5835],
          [171.2823,   7.0833, 454.6725, 155.8632],
          [ 26.7208, 106.3832, 294.6958, 283.6154],
          [ 24.9710,   2.9666, 303.8183, 156.8428],
          [ 96.4392,  56.1436, 381.3443, 204.8586],
          [180.9019, 246.0641, 458.5162, 426.1462],
          [171.1024, 111.2884, 461.0391, 258.0874],
          [172.8290, 167.7478, 464.7236, 314.8942],
          [318.7914, 111.3342, 601.7883, 256.7699],
          [431.6300, 167.0632, 664.0000, 314.5848],
   

In [76]:
# Look up the class name stored at label id 1 (id 0 is "__background__").
label_transform.labels[1]

'signature'