In [107]:
try:
    import wandb
except:
    !pip install wandb
    import wandb
!wandb login

[34m[1mwandb[0m: Currently logged in as: [33map-wt[0m (use `wandb login --relogin` to force relogin)


In [108]:
try:
    import torchmetrics
except:
    !pip install torchmetrics
    import torchmetrics

In [109]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [110]:
import pandas as pd
import numpy as np
from copy import deepcopy
import os
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import torchvision
import ast
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
from torchmetrics.detection.map import MeanAveragePrecision


# this should probably be changed to something smart, right?
KAGGLE_PATH_ANNOTATIONS = '/kaggle/input/tensorflow-great-barrier-reef/train.csv'
KAGGLE_PATH_IMG_DIR = '/kaggle/input/tensorflow-great-barrier-reef/train_images/'
LOCAL_PATH_ANNOTATIONS = 'data/train.csv'
LOCAL_PATH_IMG_DIR = 'data/train_images/'
COLAB_PATH_ANNOTATIONS = '/content/drive/MyDrive/data/train.csv'
COLAB_PATH_IMG_DIR = '/content/drive/MyDrive/data/train_images/'

wandb.config = {
  "learning_rate": 0.001,
  "epochs": 2,
  "batch_size": 2,
  "momentum": 0.9,
  "weight_decay": 0.0005, 
  "confidence_threshold": 0.5 # save a bounding box if model returned confidence above this threshold
}

In [111]:
class StarfishDataset(Dataset):
    def __init__(self,
                 annotations_file=COLAB_PATH_ANNOTATIONS,
                 img_dir=COLAB_PATH_IMG_DIR
                 ):
        self.img_labels = pd.read_csv(annotations_file)
        self.annotated = self.img_labels[self.img_labels['annotations'] != '[]']  # get only annotated frames
        self.img_dir = img_dir

    def __len__(self):
        return len(self.annotated)

    def __getitem__(self, idx):
        image = read_image(os.path.join(self.img_dir, 'video_{}'.format(self.annotated.iloc[idx][0]),
                                        '{}.jpg'.format(self.annotated.iloc[idx][2])))
        min_image = image.min()
        max_image = image.max()
        # normalize image to 0-1 - required by torchvision
        image -= min_image
        image = torch.FloatTensor(image/max_image)
        labels = self.annotated.iloc[idx][-1]
        labels = ast.literal_eval(labels)
        coords = []
        for parsed_label in labels:
            x1, y1 = parsed_label['x'], parsed_label['y']
            x2, y2 = x1+parsed_label['width'], y1+parsed_label['height']
            coords.append([x1, y1, x2, y2])

        boxes = torch.FloatTensor(coords)
        labels = torch.LongTensor([1 for _ in range(len(coords))]) # label has to be integer, since we have only one label I coded it as 1 for simplicity
        return image, boxes, labels

# dataset = StarfishDataset()
# dataset.__getitem__(0)


In [112]:
def collate_fn(batch):
    targets = []
    images = []
    for imgs, boxes, labels in batch:
        images.append(imgs)
        d = {}
        d['boxes'] = boxes
        d['labels'] = labels
        targets.append(d)
    return images, targets

def slice_output(output: dict, confidence_threshold: float = wandb.config['confidence_threshold']) -> dict:
    """
    this method is responsible for validating models output w.r.t confidence_threshold defined above.
    It accepts an output dictionary from model, namely {'boxes':[], 'labels':[], 'scores':[]}
    It returns a dictionary sliced to items with score above confidence_threshold
    """

    num_valid_elements = np.sum(np.array(output['scores']) >= confidence_threshold)
    # temporary option to make sure, that it returns at least one element, although it should probably be fixed,
    # should there be any frames where there is no starfish
    if num_valid_elements == 0:
        num_valid_elements = 1
    res = {}
    for key, value in output.items():
        res[key] = value[:num_valid_elements]
    return res

In [113]:
# https://towardsdatascience.com/evaluating-performance-of-an-object-detection-model-137a349c517b

def calc_iou(gt_bbox, pred_bbox):
    '''
    This function takes the predicted bounding box and ground truth bounding box and 
    return the IoU ratio
    '''
    x_topleft_gt, y_topleft_gt, x_bottomright_gt, y_bottomright_gt= gt_bbox
    print(pred_bbox)
    x_topleft_p, y_topleft_p, x_bottomright_p, y_bottomright_p= pred_bbox
    
    if (x_topleft_gt > x_bottomright_gt) or (y_topleft_gt> y_bottomright_gt):
        raise AssertionError("Ground Truth Bounding Box is not correct")
    if (x_topleft_p > x_bottomright_p) or (y_topleft_p> y_bottomright_p):
        raise AssertionError("Predicted Bounding Box is not correct",x_topleft_p, x_bottomright_p,y_topleft_p,y_bottomright_gt)
        
         
    #if the GT bbox and predcited BBox do not overlap then iou=0
    if(x_bottomright_gt< x_topleft_p):
        # If bottom right of x-coordinate  GT  bbox is less than or above the top left of x coordinate of  the predicted BBox
        
        return 0.0
    if(y_bottomright_gt< y_topleft_p):  # If bottom right of y-coordinate  GT  bbox is less than or above the top left of y coordinate of  the predicted BBox
        
        return 0.0
    if(x_topleft_gt> x_bottomright_p): # If bottom right of x-coordinate  GT  bbox is greater than or below the bottom right  of x coordinate of  the predcited BBox
        
        return 0.0
    if(y_topleft_gt> y_bottomright_p): # If bottom right of y-coordinate  GT  bbox is greater than or below the bottom right  of y coordinate of  the predcited BBox
        
        return 0.0
    
    
    GT_bbox_area = (x_bottomright_gt -  x_topleft_gt + 1) * (  y_bottomright_gt -y_topleft_gt + 1)
    Pred_bbox_area =(x_bottomright_p - x_topleft_p + 1 ) * ( y_bottomright_p -y_topleft_p + 1)
    
    x_top_left =np.max([x_topleft_gt, x_topleft_p])
    y_top_left = np.max([y_topleft_gt, y_topleft_p])
    x_bottom_right = np.min([x_bottomright_gt, x_bottomright_p])
    y_bottom_right = np.min([y_bottomright_gt, y_bottomright_p])
    
    intersection_area = (x_bottom_right- x_top_left + 1) * (y_bottom_right-y_top_left  + 1)
    
    union_area = (GT_bbox_area + Pred_bbox_area - intersection_area)
   
    return intersection_area/union_area

def calc_precision_recall(image_results):
    """Calculates precision and recall from the set of images
    Args:
        img_results (dict): dictionary formatted like:
            {
                'img_id1': {'true_pos': int, 'false_pos': int, 'false_neg': int},
                'img_id2': ...
                ...
            }
    Returns:
        tuple: of floats of (precision, recall)
    """
    true_positive=0
    false_positive=0
    false_negative=0
    for img_id, res in image_results.items():
        true_positive +=res['true_positive']
        false_positive += res['false_positive']
        false_negative += res['false_negative']
        try:
            precision = true_positive/(true_positive+ false_positive)
        except ZeroDivisionError:
            precision=0.0
        try:
            recall = true_positive/(true_positive + false_negative)
        except ZeroDivisionError:
            recall=0.0
    return (precision, recall)

def get_single_image_results(gt_boxes, pred_boxes, iou_thr):
    """Calculates number of true_pos, false_pos, false_neg from single batch of boxes.
    Args:
        gt_boxes (list of list of floats): list of locations of ground truth
            objects as [xmin, ymin, xmax, ymax]
        pred_boxes (dict): dict of dicts of 'boxes' (formatted like `gt_boxes`)
            and 'scores'
        iou_thr (float): value of IoU to consider as threshold for a
            true prediction.
    Returns:
        dict: true positives (int), false positives (int), false negatives (int)
    """
    all_pred_indices= range(len(pred_boxes))
    all_gt_indices=range(len(gt_boxes))
    if len(all_pred_indices)==0:
        tp=0
        fp=0
        fn=0
        return {'true_positive':tp, 'false_positive':fp, 'false_negative':fn}
    if len(all_gt_indices)==0:
        tp=0
        fp=0
        fn=0
        return {'true_positive':tp, 'false_positive':fp, 'false_negative':fn}
    
    gt_idx_thr=[]
    pred_idx_thr=[]
    ious=[]
    # print('PREDBOXES', pred_boxes)
    # print('GTBOXES', gt_boxes)
    # trying to adjust it to our data
    # for preds, targets in zip(pred_boxes.items(), gt_boxes.items()):

    for ipb, pred_box in enumerate(pred_boxes):
        for igb, gt_box in enumerate(gt_boxes):
            # print('predbox1', pred_boxes[pred_box])
            # print('gtbox1', gt_boxes[gt_box])
            iou= calc_iou(gt_box, pred_box)
            
            if iou >iou_thr:
                gt_idx_thr.append(igb)
                pred_idx_thr.append(ipb)
                ious.append(iou)
    iou_sort = np.argsort(ious)[::1]
    if len(iou_sort)==0:
        tp=0
        fp=0
        fn=0
        return {'true_positive':tp, 'false_positive':fp, 'false_negative':fn}
    else:
        gt_match_idx=[]
        pred_match_idx=[]
        for idx in iou_sort:
            gt_idx=gt_idx_thr[idx]
            pr_idx= pred_idx_thr[idx]
            # If the boxes are unmatched, add them to matches
            if(gt_idx not in gt_match_idx) and (pr_idx not in pred_match_idx):
                gt_match_idx.append(gt_idx)
                pred_match_idx.append(pr_idx)
        tp= len(gt_match_idx)
        fp= len(pred_boxes) - len(pred_match_idx)
        fn = len(gt_boxes) - len(gt_match_idx)
    return {'true_positive': tp, 'false_positive': fp, 'false_negative': fn}

In [114]:
torch.manual_seed(23)

# IF YOU WANT TO RUN PROPER MODEL LEARNING, MAKE SURE TO CHANGE DATASET SIZES

dataset = StarfishDataset()
train_size = int(0.09 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

# extract only small part of the data for faster learning / testing process
train_size = int(0.8 * len(train_dataset))
test_size = len(train_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(train_dataset, [train_size, test_size])


print('Train dataset: {} instances, test dataset: {}'.format(len(train_dataset), len(test_dataset)))


train_dataloader = DataLoader(
    train_dataset, batch_size=wandb.config['batch_size'], shuffle=False, num_workers=1, collate_fn = collate_fn)
test_dataloader = DataLoader(
    test_dataset, batch_size=wandb.config['batch_size'], shuffle=False, num_workers=1,  collate_fn = collate_fn)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
cpu = torch.device('cpu')
print('Used device: {}'.format(device))

num_classes = 2  # starfish and not starfish I guess

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
model.to(device)

params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=wandb.config['learning_rate'], momentum=wandb.config['momentum'], weight_decay=wandb.config['weight_decay'])
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

Train dataset: 353 instances, test dataset: 89
Used device: cuda


In [115]:
# https://pytorch.org/vision/stable/models.html#runtime-characteristics see Faster R-CNN for the details of this model, what it requires, returns, etc

# https://github.com/pytorch/vision/blob/main/references/detection/engine.py probably see training and eval loops here

# wandb.init(project="great-barrier-reef", entity="ap-wt", config = wandb.config)

# wandb.watch(model, log="all")

for e in tqdm(range(wandb.config['epochs'])):
    print('\n')
    model.train()

    for idx, (images, targets) in enumerate(train_dataloader):

        images = list(image.to(device) for image in images)

        for d in targets:
            d['boxes'] = d['boxes'].to(device)
            d['labels'] = d['labels'].to(device)

        loss_dict = model(images, targets)
        loss = sum(loss for loss in loss_dict.values())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    lr_scheduler.step()
            
    model.eval()
    with torch.no_grad():
        gt_boxes = dict()
        pred_boxes = dict()
        for idx, (images, targets) in enumerate(test_dataloader):

            images = list(image.to(device) for image in images)
            predictions = model(images)
            outputs = [{k: v.to(cpu) for k, v in t.items()} for t in predictions]

            # TODO: add some comparison with 'targets' perhaps
            outputs = [slice_output(out) for out in outputs]

            # print('Output', outputs)
            # print('Targets', targets)
            gt_boxes[idx] = [d['boxes'].tolist()[0] for d in targets]
            tmp_pred_boxes = {'boxes': [], "scores": []}
            for d in outputs:
                tmp_pred_boxes['boxes'].append(d['boxes'].tolist()[0])
                tmp_pred_boxes['scores'].append(d['scores'].tolist()[0])
            pred_boxes[idx] = tmp_pred_boxes
            # print('gt_boxes', gt_boxes)
            # print('pred_boxes', pred_boxes)
        tmp = {}
        for idx, (gt, prd) in enumerate(zip(gt_boxes, pred_boxes)):
            res = get_single_image_results(gt_boxes, pred_boxes, 0.5)
            print('res', res)
            tmp[idx] = res
        pr = calc_precision_recall(tmp)
        print('pr', pr)
        break

            # metric = MeanAveragePrecision()
            # metric.update(outputs, targets)
            # metrics = metric.compute()
            # if idx % 100 == 0:
                # wandb.log({'MAP':metrics['map'], 'MAR_1':metrics['mar_1']})
        


    lr_scheduler.step()
# wandb.finish()

  0%|          | 0/2 [00:00<?, ?it/s]





  0%|          | 0/2 [04:53<?, ?it/s]


TypeError: ignored

In [120]:
tmp = {}
for idx, (gt, prd) in enumerate(zip(gt_boxes.values(), pred_boxes.values())):
    print(gt, prd)
    res = get_single_image_results(gt, prd['boxes'], 0.5)
    print('res', res)
    tmp[idx] = res
pr = calc_precision_recall(tmp)
print('pr', pr)


[[575.0, 197.0, 645.0, 245.0], [343.0, 405.0, 377.0, 440.0]] {'boxes': [[589.599365234375, 200.47019958496094, 634.9077758789062, 239.22512817382812], [209.7011260986328, 544.3060913085938, 274.01300048828125, 604.1810913085938]], 'scores': [0.3593592941761017, 0.428097128868103]}
[589.599365234375, 200.47019958496094, 634.9077758789062, 239.22512817382812]
[589.599365234375, 200.47019958496094, 634.9077758789062, 239.22512817382812]
[209.7011260986328, 544.3060913085938, 274.01300048828125, 604.1810913085938]
[209.7011260986328, 544.3060913085938, 274.01300048828125, 604.1810913085938]
res {'true_positive': 1, 'false_positive': 1, 'false_negative': 1}
[[510.0, 0.0, 545.0, 38.0], [152.0, 221.0, 196.0, 253.0]] {'boxes': [[424.8997497558594, 532.1857299804688, 495.3158874511719, 608.66943359375], [254.62112426757812, 291.3923645019531, 296.9801025390625, 329.65924072265625]], 'scores': [0.3319912850856781, 0.3741723597049713]}
[424.8997497558594, 532.1857299804688, 495.3158874511719, 608

In [None]:
d = {0: [[575.0, 197.0, 645.0, 245.0], [343.0, 405.0, 377.0, 440.0]], 1: [[510.0, 0.0, 545.0, 38.0], [152.0, 221.0, 196.0, 253.0]], 2: [[87.0, 269.0, 154.0, 313.0], [662.0, 205.0, 691.0, 231.0]], 3: [[869.0, 591.0, 929.0, 654.0], [526.0, 653.0, 577.0, 701.0]], 4: [[564.0, 358.0, 648.0, 422.0], [466.0, 653.0, 537.0, 718.0]], 5: [[521.0, 325.0, 562.0, 367.0], [801.0, 389.0, 832.0, 414.0]], 6: [[656.0, 132.0, 693.0, 172.0], [868.0, 107.0, 905.0, 134.0]], 7: [[854.0, 484.0, 886.0, 522.0], [99.0, 144.0, 159.0, 208.0]], 8: [[404.0, 344.0, 450.0, 383.0], [78.0, 572.0, 150.0, 625.0]], 9: [[173.0, 578.0, 240.0, 631.0], [322.0, 469.0, 344.0, 493.0]], 10: [[310.0, 229.0, 391.0, 323.0], [499.0, 312.0, 561.0, 372.0]], 11: [[790.0, 156.0, 833.0, 185.0], [321.0, 35.0, 351.0, 67.0]], 12: [[401.0, 283.0, 467.0, 346.0], [413.0, 169.0, 444.0, 206.0]], 13: [[161.0, 87.0, 219.0, 141.0], [606.0, 351.0, 661.0, 407.0]], 14: [[435.0, 145.0, 492.0, 202.0], [524.0, 28.0, 543.0, 48.0]], 15: [[624.0, 442.0, 684.0, 498.0], [1011.0, 491.0, 1047.0, 523.0]], 16: [[0.0, 319.0, 25.0, 377.0], [401.0, 151.0, 448.0, 194.0]], 17: [[315.0, 311.0, 355.0, 341.0], [394.0, 576.0, 435.0, 612.0]], 18: [[243.0, 618.0, 289.0, 660.0], [227.0, 457.0, 255.0, 486.0]], 19: [[917.0, 333.0, 959.0, 368.0], [593.0, 100.0, 630.0, 136.0]], 20: [[522.0, 667.0, 566.0, 702.0], [1012.0, 328.0, 1053.0, 369.0]], 21: [[129.0, 167.0, 163.0, 201.0], [649.0, 166.0, 674.0, 188.0]], 22: [[191.0, 409.0, 228.0, 433.0], [478.0, 138.0, 516.0, 179.0]], 23: [[452.0, 213.0, 510.0, 250.0], [252.0, 375.0, 318.0, 444.0]], 24: [[510.0, 355.0, 576.0, 418.0], [316.0, 478.0, 353.0, 515.0]], 25: [[778.0, 553.0, 828.0, 612.0], [420.0, 162.0, 451.0, 199.0]], 26: [[868.0, 425.0, 899.0, 450.0], [70.0, 511.0, 129.0, 566.0]], 27: [[178.0, 391.0, 225.0, 422.0], [166.0, 496.0, 213.0, 539.0]], 28: [[592.0, 47.0, 750.0, 172.0], [579.0, 462.0, 626.0, 515.0]], 29: [[700.0, 262.0, 734.0, 295.0], [545.0, 590.0, 609.0, 655.0]], 30: [[100.0, 588.0, 154.0, 636.0], [578.0, 346.0, 658.0, 409.0]], 31: [[695.0, 425.0, 760.0, 503.0], [152.0, 66.0, 205.0, 114.0]], 32: [[54.0, 166.0, 94.0, 203.0], [341.0, 385.0, 360.0, 401.0]], 33: [[358.0, 307.0, 399.0, 338.0], [565.0, 262.0, 600.0, 294.0]], 34: [[532.0, 301.0, 571.0, 339.0], [169.0, 2.0, 221.0, 35.0]], 35: [[80.0, 199.0, 154.0, 281.0], [414.0, 255.0, 458.0, 284.0]], 36: [[957.0, 142.0, 1027.0, 205.0], [108.0, 367.0, 158.0, 394.0]], 37: [[360.0, 264.0, 397.0, 291.0], [815.0, 312.0, 864.0, 371.0]], 38: [[517.0, 539.0, 588.0, 607.0], [767.0, 63.0, 807.0, 88.0]], 39: [[521.0, 138.0, 608.0, 209.0], [193.0, 83.0, 228.0, 126.0]], 40: [[39.0, 199.0, 102.0, 238.0], [878.0, 72.0, 940.0, 119.0]], 41: [[520.0, 325.0, 560.0, 359.0], [1099.0, 189.0, 1176.0, 260.0]], 42: [[507.0, 554.0, 578.0, 622.0], [741.0, 682.0, 784.0, 719.0]], 43: [[956.0, 570.0, 992.0, 611.0], [426.0, 201.0, 483.0, 258.0]], 44: [[575.0, 498.0, 617.0, 540.0]]}
for idx, x in enumerate(d):
    print(idx, d[x])


In [None]:
# torch.save(model.state_dict(), 'models/FastRCNN.pt')