## Import libraries

In [None]:
import numpy as np 
import pandas as pd
from bs4 import BeautifulSoup
import torchvision
from torchvision import transforms, datasets, models
import torch
from PIL import Image
import random
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import os
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

In [None]:
# IPython autoreload: re-import modified modules automatically before each cell executes.
%reload_ext autoreload
%autoreload 2

In [None]:
def generate_box(obj):
    """Return the integer bounding box ``[xmin, ymin, xmax, ymax]`` of one annotated object.

    *obj* is an XML element whose ``find(tag).text`` yields each coordinate.
    """
    corner_tags = ('xmin', 'ymin', 'xmax', 'ymax')
    return [int(obj.find(tag).text) for tag in corner_tags]

def generate_label(obj):
    """Map an object's ``<name>`` text to a class id.

    ``with_mask`` -> 1, ``without_mask`` -> 2, anything else
    (e.g. ``mask_weared_incorrect``) -> 3. Id 0 is left free for background.
    """
    class_ids = {"with_mask": 1, "without_mask": 2}
    return class_ids.get(obj.find('name').text, 3)

def generate_target(file):
    """Parse one VOC-style XML annotation file into a detection target dict.

    Args:
        file: path to the ``.xml`` annotation file.

    Returns:
        dict with
        - ``"boxes"``: float32 tensor of shape ``(N, 4)``, one
          ``[xmin, ymin, xmax, ymax]`` row per ``<object>``. The shape is ``(0, 4)``
          when the file has no objects.
        - ``"labels"``: int64 tensor of shape ``(N,)`` with ids from ``generate_label``.
    """
    # Read bytes so BeautifulSoup can honour the file's own XML encoding declaration.
    with open(file, "rb") as f:
        soup = BeautifulSoup(f.read(), 'xml')
    objects = soup.find_all('object')
    boxes = [generate_box(obj) for obj in objects]
    labels = [generate_label(obj) for obj in objects]
    return {
        # torch.as_tensor([]) has shape (0,); reshape keeps the (N, 4) box layout the
        # detector expects even for images without any annotated object.
        "boxes": torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4),
        "labels": torch.as_tensor(labels, dtype=torch.int64),
    }

class MaskDataset(torch.utils.data.Dataset):
    """Face-mask detection dataset pairing images with their XML annotations.

    Args:
        imgs: image file names, relative to ``<base_path>/images``.
        labels: annotation file names, relative to ``<base_path>/annotations``.
            ``labels[i]`` must describe ``imgs[i]``; pairing is by position.
        transforms: optional callable applied to the RGB PIL image (not to the target).
        base_path: dataset root containing ``images/`` and ``annotations/``,
            with or without a trailing separator.

    Each item is ``(image, target)``, where ``target`` comes from ``generate_target``.
    """

    def __init__(self, imgs, labels, transforms, base_path):
        self.imgs = imgs
        self.labels = labels
        self.transforms = transforms
        self.base_path = base_path

    def _paths(self, idx):
        """Return the ``(image_path, annotation_path)`` pair for sample *idx*."""
        img_path = os.path.join(self.base_path, "images", self.imgs[idx])
        label_path = os.path.join(self.base_path, "annotations", self.labels[idx])
        return img_path, label_path

    def __getitem__(self, idx):
        img_path, label_path = self._paths(idx)
        # The context manager closes the underlying file; convert() returns an
        # independent, fully loaded image, so it stays usable after the with block.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
        target = generate_target(label_path)

        if self.transforms is not None:
            img = self.transforms(img)

        return img, target

    def __len__(self):
        return len(self.imgs)

### Set `base_path` to your own dataset root, laid out as follows

```
base path
├── images
│   ├── maksssksksss0.png
│   ├── maksssksksss1.png
│   └── ...
└── annotations
    ├── maksssksksss0.xml
    ├── maksssksksss1.xml
    └── ...
``` 

