In [None]:
import random as rnd
import numpy as np
import cv2 as cv
import torch


# Names exported by this cell: the two augmentation methods and the random selector that picks between them.
__all__ = [
    'Mosaic',
    'MixUp',
    'RandomAugmentation'
]


class Mosaic:
    """
    Mosaic augmentation: stitches crops of four images into one image and keeps the boxes that
    fall into the crop taken from their own image.

    Each image contributes the region at its own position (upper-left, upper-right, lower-left,
    lower-right), so box coordinates are clipped but never shifted.
    """
    # `apply` unpacks five images: two may be blended via MixUp, the result plus three more form the tiles.
    m_num_required_elements = 5

    @staticmethod
    def get_area(box: np.ndarray) -> float:
        """
        Calculates the area of a box.

        Args:
            box (np.ndarray): coordinates of the box    [pascal-voc form]

        Returns:
            (float) Area of the box.
        """
        return (box[2] - box[0]) * (box[3] - box[1])

    @staticmethod
    def get_overlap_percentage(b_1: np.ndarray, b_2: np.ndarray, max_xy: tuple = (1023, 1023)) -> tuple:
        """
        Calculates the percentage of the overlap area with respect to box 1 and the overlap 'directions'.

        Note: b_1 is clamped to max_xy IN PLACE; `apply_to_bboxes` relies on this side effect.

        Args:
            b_1 (np.ndarray): coordinates of the first box  [pascal-voc form]
            b_2 (np.ndarray): coordinates of the second box [pascal-voc form]
            max_xy (tuple): largest allowed (x_max, y_max) of box 1; defaults to (1023, 1023)

        Returns:
            (float) Percentage of the overlap area with respect to box 1 (0.0 if the boxes do not overlap).
            (tuple) Overlap 'directions'.       (up, down, left, right)
                    A flag is True when box 1 reaches or exceeds the matching edge of box 2.
        """
        max_x, max_y = max_xy

        # In case a bbox coordinate is on the edge of an image (--> make it fit in the array)
        if b_1[2] > max_x:
            b_1[2] = max_x
        if b_1[3] > max_y:
            b_1[3] = max_y

        # Calculate overlap extents along both axes
        d_x = min(b_1[2], b_2[2]) - max(b_1[0], b_2[0])
        d_y = min(b_1[3], b_2[3]) - max(b_1[1], b_2[1])

        # A negative extent on EITHER axis means the boxes are disjoint.
        if d_x < 0 or d_y < 0:
            return 0.0, (False, False, False, False)

        box_1_area = Mosaic.get_area(box=b_1)
        if box_1_area <= 0:
            # Degenerate box: avoid dividing by zero and report it as not overlapping.
            return 0.0, (False, False, False, False)

        overlap_area = d_x * d_y

        # Calculate overlap 'directions'        (up, down, left, right)
        overlap_directions = (
            bool(b_1[1] <= b_2[1]),
            bool(b_1[3] >= b_2[3]),
            bool(b_1[0] <= b_2[0]),
            bool(b_1[2] >= b_2[2])
        )

        return (overlap_area / box_1_area), overlap_directions

    @staticmethod
    def apply(imgs: tuple, bboxes: tuple, p: float, p_mixup: float = 0.8) -> tuple:
        """
        Applies the mosaic augmentation with probability p.

        Args:
            imgs (tuple): five images of identical shape
            bboxes (tuple): five box collections [pascal-voc form], one per image
            p (float): probability of building a mosaic
            p_mixup (float): probability passed to MixUp for blending the first two images

        Returns:
            (np.ndarray) The resulting image.
            (list) The kept boxes as tuples; when no mosaic is built, the untouched boxes of the chosen image.
        """
        if rnd.random() > p:
            # No mosaic: hand back one of the first four inputs unchanged.
            i = rnd.randint(0, 3)
            return imgs[i], bboxes[i]

        # The mosaic spans indices [0, h-1) x [0, w-1), i.e. one row/column less than the inputs.
        h, w = imgs[0].shape[:2]
        shape = (h - 1, w - 1)

        # Apply random MixUp
        img_0, img_1, img_2, img_3, img_4 = imgs
        bboxes_0, bboxes_1, bboxes_2, bboxes_3, bboxes_4 = bboxes
        img_0, bboxes_0 = MixUp.apply(imgs=(img_0, img_1), bboxes=(bboxes_0, bboxes_1), p=p_mixup, p_chaos=0.0)
        imgs = (img_0, img_2, img_3, img_4)
        bboxes = (bboxes_0, bboxes_2, bboxes_3, bboxes_4)

        # Relative position of the vertical / horizontal seam
        rnd_x_scale = rnd.uniform(0.2, 0.8)
        rnd_y_scale = rnd.uniform(0.2, 0.8)

        img = Mosaic.apply_to_imgs(imgs, shape, rnd_x_scale, rnd_y_scale)
        bboxes = Mosaic.apply_to_bboxes(bboxes, shape, rnd_x_scale, rnd_y_scale)

        return img, bboxes

    @staticmethod
    def apply_to_imgs(imgs: tuple, shape: tuple, rnd_x_scale: float, rnd_y_scale: float) -> np.ndarray:
        """
        Cuts the four tiles out of the four images and concatenates them to one image.

        Args:
            imgs (tuple): four images (upper-left, upper-right, lower-left, lower-right)
            shape (tuple): (height, width) of the resulting mosaic
            rnd_x_scale (float): relative x position of the vertical seam
            rnd_y_scale (float): relative y position of the horizontal seam

        Returns:
            (np.ndarray) The mosaic image.
        """
        img_1, img_2, img_3, img_4 = imgs
        h, w = shape
        split_x = int(rnd_x_scale * w)
        split_y = int(rnd_y_scale * h)

        # Concat the four input images to upper and lower image pairs
        img_upper = np.concatenate((img_1[:split_y, :split_x], img_2[:split_y, split_x:w]), axis=1)
        img_lower = np.concatenate((img_3[split_y:h, :split_x], img_4[split_y:h, split_x:w]), axis=1)

        # Concat upper and lower image pairs to final image
        return np.concatenate((img_upper, img_lower), axis=0)

    @staticmethod
    def apply_to_bboxes(bboxes: tuple, shape: tuple, rnd_x_scale: float, rnd_y_scale: float) -> list:
        """
        Keeps and clips the boxes that lie inside the tile taken from their own image.

        Boxes are clipped in place. A clipped box is only kept if more than 25% of its original
        area remains.

        Args:
            bboxes (tuple): four box collections [pascal-voc form], ordered like the tiles
            shape (tuple): (height, width) of the resulting mosaic
            rnd_x_scale (float): relative x position of the vertical seam
            rnd_y_scale (float): relative y position of the horizontal seam

        Returns:
            (list) Kept boxes as tuples [pascal-voc form].
        """
        h, w = shape
        d_x = int(rnd_x_scale * w)
        d_y = int(rnd_y_scale * h)
        # pascal-voc form; [left_upper, right_upper, left_lower, right_lower]
        img_areas = [[0, 0, d_x, d_y], [d_x, 0, w, d_y], [0, d_y, d_x, h], [d_x, d_y, w, h]]
        overlap_threshold = 0.0
        p_area_threshold = 0.25

        keep_bboxes = []
        for tile_idx, (tile_bboxes, img_area) in enumerate(zip(bboxes, img_areas)):
            x_mode = tile_idx % 2       # 0 for upper_left and lower_left; 1 for upper_right and lower_right
            y_mode = int(tile_idx > 1)  # 0 for upper_left and upper_right; 1 for lower_left and lower_right
            for bbox in tile_bboxes:
                # Clamp to the real mosaic bounds instead of a fixed image size.
                op, (up, down, left, right) = Mosaic.get_overlap_percentage(b_1=bbox, b_2=img_area,
                                                                            max_xy=(w, h))
                if op > overlap_threshold:
                    if op < 1.0:
                        area_0 = Mosaic.get_area(box=bbox)
                        # Clip the side that crosses the vertical seam (x_min for right tiles, x_max for left tiles)
                        if (left and x_mode == 1) or (right and x_mode == 0):
                            bbox[2 - x_mode * 2] = d_x
                        # Clip the side that crosses the horizontal seam (y_min for lower tiles, y_max for upper tiles)
                        if (up and y_mode == 1) or (down and y_mode == 0):
                            bbox[3 - y_mode * 2] = d_y
                        p_area = Mosaic.get_area(box=bbox) / area_0
                        if p_area > p_area_threshold:
                            keep_bboxes.append(tuple(bbox.tolist()))
                    else:
                        keep_bboxes.append(tuple(bbox.tolist()))

        return keep_bboxes


