# Import libraries

In [None]:
import ast
import collections
import copy
import gc
import json
import os
import random
import time

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchvision
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import Dataset, DataLoader
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.transforms import functional as F
from tqdm import tqdm

In [None]:
def fix_all_seeds(seed):
    """Seed every random number generator this notebook relies on.

    Sets ``PYTHONHASHSEED`` in the environment and seeds Python's ``random``
    module, NumPy's global generator and PyTorch (through
    ``torch.manual_seed``, ``torch.cuda.manual_seed`` and
    ``torch.cuda.manual_seed_all``) with the same value.

    Args:
        seed: Seed value shared by all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    for seed_torch in (torch.manual_seed,
                       torch.cuda.manual_seed,
                       torch.cuda.manual_seed_all):
        seed_torch(seed)


fix_all_seeds(42)

In [None]:
# Annotation table loaded with pandas in the preprocessing cell.
TRAIN_CSV = "../input/tensorflow-great-barrier-reef/train.csv"
# Root folder of the original frames; "video_<id>/<frame>.jpg" is appended per row.
IMAGE_PATH = "../input/tensorflow-great-barrier-reef/train_images/"
# Folder of the generated images; "<Index>.png" is appended per row.
GEN_PATH = "../input/funie-gan1/ganpic/"

# Configuration

In [None]:
# Base image size; GBRDataset multiplies both by the resize factor when one is given.
WIDTH = 1280
HEIGHT = 720

# Number of outputs of the replacement box predictor (the dataset labels every box as class 1).
NUM_CLASSES = 2

# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(DEVICE)

# Per-channel mean/std used by the Normalize transform and passed to the model when NORMALIZE is True.
RESNET_MEAN = (0.485, 0.456, 0.406)
RESNET_STD = (0.229, 0.224, 0.225)

# Scale factor applied to images and boxes by GBRDataset; None keeps the original size.
RESIZE = None

# Batch size of both the train and the validation DataLoader.
BATCH_SIZE = 4

# If True, GBRDataset loads the file at `GAN_path` instead of the one at `image_path`.
GEN = True

# No changes tried with the optimizer yet.
MOMENTUM = 0.9
LEARNING_RATE = 0.01
WEIGHT_DECAY = 0.0005

# Normalize to resnet mean and std if True.
NORMALIZE = False 


# Step the learning-rate scheduler once per epoch if True.
# (The scheduler created below is an ExponentialLR, not a StepLR.)
USE_SCHEDULER = True

# Number of epochs
NUM_EPOCHS = 25

# If True, only the first 40 rows of each split are used and training runs for 10 epochs.
DEBUG = False

# Data preprocessing

In [None]:
# Load the annotation table and derive the per-image columns used below.
train = pd.read_csv(TRAIN_CSV)

# image_id is split on '-' into two parts; the frame is expected at
# IMAGE_PATH/video_<first part>/<second part>.jpg.
train['image_path'] = train['image_id'].apply(
    lambda x: IMAGE_PATH + 'video_' + x.split('-')[0] + '/' + x.split('-')[1] + '.jpg')

# The annotations column holds a Python-literal list of box dicts as text.
# ast.literal_eval only parses literals, so a malformed or malicious cell
# cannot execute arbitrary code the way eval() would.
train['annotations'] = train['annotations'].apply(lambda x: list(ast.literal_eval(x)))
train['num_boxes'] = train['annotations'].apply(len)

# Keep only the rows that have at least one box. The non-inplace
# reset_index() returns a new DataFrame, so the column assignments below do
# not write into a slice of `train` (avoids pandas' SettingWithCopyWarning).
image_df = train[train['num_boxes'] != 0].reset_index(drop=True)
image_df['Index'] = image_df.index
image_df['GAN_path'] = image_df['Index'].apply(lambda x: GEN_PATH + f'{x}.png')

del train

In [None]:
# Assign every row of image_df to one of 5 validation folds, stratified by
# video_id. random_state is pinned (to the same value passed to
# fix_all_seeds) so the split does not depend on how much of NumPy's global
# random stream was consumed before this cell ran, e.g. when it is re-run.
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
for fold, (train_idx, val_idx) in enumerate(skf.split(image_df, image_df["video_id"])):
    # Only the validation indices are needed: each row is in exactly one of them.
    image_df.loc[val_idx, 'fold'] = fold

In [None]:
# These are slight redefinitions of the torchvision transform classes.
# The difference is that they take and return the target together with the image.
# Copied from Abishek
class Compose(object):
    """Chain several ``(image, target)`` transforms into a single callable."""

    def __init__(self, transforms):
        # Stored as given; the transforms are applied in this order.
        self.transforms = transforms

    def __call__(self, image, target):
        """Apply every transform in turn and return the final ``(image, target)``."""
        result = (image, target)
        for apply_step in self.transforms:
            new_image, new_target = apply_step(*result)
            result = (new_image, new_target)
        return result

class VerticalFlip(object):
    """Randomly flip an image tensor and its boxes along the height axis."""

    def __init__(self, prob):
        # Probability with which a call performs the flip.
        self.prob = prob

    def __call__(self, image, target):
        """Flip ``image`` and ``target["boxes"]`` with probability ``prob``.

        The boxes tensor is updated in place and written back to ``target``.
        """
        if not random.random() < self.prob:
            return image, target

        height, _ = image.shape[-2:]
        flipped = image.flip(-2)
        boxes = target["boxes"]
        # After the flip the old ymax is nearest to the top, so it becomes
        # the new ymin (and the old ymin becomes the new ymax).
        boxes[:, [1, 3]] = height - boxes[:, [3, 1]]
        target["boxes"] = boxes
        return flipped, target

class HorizontalFlip(object):
    """Randomly flip an image tensor and its boxes along the width axis."""

    def __init__(self, prob):
        # Probability with which a call performs the flip.
        self.prob = prob

    def __call__(self, image, target):
        """Flip ``image`` and ``target["boxes"]`` with probability ``prob``.

        The boxes tensor is updated in place and written back to ``target``.
        """
        if not random.random() < self.prob:
            return image, target

        _, width = image.shape[-2:]
        flipped = image.flip(-1)
        boxes = target["boxes"]
        # After the flip the old xmax is nearest to the left edge, so it
        # becomes the new xmin (and the old xmin becomes the new xmax).
        boxes[:, [0, 2]] = width - boxes[:, [2, 0]]
        target["boxes"] = boxes
        return flipped, target

class Normalize(object):
    """Normalize an image with the RESNET_MEAN / RESNET_STD constants."""

    def __call__(self, image, target):
        """Return the normalized image; ``target`` is passed through unchanged."""
        normalized = F.normalize(image, RESNET_MEAN, RESNET_STD)
        return normalized, target

class ToTensor(object):
    """Convert the image with torchvision's ``to_tensor``."""

    def __call__(self, image, target):
        """Return the converted image; ``target`` is passed through unchanged."""
        as_tensor = F.to_tensor(image)
        return as_tensor, target

