In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchmetrics.classification import MulticlassConfusionMatrix
from torch.utils.data import Dataset, DataLoader

import cv2
import numpy as np
import os
import glob as glob
from PIL import Image

import albumentations as A
from albumentations.pytorch import ToTensorV2

import matplotlib.pyplot as plt
import matplotlib.patches as patches
import random

from collections import Counter

from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")

from sklearn.metrics import confusion_matrix
import seaborn as sn
import pandas as pd


from torch.utils.tensorboard import SummaryWriter
import numpy as np
writer = SummaryWriter()


In [None]:
# Dataset locations: image folders and the folders holding the per-image label text files.
train_data_path = './dataset/train/images/'
train_lbl_path = './dataset/train/labels/'

valid_data_path = './dataset/valid/images/'
valid_lbl_path = './dataset/valid/labels/'

# Only an image folder is configured for the test split (no label folder).
test_data_path = './dataset/test/images/'

# KITTI Dataset
# Class names; an integer class id is used as an index into this list when plotting/logging.
CLASSES = [ 
    'Car', 'Van', 'Truck', 'Pedestrian', 'Person_sitting', 'Cyclist', 'Tram', 'Misc', 'DontCare'
]

NUM_CLASSES = len(CLASSES)
NUM_WORKERS = 4          # passed to DataLoader(num_workers=...)
BATCH_SIZE = 32
RESIZE_TO = 640          # side length of the square image fed to the network
EPOCHS = 5
LEARNING_RATE = 1e-5
WEIGHT_DECAY = 1e-4
CONF_THRESHOLD = 0.5     # objectness threshold used during evaluation
MAP_IOU_THRESH = 0.5     # IoU above which a detection counts as a true positive for AP
NMS_IOU_THRESH = 0.45    # IoU threshold used by non-max suppression
PIN_MEMORY = True        # passed to DataLoader(pin_memory=...)
LOAD_MODEL = False       # when True, main() restores a checkpoint before training
SAVE_MODEL = True        # when True, main() writes a checkpoint after every epoch

# Grid sizes of the two prediction scales (input size divided by 32 and by 16).
S = [RESIZE_TO // 32, RESIZE_TO // 16]

# Three (width, height) anchor pairs per scale, listed in the same order as S.
# main() multiplies each row by the matching entry of S before passing it to the loss.
ANCHORS = [
    [(0.28, 0.22), (0.38, 0.48), (0.9, 0.78)],
    [(0.07, 0.15), (0.15, 0.11), (0.14, 0.29)],
]

# Training images are first resized/padded to RESIZE_TO * scale and then randomly
# cropped back to RESIZE_TO (see train_transforms).
scale = 1.1

In [None]:
# Use the GPU when PyTorch reports one as available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE  # last expression of the cell, so the notebook displays the chosen device

In [None]:
# Training-time augmentation pipeline. Bounding boxes are passed in YOLO format
# and are transformed together with the image.
train_transforms = A.Compose(
    [
        # Resize so the longest side is RESIZE_TO * scale, pad to a square of that
        # size, then take a random RESIZE_TO x RESIZE_TO crop.
        A.LongestMaxSize(max_size=int(RESIZE_TO * scale)),
        A.PadIfNeeded(
            min_height=int(RESIZE_TO * scale),
            min_width=int(RESIZE_TO * scale),
            border_mode=cv2.BORDER_CONSTANT,
        ),
        A.RandomCrop(width=RESIZE_TO, height=RESIZE_TO),
        # NOTE(review): some albumentations releases restrict `hue` to at most 0.5;
        # confirm that 0.6 is accepted by the installed version.
        A.ColorJitter(brightness=0.6, contrast=0.6, saturation=0.6, hue=0.6, p=0.4),
        # OneOf currently wraps a single geometric transform.
        A.OneOf(
            [
                A.ShiftScaleRotate(
                    rotate_limit=20, p=0.5, border_mode=cv2.BORDER_CONSTANT
                )
            ],
            p=1.0,
        ),
        A.HorizontalFlip(p=0.5),
        # Low-probability photometric perturbations.
        A.Blur(p=0.1),
        A.CLAHE(p=0.1),
        A.Posterize(p=0.1),
        A.ToGray(p=0.1),
        A.ChannelShuffle(p=0.05),
        # With mean 0 and std 1 this only divides pixel values by max_pixel_value.
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255,),
        ToTensorV2(),
    ],
    # label_fields is empty: the class id travels as the fifth element of each box
    # (see MelNetDataset.__getitem__).
    bbox_params=A.BboxParams(format="yolo", min_visibility=0.4, label_fields=[],),
)

# Evaluation-time pipeline: deterministic resize + pad to RESIZE_TO, same normalisation,
# no random augmentation.
test_transforms = A.Compose(
    [
        A.LongestMaxSize(max_size=RESIZE_TO),
        A.PadIfNeeded(
            min_height=RESIZE_TO, min_width=RESIZE_TO, border_mode=cv2.BORDER_CONSTANT
        ),
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255,),
        ToTensorV2(),
    ],
    bbox_params=A.BboxParams(format="yolo", min_visibility=0.4, label_fields=[]),
)

In [None]:
""" 
Information about architecture config:
"B" indicating a residual block
"S" is for scale prediction block
"U" is for upsampling the feature map
"""
# Layer-by-layer description consumed by MelNet._create_conv_layers:
#   tuple      -> (out_channels, kernel_size, stride) for one CNNBlock
#   ["B", n]   -> ResidualBlock with n repeats
#   "S"        -> prediction branch (ResidualBlock + 1x1 CNNBlock + ScalePrediction)
#   "U"        -> nn.Upsample(scale_factor=2); its output is concatenated with a saved
#                 feature map in MelNet.forward
config = [
    (32, 3, 1),
    (64, 3, 2),
    ["B", 4],
    (128, 3, 2),
    ["B", 6],
    (256, 3, 2),
    ["B", 8],   # 8-repeat blocks are the ones MelNet.forward saves as route connections
    (512, 3, 2),
    ["B", 8],
    (1024, 3, 2),
    ["B", 6],
    (512, 1, 1),
    (1024, 3, 1),
    "S",        # first prediction scale
    (256, 1, 1),
    "U",
    (256, 1, 1),
    (512, 3, 1),
    "S",        # second prediction scale
]

