## Introduction
* Reference kernel => https://www.kaggle.com/sovitrath/pytorch-starter-faster-rcnn-train?scriptVersionId=38399463

In [None]:
# !pip install --upgrade torch torchvision torchaudio

In [None]:
!pip install efficientnet_pytorch

In [None]:
import torch

torch.__version__

## All Imports

In [None]:
import pandas as pd
import numpy as np
import cv2
import os
import re
import albumentations as A
import torch
import torchvision
import time

from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SequentialSampler
from PIL import Image
from albumentations.pytorch.transforms import ToTensorV2
from matplotlib import pyplot as plt
from tqdm import tqdm

import matplotlib 
matplotlib.style.use('ggplot')

## Constant Paths

In [None]:
# Root directory of the input dataset; the train and test image folders
# are derived from it.
DIR_INPUT = '../input/global-wheat-detection'
DIR_TRAIN = f"{DIR_INPUT}/train"
DIR_TEST = f"{DIR_INPUT}/test"

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(DEVICE)

## Create Model

In [None]:
# # load the model
# model = torchvision.models.resnet18(pretrained=True).eval()
# # hook the feature extractor
# # https://github.com/zhoubolei/CAM/blob/master/pytorch_CAM.py
# features_blobs = []
# def hook_feature(module, input, output):
#     features_blobs.append(output.data.cpu().numpy())
    
# new_model = model._modules.get('layer4').register_forward_hook(hook_feature)
# print(new_model)

In [None]:
"""
model.py

We will create a FasterRCNN object detector with EfficientNet backbone for custom training.
Reference: https://github.com/lukemelas/EfficientNet-PyTorch
"""
def create_model():
    """Build a Faster R-CNN detector on an EfficientNet-B0 feature extractor.

    The pretrained EfficientNet-B0 is loaded once and its stem convolution,
    first batch-norm, block list and head convolution are chained into a
    plain ``torch.nn.Sequential`` that serves as the detection backbone.
    The backbone is printed as a side effect.

    Returns:
        A ``FasterRCNN`` model with ``num_classes=2`` (torchvision counts
        the background as one of the classes).
    """
    from efficientnet_pytorch import EfficientNet

    # Load the pretrained network a single time and reuse its sub-modules;
    # creating one pretrained instance per sub-module repeats the weight
    # loading without changing the result.
    effnet = EfficientNet.from_pretrained('efficientnet-b0')
    conv_stem = torch.nn.Sequential(effnet._conv_stem)
    bn = torch.nn.Sequential(effnet._bn0)
    blocks = torch.nn.Sequential(*effnet._blocks)
    conv_head = torch.nn.Sequential(effnet._conv_head)
    # NOTE(review): only these four module groups are chained; anything
    # EfficientNet's own forward pass applies between them (activations,
    # the batch-norm after the head convolution) is not part of this
    # backbone -- confirm that is intended.
    backbone = torch.nn.Sequential(conv_stem, bn, blocks, conv_head)

    # FasterRCNN needs to know the number of output channels of the
    # backbone. For the EfficientNet-B0 head convolution it is 1280.
    backbone.out_channels = 1280
    print(backbone)
    print('-'*70)

    # Make the RPN generate 5 x 3 anchors per spatial location, with
    # 5 different sizes and 3 different aspect ratios. The arguments are
    # Tuple[Tuple[int]] because each feature map could potentially have
    # different sizes and aspect ratios; here there is a single map.
    anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
                                       aspect_ratios=((0.5, 1.0, 2.0),))

    # Feature maps used for region-of-interest cropping and the size of
    # the crop after rescaling. When the backbone returns a single Tensor
    # the feature map is addressed as '0'; more generally the backbone
    # returns an OrderedDict[Tensor] and featmap_names selects which
    # feature maps to use.
    roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
                                                    output_size=7,
                                                    sampling_ratio=2)

    # Put everything together.
    model = FasterRCNN(backbone,
                       num_classes=2,
                       rpn_anchor_generator=anchor_generator,
                       box_roi_pool=roi_pooler)
    return model

# Build the detector once; it is moved to DEVICE and trained further below.
model = create_model()
print('-'*50)
# print(model)

## Prepare Proper DataFrame

In [None]:
# Load the annotation table that ships with the dataset and show its
# dimensions and first rows.
train_df = pd.read_csv(f"{DIR_INPUT}/train.csv")
print(train_df.shape)
train_df.head()

In [None]:
# Create the four box-component columns up front, filled with a -1
# placeholder until the real values are parsed from the 'bbox' column.
for _component in ('x', 'y', 'w', 'h'):
    train_df[_component] = -1

