In [1]:
try:
    import wandb
except:
    !pip install wandb
    import wandb
!wandb login

[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
Aborted!


In [26]:
try:
    import torchmetrics
except:
    !pip install torchmetrics
    import torchmetrics

In [27]:
# from google.colab import drive
# drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [54]:
import ast
import os

import numpy as np
import pandas as pd
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from tqdm import tqdm

try:
    # Location of the detection mAP metric in current torchmetrics releases.
    from torchmetrics.detection.mean_ap import MeanAveragePrecision
except ImportError:
    # Older torchmetrics releases (the path this notebook originally used).
    from torchmetrics.detection.map import MeanAveragePrecision


# Dataset locations for each environment this notebook has been run in.
# StarfishDataset defaults to the Colab pair; pass another pair explicitly to use it.
# TODO: select the pair automatically (e.g. the first one that exists on disk).
KAGGLE_PATH_ANNOTATIONS = '/kaggle/input/tensorflow-great-barrier-reef/train.csv'
KAGGLE_PATH_IMG_DIR = '/kaggle/input/tensorflow-great-barrier-reef/train_images/'
LOCAL_PATH_ANNOTATIONS = 'data/train.csv'
LOCAL_PATH_IMG_DIR = 'data/train_images/'
COLAB_PATH_ANNOTATIONS = '/content/drive/MyDrive/data/train.csv'
COLAB_PATH_IMG_DIR = '/content/drive/MyDrive/data/train_images/'

# Hyperparameters. Stored on wandb.config so later cells and wandb.init(config=...) read one source.
wandb.config = {
  "learning_rate": 0.001,
  "epochs": 2,
  "batch_size": 2,
  "momentum": 0.9,
  "weight_decay": 0.0005,
  "confidence_threshold": 0.7  # slice_output keeps predicted boxes scoring at or above this
}

In [55]:
class StarfishDataset(Dataset):
    """Annotated frames from the starfish annotations CSV, served as (image, boxes, labels) samples.

    Rows whose ``annotations`` column is the literal string ``'[]'`` are dropped, so every
    sample has at least one box. Images are read from ``<img_dir>/video_<col 0>/<col 2>.jpg``.

    Args:
        annotations_file: path to the annotations CSV.
        img_dir: directory containing the ``video_*`` image folders.
    """

    def __init__(self,
                 annotations_file=COLAB_PATH_ANNOTATIONS,
                 img_dir=COLAB_PATH_IMG_DIR
                 ):
        self.img_labels = pd.read_csv(annotations_file)
        self.annotated = self.img_labels[self.img_labels['annotations'] != '[]']  # get only annotated frames
        self.img_dir = img_dir

    def __len__(self):
        return len(self.annotated)

    def __getitem__(self, idx):
        """Return ``(image, boxes, labels)`` for the ``idx``-th annotated frame.

        image:  float32 tensor min-max normalised to [0, 1] (torchvision detection models expect
                float images in this range).
        boxes:  float32 tensor of shape (num_boxes, 4) holding (x1, y1, x2, y2).
        labels: int64 tensor of ones, one per box (single foreground class).
        """
        row = self.annotated.iloc[idx]
        # Columns are addressed by position: 0 -> video folder number, 2 -> frame file name.
        image_path = os.path.join(self.img_dir,
                                  'video_{}'.format(row.iloc[0]),
                                  '{}.jpg'.format(row.iloc[2]))
        image = read_image(image_path).float()

        # Min-max normalise to [0, 1]. Divide by the range (max - min), not by max, so the
        # brightest pixel maps to exactly 1. A constant image would give 0/0, so map it to zeros.
        min_image, max_image = image.min(), image.max()
        value_range = max_image - min_image
        if value_range > 0:
            image = (image - min_image) / value_range
        else:
            image = torch.zeros_like(image)

        # The last column holds a Python-literal list of {'x', 'y', 'width', 'height'} dicts;
        # convert each one to corner form (x1, y1, x2, y2).
        annotations = ast.literal_eval(row.iloc[-1])
        coords = [[a['x'], a['y'], a['x'] + a['width'], a['y'] + a['height']] for a in annotations]

        # reshape keeps the (N, 4) shape even when the list is empty.
        boxes = torch.tensor(coords, dtype=torch.float32).reshape(-1, 4)
        # label has to be integer; with only one class every box is labelled 1
        labels = torch.ones(len(coords), dtype=torch.int64)
        return image, boxes, labels

# dataset = StarfishDataset()
# dataset.__getitem__(0)


In [56]:
def collate_fn(batch):
    """Turn a batch of ``(image, boxes, labels)`` samples into the list-based detection format.

    Returns ``(images, targets)``: ``images`` is the list of sample images, and ``targets`` is a
    parallel list of ``{'boxes': ..., 'labels': ...}`` dicts, one per sample. Images are kept as
    a list (not stacked) so samples may have different sizes.
    """
    samples = list(batch)
    images = [image for image, _, _ in samples]
    targets = [
        {'boxes': sample_boxes, 'labels': sample_labels}
        for _, sample_boxes, sample_labels in samples
    ]
    return images, targets

def slice_output(output: dict, confidence_threshold: float = wandb.config['confidence_threshold']) -> dict:
    """
    Trim one model output dict ({'boxes': ..., 'labels': ..., 'scores': ...}) to its confident detections.

    Counts the scores that are >= confidence_threshold and truncates every value in ``output`` to
    that many leading entries. This equals "keep the confident detections" only when the entries
    are ordered by descending score. The default threshold is read from wandb.config once, when
    this function is defined.
    """
    n_confident = (np.array(output['scores']) >= confidence_threshold).sum()
    # Temporary workaround: always keep at least one entry, even if nothing is confident.
    # TODO: revisit for frames that contain no starfish.
    n_keep = n_confident if n_confident else 1
    return {key: value[:n_keep] for key, value in output.items()}

{'boxes': [[601.2649, 390.7638, 653.7231, 436.5513],
  [541.5198, 465.4016, 588.3859, 511.2046],
  [426.1378, 628.664, 485.297, 695.2532]],
 'labels': [1, 1, 1],
 'scores': [0.8809, 0.8313, 0.7175]}

In [57]:
torch.manual_seed(23)

# IF YOU WANT TO RUN PROPER MODEL LEARNING, MAKE SURE TO CHANGE DATASET SIZES
# To keep runs short, only SUBSET_FRACTION of the annotated frames is used (the rest is discarded).
# That subset is then split TRAIN_FRACTION / (1 - TRAIN_FRACTION) into train and test.
SUBSET_FRACTION = 0.1
TRAIN_FRACTION = 0.8

dataset = StarfishDataset()
subset_size = int(SUBSET_FRACTION * len(dataset))
subset_dataset, _discarded_dataset = torch.utils.data.random_split(
    dataset, [subset_size, len(dataset) - subset_size])

train_size = int(TRAIN_FRACTION * len(subset_dataset))
test_size = len(subset_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(subset_dataset, [train_size, test_size])

print('Train dataset: {} instances, test dataset: {}'.format(len(train_dataset), len(test_dataset)))

# Reshuffle the training set every epoch; evaluation order does not matter.
train_dataloader = DataLoader(
    train_dataset, batch_size=wandb.config['batch_size'], shuffle=True, num_workers=1, collate_fn=collate_fn)
test_dataloader = DataLoader(
    test_dataset, batch_size=wandb.config['batch_size'], shuffle=False, num_workers=1, collate_fn=collate_fn)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
cpu = torch.device('cpu')
print('Used device: {}'.format(device))

# torchvision detection models count background as class 0; StarfishDataset labels every starfish 1.
num_classes = 2

# Pretrained Faster R-CNN with its box predictor replaced by a fresh num_classes head.
# NOTE(review): newer torchvision versions deprecate `pretrained=True` in favour of `weights=...`.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
model.to(device)

params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=wandb.config['learning_rate'], momentum=wandb.config['momentum'], weight_decay=wandb.config['weight_decay'])
# Multiplies the LR by 0.1 every 3 scheduler steps; the training loop must call lr_scheduler.step().
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

Train dataset: 392 instances, test dataset: 99
Used device: cuda


In [None]:
# Faster R-CNN inputs/outputs and requirements:
# https://pytorch.org/vision/stable/models.html#runtime-characteristics
# Train/eval loops follow the torchvision detection reference engine:
# https://github.com/pytorch/vision/blob/main/references/detection/engine.py

wandb.init(project="great-barrier-reef", entity="ap-wt", config=wandb.config)
for epoch in tqdm(range(wandb.config['epochs'])):
    # ---- training ----
    model.train()
    epoch_loss, num_batches = 0.0, 0
    for images, targets in train_dataloader:
        images = [image.to(device) for image in images]
        targets = [{key: value.to(device) for key, value in target.items()} for target in targets]

        # In train mode, torchvision detection models return a dict of loss tensors.
        loss_dict = model(images, targets)
        loss = sum(loss_dict.values())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    # Step the LR scheduler (not the optimizer) once per epoch. An extra optimizer.step() here
    # would re-apply the last batch's stale gradients.
    lr_scheduler.step()

    # ---- evaluation: a single mAP over the whole test set ----
    model.eval()
    metric = MeanAveragePrecision()
    with torch.no_grad():
        for images, targets in test_dataloader:
            images = [image.to(device) for image in images]
            predictions = model(images)
            # targets stay on the CPU, so move the predictions there for the metric
            outputs = [{key: value.to(cpu) for key, value in prediction.items()} for prediction in predictions]
            outputs = [slice_output(output) for output in outputs]
            metric.update(outputs, targets)
    metrics = metric.compute()

    wandb.log({
        'epoch': epoch,
        'train_loss': epoch_loss / max(num_batches, 1),
        'MAP': metrics['map'],
        'MAR_1': metrics['mar_1'],
    })
wandb.finish()

In [None]:
# Save only the weights; reload with model.load_state_dict(torch.load(MODEL_PATH)).
MODEL_PATH = 'models/FastRCNN.pt'
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)  # torch.save does not create missing directories
torch.save(model.state_dict(), MODEL_PATH)