class CNNBlock(nn.Module):
    """Convolution optionally followed by BatchNorm2d and LeakyReLU(0.1).

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        bn_act: when True the convolution has no bias and its output goes through
            batch norm and leaky ReLU; when False the block is a plain biased conv.
        **kwargs: forwarded unchanged to ``nn.Conv2d`` (kernel_size, stride, padding...).
    """

    def __init__(self, in_channels, out_channels, bn_act=True, **kwargs):
        super().__init__()
        self.use_bn_act = bn_act
        conv_has_bias = not bn_act
        self.conv = nn.Conv2d(in_channels, out_channels, bias=conv_has_bias, **kwargs)
        # Both sub-modules are always created, even when forward() bypasses them.
        self.bn = nn.BatchNorm2d(out_channels)
        self.leaky = nn.LeakyReLU(0.1)

    def forward(self, x):
        features = self.conv(x)
        if not self.use_bn_act:
            return features
        normalized = self.bn(features)
        return self.leaky(normalized)


class ResidualBlock(nn.Module):
    """Stack of ``num_repeats`` (1x1 squeeze -> 3x3 expand) convolution pairs.

    Args:
        channels: channel count of the input; every pair returns the same count.
        use_residual: when True each pair's output is added to its input.
        num_repeats: how many pairs are stacked.
    """

    def __init__(self, channels, use_residual=True, num_repeats=1):
        super().__init__()
        squeezed = channels // 2
        pairs = [
            nn.Sequential(
                CNNBlock(channels, squeezed, kernel_size=1),
                CNNBlock(squeezed, channels, kernel_size=3, padding=1),
            )
            for _ in range(num_repeats)
        ]
        self.layers = nn.ModuleList(pairs)
        self.num_repeats = num_repeats
        self.use_residual = use_residual

    def forward(self, x):
        for pair in self.layers:
            transformed = pair(x)
            x = x + transformed if self.use_residual else transformed
        return x


class ScalePrediction(nn.Module):
    """Prediction head for one scale.

    Produces, for each of 3 anchors and each grid cell, ``num_classes + 5`` values.
    The output is laid out as (batch, 3, height, width, num_classes + 5).
    """

    def __init__(self, num_classes, in_channels=3):
        super().__init__()
        self.num_classes = num_classes
        widened = 2 * in_channels
        values_per_anchor = num_classes + 5
        head = [
            CNNBlock(in_channels, widened, kernel_size=3, padding=1),
            # final 1x1 conv emits raw values: no batch norm / activation
            CNNBlock(widened, 3 * values_per_anchor, bn_act=False, kernel_size=1),
        ]
        self.pred = nn.Sequential(*head)

    def forward(self, x):
        batch, height, width = x.shape[0], x.shape[2], x.shape[3]
        raw = self.pred(x)
        per_anchor = raw.reshape(batch, 3, self.num_classes + 5, height, width)
        # move the per-anchor value axis last
        return per_anchor.permute(0, 1, 3, 4, 2)

class MelNet(nn.Module):
    """Detector assembled from the module-level ``config`` list.

    ``forward`` returns one tensor per ScalePrediction layer, in the order the
    layers appear in ``config``.
    """

    def __init__(self, num_classes, in_channels=3):
        super().__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.layers = self._create_conv_layers()

    def forward(self, x):
        outputs = []
        saved_routes = []  # stack of feature maps kept for later concatenation

        for layer in self.layers:
            if isinstance(layer, ScalePrediction):
                # Heads branch off; x itself continues through the network unchanged.
                outputs.append(layer(x))
                continue

            x = layer(x)

            if isinstance(layer, ResidualBlock):
                # Only the 8-repeat residual stages are remembered.
                if layer.num_repeats == 8:
                    saved_routes.append(x)
            elif isinstance(layer, nn.Upsample):
                # Join the upsampled map with the most recently saved one.
                x = torch.cat([x, saved_routes.pop()], dim=1)

        return outputs

    def _create_conv_layers(self):
        """Translate ``config`` into an nn.ModuleList, tracking the running channel count."""
        built = nn.ModuleList()
        channels = self.in_channels

        for spec in config:
            if isinstance(spec, tuple):
                out_channels, kernel_size, stride = spec
                pad = 1 if kernel_size == 3 else 0
                built.append(
                    CNNBlock(
                        channels,
                        out_channels,
                        kernel_size=kernel_size,
                        stride=stride,
                        padding=pad,
                    )
                )
                channels = out_channels
                continue

            if isinstance(spec, list):
                built.append(ResidualBlock(channels, num_repeats=spec[1]))
                continue

            if not isinstance(spec, str):
                continue

            if spec == "S":
                halved = channels // 2
                built.extend(
                    [
                        ResidualBlock(channels, use_residual=False, num_repeats=1),
                        CNNBlock(channels, halved, kernel_size=1),
                        ScalePrediction(in_channels=halved, num_classes=self.num_classes),
                    ]
                )
                channels = halved
            elif spec == "U":
                built.append(nn.Upsample(scale_factor=2))
                # forward() concatenates a saved map after upsampling; the channel
                # count is tripled here to account for it.
                channels = channels * 3

        return built

# Instantiate the network once when this cell runs as the main module; the result is
# bound to the module-level name `model`.
if __name__ == "__main__": 
    model = MelNet(num_classes=NUM_CLASSES)


In [None]:
"""
Parameters:
    boxes1 (tensor): width and height of the first bounding boxes
    boxes2 (tensor): width and height of the second bounding boxes
Returns:
    tensor: IOU
"""
def iou_width_height(boxes1, boxes2):
    """IoU computed from widths and heights only.

    The overlap is min(width) * min(height), i.e. both boxes are treated as if
    they were aligned at a common corner.

    Parameters:
        boxes1 (tensor): (..., 2) widths and heights of the first boxes
        boxes2 (tensor): (..., 2) widths and heights of the second boxes
    Returns:
        tensor: IoU, broadcast over the leading dimensions
    """
    w1, h1 = boxes1[..., 0], boxes1[..., 1]
    w2, h2 = boxes2[..., 0], boxes2[..., 1]

    overlap = torch.min(w1, w2) * torch.min(h1, h2)
    combined = w1 * h1 + w2 * h2 - overlap
    return overlap / combined