def expand_bbox(x):
    """Extract the numbers contained in a bounding-box string.

    Args:
        x: Text holding the box components, e.g.
            ``"[834.0, 222.0, 56.0, 36.0]"``.

    Returns:
        A float ``numpy`` array with every unsigned decimal number found in
        ``x``, in order of appearance. When ``x`` contains no number the
        placeholder ``[-1, -1, -1, -1]`` is returned, also as floats, so
        both paths yield the same type and dtype.
    """
    # The pattern only matches digits with an optional fractional part; a
    # leading minus sign is not captured.
    numbers = re.findall(r"([0-9]+[.]?[0-9]*)", x)
    if not numbers:
        numbers = [-1, -1, -1, -1]
    return np.array(numbers, dtype=float)

# Split every 'bbox' string into its four components and store them in the
# x / y / w / h columns; the raw string column is no longer needed afterwards.
train_df[['x', 'y', 'w', 'h']] = np.stack(train_df['bbox'].apply(expand_bbox))
train_df.drop(columns=['bbox'], inplace=True)

# Cast with the builtin ``float``: the ``np.float`` alias has been removed
# from NumPy and raises AttributeError on current releases.
for _component in ('x', 'y', 'w', 'h'):
    train_df[_component] = train_df[_component].astype(float)

train_df.head()

In [None]:
# Split by image rather than by annotation row: the last 665 unique image
# IDs are held out for validation and all earlier ones are used for training.
unique_image_ids = train_df['image_id'].unique()
print(f"Uninque image IDs: {len(unique_image_ids)}")
valid_ids = unique_image_ids[-665:]
train_ids = unique_image_ids[:-665]
print(f"Unqiue image IDs for training: {len(train_ids)}")
print(f"Unqiue image IDs for validation: {len(valid_ids)}")

In [None]:
# Order matters: valid_df must be taken from the full table before train_df
# is overwritten with its training-only subset.
valid_df = train_df[train_df['image_id'].isin(valid_ids)]
train_df = train_df[train_df['image_id'].isin(train_ids)]

print(f"Total training annotation instances: {len(train_df)}")
print(f"Total validation annotation instances: {len(valid_df)}")

In [None]:
# Preview the first 7 training annotation rows.
train_df.head(7)

In [None]:
# Preview the first 7 validation annotation rows.
valid_df.head(7)

## Prepare Dataset

In [None]:
class WheatDataset(Dataset):
    """Dataset yielding one image and its box annotations per unique image ID.

    Args:
        dataframe: Annotation table with an ``image_id`` column and one row
            per box, the box being stored in the ``x``, ``y``, ``w`` and
            ``h`` columns.
        image_dir: Directory containing the ``<image_id>.jpg`` files.
        transforms: Optional callable invoked as
            ``transforms(image=..., bboxes=..., labels=...)`` and returning
            a mapping with ``image`` and ``bboxes`` entries.
    """

    def __init__(self, dataframe, image_dir, transforms=None):
        super().__init__()

        self.image_ids = dataframe['image_id'].unique()
        self.df = dataframe
        self.image_dir = image_dir
        self.transforms = transforms

    def __getitem__(self, index: int):
        """Return ``(image, target, image_id)`` for the ``index``-th image.

        ``target`` is a dict of tensors: ``boxes`` (N x 4, float32, corner
        format ``x1, y1, x2, y2``), ``labels`` (all ones), ``image_id``
        (the dataset index), ``area`` and ``iscrowd`` (all zeros).

        Raises:
            FileNotFoundError: If the image file cannot be read.
        """
        image_id = self.image_ids[index]
        records = self.df[self.df['image_id'] == image_id]

        image_path = f"{self.image_dir}/{image_id}.jpg"
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread reports a missing/unreadable file by returning
            # None instead of raising, so fail with a clear message here.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image.astype(np.float32), cv2.COLOR_BGR2RGB)
        image /= 255.0

        # Work on a private copy so the in-place edits below can never
        # touch (or be blocked by) the dataframe's own storage.
        boxes = records[['x', 'y', 'w', 'h']].values.copy()
        # Convert (x, y, w, h) to corner format (x1, y1, x2, y2).
        boxes[:, 2] = boxes[:, 0] + boxes[:, 2]
        boxes[:, 3] = boxes[:, 1] + boxes[:, 3]

        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        area = torch.as_tensor(area, dtype=torch.float32)

        # as there is only one class
        labels = torch.ones((records.shape[0], ), dtype=torch.int64)

        # no crowd instances
        iscrowd = torch.zeros((records.shape[0], ), dtype=torch.int64)

        target = {}
        target['labels'] = labels
        target['image_id'] = torch.tensor([index])
        target['area'] = area
        target['iscrowd'] = iscrowd

        if self.transforms:
            sample = self.transforms(image=image, bboxes=boxes, labels=labels)
            image = sample['image']
            boxes = sample['bboxes']
        else:
            # Without a transform pipeline, still hand back a channel-first
            # tensor so callers can treat both cases the same way.
            image = torch.from_numpy(image).permute(2, 0, 1)

        # reshape(-1, 4) keeps an (0, 4) tensor when no box is left, where
        # stacking an empty sequence would raise.
        target['boxes'] = torch.as_tensor(
            np.asarray(boxes, dtype=np.float32)
        ).reshape(-1, 4)

        # Returned on every path; previously this only happened when
        # transforms were supplied and None was returned otherwise.
        return image, target, image_id

    def __len__(self) -> int:
        """Return the number of unique images in the dataset."""
        return self.image_ids.shape[0]