In [None]:
base_path = "/your/own/base/path/"
imgs = sorted(os.listdir(os.path.join(base_path, "images")))
labels = sorted(os.listdir(os.path.join(base_path, "annotations")))

# MaskDataset pairs imgs[i] with labels[i] purely by position, so both sorted listings
# must line up file-for-file (maksssksksss0.png <-> maksssksksss0.xml). Stray files in
# either folder would otherwise silently attach the wrong boxes to every later image.
if len(imgs) != len(labels) or any(
    os.path.splitext(im)[0] != os.path.splitext(lb)[0] for im, lb in zip(imgs, labels)
):
    raise ValueError(
        "images/ and annotations/ do not pair up one-to-one by file name; "
        "MaskDataset matches imgs[i] with labels[i] by position"
    )

## Split into balanced train/validation sets

In [None]:
from collections import defaultdict


def label_key(label_tensor):
    """Return a canonical string key for the set of class ids in *label_tensor*.

    Example: ``tensor([2, 1, 2])`` -> ``"[1, 2]"``. ``tolist()`` yields plain Python
    ints, so the key does not depend on NumPy's scalar repr (NumPy >= 2 would print
    ``np.int64(1)``). Sorting makes it independent of set iteration order.
    """
    return str(sorted(set(label_tensor.tolist())))


def return_idx(lbl, val_count=2, base=None):
    """Group annotation files by the combination of classes they contain.

    Args:
        lbl: annotation file names under ``<base>/annotations``.
        val_count: how many indices to keep for each label combination.
        base: dataset root; defaults to the global ``base_path``.

    Returns:
        defaultdict mapping a key such as ``"[1, 2]"`` to a list of indices into
        *lbl*. Every combination keeps its first *val_count* indices, except ``"[1]"``
        (images labelled only with class 1), which keeps all of them so the caller
        can undersample that group.
    """
    root = base_path if base is None else base
    result = defaultdict(list)
    for idx, lb in enumerate(lbl):
        target = generate_target(os.path.join(root, "annotations", lb))
        label = label_key(target["labels"])
        if label == '[1]' or len(result[label]) < val_count:
            result[label].append(idx)
    return result
    
a = return_idx(labels)
# Indices of images labelled only with class 1, minus the first two (return_idx's
# val_count=2), which are reserved for validation. A random subset of up to 500 of
# them is later excluded from training to undersample this group. min() avoids a
# ValueError from random.sample when fewer than 500 candidates exist.
class_1 = a['[1]'][2:]
sampleList = random.sample(class_1, min(500, len(class_1)))

In [None]:
# Validation set: the first val_count (=2, return_idx's default) indices of every label
# combination. return_idx keeps *all* '[1]' indices so they can be undersampled. Without
# the slice, every class-1-only image would land in validation and none in training.
val_list = []
for aa in a.values():
    val_list.extend(aa[:2])

In [None]:
# Training set: every image index that is neither reserved for validation nor dropped by
# undersampling. A set keeps each membership test O(1) instead of scanning two lists.
excluded = set(val_list).union(sampleList)
alls = list(range(len(imgs)))
train_list = [x for x in alls if x not in excluded]

In [None]:
# Both splits only convert PIL images to float tensors (channels-first, values in [0, 1]).
# No augmentation is applied yet; the two pipelines are kept separate so they can diverge.
train_transform = transforms.Compose([transforms.ToTensor()])
valid_transform = transforms.Compose([transforms.ToTensor()])

In [None]:
def collate_fn(batch):
    """Turn a list of ``(image, target)`` samples into ``(images, targets)`` tuples.

    Nothing is stacked: torchvision detection models take a list of image tensors,
    which may differ in size.
    """
    columns = zip(*batch)
    return tuple(columns)


train_dataset = MaskDataset([imgs[i] for i in train_list], [labels[i] for i in train_list], train_transform, base_path)
train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

valid_dataset = MaskDataset([imgs[i] for i in val_list], [labels[i] for i in val_list], valid_transform, base_path)
valid_data_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)