class AdBright(object):
    """Adjust the brightness of an image by a fixed factor.

    Args:
        brightness_factor: Factor forwarded to torchvision's
            ``adjust_brightness``. Defaults to 1.4, the value that used to
            be hard-coded, so ``AdBright()`` behaves as before.
    """

    def __init__(self, brightness_factor=1.4):
        self.brightness_factor = brightness_factor

    def __call__(self, image, target):
        """Return the brightness-adjusted image; ``target`` is passed through unchanged."""
        image = F.adjust_brightness(image, brightness_factor=self.brightness_factor)
        return image, target
    

def get_transform(train):
    """Build the transform pipeline for a dataset.

    The pipeline always starts with ``ToTensor``; ``Normalize`` follows when
    the module-level ``NORMALIZE`` flag is set.

    Args:
        train: When truthy, random horizontal and vertical flips (each with
            probability 0.5) are appended as data augmentation.

    Returns:
        A ``Compose`` wrapping the selected transforms.
    """
    pipeline = [ToTensor()]
    pipeline += [Normalize()] if NORMALIZE else []

    if train:
        # Data augmentation for train. AdBright() is currently disabled.
        pipeline.extend([HorizontalFlip(0.5), VerticalFlip(0.5)])

    return Compose(pipeline)

In [None]:
class GBRDataset(Dataset):
    """Detection dataset yielding ``(image, target)`` pairs for Faster R-CNN.

    Each sample is read from disk with OpenCV, converted to float32, divided
    by 255 and paired with a target dict built from the row's annotations.
    The module-level ``GEN`` flag selects whether the file at ``GAN_path`` or
    the one at ``image_path`` is loaded.

    Args:
        df: DataFrame with the columns ``image_id``, ``image_path``,
            ``annotations`` (a list of dicts with the keys ``x``, ``y``,
            ``width`` and ``height``) and ``GAN_path``.
        transforms: Optional callable taking and returning ``(image, target)``.
        resize: Optional scale factor applied to the image and its boxes;
            ``None`` keeps the original size.
    """

    def __init__(self, df, transforms=None, resize=None):
        self.transforms = transforms
        self.df = df
        self.resize = resize
        if self.resize is not None:
            self.height = int(HEIGHT * resize)
            self.width = int(WIDTH * resize)
        else:
            self.height = HEIGHT
            self.width = WIDTH

        # Records are keyed by position (0..len-1) rather than by the
        # DataFrame index label, because __getitem__ receives positions.
        # A plain dict is used so that an unknown index raises instead of
        # silently inserting an empty record (and growing the dataset).
        self.image_info = {
            position: {
                'image_id': row['image_id'],
                'image_path': row['image_path'],
                'annotations': row["annotations"],
                'GAN_path': row["GAN_path"],
            }
            for position, (_, row) in enumerate(df.iterrows())
        }

    def get_box(self, item):
        """Convert an annotation dict to ``[xmin, ymin, xmax, ymax]``."""
        xmin = item['x']
        xmax = xmin + item['width']
        ymin = item['y']
        ymax = ymin + item['height']
        return [xmin, ymin, xmax, ymax]

    def resize_boxes(self, boxes, resize):
        """Scale every coordinate of an ``(N, 4)`` boxes tensor by ``resize``."""
        xmin, ymin, xmax, ymax = boxes.unbind(1)
        xmin = xmin * resize
        xmax = xmax * resize
        ymin = ymin * resize
        ymax = ymax * resize
        return torch.stack((xmin, ymin, xmax, ymax), dim=1)

    def _read_image(self, path, to_rgb):
        """Read ``path`` with OpenCV as float32, optionally converting BGR to RGB."""
        img = cv2.imread(path)
        if img is None:
            # cv2.imread does not raise on a missing or unreadable file; it
            # returns None, which would otherwise fail later with an
            # unrelated AttributeError.
            raise FileNotFoundError(f'Could not read image: {path}')
        if to_rgb:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return img.astype(np.float32)

    def __getitem__(self, idx):
        """Return the image and the target dict of the sample at position ``idx``.

        Raises:
            IndexError: If ``idx`` is not a valid position.
            FileNotFoundError: If the image file cannot be read.
        """
        try:
            info = self.image_info[idx]
        except KeyError:
            raise IndexError(
                f'index {idx} is out of range for a dataset of size {len(self)}'
            ) from None

        if GEN:
            # NOTE(review): generated images are used in the channel order
            # cv2.imread returns, without the BGR->RGB conversion applied to
            # the original frames - confirm this matches how they were saved.
            img = self._read_image(info["GAN_path"], to_rgb=False)
        else:
            img = self._read_image(info["image_path"], to_rgb=True)
        img /= 255.0

        n_objects = len(info['annotations'])
        boxes = [self.get_box(item) for item in info['annotations']]
        # reshape keeps the (N, 4) layout even when there are no boxes, so
        # the column indexing below also works for an empty annotation list.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)

        if self.resize is not None:
            img = cv2.resize(img, (self.width, self.height), interpolation=cv2.INTER_LINEAR)
            boxes = self.resize_boxes(boxes, self.resize)

        # dummy labels: every box belongs to class 1
        labels = torch.ones((n_objects,), dtype=torch.int64)

        image_id = torch.tensor([idx])

        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        area = torch.as_tensor(area, dtype=torch.float32)

        iscrowd = torch.zeros((n_objects,), dtype=torch.int64)

        # This is the required target for the Faster R-CNN
        target = {
            'boxes': boxes,
            'labels': labels,
            'image_id': image_id,
            'area': area,
            'iscrowd': iscrowd
        }

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        return len(self.image_info)