"""
Parameters:
    boxes_preds (tensor): Predictions of Bounding Boxes
    boxes_labels (tensor): Correct labels of Bounding Boxes
Returns:
    tensor: IOU
"""
def intersection_over_union(boxes_preds, boxes_labels, box_format="midpoint"):
    """Element-wise IoU between two equally shaped sets of boxes.

    Parameters:
        boxes_preds (tensor): (..., 4) predicted boxes
        boxes_labels (tensor): (..., 4) reference boxes
        box_format (str): "midpoint" for (x, y, w, h) with x, y the box centre,
            or "corners" for (x1, y1, x2, y2)
    Returns:
        tensor: (..., 1) IoU values
    Raises:
        ValueError: if ``box_format`` is neither "midpoint" nor "corners".
    """
    if box_format == "midpoint":
        box1_x1 = boxes_preds[..., 0:1] - boxes_preds[..., 2:3] / 2
        box1_y1 = boxes_preds[..., 1:2] - boxes_preds[..., 3:4] / 2
        box1_x2 = boxes_preds[..., 0:1] + boxes_preds[..., 2:3] / 2
        box1_y2 = boxes_preds[..., 1:2] + boxes_preds[..., 3:4] / 2
        box2_x1 = boxes_labels[..., 0:1] - boxes_labels[..., 2:3] / 2
        box2_y1 = boxes_labels[..., 1:2] - boxes_labels[..., 3:4] / 2
        box2_x2 = boxes_labels[..., 0:1] + boxes_labels[..., 2:3] / 2
        box2_y2 = boxes_labels[..., 1:2] + boxes_labels[..., 3:4] / 2
    elif box_format == "corners":
        box1_x1 = boxes_preds[..., 0:1]
        box1_y1 = boxes_preds[..., 1:2]
        box1_x2 = boxes_preds[..., 2:3]
        box1_y2 = boxes_preds[..., 3:4]
        box2_x1 = boxes_labels[..., 0:1]
        box2_y1 = boxes_labels[..., 1:2]
        box2_x2 = boxes_labels[..., 2:3]
        box2_y2 = boxes_labels[..., 3:4]
    else:
        # Previously an unknown format fell through and failed later with an
        # UnboundLocalError on box1_x1.
        raise ValueError(
            f"box_format must be 'midpoint' or 'corners', got {box_format!r}"
        )

    # Corners of the overlap rectangle.
    x1 = torch.max(box1_x1, box2_x1)
    y1 = torch.max(box1_y1, box2_y1)
    x2 = torch.min(box1_x2, box2_x2)
    y2 = torch.min(box1_y2, box2_y2)

    # clamp(0) makes the overlap zero when the boxes do not intersect.
    intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
    box1_area = abs((box1_x2 - box1_x1) * (box1_y2 - box1_y1))
    box2_area = abs((box2_x2 - box2_x1) * (box2_y2 - box2_y1))

    # The small constant avoids division by zero for degenerate boxes.
    return intersection / (box1_area + box2_area - intersection + 1e-6)

"""
Parameters:
    bboxes (list): list of lists containing all bboxes with each bboxes
    specified as [class_pred, prob_score, x1, y1, x2, y2]
    iou_threshold (float)
    threshold (float)
Returns:
    list: bboxes after performing NMS
"""
def non_max_suppression(bboxes, iou_threshold, threshold, box_format="corners"):
    """Greedy per-class non-max suppression.

    Parameters:
        bboxes (list): boxes given as [class_pred, prob_score, c0, c1, c2, c3], where
            the four coordinates are interpreted according to ``box_format``
        iou_threshold (float): a lower-scored box of the same class is dropped when
            its IoU with a kept box is not below this value
        threshold (float): boxes whose score is not above this value are discarded first
        box_format (str): passed through to intersection_over_union
    Returns:
        list: surviving boxes, ordered by descending score
    """
    candidates = sorted(
        (box for box in bboxes if box[1] > threshold),
        key=lambda box: box[1],
        reverse=True,
    )
    kept = []

    while candidates:
        best, remaining = candidates[0], candidates[1:]
        survivors = []
        for other in remaining:
            # Boxes of a different class never suppress each other.
            if other[0] != best[0]:
                survivors.append(other)
                continue
            overlap = intersection_over_union(
                torch.tensor(best[2:]),
                torch.tensor(other[2:]),
                box_format=box_format,
            )
            if overlap < iou_threshold:
                survivors.append(other)
        candidates = survivors
        kept.append(best)

    return kept
    
"""
Parameters:
    pred_boxes (list): list of lists containing all bboxes with each bboxes
    specified as [train_idx, class_prediction, prob_score, x1, y1, x2, y2]
    true_boxes (list): Similar as pred_boxes except all the correct ones
    iou_threshold (float): threshold where predicted bboxes is correct
Returns:
    mAP value
"""
def mean_average_precision(
    pred_boxes, true_boxes, epoch, num_classes, iou_threshold=0.5, box_format="midpoint"
):
    """Per-class average precision at a single IoU threshold.

    Parameters:
        pred_boxes (list): boxes given as
            [train_idx, class_prediction, prob_score, c0, c1, c2, c3]
        true_boxes (list): ground-truth boxes in the same layout
        epoch: accepted for call compatibility; not used in the computation
        num_classes (int): class ids 0 .. num_classes - 1 are evaluated
        iou_threshold (float): IoU above which a detection can match a ground truth
        box_format (str): passed through to intersection_over_union
    Returns:
        list: one AP tensor per class that has at least one ground-truth box;
        classes without ground truth are skipped, so the list may be shorter
        than ``num_classes``.
    """
    eps = 1e-6
    per_class_ap = []

    for cls in range(num_classes):
        cls_detections = [det for det in pred_boxes if det[1] == cls]
        cls_truths = [gt for gt in true_boxes if gt[1] == cls]

        n_truths = len(cls_truths)
        if n_truths == 0:
            continue

        # One "already matched" flag per ground-truth box, keyed by image index.
        matched = Counter(gt[0] for gt in cls_truths)
        for image_idx, count in matched.items():
            matched[image_idx] = torch.zeros(count)

        cls_detections.sort(key=lambda det: det[2], reverse=True)
        n_detections = len(cls_detections)
        true_pos = torch.zeros(n_detections)
        false_pos = torch.zeros(n_detections)

        for det_idx, det in enumerate(cls_detections):
            same_image_truths = [gt for gt in cls_truths if gt[0] == det[0]]

            best_iou = 0
            for gt_idx, gt in enumerate(same_image_truths):
                iou = intersection_over_union(
                    torch.tensor(det[3:]),
                    torch.tensor(gt[3:]),
                    box_format=box_format,
                )
                if iou > best_iou:
                    best_iou = iou
                    best_gt_idx = gt_idx

            claimed_new_truth = False
            if best_iou > iou_threshold:
                flags = matched[det[0]]
                # Each ground truth can be claimed by one detection only.
                if flags[best_gt_idx] == 0:
                    flags[best_gt_idx] = 1
                    claimed_new_truth = True

            if claimed_new_truth:
                true_pos[det_idx] = 1
            else:
                false_pos[det_idx] = 1

        tp_running = torch.cumsum(true_pos, dim=0)
        fp_running = torch.cumsum(false_pos, dim=0)
        recalls = tp_running / (n_truths + eps)
        precisions = tp_running / (tp_running + fp_running + eps)
        # Prepend the (recall=0, precision=1) point before integrating.
        precisions = torch.cat((torch.tensor([1]), precisions))
        recalls = torch.cat((torch.tensor([0]), recalls))

        per_class_ap.append(torch.trapz(precisions, recalls))

    return per_class_ap