## Transforms

In [None]:
# define the training transforms
def get_train_transform():
    """Build the augmentation pipeline applied to training samples.

    Returns:
        An ``A.Compose`` that randomly flips, rotates by multiples of 90
        degrees and blurs the image, then converts it to a tensor. Boxes
        are handled in ``pascal_voc`` format with their labels passed in
        the ``labels`` field.
    """
    return A.Compose([
        # Probabilities are passed by keyword: the first positional
        # parameter of an albumentations transform is not guaranteed to be
        # ``p``, so a bare ``0.5`` can be taken for a different argument.
        A.Flip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.MotionBlur(p=0.2),
        A.MedianBlur(blur_limit=3, p=0.1),
        A.Blur(blur_limit=3, p=0.1),
        ToTensorV2(p=1.0),
    ], bbox_params={
        'format': 'pascal_voc',
        'label_fields': ['labels']
    })

# validation pipeline: tensor conversion only, no augmentation
def get_valid_transform():
    """Build the transform pipeline applied to validation samples.

    Returns:
        An ``A.Compose`` whose only step is ``ToTensorV2``; boxes are
        handled in ``pascal_voc`` format with labels in the ``labels`` field.
    """
    bbox_settings = {
        'format': 'pascal_voc',
        'label_fields': ['labels'],
    }
    steps = [ToTensorV2(p=1.0)]
    return A.Compose(steps, bbox_params=bbox_settings)

## Utilities and Helper Functions

In [None]:
class Averager:
    """Running mean of the values handed to :meth:`send`."""

    def __init__(self):
        # Start from the same empty state that reset() produces.
        self.reset()

    def send(self, value):
        """Add ``value`` to the running total and count one more sample."""
        self.current_total += value
        self.iterations += 1

    @property
    def value(self):
        """Mean of the values sent since the last reset, or ``0`` if none."""
        if not self.iterations:
            return 0
        return 1.0 * self.current_total / self.iterations

    def reset(self):
        """Forget everything sent so far."""
        self.current_total, self.iterations = 0.0, 0.0

In [None]:
# Number of images per batch for both the training and validation loaders.
BATCH_SIZE = 8

def collate_fn(batch):
    """Regroup a list of per-sample tuples into one tuple per field.

    A batch ``[(a1, b1), (a2, b2)]`` becomes ``((a1, a2), (b1, b2))``; the
    fields are kept as plain tuples rather than being stacked.
    """
    per_field = zip(*batch)
    return tuple(per_field)

# Both splits read their images from DIR_TRAIN; only the transform
# pipeline differs between them.
train_dataset = WheatDataset(train_df, DIR_TRAIN, get_train_transform())
valid_dataset = WheatDataset(valid_df, DIR_TRAIN, get_valid_transform())

# Options shared by the two loaders; only shuffling differs.
_loader_options = {
    'batch_size': BATCH_SIZE,
    'num_workers': 4,
    'collate_fn': collate_fn,
}

train_data_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
valid_data_loader = DataLoader(valid_dataset, shuffle=False, **_loader_options)

## Sample Visualization

In [None]:
# Pull a single training batch and move its tensors to DEVICE for inspection.
images, targets, image_ids = next(iter(train_data_loader))
images = [image.to(DEVICE) for image in images]
targets = [
    {key: tensor.to(DEVICE) for key, tensor in target.items()}
    for target in targets
]

In [None]:
# Draw the ground-truth boxes on the first two images of the sampled batch.
for i in range(2):
    # Box corners are cast to integers so they can be used as pixel coordinates.
    boxes = targets[i]['boxes'].cpu().numpy().astype(np.int32)
    # Move the channel axis last so the array has the layout cv2/matplotlib expect.
    sample = images[i].permute(1, 2, 0).cpu().numpy()
    plt.figure(figsize=(12, 9))
    # The image is in RGB; convert to BGR so cv2 colour tuples are interpreted as intended.
    sample = cv2.cvtColor(sample, cv2.COLOR_RGB2BGR)
    for box in boxes:
        # (0, 0, 255) is red in BGR channel order; line thickness is 3.
        cv2.rectangle(sample,  # drawn in place on the BGR copy
                      (box[0], box[1]),
                      (box[2], box[3]),
                      (0, 0, 255), 3)
    # Reverse the channel axis (BGR back to RGB) for display.
    plt.imshow(sample[:, :, ::-1])
    plt.axis('off')