In [None]:
def get_box(item):
    """Convert an annotation dict to ``[xmin, ymin, xmax, ymax]``.

    Args:
        item: Mapping with the keys ``x``, ``y``, ``width`` and ``height``.

    Returns:
        List holding the left, top, right and bottom coordinates.
    """
    horizontal = (item['x'], item['width'])
    vertical = (item['y'], item['height'])
    (left, box_width), (top, box_height) = horizontal, vertical
    return [left, top, left + box_width, top + box_height]

def plot_from_df(idx):
    """Show the original frame at position ``idx`` of ``image_df`` with its boxes.

    The frame is read with OpenCV, converted from BGR to RGB, and every
    annotation is drawn as a rectangle before the image is displayed.
    """
    frame = cv2.imread(image_df.iloc[idx]["image_path"])
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    for xmin, ymin, xmax, ymax in map(get_box, image_df.iloc[idx]['annotations']):
        top_left = (int(xmin), int(ymin))
        bottom_right = (int(xmax), int(ymax))
        cv2.rectangle(frame, top_left, bottom_right, (255, 0, 0), thickness=2)
    plt.figure(figsize=(10, 10))
    plt.imshow(frame)
    
def gen_from_df(idx):
    """Show the generated image at position ``idx`` of ``image_df`` with its boxes.

    The file at ``GAN_path`` is read with OpenCV and displayed as loaded (no
    colour conversion), with every annotation drawn as a rectangle.
    """
    picture = cv2.imread(image_df.iloc[idx]["GAN_path"])
    for xmin, ymin, xmax, ymax in map(get_box, image_df.iloc[idx]['annotations']):
        top_left = (int(xmin), int(ymin))
        bottom_right = (int(xmax), int(ymax))
        cv2.rectangle(picture, top_left, bottom_right, (255, 0, 0), thickness=2)
    plt.figure(figsize=(10, 10))
    plt.imshow(picture)