def get_evaluation_bboxes(
    loader,
    model,
    iou_threshold,
    anchors,
    threshold,
    box_format="midpoint",
    device="cuda",
):
    """Collect NMS-filtered predictions and ground-truth boxes for a whole loader.

    Every returned box is prefixed with a running image index so predictions and
    ground truths can be matched per image. The model is put in eval mode for the
    pass and switched back to train mode before returning.

    Returns:
        (all_pred_boxes, all_true_boxes): two lists of
        [image_idx, class, score, c0, c1, c2, c3] lists.
    """
    model.eval()
    image_counter = 0
    all_pred_boxes, all_true_boxes = [], []

    for x, labels in tqdm(loader):
        x = x.to(device)
        with torch.no_grad():
            predictions = model(x)

        n_images = x.shape[0]
        per_image_boxes = [[] for _ in range(n_images)]

        for scale_idx in range(2):
            scale_out = predictions[scale_idx]
            grid = scale_out.shape[2]
            # anchors are multiplied by the grid size before decoding
            scaled_anchor = torch.tensor([*anchors[scale_idx]]).to(device) * grid
            decoded = cells_to_bboxes(scale_out, scaled_anchor, S=grid, is_preds=True)
            for image_pos, boxes in enumerate(decoded):
                per_image_boxes[image_pos] += boxes

        # Ground truth is decoded from the second scale's targets only, using the
        # grid size and anchors left over from the last loop iteration.
        true_bboxes = cells_to_bboxes(labels[1], scaled_anchor, S=grid, is_preds=False)

        for image_pos in range(n_images):
            kept = non_max_suppression(
                per_image_boxes[image_pos],
                iou_threshold=iou_threshold,
                threshold=threshold,
                box_format=box_format,
            )
            all_pred_boxes.extend([image_counter] + box for box in kept)
            all_true_boxes.extend(
                [image_counter] + box
                for box in true_bboxes[image_pos]
                if box[1] > threshold
            )
            image_counter += 1

    model.train()
    return all_pred_boxes, all_true_boxes


"""
Scales the predictions coming from the model to
be relative to the entire image such that they can, for example,
later be plotted.
INPUT:
predictions: tensor of size (N, 3, S, S, num_classes+5)
anchors: the anchors used for the predictions
S: the number of cells the image is divided in on the width (and height)
is_preds: whether the input is predictions or the true bounding boxes
OUTPUT:
converted_bboxes: the converted boxes as a nested list of size (N, num_anchors * S * S, 1+5),
                  each box holding class index, object score and bounding box coordinates
"""
def cells_to_bboxes(predictions, anchors, S, is_preds=True):
    """Convert cell-relative boxes to image-relative boxes.

    Parameters:
        predictions (tensor): (N, num_anchors, S, S, C) tensor. For model output
            (``is_preds=True``) index 0 is the objectness logit, 1:5 the raw box and
            5: the class logits; for targets (``is_preds=False``) index 0 is the
            objectness, 1:5 the box and 5 the class id.
        anchors (tensor): (num_anchors, 2) anchors; only its length is used when
            ``is_preds`` is False.
        S (int): number of grid cells along each side.
        is_preds (bool): whether ``predictions`` is model output or a target tensor.
    Returns:
        list: N lists of ``num_anchors * S * S`` boxes, each box being
        [class, score, x, y, w, h] with coordinates divided by S.

    The input tensor is left unmodified.
    """
    batch_size = predictions.shape[0]
    num_anchors = len(anchors)
    # Work on a copy: slicing returns a view, and writing the decoded values
    # through it would silently overwrite the caller's tensor.
    box_predictions = predictions[..., 1:5].clone()

    if is_preds:
        anchors = anchors.reshape(1, num_anchors, 1, 1, 2)
        box_predictions[..., 0:2] = torch.sigmoid(box_predictions[..., 0:2])
        box_predictions[..., 2:] = torch.exp(box_predictions[..., 2:]) * anchors
        scores = torch.sigmoid(predictions[..., 0:1])
        best_class = torch.argmax(predictions[..., 5:], dim=-1).unsqueeze(-1)
    else:
        scores = predictions[..., 0:1]
        best_class = predictions[..., 5:6]

    # Column index of every cell, broadcast over batch, anchor and row; the row
    # index is the same range laid out along the other grid axis.
    grid_range = torch.arange(S, device=predictions.device)
    col_index = grid_range.view(1, 1, 1, S, 1)
    row_index = grid_range.view(1, 1, S, 1, 1)

    x = 1 / S * (box_predictions[..., 0:1] + col_index)
    y = 1 / S * (box_predictions[..., 1:2] + row_index)
    w_h = 1 / S * box_predictions[..., 2:4]

    converted_bboxes = torch.cat((best_class, scores, x, y, w_h), dim=-1).reshape(
        batch_size, num_anchors * S * S, 6
    )
    return converted_bboxes.tolist()

def check_class_accuracy(model, loader, epoch, threshold):
    """Report class, no-object and object accuracy over ``loader``.

    The three percentages are written to the module-level TensorBoard ``writer``
    at step ``epoch`` and printed. The model is put in eval mode for the pass and
    switched back to train mode afterwards.

    Parameters:
        model: network returning one prediction tensor per scale.
        loader: yields (images, targets) where targets holds one tensor per scale.
        epoch (int): global step used for the TensorBoard scalars.
        threshold (float): sigmoid(objectness) above this counts as "object".
    """
    model.eval()
    tot_class_preds, correct_class = 0, 0
    tot_noobj, correct_noobj = 0, 0
    tot_obj, correct_obj = 0, 0

    for x, y in tqdm(loader):
        x = x.to(DEVICE)
        with torch.no_grad():
            out = model(x)

        for i in range(2):
            # Keep the moved tensor local instead of writing it back into the batch.
            target = y[i].to(DEVICE)
            obj = target[..., 0] == 1
            noobj = target[..., 0] == 0

            correct_class += torch.sum(
                torch.argmax(out[i][..., 5:][obj], dim=-1) == target[..., 5][obj]
            )
            tot_class_preds += torch.sum(obj)

            obj_preds = torch.sigmoid(out[i][..., 0]) > threshold
            correct_obj += torch.sum(obj_preds[obj] == target[..., 0][obj])
            tot_obj += torch.sum(obj)
            correct_noobj += torch.sum(obj_preds[noobj] == target[..., 0][noobj])
            tot_noobj += torch.sum(noobj)

    # The tiny constant keeps the division defined when a category is empty.
    class_acc = (correct_class / (tot_class_preds + 1e-16)) * 100
    no_obj_acc = (correct_noobj / (tot_noobj + 1e-16)) * 100
    obj_acc = (correct_obj / (tot_obj + 1e-16)) * 100

    # NOTE(review): the tags end in "/train" although main() calls this with the
    # validation loader; the tags are kept as-is so existing logs stay comparable.
    writer.add_scalar('Class Accuracy/train', class_acc, epoch)
    writer.add_scalar('No Obj Accuracy/train', no_obj_acc, epoch)
    writer.add_scalar('Obj Accuracy/train', obj_acc, epoch)

    # ".2f" (two decimals) replaces "2f", which only set a minimum width of 2.
    print(f"Class accuracy is: {class_acc:.2f}%")
    print(f"No obj accuracy is: {no_obj_acc:.2f}%")
    print(f"Obj accuracy is: {obj_acc:.2f}%")
    model.train()


