<a href="https://colab.research.google.com/github/wtaisner/tensorflow-great-barrier-reef/blob/main/competition.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
try:
    import wandb
except:
    !pip install wandb
    import wandb
!wandb login

[34m[1mwandb[0m: Currently logged in as: [33map-wt[0m (use `wandb login --relogin` to force relogin)


In [2]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
import pandas as pd
import os
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import torch.nn as nn
import torchvision
import ast
import torch
from tqdm import tqdm

import matplotlib.pyplot as plt
import matplotlib.patches as patches

# this should probably be changed to something smart, right?
KAGGLE_PATH_ANNOTATIONS = '/kaggle/input/tensorflow-great-barrier-reef/train.csv'
KAGGLE_PATH_IMG_DIR = '/kaggle/input/tensorflow-great-barrier-reef/train_images/'
LOCAL_PATH_ANNOTATIONS = 'data/train.csv'
LOCAL_PATH_IMG_DIR = 'data/train_images/'
COLAB_PATH_ANNOTATIONS = '/content/drive/MyDrive/data/train.csv'
COLAB_PATH_IMG_DIR = '/content/drive/MyDrive/data/train_images/'

wandb.config = {
  "learning_rate": 0.001,
  "epochs": 10,
  "batch_size": 2,
  "momentum": 0.9,
  "weight_decay": 0.0005
}

In [4]:
class StarfishDataset(Dataset):
    def __init__(self,
                 annotations_file=COLAB_PATH_ANNOTATIONS,
                 img_dir=COLAB_PATH_IMG_DIR
                 ):
        self.img_labels = pd.read_csv(annotations_file)
        self.annotated = self.img_labels[self.img_labels['annotations'] != '[]']  # get only annotated frames
        self.img_dir = img_dir

    def __len__(self):
        return len(self.annotated)

    def __getitem__(self, idx):
        image = read_image(os.path.join(self.img_dir, 'video_{}'.format(self.annotated.iloc[idx][0]),
                                        '{}.jpg'.format(self.annotated.iloc[idx][2])))
        min_image = image.min()
        max_image = image.max()
        # normalize image to 0-1 - required by torchvision
        image -= min_image
        image = torch.FloatTensor(image/max_image)
        # print(image.shape) # image shape has to be [C, H, W], it is :)
        labels = self.annotated.iloc[idx][-1]
        labels = ast.literal_eval(labels)
        coords = []
        for parsed_label in labels:
            x1, y1 = parsed_label['x'], parsed_label['y']
            x2, y2 = x1+parsed_label['width'], y1+parsed_label['height']
            coords.append([x1, y1, x2, y2])

            # fig, ax = plt.subplots()
            # ax.imshow(image.permute(1, 2, 0))
            # rect = patches.Rectangle((x1, y1), parsed_label['width'], parsed_label['height'], linewidth=1, edgecolor='r', facecolor='none')
            # ax.add_patch(rect)
            # plt.show()

        boxes = torch.FloatTensor(coords)
        labels = torch.LongTensor([1 for _ in range(len(coords))]) # label has to be integer, since we have only one label I coded it as 1 for simplicity
        # print(target[0].shape, target[1].shape)
        return image, boxes, labels

# dataset = StarfishDataset()
# dataset.__getitem__(0)


In [5]:
def collate_fn(batch):
    targets = []
    images = []
    for imgs, boxes, labels in batch:
        images.append(imgs)
        d = {}
        d['boxes'] = boxes
        d['labels'] = labels
        targets.append(d)
    return images, targets

In [6]:
torch.manual_seed(1)

dataset = StarfishDataset()
train_size = int(0.3 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

# extract only small part of the data for faster learning / testing process
train_size = int(0.8 * len(train_dataset))
test_size = len(train_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(train_dataset, [train_size, test_size])


print('Train dataset: {} instances, test dataset: {}'.format(len(train_dataset), len(test_dataset)))


train_dataloader = DataLoader(
    train_dataset, batch_size=wandb.config['batch_size'], shuffle=False, num_workers=1, collate_fn = collate_fn)
test_dataloader = DataLoader(
    test_dataset, batch_size=wandb.config['batch_size'], shuffle=False, num_workers=1,  collate_fn = collate_fn)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
cpu = torch.device('cpu')
print('Used device: {}'.format(device))

num_classes = 2  # starfish and not starfish I guess

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
model.to(device)

params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=wandb.config['learning_rate'], momentum=wandb.config['momentum'], weight_decay=wandb.config['weight_decay'])
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

Train dataset: 1180 instances, test dataset: 295
Used device: cuda


In [7]:
# images, boxes = torch.rand(4, 3, 600, 1200), torch.rand(4, 11, 4)
# labels = torch.randint(1, 91, (4, 11))
# images = list(image for image in images)
# targets = []
# for i in range(len(images)):
#     d = {}
#     d['boxes'] = boxes[i]
#     d['labels'] = labels[i]
#     targets.append(d)

# print('Images {} \n Boxes {} \n labels {} \n targets {} \n'.format(len(images), boxes.shape, labels.shape, targets))

In [None]:
# https://pytorch.org/vision/stable/models.html#runtime-characteristics see Faster R-CNN for the details of this model, what it requires, returns, etc

# https://github.com/pytorch/vision/blob/main/references/detection/engine.py probably see training and eval loops here

wandb.init(project="great-barrier-reef", entity="ap-wt", config = wandb.config)

for e in tqdm(range(wandb.config['epochs'])):
# for e in tqdm(range(2)):
    print('\n')
    
    model.train()

    for idx, (images, targets) in enumerate(train_dataloader):

        images = list(image.to(device) for image in images)

        for d in targets:
            d['boxes'] = d['boxes'].to(device)
            d['labels'] = d['labels'].to(device)

        loss_dict = model(images, targets)
        loss = sum(loss for loss in loss_dict.values())
        wandb.log({'train_loss':loss})
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    model.eval()
    print('Reached eval')
    with torch.no_grad():
        groundtruth, predictions = None, None
        for idx, (images, targets) in enumerate(test_dataloader):

            images = list(image.to(device) for image in images)
            predictions = model(images)
            outputs = [{k: v.to(cpu) for k, v in t.items()} for t in predictions]
            # TODO: add some comparison with 'targets' perhaps
            if idx % 50 == 0:
                print(outputs)
            wandb.log({"adam": outputs})

    optimizer.step()
wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33map-wt[0m (use `wandb login --relogin` to force relogin)


  0%|          | 0/10 [00:00<?, ?it/s]





  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Reached eval
[{'boxes': tensor([[ 597.3156,  378.6286,  663.7199,  439.7719],
        [ 790.4837,  611.1317,  874.0908,  681.3749],
        [ 497.7653,  391.9684,  560.9460,  459.7655],
        [ 195.9283,  208.9145,  246.7803,  250.2139],
        [ 547.5464,  339.0098,  608.0395,  393.8109],
        [ 234.7136,  337.5928,  294.2419,  391.3565],
        [ 302.3225,  406.7838,  340.5309,  447.2247],
        [ 577.9064,  554.4933,  633.1340,  623.2035],
        [ 519.1411,  340.0183,  598.9832,  405.3338],
        [ 451.2263,   94.3412,  492.9230,  137.1499],
        [ 501.3882,  356.9796,  578.8117,  465.8529],
        [ 325.3591,  554.5240,  397.9320,  628.4894],
        [ 510.0660,  346.2499,  565.0208,  400.6040],
        [ 402.8376,   61.2239,  442.3608,   97.9702],
        [ 202.4045,  205.9955,  247.3283,  233.1169],
        [ 534.7263,  338.5306,  608.8254,  446.9872],
        [ 343.2345,  568.6543,  390.0980,  615.6118],
        [ 531.4572,  395.3956,  561.9553,  447.1858],
    

 10%|█         | 1/10 [06:37<59:37, 397.54s/it]