In [None]:
# Build a dataset from every row outside fold 2, without flip augmentation.
t = GBRDataset(image_df[image_df['fold'] != 2].reset_index(drop=True), resize=None, transforms=get_transform(train=False))

In [None]:
# Show row 2500 of image_df twice: the original frame and the generated image, each with its boxes.
plot_from_df(2500)
gen_from_df(2500)

In [None]:
# Drop the dataset created above and run the garbage collector.
del t
gc.collect()

In [None]:
def collate_fn(batch):
    """Transpose a batch of samples into one tuple per sample field.

    A list of ``(image, target)`` pairs becomes ``(images, targets)``, so
    samples are grouped without being stacked into a single tensor.
    """
    per_field = zip(*batch)
    return tuple(per_field)

def prepare_loaders(fold):
    """Create the train and validation DataLoaders for one fold.

    Rows of ``image_df`` whose ``fold`` equals ``fold`` form the validation
    set; all other rows form the training set. When ``DEBUG`` is set, only
    the first 40 rows of each split are used.

    Args:
        fold: Fold number held out for validation.

    Returns:
        Tuple ``(train_loader, valid_loader)``.
    """
    train_df = image_df[image_df.fold != fold].reset_index(drop=True)
    valid_df = image_df[image_df.fold == fold].reset_index(drop=True)

    if DEBUG:
        train_df, valid_df = train_df[:40], valid_df[:40]

    train_dataset = GBRDataset(train_df, resize=RESIZE, transforms=get_transform(train=True))
    # Validation uses the non-augmenting pipeline: random flips would make
    # the validation loss vary between evaluations of the same weights.
    valid_dataset = GBRDataset(valid_df, resize=RESIZE, transforms=get_transform(train=False))

    # Shuffle the training samples every epoch; the split keeps the row order
    # of image_df, so unshuffled batches would always hold neighbouring rows
    # in the same order. Validation order does not affect the mean loss.
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                              num_workers=2, shuffle=True, collate_fn=collate_fn)
    valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE,
                              num_workers=2, shuffle=False, collate_fn=collate_fn)
    # Report sample counts: len(loader) would be the number of batches, not rows.
    print(f'Train_df has {len(train_dataset)} rows')
    print(f'Valid_df has {len(valid_dataset)} rows')

    return train_loader, valid_loader