def get_mean_std(loader):
    """Estimate per-channel mean and standard deviation of the images in ``loader``.

    Each batch contributes its own per-channel E[x] and E[x^2] (reduced over
    dimensions 0, 2 and 3); the batch values are averaged with equal weight and
    the std is derived as sqrt(E[x^2] - E[x]^2).

    Returns:
        (mean, std)
    """
    reduce_dims = [0, 2, 3]
    mean_accum = 0
    sq_mean_accum = 0
    batch_count = 0

    for images, _ in tqdm(loader):
        mean_accum = mean_accum + torch.mean(images, dim=reduce_dims)
        sq_mean_accum = sq_mean_accum + torch.mean(images ** 2, dim=reduce_dims)
        batch_count += 1

    mean = mean_accum / batch_count
    variance = sq_mean_accum / batch_count - mean ** 2
    std = variance ** 0.5

    return mean, std


def save_checkpoint(model, optimizer, filename="./checkpoint/my_checkpoint.pth.tar"):
    """Write the model and optimizer state dicts to ``filename``.

    The parent directory is created when it does not exist yet, so the default
    path works on a fresh checkout.

    Parameters:
        model: module whose ``state_dict()`` is stored under "state_dict".
        optimizer: optimizer whose ``state_dict()`` is stored under "optimizer".
        filename (str): destination path.
    """
    print("=> Saving checkpoint")
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    target_dir = os.path.dirname(filename)
    if target_dir:
        # torch.save does not create missing directories.
        os.makedirs(target_dir, exist_ok=True)
    torch.save(checkpoint, filename)


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore model and optimizer state from ``checkpoint_file``.

    Parameters:
        checkpoint_file (str): path to a file holding a dict with the keys
            "state_dict" and "optimizer" (the layout written by save_checkpoint).
        model: module that receives the stored "state_dict".
        optimizer: optimizer that receives the stored "optimizer" state.
        lr (float): learning rate written into every parameter group after
            loading, replacing the value stored in the checkpoint.
    """
    print("=> Loading checkpoint")
    # Map tensors onto the notebook-wide DEVICE (the original referenced the
    # undefined name `confiDEVICE`, which raised NameError).
    checkpoint = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # Loading the optimizer state restores the old learning rate; override it.
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


def get_loaders():
    """Build the three DataLoaders used by main().

    Returns:
        (train_loader, test_loader, train_eval_loader) where
        - train_loader: training split with train_transforms, shuffled;
        - test_loader: validation split with test_transforms;
        - train_eval_loader: training split with test_transforms, for measuring
          performance on the training data without augmentation.
    """
    grid_sizes = [RESIZE_TO // 32, RESIZE_TO // 16]

    def _build_dataset(transforms, img_dir, label_dir):
        # Every split shares grid sizes and anchors.
        return MelNetDataset(
            transforms=transforms,
            S=grid_sizes,
            img_dir=img_dir,
            label_dir=label_dir,
            anchors=ANCHORS,
        )

    def _build_loader(dataset, shuffle):
        return DataLoader(
            dataset=dataset,
            batch_size=BATCH_SIZE,
            num_workers=NUM_WORKERS,
            pin_memory=PIN_MEMORY,
            shuffle=shuffle,
            drop_last=False,
        )

    train_dataset = _build_dataset(train_transforms, train_data_path, train_lbl_path)
    test_dataset = _build_dataset(test_transforms, valid_data_path, valid_lbl_path)

    # The original built this dataset from test_data_path without a label_dir,
    # which raises TypeError because label_dir is a required argument and no test
    # labels are configured. The name says "train eval", so it now reads the
    # training split with the deterministic evaluation transforms.
    # NOTE(review): confirm this is the intended split.
    train_eval_dataset = _build_dataset(test_transforms, train_data_path, train_lbl_path)

    train_loader = _build_loader(train_dataset, shuffle=True)
    test_loader = _build_loader(test_dataset, shuffle=False)
    train_eval_loader = _build_loader(train_eval_dataset, shuffle=False)

    return train_loader, test_loader, train_eval_loader

def plot_couple_examples(model, loader, thresh, iou_thresh, anchors):
    """Run the model on the first batch of ``loader`` and plot the NMS-filtered boxes.

    Parameters:
        model: network returning one prediction tensor per scale.
        loader: DataLoader yielding (images, targets); only the first batch is used.
        thresh (float): score threshold passed to non_max_suppression.
        iou_thresh (float): IoU threshold passed to non_max_suppression.
        anchors: per-scale anchors, either tensors or nested lists/tuples.
            NOTE(review): get_evaluation_bboxes multiplies anchors by the grid size
            before decoding; this function passes them through unchanged, so callers
            presumably supply already-scaled anchors - confirm.
    """
    model.eval()
    x, y = next(iter(loader))
    # Follow the notebook-wide DEVICE instead of a hard-coded "cuda", so the
    # function also works on CPU-only machines.
    x = x.to(DEVICE)

    with torch.no_grad():
        out = model(x)
        bboxes = [[] for _ in range(x.shape[0])]
        for i in range(2):
            batch_size, _, S, _, _ = out[i].shape
            # cells_to_bboxes calls .reshape on the anchors, so plain lists/tuples
            # must be converted to a tensor first; tensors pass through unchanged.
            anchor = torch.as_tensor(anchors[i]).to(out[i].device)
            boxes_scale_i = cells_to_bboxes(out[i], anchor, S=S, is_preds=True)
            for idx, box in enumerate(boxes_scale_i):
                bboxes[idx] += box

    model.train()

    for i in range(batch_size):
        nms_boxes = non_max_suppression(
            bboxes[i], iou_threshold=iou_thresh, threshold=thresh, box_format="midpoint",
        )
        # plot_image expects a channels-last array on the CPU.
        plot_image(x[i].permute(1, 2, 0).detach().cpu(), nms_boxes)



def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN behaviour."""
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