# Small loader for the inference/plotting section (it reads `test_data_loader`).
# NOTE: the last five images are not held out; they may also belong to the train or
# validation split, so treat the result as a visual sanity check, not an evaluation.
test_dataset = MaskDataset(imgs[-5:], labels[-5:], valid_transform, base_path)
test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=2, shuffle=False, collate_fn=collate_fn)

len(train_dataset), len(valid_dataset)

In [None]:
# Sanity check: True means the device chosen below will be CUDA.
torch.cuda.is_available()

True

## Model Training

In [None]:
def get_model_instance_segmentation(num_classes):
    """Build a pretrained Faster R-CNN (ResNet-50 FPN) with a freshly initialised box head.

    Despite the name, this is a detection model (boxes + labels), not a segmentation one.

    Args:
        num_classes: number of foreground classes. One extra output is added for
            background (id 0), matching the 1..3 ids produced by ``generate_label``.
    """
    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    head_in_features = detector.roi_heads.box_predictor.cls_score.in_features
    detector.roi_heads.box_predictor = FastRCNNPredictor(head_in_features, num_classes + 1)
    return detector


model = get_model_instance_segmentation(3)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
import time
from tqdm.notebook import tqdm
import matplotlib.patches as mpatches

num_epochs = 25
model.to(device)

# Optimise only the parameters that are trainable.
# NOTE(review): StepLR(step_size=3, gamma=0.1) divides the lr by 10 every 3 epochs, so from
# epoch 10 on it is 2e-5 or less for the remaining epochs of 25. Confirm this is intended.
params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.SGD(params, momentum=0.9, lr=0.02)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=0.1, step_size=3)

In [None]:
# Resume from a previous run's checkpoint if one exists in the working directory.
model_name = "fastrcnn_res50_epoch25"
resume = f"{model_name}.pth"
start_epoch = 0  # number of epochs already trained; 0 means start from scratch
if os.path.isfile(resume):
    print("=> loading checkpoint '{}'".format(model_name))
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine (and vice versa).
    checkpoint = torch.load(resume, map_location=device)
    start_epoch = checkpoint['epoch']
    lr_scheduler.load_state_dict(checkpoint['scheduler'])
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    print("=> loaded checkpoint '{}' (epoch {})" .format(model_name, start_epoch))
else:
    print("=> no checkpoint found at '{}'".format(model_name))

In [None]:
# Train, recording the mean train/validation loss of every epoch and redrawing the curves.
# `start_epoch` is set by the resume cell when a checkpoint was loaded; otherwise start at 0.
first_epoch = globals().get("start_epoch", 0)
total_train_loss = []
total_valid_loss = []
start_time = time.time()
for epoch in range(first_epoch, num_epochs):
    print(f'Epoch :{epoch + 1}')
    train_loss = []
    valid_loss = []

    model.train()
    # batch_* names avoid overwriting the global `imgs` list of image file names.
    for batch_imgs, batch_targets in tqdm(train_data_loader):
        batch_imgs = [img.to(device) for img in batch_imgs]
        batch_targets = [{k: v.to(device) for k, v in t.items()} for t in batch_targets]
        loss_dict = model(batch_imgs, batch_targets)
        losses = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        train_loss.append(losses.item())
    epoch_train_loss = np.mean(train_loss)
    total_train_loss.append(epoch_train_loss)
    print(f'Epoch train loss is {epoch_train_loss}')

    # The model deliberately stays in train mode: torchvision detection models only return
    # the loss dict in training mode. no_grad skips building the autograd graph.
    with torch.no_grad():
        for batch_imgs, batch_targets in tqdm(valid_data_loader):
            batch_imgs = [img.to(device) for img in batch_imgs]
            batch_targets = [{k: v.to(device) for k, v in t.items()} for t in batch_targets]
            loss_dict = model(batch_imgs, batch_targets)
            losses = sum(loss for loss in loss_dict.values())
            valid_loss.append(losses.item())

    epoch_valid_loss = np.mean(valid_loss)
    total_valid_loss.append(epoch_valid_loss)
    print(f'Epoch valid loss is {epoch_valid_loss}')
    lr_scheduler.step()

    plt.plot(range(len(total_train_loss)), total_train_loss, 'b', range(len(total_valid_loss)), total_valid_loss, 'r')
    red_patch = mpatches.Patch(color='red', label='Validation')
    blue_patch = mpatches.Patch(color='blue', label='Training')
    plt.legend(handles=[red_patch, blue_patch])
    plt.show()