# Model

In [None]:
# Copy the Faster R-CNN checkpoint from the input dataset into the torch hub
# checkpoint cache (presumably so `pretrained=True` below can load it without
# a download - confirm).
!mkdir -p /root/.cache/torch/hub/checkpoints/
!cp ../input/fasterrcnn/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth /root/.cache/torch/hub/checkpoints/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth

In [None]:
def get_model():
    """Build a pretrained Faster R-CNN (ResNet-50 FPN) with a fresh box head.

    When the module-level ``NORMALIZE`` flag is set, ``RESNET_MEAN`` and
    ``RESNET_STD`` are passed to the model as ``image_mean`` / ``image_std``.

    NOTE(review): torchvision detection models apply ``image_mean`` /
    ``image_std`` themselves, so with ``NORMALIZE`` on, images that already
    went through the ``Normalize`` transform may be normalized twice - confirm.

    Returns:
        The model with its box predictor replaced by a ``FastRCNNPredictor``
        that has ``NUM_CLASSES`` outputs.
    """
    # Only override the model's normalization statistics when NORMALIZE is on.
    norm_kwargs = {'image_mean': RESNET_MEAN, 'image_std': RESNET_STD} if NORMALIZE else {}
    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True, **norm_kwargs)

    # The new head must accept the same number of features as the old classifier.
    old_head = detector.roi_heads.box_predictor
    detector.roi_heads.box_predictor = FastRCNNPredictor(old_head.cls_score.in_features, NUM_CLASSES)

    return detector


# Get the Faster R-CNN model
# Faster R-CNN predicts class labels and bounding boxes; it has no mask branch.
# Only the bounding boxes of the single foreground class (label 1) are used here.
model = get_model()
model.to(DEVICE)

In [None]:
# Optimize only the parameters that require gradients.
params = [p for p in model.parameters() if p.requires_grad]
# SGD configured from the LEARNING_RATE, WEIGHT_DECAY and MOMENTUM constants.
optimizer = torch.optim.SGD(params, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY, momentum = MOMENTUM)
# Each scheduler step multiplies the learning rate by gamma (0.85).
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.85)