class MixUp:
    """
    MixUp augmentation: blends two (or, in 'chaos' mode, three) images and merges their boxes.
    """
    # Number of samples a caller should provide so that the three-image blend is possible.
    m_num_required_elements = 3

    @staticmethod
    def apply(imgs: tuple, bboxes: tuple, p: float, p_chaos=0.3) -> tuple:
        """
        Applies MixUp with probability p.

        Args:
            imgs (tuple): two or three images of identical shape and dtype
            bboxes (tuple): box tensors [pascal-voc form], one per image
            p (float): probability of blending at all
            p_chaos (float): probability of blending three images instead of two

        Returns:
            (np.ndarray) The blended image, or one of the first two inputs if no blending happens.
            (torch.Tensor) The concatenated boxes, or the boxes of the returned input image.
        """
        if rnd.random() > p:
            # No blending: hand back one of the first two inputs unchanged.
            i = rnd.randint(0, 1)
            return imgs[i], bboxes[i]

        # Always draw the random number so the random stream is consumed as before,
        # but only blend three images when three were actually supplied.
        chaos_requested = not rnd.random() > p_chaos
        has_three = len(imgs) >= 3 and len(bboxes) >= 3
        n_imgs = 3 if (chaos_requested and has_three) else 2

        return_img = MixUp.apply_to_imgs(imgs=imgs, n_imgs=n_imgs)
        return_bboxes = MixUp.apply_to_bboxes(bboxes=bboxes, n_imgs=n_imgs)

        return return_img, return_bboxes

    @staticmethod
    def apply_to_imgs(imgs: tuple, n_imgs: int) -> np.ndarray:
        """
        Blends the first two images with equal weight and optionally mixes in a third one.

        Args:
            imgs (tuple): the images; exactly three are required if n_imgs == 3
            n_imgs (int): number of images to blend (2 or 3)

        Returns:
            (np.ndarray) The blended image.
        """
        if n_imgs == 3:
            img_1, img_2, img_3 = imgs
        else:
            img_1, img_2 = imgs[:2]

        weight = 0.5
        new_img = cv.addWeighted(src1=img_1, alpha=weight, src2=img_2, beta=(1 - weight), gamma=0.0)

        if n_imgs == 3:
            # NOTE(review): this gives the two-image blend weight 1/3 and the third image 2/3,
            # i.e. final weights 1/6, 1/6, 2/3 -- confirm whether equal thirds were intended.
            weight = (weight * 2) / 3
            new_img = cv.addWeighted(src1=new_img, alpha=weight, src2=img_3, beta=(1 - weight), gamma=0.0)

        return new_img

    @staticmethod
    def apply_to_bboxes(bboxes: tuple, n_imgs: int) -> torch.Tensor:
        """
        Concatenates the boxes of the blended images.

        Args:
            bboxes (tuple): box tensors; exactly three are required if n_imgs == 3
            n_imgs (int): number of images that were blended (2 or 3)

        Returns:
            (torch.Tensor) All boxes of the blended images stacked along dim 0.
        """
        if n_imgs == 3:
            bboxes_1, bboxes_2, bboxes_3 = bboxes
            selected = (bboxes_1, bboxes_2, bboxes_3)
        else:
            bboxes_1, bboxes_2 = bboxes[:2]
            selected = (bboxes_1, bboxes_2)

        return torch.cat(tensors=selected, dim=0)