In [None]:


def plot_image(image, boxes):
    """Show ``image`` with one labelled rectangle per box.

    Parameters:
        image: array-like of shape (height, width, channels).
        boxes: iterable of [class, score, x, y, w, h]; x, y are the box centre and
            all four coordinates are fractions of the image size.
    """
    class_labels = CLASSES
    palette = plt.get_cmap("tab20b")
    colors = [palette(position) for position in np.linspace(0, 1, len(class_labels))]

    im = np.array(image)
    height, width, _ = im.shape

    fig, ax = plt.subplots(1)
    ax.imshow(im)

    for entry in boxes:
        class_idx = int(entry[0])
        center_x, center_y, box_w, box_h = entry[2:][0], entry[2:][1], entry[2:][2], entry[2:][3]

        # Convert the centre-based box to its upper-left corner.
        upper_left_x = center_x - box_w / 2
        upper_left_y = center_y - box_h / 2
        anchor_point = (upper_left_x * width, upper_left_y * height)

        outline = patches.Rectangle(
            anchor_point,
            box_w * width,
            box_h * height,
            linewidth=2,
            edgecolor=colors[class_idx],
            facecolor="none",
        )
        ax.add_patch(outline)

        plt.text(
            upper_left_x * width,
            upper_left_y * height,
            s=class_labels[class_idx],
            color="white",
            verticalalignment="top",
            bbox={"color": colors[class_idx], "pad": 0},
        )

    plt.show()



In [None]:
class MelNetDataset(Dataset):
    """Images plus per-scale target tensors for a two-scale detector.

    Each image ``<name>.<ext>`` in ``img_dir`` is paired with ``<name>.txt`` in
    ``label_dir``. Every non-empty label line holds ``class xc yc w h``.

    Parameters:
        img_dir (str): directory containing the images.
        label_dir (str): directory containing the label text files.
        anchors: two lists of (width, height) anchors, one list per scale.
        S (list): grid size of each scale.
        C (int): stored on the instance; not used when building targets.
        transforms: optional callable invoked as ``transforms(image=..., bboxes=...)``
            and returning a dict with the keys "image" and "bboxes".
    """

    def __init__(
        self, img_dir, label_dir, anchors, S = [13, 26], C = 20, transforms=None
    ):
        # Sorted so that an index refers to the same image on every run
        # (glob order is otherwise filesystem dependent).
        self.all_images = sorted(glob.glob(f"{img_dir}/*"))
        self.labels_path = label_dir
        self.transforms = transforms
        self.S = S
        # Anchors of both scales in one (num_anchors, 2) tensor.
        self.anchors = torch.tensor(anchors[0] + anchors[1])
        self.num_anchors = self.anchors.shape[0]
        self.num_anchors_per_scale = self.num_anchors // 2
        self.C = C
        self.ignore_iou_thresh = 0.5

    def __len__(self):
        return len(self.all_images)

    def __getitem__(self, index):
        """Return ``(image, targets)`` where targets is a tuple with one tensor per scale."""
        image_path = self.all_images[index]
        image = np.array(Image.open(image_path).convert("RGB"))

        # splitext copes with extensions of any length (".jpeg", ".png", ...).
        stem = os.path.splitext(os.path.basename(image_path))[0]
        lbl_path = os.path.join(self.labels_path, stem + '.txt')

        bboxes = self._read_labels(lbl_path)

        if self.transforms:
            augmentations = self.transforms(image = image, bboxes = bboxes)
            image = augmentations["image"]
            bboxes = augmentations["bboxes"]
        # Without transforms the boxes read from disk are used as they are
        # (previously `bboxes` was unbound in that case).

        return image, tuple(self._build_targets(bboxes))

    @staticmethod
    def _read_labels(lbl_path):
        """Parse a label file into an (n, 5) array of [xc, yc, w, h, class]."""
        rows = []
        with open(lbl_path, 'r') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue  # tolerate blank lines, e.g. a trailing newline
                class_id, xc, yc, w, h = (float(value) for value in fields[:5])
                rows.append([xc, yc, w, h, class_id])
        return np.array(rows, dtype=np.float64).reshape(-1, 5)

    def _build_targets(self, bboxes):
        """Encode boxes into one (anchors_per_scale, S, S, 6) tensor per scale.

        The last axis holds [objectness, x_cell, y_cell, w_cell, h_cell, class].
        Objectness is 1 for the anchor assigned to a box, -1 for unassigned anchors
        whose width/height IoU with the box exceeds ``ignore_iou_thresh`` and 0 otherwise.
        """
        targets = [
            torch.zeros((self.num_anchors_per_scale, S, S, 6)) for S in self.S
        ]

        for box in bboxes:
            x, y, width, height, class_label = box

            box_wh = torch.tensor([float(width), float(height)])
            iou_anchors = iou_width_height(box_wh, self.anchors)
            # Try anchors from best to worst shape match.
            anchor_order = iou_anchors.argsort(descending = True, dim = 0).tolist()

            has_anchor = [False] * len(self.S)

            for anchor_idx in anchor_order:
                scale_idx = anchor_idx // self.num_anchors_per_scale
                anchor_on_scale = anchor_idx % self.num_anchors_per_scale

                S = self.S[scale_idx]

                # Clamp so a centre lying exactly on the right/bottom border
                # (coordinate 1.0) maps to the last cell instead of index S.
                i = min(int(S * y), S - 1)
                j = min(int(S * x), S - 1)

                anchor_taken = targets[scale_idx][anchor_on_scale, i, j, 0]

                if not anchor_taken and not has_anchor[scale_idx]:
                    targets[scale_idx][anchor_on_scale, i, j, 0] = 1

                    # Box centre relative to its cell, size in cell units.
                    x_cell, y_cell = S * x - j, S * y - i
                    width_cell, height_cell = width * S, height * S

                    targets[scale_idx][anchor_on_scale, i, j, 1:5] = torch.tensor(
                        [x_cell, y_cell, width_cell, height_cell]
                    )
                    targets[scale_idx][anchor_on_scale, i, j, 5] = int(class_label)
                    has_anchor[scale_idx] = True

                elif not anchor_taken and iou_anchors[anchor_idx] > self.ignore_iou_thresh:
                    # Good shape match that was not chosen: excluded from both the
                    # object and the no-object masks used by the loss.
                    targets[scale_idx][anchor_on_scale, i, j, 0] = -1

        return targets