## Training

In [None]:
# Move the model's parameters to the selected device before building the optimizer.
model = model.to(DEVICE)

In [None]:
# print(model)

In [None]:
# Optimize only the parameters that require gradients, using SGD with
# momentum and weight decay.
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.001, momentum=0.9, weight_decay=0.0005)
# No learning-rate schedule: train() skips the scheduler step while this is None.
lr_scheduler = None

In [None]:
# Number of full passes over the training data.
NUM_EPOCHS = 55

In [None]:
# Per-epoch running means of the batch losses (reset at the start of every epoch).
train_loss_hist = Averager()
val_loss_hist = Averager()
# Global batch counters; they keep counting across epochs.
train_itr = 1
val_itr = 1
# Every batch loss seen so far, across all epochs; used for the loss plots.
train_loss_list = []
val_loss_list = []

In [None]:
def train(train_data_loader):
    """Run one optimisation pass over ``train_data_loader``.

    For every batch the summed model loss is back-propagated and the
    optimizer is stepped. Each batch loss is appended to the global
    ``train_loss_list`` and sent to ``train_loss_hist``; the global
    ``train_itr`` counter is advanced and every 50th iteration is logged.
    The learning-rate scheduler, when one is set, is stepped once at the end.

    Returns:
        The global ``train_loss_list`` holding all batch losses so far.
    """
    global train_itr, train_loss_list

    for images, targets, image_ids in train_data_loader:
        images = [image.to(DEVICE) for image in images]
        targets = [
            {key: tensor.to(DEVICE) for key, tensor in target.items()}
            for target in targets
        ]

        # Called with targets, the model yields a dict of loss terms.
        loss_dict = model(images, targets)
        losses = sum(loss_dict.values())

        loss_value = losses.item()
        train_loss_list.append(loss_value)
        train_loss_hist.send(loss_value)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        if train_itr % 50 == 0:
            print(f"Training iteration #{train_itr} loss: {loss_value}")
        train_itr += 1

    # update the learning rate once per pass
    if lr_scheduler is not None:
        lr_scheduler.step()
    return train_loss_list

In [None]:
def validate(valid_data_loader):
    """Compute the loss over ``valid_data_loader`` without updating weights.

    The model is called with targets inside ``torch.no_grad()`` to obtain
    its loss terms. Each summed batch loss is appended to the global
    ``val_loss_list`` and sent to ``val_loss_hist``; the global ``val_itr``
    counter is advanced.

    Returns:
        The global ``val_loss_list`` holding all batch losses so far.
    """
    global val_itr, val_loss_list

    for images, targets, image_ids in valid_data_loader:
        images = [image.to(DEVICE) for image in images]
        targets = [
            {key: tensor.to(DEVICE) for key, tensor in target.items()}
            for target in targets
        ]

        # No gradients are needed: this pass only measures the loss.
        with torch.no_grad():
            loss_dict = model(images, targets)

        losses = sum(loss_dict.values())
        loss_value = losses.item()

        val_loss_list.append(loss_value)
        val_loss_hist.send(loss_value)

        val_itr += 1

    return val_loss_list

In [None]:
# Main loop: one training pass and one validation pass per epoch, with a
# checkpoint every fifth epoch and a loss plot saved after every epoch.
for epoch in range(NUM_EPOCHS):
    start = time.time()
    # The running means are per-epoch; the loss lists keep growing.
    train_loss_hist.reset()
    val_loss_hist.reset()
    train_loss = train(train_data_loader)
    val_loss = validate(valid_data_loader)
    print(f"Epoch #{epoch} train loss: {train_loss_hist.value}")
    print(f"Epoch #{epoch} validation loss: {val_loss_hist.value}")
    end = time.time()
    print(f"Took {(end - start) / 60} minutes for epoch {epoch}")

    # Only announce a save when a checkpoint is actually written.
    if epoch % 5 == 0:
        print('SAVING MODEL...')
        torch.save(model.state_dict(), f"fasterrcnn_efficientnetb0_{epoch+1}.pth")
        print('SAVING COMPLETE...\n')

    # Both curves are indexed by their own batch counter, so the two
    # x-axes do not line up batch-for-batch.
    plt.plot(val_loss, color='red')
    plt.plot(train_loss, color='blue', alpha=0.5)
    plt.xlabel('iterations')
    plt.ylabel('loss')
    plt.savefig(f"loss_{epoch+1}.png")
    if epoch % 5 == 0:
        plt.show()
    # Actually call close(): a bare ``plt.close`` is a no-op expression and
    # leaves the figure open, so later epochs would draw over earlier curves.
    plt.close()