class RandomAugmentation:
    """
    Picks one of the available augmentation classes at random.
    """
    # Default selection weights, index-aligned with m_augmentation_methods.
    m_augmenatation_distribution = [0.5, 0.5]
    m_augmentation_methods = [Mosaic, MixUp]

    @staticmethod
    def get_random_augmentation(augmentation_distribution=m_augmenatation_distribution) -> object:
        """
        Draws a single augmentation class according to the given weights.

        Args:
            augmentation_distribution (list): selection weights, one per entry of m_augmentation_methods

        Returns:
            (object) The chosen augmentation class (not an instance).
        """
        (chosen,) = rnd.choices(RandomAugmentation.m_augmentation_methods,
                                weights=augmentation_distribution,
                                k=1)
        return chosen


In [None]:
import os
import pandas as pd
import torch
from ast import literal_eval
import cv2 as cv
import numpy as np
import random as rnd


class WheatDetectionDataset:
    """
    Dataset of wheat images with (in 'train' mode) their bounding boxes.

    Images are listed from `<root>/<mode>`; in 'train' mode the boxes are read from `<root>/train.csv`
    (columns `image_id` and `bbox`, the latter as a string `[x_min, y_min, width, height]`).
    """

    def __init__(self, root, transforms, mode):
        """
        Args:
            root: dataset root directory
            transforms: callable applied to the samples, or None
            mode: name of the image sub-directory; 'train' additionally loads the box annotations
        """
        self.root = root
        self.transforms = transforms
        self.mode = mode

        self.imgs = list(os.listdir(os.path.join(root, mode)))
        # Compare by value: `is` only checks object identity and is unreliable for strings.
        if mode == 'train':
            self.preds = pd.read_csv(os.path.join(root, 'train.csv'))

    def get_img(self, index):
        """
        Loads the image at `index` as a float32 RGB array scaled to [0, 1].

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        img_path = os.path.join(self.root, self.mode, self.imgs[index])
        img = cv.imread(img_path, cv.IMREAD_COLOR)
        if img is None:
            # cv.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f'Could not read image: {img_path}')
        img = cv.cvtColor(img, cv.COLOR_BGR2RGB).astype(np.float32)
        img /= 255.0

        return img

    def __getitem__(self, index, mode=0):
        """
        Returns the sample at `index`.

        Args:
            index: position in the image list
            mode: 0 applies the augmentation + transforms (train mode only);
                  any other value returns the raw image and boxes (used to fetch augmentation inputs)

        Returns:
            Train mode: (image, target) where target holds 'boxes', 'image_id', 'area', 'labels', 'is_crowd'.
            Other modes: (image, file name).
        """
        # Load image
        img = self.get_img(index=index)

        if self.mode != 'train':
            if self.transforms is not None:
                img = self.transforms(img)

            return img, self.imgs[index]

        # Load prediction boxes
        image_name = self.imgs[index].split('.')[0]
        raw_boxes = self.preds.loc[self.preds['image_id'] == image_name]['bbox']
        boxes = []
        # Box coords in pascal-voc form
        for box in raw_boxes:
            coords = literal_eval(box)
            x_min = float(coords[0])
            y_min = float(coords[1])
            x_max = float(x_min + coords[2])
            y_max = float(y_min + coords[3])
            boxes.append([x_min, y_min, x_max, y_max])

        target = {}
        # reshape keeps a well-formed (0, 4) tensor for images without any box
        target['boxes'] = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        target['image_id'] = torch.tensor([index])

        if self.transforms is not None and mode == 0:
            # Apply own augmentation
            # NOTE: the augmented sample is built solely from randomly drawn images; the image loaded
            # for `index` above is replaced by the augmentation result.
            augmentation_method = RandomAugmentation.get_random_augmentation(augmentation_distribution=[0.5, 0.5])
            imgs = []
            bboxes = []
            for _ in range(augmentation_method.m_num_required_elements):
                apply_img, apply_target = self.__getitem__(index=rnd.randint(0, (len(self.imgs) - 1)), mode=1)
                imgs.append(apply_img)
                bboxes.append(apply_target['boxes'])

            aug_img, aug_bboxes = augmentation_method.apply(
                imgs=tuple(imgs),
                bboxes=tuple(bboxes),
                p=0.9
            )
            # The augmentation may return either a tensor or a list of tuples; normalise to plain lists.
            if torch.is_tensor(aug_bboxes):
                aug_bboxes = aug_bboxes.tolist()

            sample = {
                'image': aug_img,
                'bboxes': aug_bboxes,
                # One label per box -- presumably the transform is an albumentations Compose with
                # label_fields=['labels'], which needs a sequence aligned with 'bboxes'.
                'labels': [1] * len(aug_bboxes)
            }

            sample = self.transforms(**sample)
            img = sample['image']
            out_boxes = [list(b[:4]) for b in sample['bboxes']]
            target['boxes'] = torch.tensor(out_boxes, dtype=torch.float32).reshape(-1, 4)

        # Derive everything from the FINAL boxes so that area, labels and boxes always agree.
        final_boxes = target['boxes']
        num_boxes = len(final_boxes)
        area = (final_boxes[:, 3] - final_boxes[:, 1]) * (final_boxes[:, 2] - final_boxes[:, 0])

        target['area'] = torch.as_tensor(area, dtype=torch.float32)
        # set up labels (single foreground class)
        target['labels'] = torch.ones((num_boxes, ), dtype=torch.int64)
        target['is_crowd'] = torch.zeros((num_boxes, ), dtype=torch.int64)

        return img, target

    def __len__(self) -> int:
        """Returns the number of images found in the mode directory."""
        return len(self.imgs)


In [None]:
# using: http://openaccess.thecvf.com/content_CVPR_2019/papers/He_Bag_of_Tricks_for_Image_Classification_with_Convolutional_Neural_Networks_CVPR_2019_paper.pdf

from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, FasterRCNN
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
import torchvision.transforms as T
import torchvision
import torch
import numpy as np
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import matplotlib.pyplot as plt
import cv2 as cv


def collate_fn(batch):
    """
    Transposes a batch of samples into per-field tuples.

    A list like [(img_a, target_a), (img_b, target_b)] becomes ((img_a, img_b), (target_a, target_b)),
    so samples are not stacked into a single tensor.
    """
    per_field = zip(*batch)
    return tuple(per_field)


def get_transform(train):
    """
    Builds a torchvision transform pipeline that ends with a tensor conversion.

    With `train` set, the pipeline starts with a PIL conversion followed by two augmentations.
    NOTE(review): `augmentation` is not defined or imported in the visible file -- confirm where it
    comes from before calling this with train=True.
    """
    pipeline = ([T.ToPILImage()] + augmentation.get_augmentation(n=2)) if train else []
    pipeline.append(T.ToTensor())
    #pipeline.append(T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
    return T.Compose(pipeline)


def get_train_transform():
    """
    Builds the albumentations training pipeline.

    Five colour/blur augmentations are each applied with probability 0.1, followed by the conversion
    to a tensor. Boxes are declared in pascal-voc form with 'labels' as the accompanying label field.
    """
    pixel_augmentations = [
        A.RandomContrast(p=0.1),
        A.RandomBrightness(p=0.1),
        A.RandomGamma(p=0.1),
        A.GaussianBlur(p=0.1),
        A.HueSaturationValue(p=0.1),
    ]
    to_tensor = ToTensorV2(p=1.0)
    box_settings = {'format': 'pascal_voc', 'label_fields': ['labels']}

    return A.Compose(pixel_augmentations + [to_tensor], bbox_params=box_settings)


# using: https://www.kaggle.com/maherdeebcv/test-pytorch-faster-r-cnn-with-resnet152-backbone
def fasterrcnn_resnet_fpn(pretrained=False, num_classes=2, pretrained_backbone=False, **kwargs):
    """
    Creates a Faster R-CNN detector on top of a ResNet-152 FPN backbone.

    Args:
        pretrained: when truthy, the pretrained backbone weights are not requested
        num_classes: number of output classes (including background)
        pretrained_backbone: whether to request pretrained backbone weights
        **kwargs: forwarded to FasterRCNN

    Returns:
        The constructed FasterRCNN model.
    """
    # A truthy `pretrained` switches the pretrained backbone off.
    backbone_pretrained = False if pretrained else pretrained_backbone

    return FasterRCNN(
        backbone=resnet_fpn_backbone(backbone_name='resnet152', pretrained=backbone_pretrained),
        num_classes=num_classes,
        **kwargs
    )


def get_model(num_classes):
    """
    Builds the detector and replaces its box predictor with one for `num_classes` classes.

    Args:
        num_classes: number of classes of the new box predictor (including background)

    Returns:
        The detector with the fresh FastRCNNPredictor head.
    """
    #detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    detector = fasterrcnn_resnet_fpn(pretrained=False, pretrained_backbone=True)

    # Reuse the input width of the existing classification layer for the new head.
    roi_heads = detector.roi_heads
    head_in_features = roi_heads.box_predictor.cls_score.in_features
    roi_heads.box_predictor = FastRCNNPredictor(head_in_features, num_classes)

    return detector


def load_checkpoint(checkpoint_path, model, optim, map_location=None):
    """
    Restores model and optimizer state from a checkpoint file.

    The checkpoint is expected to be a dict with the keys 'model_state_dict',
    'optimizer_state_dict' and 'epoch'.

    Args:
        checkpoint_path: path of the checkpoint file
        model: model whose state is restored in place
        optim: optimizer whose state is restored in place
        map_location: forwarded to torch.load; lets a checkpoint saved on one device be loaded
                      on another (e.g. 'cpu'). None keeps torch.load's default behaviour.

    Returns:
        (model, optim, epoch) -- the same model and optimizer objects plus the stored epoch.
    """
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optim.load_state_dict(checkpoint['optimizer_state_dict'])

    return model, optim, checkpoint['epoch']


def main():
    """
    Trains the Faster R-CNN wheat detector and saves the model plus a resumable checkpoint.

    Builds the training dataset and loader, creates the model and an SGD optimizer, optionally
    restores a checkpoint, then trains for three epochs past the start epoch. The learning rate is
    set by hand after every optimizer step: linear warmup for the first `num_batches_warmup`
    batches, cosine decay afterwards.
    """
    # Training runs on CUDA only; there is no CPU fallback.
    device = torch.device('cuda')
    
    # When True, model/optimizer state and the start epoch are restored from a checkpoint below.
    continue_training = True

    num_classes = 2     # 1 class (wheat) + background
    #dataset_train = WheatDetectionDataset(root='data', transforms=get_transform(train=True), mode='train')
    dataset_train = WheatDetectionDataset(root='../input/global-wheat-detection', transforms=get_train_transform(), mode='train')

    # define dataset loaders
    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, batch_size=8, shuffle=True, num_workers=0, collate_fn=collate_fn
    )

    # get model
    model = get_model(num_classes).to(device)
    #model = torch.load('models/wheat_model_10.0.pth', map_location=device).train()
    print('Model loaded successfully.')

    i_lr = 0.005      # initial learning rate
    num_batches_warmup = 325
    # Only parameters that require gradients are handed to the optimizer.
    params = [p for p in model.parameters() if p.requires_grad]
    optim = torch.optim.SGD(params=params, lr=i_lr, momentum=0.9, weight_decay=0.0005)
    
    start_epoch = 0
    
    if continue_training:
        model, optim, start_epoch = load_checkpoint(checkpoint_path='../input/wheat-detection-model/wheat_model_checkpoint_11.4.pth', 
                                                    model=model, optim=optim)
        print('Model checkpoint loaded successfully.')
    
    # Train for three more epochs, counted from the (possibly restored) start epoch.
    num_epochs = start_epoch + 3

    len_dataloader = len(data_loader_train)
    
    for epoch in range(start_epoch, num_epochs):
        model.train()
        i = 0   # 1-based batch counter within the current epoch

        loss_hist = []
        for imgs, targets in data_loader_train:
            i += 1

            # Move the batch to the training device.
            imgs = [img.to(device) for img in imgs]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            try:
                # The string literal below is disabled debug code (draws the boxes of the first
                # image of the batch); it is evaluated as a no-op expression.
                """test_img = imgs[0].detach().permute(1, 2, 0).to('cpu').numpy()
                bboxes = targets[0]['boxes'].detach().to('cpu')
                for it, bbox in enumerate(bboxes):
                    bbox = bbox.numpy()
                    p_1 = (int(bbox[0]), int(bbox[1]))
                    p_2 = (int(bbox[2]), int(bbox[3]))
                    test_img = cv.rectangle(img=test_img, pt1=p_1, pt2=p_2, color=(0, 0, 255))
                #print(bboxes)
                #print(test_img.shape)
                cv.imshow('TEST_0', test_img)
                cv.waitKey()"""

                # The model is called with targets and its returned loss dict is summed into one scalar.
                loss_dict = model(imgs, targets)
                losses = sum(loss for loss in loss_dict.values())

                optim.zero_grad()
                losses.backward()
                optim.step()

                loss_hist.append(losses.__float__())

                # adjust learning rate
                # Global batch index; with a restored start epoch this already begins at
                # start_epoch * len_dataloader, so the warmup branch may never be taken.
                current_batch = (epoch * len_dataloader) + i
                if current_batch <= num_batches_warmup:
                    # learning rate warmup
                    # starting with a too big learning rate may result in something unwanted
                    lr = current_batch * (i_lr / num_batches_warmup)
                else:
                    # cosine learning rate decay
                    # (smoother than step learning rate decay)
                    lr = i_lr * 0.5 * (1 + np.cos(((current_batch-num_batches_warmup) * np.pi) / (len_dataloader * num_epochs)))
                # The new rate takes effect for the NEXT optimizer step.
                for param_group in optim.param_groups:
                    param_group['lr'] = lr

                # Report the mean loss every 214 batches and start a new averaging window.
                if i % 214 == 0:
                    print(f'Iteration: {i}/{len_dataloader}, Loss: {np.mean(loss_hist)}, Epoch: {epoch}')
                    loss_hist = []

            except IndexError:
                # NOTE(review): any IndexError raised during the step silently skips the batch
                # (presumably meant for samples without boxes -- confirm); nothing is logged.
                continue

    # Save the whole model object and, separately, a checkpoint that load_checkpoint can restore.
    torch.save(model, '/kaggle/working/wheat_model_11.5.pth')
    torch.save({
        'epoch': num_epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optim.state_dict(),
    }, '/kaggle/working/wheat_model_checkpoint_11.5.pth')


# Start training as soon as this cell is executed.
main()