In [None]:
class MelNetLoss(nn.Module):
    """Loss for a single prediction scale.

    Weighted sum of four terms: box regression (MSE), objectness for cells that
    contain an object (MSE against the IoU), objectness for empty cells (BCE with
    logits) and classification (cross entropy).
    """

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.entropy = nn.CrossEntropyLoss()
        self.sigmoid = nn.Sigmoid()

        # Relative weights of the four loss terms.
        self.lambda_class = 1
        self.lambda_noobj = 10
        self.lambda_obj = 1
        self.lambda_box = 10

    def forward(self, predictions, target, anchors):
        """Compute the loss for one scale.

        Parameters:
            predictions (tensor): (N, 3, S, S, 5 + num_classes) raw network output;
                index 0 is the objectness logit, 1:5 the raw box, 5: the class logits.
            target (tensor): (N, 3, S, S, 6) target; index 0 is the objectness
                (1 object, 0 no object, anything else is ignored), 1:5 the box and
                5 the class id.
            anchors (tensor): the 3 anchors of this scale, reshaped to (1, 3, 1, 1, 2).
        Returns:
            tensor: scalar loss.

        Neither ``predictions`` nor ``target`` is modified.
        """
        obj = target[..., 0] == 1
        noobj = target[..., 0] == 0

        # NO OBJECT LOSS
        no_object_loss = self.bce(
            (predictions[..., 0:1][noobj]), (target[..., 0:1][noobj]),
        )

        # OBJECT LOSS
        anchors = anchors.reshape(1, 3, 1, 1, 2)
        xy_pred = self.sigmoid(predictions[..., 1:3])
        box_preds = torch.cat(
            [xy_pred, torch.exp(predictions[..., 3:5]) * anchors], dim=-1
        )
        # The IoU only serves as a regression target, so no gradient flows through it.
        ious = intersection_over_union(box_preds[obj], target[..., 1:5][obj]).detach()
        object_loss = self.mse(
            self.sigmoid(predictions[..., 0:1][obj]), ious * target[..., 0:1][obj]
        )

        # BOX COORDINATES
        # Centre offsets are compared after the sigmoid; width/height are compared
        # in log space relative to the anchors. The tensors are assembled out of
        # place so the caller's predictions and targets stay untouched.
        wh_target = torch.log(1e-16 + target[..., 3:5] / anchors)
        box_pred_terms = torch.cat([xy_pred, predictions[..., 3:5]], dim=-1)
        box_target_terms = torch.cat([target[..., 1:3], wh_target], dim=-1)
        box_loss = self.mse(box_pred_terms[obj], box_target_terms[obj])

        # CLASS LOSS
        class_loss = self.entropy(
            (predictions[..., 5:][obj]), (target[..., 5][obj].long()),
        )

        return (
            self.lambda_box * box_loss
            + self.lambda_obj * object_loss
            + self.lambda_noobj * no_object_loss
            + self.lambda_class * class_loss
        )

In [None]:
class ConfusionMatrix:
    """Object-detection confusion matrix with an extra "background" row and column.

    ``matrix[p, t]`` counts ground truths of class ``t`` matched to a detection of
    class ``p``. Index ``nc`` is background: row ``nc`` holds missed ground truths,
    column ``nc`` holds detections without a matching ground truth.
    """

    def __init__(self, nc, conf=0.25, iou_thres=0.45):
        self.matrix = np.zeros((nc + 1, nc + 1))
        self.nc = nc
        self.conf = conf
        self.iou_thres = iou_thres

    @staticmethod
    def _box_iou(boxes1, boxes2, eps=1e-7):
        """Pairwise IoU, shape (len(boxes1), len(boxes2)).

        Assumes both inputs hold corner-format (x1, y1, x2, y2) boxes - TODO confirm
        against the caller. Replaces the `box_iou` name that was never defined or
        imported in this notebook.
        """
        area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
        area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
        top_left = torch.max(boxes1[:, None, :2], boxes2[None, :, :2])
        bottom_right = torch.min(boxes1[:, None, 2:], boxes2[None, :, 2:])
        overlap_wh = (bottom_right - top_left).clamp(min=0)
        inter = overlap_wh[..., 0] * overlap_wh[..., 1]
        return inter / (area1[:, None] + area2[None, :] - inter + eps)

    def process_batch(self, detections, labels):
        """Update the matrix with the detections and labels of one image.

        Parameters:
            detections (tensor | None): (n, 6) rows of [4 box values, conf, class],
                or None when nothing was detected.
            labels (tensor): (m, 5) rows of [class, 4 box values]. A 1-D tensor of
                class ids is also accepted when ``detections`` is None.
        """
        if detections is None:
            # Only the class column is relevant; previously whole label rows
            # (including box coordinates) were used as column indices.
            class_column = labels[:, 0] if labels.dim() > 1 else labels
            for gc in class_column.int().tolist():
                self.matrix[self.nc, gc] += 1
            return

        detections = detections[detections[:, 4] > self.conf]
        # Plain numpy integer arrays keep the matrix indexing below independent
        # of tensor-to-index conversion rules.
        gt_classes = labels[:, 0].int().cpu().numpy()
        detection_classes = detections[:, 5].int().cpu().numpy()
        iou = self._box_iou(labels[:, 1:], detections[:, :4])

        x = torch.where(iou > self.iou_thres)
        if x[0].shape[0]:
            # Rows of [label_idx, detection_idx, iou].
            matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
            if x[0].shape[0] > 1:
                # Keep, per detection and then per label, only the best-IoU pair.
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
        else:
            matches = np.zeros((0, 3))

        n = matches.shape[0] > 0
        m0, m1, _ = matches.transpose().astype(int)
        for i, gc in enumerate(gt_classes):
            j = m0 == i
            if n and sum(j) == 1:
                self.matrix[detection_classes[m1[j]], gc] += 1
            else:
                self.matrix[self.nc, gc] += 1

        # NOTE(review): unmatched detections are only counted when at least one
        # match exists in the image; kept as in the original - confirm intent.
        if n:
            for i, dc in enumerate(detection_classes):
                if not any(m1 == i):
                    self.matrix[dc, self.nc] += 1

    def tp_fp(self):
        """Return per-class (true positives, false positives), background excluded."""
        tp = self.matrix.diagonal()
        fp = self.matrix.sum(1) - tp

        return tp[:-1], fp[:-1]

    def plot(self, normalize=True, save_dir='', names=()):
        """Save a heatmap of the matrix as ``confusion_matrix.png`` inside ``save_dir``.

        Parameters:
            normalize (bool): divide every column by its sum before plotting.
            save_dir (str): output directory; '' means the current directory.
            names (sequence): class names; used as tick labels only when their
                number equals ``nc``.
        """
        import seaborn as sn

        array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1)
        array[array < 0.005] = np.nan  # blank out near-empty cells

        fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
        nc, n_names = self.nc, len(names)  # renamed from `nn`, which shadowed torch.nn
        sn.set(font_scale=1.0 if nc < 50 else 0.8)
        labels = (0 < n_names < 99) and (n_names == nc)
        # list(...) so tuple `names` (the default type) can be extended as well.
        ticklabels = (list(names) + ['background']) if labels else 'auto'
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            sn.heatmap(array,
                       ax=ax,
                       annot=nc < 30,
                       annot_kws={
                           'size': 8},
                       cmap='Blues',
                       fmt='.2f',
                       square=True,
                       vmin=0.0,
                       xticklabels=ticklabels,
                       yticklabels=ticklabels).set_facecolor((1, 1, 1))
        ax.set_xlabel('True')
        ax.set_ylabel('Predicted')
        ax.set_title('Confusion Matrix')
        # os.path.join replaces pathlib.Path, which this notebook never imports.
        fig.savefig(os.path.join(save_dir, 'confusion_matrix.png'), dpi=250)
        plt.close(fig)

    def print(self):
        """Print the matrix, one space-separated row per line."""
        for i in range(self.nc + 1):
            print(' '.join(map(str, self.matrix[i])))