In [None]:
def train_one_epoch(model, optimizer, scheduler, dataloader, device, epoch):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        model: Detection model that returns a dict of losses when called
            with ``(images, targets)`` in train mode.
        optimizer: Optimizer stepped once per batch.
        scheduler: Learning-rate scheduler, stepped once at the end of the
            epoch when the module-level ``USE_SCHEDULER`` flag is set.
        dataloader: Iterable of ``(images, targets)`` batches.
        device: Device the images and targets are moved to.
        epoch: Current epoch number (not used inside this function).

    Returns:
        Mean over all batches of the summed loss.
    """
    model.train()
    batch_losses = []

    for images, targets in tqdm(dataloader, total=len(dataloader), desc='Train '):
        images = [image.to(device) for image in images]
        # Targets go to the same `device` as the images, not to a global one.
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)

        # The model reports several loss terms; optimize their sum.
        losses = sum(loss_dict.values())
        batch_losses.append(losses.item())

        optimizer.zero_grad()  # zero the parameter gradients
        losses.backward()
        optimizer.step()

    if USE_SCHEDULER:
        scheduler.step()

    torch.cuda.empty_cache()
    gc.collect()

    return np.mean(batch_losses)

In [None]:
def valid_one_epoch(model, dataloader, device, epoch):
    """Compute the mean loss of ``model`` over ``dataloader`` without gradients.

    Args:
        model: Detection model that returns a dict of losses when called
            with ``(images, targets)`` in train mode.
        dataloader: Iterable of ``(images, targets)`` batches.
        device: Device the images and targets are moved to.
        epoch: Current epoch number (not used inside this function).

    Returns:
        Mean over all batches of the summed loss.
    """
    # torchvision detection models only return the loss dict in train mode
    # (in eval mode they return predictions), so train mode is set explicitly
    # instead of relying on a preceding train_one_epoch call. Gradients are
    # still disabled by the no_grad block below.
    model.train()
    batch_losses = []

    with torch.no_grad():
        for images, targets in tqdm(dataloader, total=len(dataloader), desc='Valid '):
            images = [image.to(device) for image in images]
            # Targets go to the same `device` as the images, not to a global one.
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            loss_dict = model(images, targets)
            losses = sum(loss_dict.values())
            batch_losses.append(losses.item())

    torch.cuda.empty_cache()
    gc.collect()

    return np.mean(batch_losses)

In [None]:
def run_training(model, optimizer, scheduler, device, num_epochs):
    """Run the full train/validate loop and save a checkpoint every epoch.

    NOTE: the data comes from the module-level ``train_loader`` and
    ``valid_loader``, which must exist before this function is called.

    Args:
        model: Model to train.
        optimizer: Optimizer passed on to ``train_one_epoch``.
        scheduler: Learning-rate scheduler passed on to ``train_one_epoch``.
        device: Device used for training and validation.
        num_epochs: Number of epochs to run.

    Returns:
        Tuple ``(model, history)`` where ``history`` maps ``'train_loss'``
        and ``'valid_loss'`` to lists with one entry per epoch (both lists
        are empty when ``num_epochs`` is 0).
    """
    if torch.cuda.is_available():
        print("cuda: {}\n".format(torch.cuda.get_device_name()))

    start = time.time()
    history = {'train_loss': [], 'valid_loss': []}
    best_loss = np.inf
    best_epoch = -1

    for epoch in range(1, num_epochs + 1):
        gc.collect()
        print(f'Epoch {epoch}/{num_epochs}', end='')

        # Honour the `device` argument instead of a global device.
        train_loss = train_one_epoch(model, optimizer, scheduler,
                                     dataloader=train_loader, device=device, epoch=epoch)
        val_loss = valid_one_epoch(model, valid_loader, device=device, epoch=epoch)

        history['train_loss'].append(train_loss)
        history['valid_loss'].append(val_loss)

        if val_loss <= best_loss:
            best_loss = val_loss
            best_epoch = epoch

        # The weights are written after every epoch, not only after an
        # improvement; "best_epoch" is merely the file-name prefix. The best
        # epoch is reported at the end of training.
        PATH = f"best_epoch-{epoch:02d}.bin"
        print(PATH)
        torch.save(model.state_dict(), PATH)
        print("Model Saved")

    end = time.time()
    time_elapsed = end - start
    print('Training complete in {:.0f}h {:.0f}m {:.0f}s'.format(
        time_elapsed // 3600, (time_elapsed % 3600) // 60, (time_elapsed % 3600) % 60))
    print("Best Score: {:.4f}".format(best_loss))
    print("Best Epoch: {:3d}".format(best_epoch))

    return model, history

In [None]:
# Hold out fold 4 for validation. run_training reads these two loaders as globals.
train_loader, valid_loader = prepare_loaders(fold = 4)
# Train for NUM_EPOCHS epochs, or for 10 when DEBUG is set.
model, history = run_training(model, optimizer, lr_scheduler,device=DEVICE, num_epochs=NUM_EPOCHS if not DEBUG else 10)


# Loss plot (train vs valid)

In [None]:
# Per-epoch mean losses returned by run_training: train in blue, validation in red.
plt.plot(history['train_loss'], color = 'b', label = 'train_loss')
plt.plot(history['valid_loss'], color = 'r', label = 'valid_loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()

In [None]:
# Display the per-epoch validation losses as the cell output.
history['valid_loss']