time_elapsed = time.time() - start_time
print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

In [None]:
# Persist everything the resume cell reads back: epoch count, model, optimizer and scheduler state.
model_name = "fastrcnn_res50_epoch25"
filename = f'{model_name}.pth'
state = dict(
    epoch=num_epochs,
    state_dict=model.state_dict(),
    optimizer=optimizer.state_dict(),
    scheduler=lr_scheduler.state_dict(),
)
torch.save(state, filename)

In [None]:
# Release unused memory cached by PyTorch's CUDA allocator before running inference.
torch.cuda.empty_cache()

## Plot image

In [None]:
# use this function for plot image with annotation
def plot_image(img_tensor, annotation, mode = "pred"):
    
    fig,ax = plt.subplots(1)
    img = img_tensor.cpu().data

    if mode=="pred":
        mask=annotation["scores"]>0.5
    else:
        mask=annotation["labels"]>0
        
    # Display the image
    ax.imshow(img.permute(1, 2, 0))
    
    for (box,label) in zip(annotation["boxes"][mask],annotation["labels"][mask]):
        xmin, ymin, xmax, ymax = box
        if label==1:
            rect = patches.Rectangle((xmin,ymin),(xmax-xmin),(ymax-ymin),linewidth=1,edgecolor='b',facecolor='none')
            print("with_mask")
        elif label==2:
            rect = patches.Rectangle((xmin,ymin),(xmax-xmin),(ymax-ymin),linewidth=1,edgecolor='g',facecolor='none')
            print("without_mask")
        else:
            rect = patches.Rectangle((xmin,ymin),(xmax-xmin),(ymax-ymin),linewidth=1,edgecolor='r',facecolor='none')
            print("mask_weared_incorrect")

        ax.add_patch(rect)
        ax.axis("off")
    plt.show()

## Inference on the test set

In [None]:
# Advance to the `iterations`-th batch of the test loader, restarting it if it runs out.
iterations = 2
dataloader_iterator = iter(test_data_loader)
for i in range(iterations):
    try:
        imgs, annotations = next(dataloader_iterator)
    except StopIteration:
        # Only exhaustion restarts the loader; real data-loading errors still propagate.
        dataloader_iterator = iter(test_data_loader)
        imgs, annotations = next(dataloader_iterator)
imgs = [img.to(device) for img in imgs]
annotations = [{k: v.to(device) for k, v in t.items()} for t in annotations]

In [None]:
# Preview the first image of the batch. The tensor is channels-first and imshow wants
# channels-last. No temporaries are bound, so the split dict `a` from above is not clobbered.
plt.imshow(np.transpose(imgs[0].cpu().numpy(), (1, 2, 0)))

In [None]:
# Run inference on the first image of the batch. no_grad skips autograd bookkeeping, so
# the returned boxes/scores carry no grad history and can be converted for plotting.
model.eval()
with torch.no_grad():
    preds = model([imgs[0]])
preds

In [None]:
# Plot the model's prediction next to the ground truth for the same image.
# `preds` holds a single entry (only imgs[0] was passed to the model), so keep n = 0.
n = 0
for heading, source, plot_mode in (("Prediction", preds, "pred"), ("Target", annotations, "target")):
    print(heading)
    plot_image(imgs[n], source[n], mode=plot_mode)