In [None]:
import torch
import os

# NOTE(review): seed_everything() sets this same flag to False; whichever assignment
# runs last wins, so running this cell overrides that choice.
torch.backends.cudnn.benchmark = True

# NOTE(review): the documented PYTORCH_CUDA_ALLOC_CONF format is a comma-separated list
# of "option:value" pairs; confirm that the installed PyTorch accepts this value.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "caching_allocator"


def train_fn(train_loader, model, epoch, optimizer, loss_fn, scaler, scaled_anchors):
    """Train ``model`` for one epoch and log the mean loss to TensorBoard.

    Parameters:
        train_loader: yields (images, targets) with one target tensor per scale.
        model: network returning one prediction tensor per scale.
        epoch (int): global step used for the 'Loss/train' scalar.
        optimizer: optimizer stepped through ``scaler``.
        loss_fn: callable ``loss_fn(prediction, target, anchors)`` for one scale.
        scaler: gradient scaler used for mixed-precision training.
        scaled_anchors: per-scale anchors passed to ``loss_fn``.
    """
    loop = tqdm(train_loader, leave=True)
    # A running total replaces re-summing the whole loss list on every batch.
    loss_total = 0
    batches_seen = 0
    mean_loss = None

    for x, y in loop:
        x = x.to(DEVICE)
        y0, y1 = y[0].to(DEVICE), y[1].to(DEVICE)

        with torch.cuda.amp.autocast():
            out = model(x)
            loss = (
                loss_fn(out[0], y0, scaled_anchors[0])
                + loss_fn(out[1], y1, scaled_anchors[1])
            )

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loss_total += loss.item()
        batches_seen += 1
        mean_loss = loss_total / batches_seen
        loop.set_postfix(loss=mean_loss)

    # With an empty loader there is nothing to log (previously this raised
    # NameError because mean_loss was never assigned).
    if mean_loss is not None:
        writer.add_scalar('Loss/train', mean_loss, epoch)



def main():
    """Train MelNet for EPOCHS epochs, checkpointing and periodically evaluating.

    Evaluation (class accuracy and per-class AP on the validation loader) runs on
    every 25th epoch.
    NOTE(review): with EPOCHS = 5 that condition is never met - confirm the interval.
    """
    model = MelNet(num_classes=NUM_CLASSES).to(DEVICE)
    optimizer = optim.Adam(
        model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
    )
    loss_fn = MelNetLoss()
    scaler = torch.cuda.amp.GradScaler()

    train_loader, test_loader, train_eval_loader = get_loaders()

    # One path for both saving and loading, so LOAD_MODEL resumes from the file
    # written by SAVE_MODEL.
    checkpoint_file = "checkpoint.pth.tar"

    if LOAD_MODEL:
        # The original call used the undefined names CHECKPOINT_FILE and epoch and
        # passed one argument too many.
        load_checkpoint(checkpoint_file, model, optimizer, LEARNING_RATE)

    # Anchors multiplied by the grid size of their scale.
    scaled_anchors = (
        torch.tensor(ANCHORS)
        * torch.tensor(S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
    ).to(DEVICE)

    for epoch in range(EPOCHS):
        current_epoch = epoch + 1
        print(f"Currently epoch {current_epoch}")
        # train_fn takes the epoch as its third argument; it was missing before.
        train_fn(train_loader, model, epoch, optimizer, loss_fn, scaler, scaled_anchors)

        if SAVE_MODEL:
            save_checkpoint(model, optimizer, filename=checkpoint_file)

        if current_epoch > 1 and current_epoch % 25 == 0:
            check_class_accuracy(model, test_loader, epoch, threshold=CONF_THRESHOLD)

            pred_boxes, true_boxes = get_evaluation_bboxes(
                test_loader,
                model,
                iou_threshold=NMS_IOU_THRESH,
                anchors=ANCHORS,
                threshold=CONF_THRESHOLD,
                device=DEVICE,  # the function's default is "cuda"
            )

            all_classes_average_precisions = mean_average_precision(
                pred_boxes,
                true_boxes,
                epoch,
                iou_threshold=MAP_IOU_THRESH,
                box_format="midpoint",
                num_classes=NUM_CLASSES
            )

            # mean_average_precision skips classes without ground-truth boxes, so
            # its result is aligned with exactly those classes that do have some.
            evaluated_classes = [
                c for c in range(NUM_CLASSES)
                if any(box[1] == c for box in true_boxes)
            ]
            ap_by_class = {
                CLASSES[c]: ap
                for c, ap in zip(evaluated_classes, all_classes_average_precisions)
            }

            for class_name, ap in ap_by_class.items():
                print('class {}: mAP value is: {}'.format(class_name, ap))
                writer.add_scalar('mAP_{}/train'.format(class_name), ap, epoch)

            if all_classes_average_precisions:
                mapval = sum(all_classes_average_precisions) / len(all_classes_average_precisions)
                print(f"mAP: {mapval.item()}")
                writer.add_scalar('mAP_50/train_50', mapval.item(), epoch)
                writer.add_scalars('mAP/train', {**ap_by_class, 'all': mapval.item()}, epoch)
            else:
                # No ground-truth boxes survived the threshold: nothing to average.
                print("mAP: no ground-truth boxes found, skipping")

            model.train()


# Start training when this cell is executed as the main module.
if __name__ == "__main__":
    main()

In [None]:
%load_ext tensorboard
%tensorboard --logdir=runs