In [None]:
!pip install wwf timm -qqq

In [None]:
!pip install --upgrade fastai

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import cv2
import glob
import random
import torch
import torchvision
from torchvision import transforms
import albumentations as A
from fastai.vision.all import *
from fastai import *
from wwf.vision.timm import *
import timm

In [None]:
def set_seed(seed=42, reproducible=False):
    """Seed the random number generators used by this notebook.

    Args:
        seed: integer seed applied to Python's `random`, NumPy's global RNG and
            torch (CPU and, when present, all CUDA devices).
        reproducible: if True, also put cuDNN into deterministic mode and turn off
            its autotuner (`benchmark`), trading GPU speed for repeatable results.

    Note:
        PYTHONHASHSEED is only read when a Python interpreter starts, so setting it
        here affects interpreters launched afterwards, not the running kernel.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Explicit for older torch versions; silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    if reproducible:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

In [None]:
# Seed every RNG up front so the augmentation demo and the data split are repeatable.
set_seed(seed=42)

In [None]:
class CFG:
    """Hyper-parameters and runtime settings shared by the notebook's cells."""
    model = 'resnext101_32x8d'  # backbone name (a ResNeXt, not a ViT); replace to try another architecture
    train_bs = 16               # training batch size
    valid_bs = 32               # validation batch size
    image_size = 224            # NOTE(review): the DataBlock cell is built with 512 — confirm which size is intended
    tta = 5                     # presumably the number of test-time-augmentation passes — confirm
    epochs = 10
    lr = 1e-3
    fp16 = True
    fp32 = True                 # NOTE(review): contradicts fp16=True; presumably only one precision flag should be on
    device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

In [None]:
# Training metadata; later cells read its `image` (file name) and `species` (label) columns.
train_df = pd.read_csv("../input/happy-whale-and-dolphin/train.csv")
train_df.head(5)

In [None]:
# Class balance of the species label.
# NOTE(review): scan this listing for near-duplicate spellings of the same species before training on it.
train_df["species"].value_counts(dropna=True)

In [None]:
# Folder of training images and one sample file used for the augmentation demo.
DIR = '../input/happy-whale-and-dolphin/train_images'
image_path = '{}/{}'.format(DIR, '00103cbe9d25ce.jpg')

In [None]:
# cv2 decodes images in BGR channel order and returns None (no exception) when it cannot
# read the file. Fail loudly on a bad path, and convert to RGB so that matplotlib and the
# augmentation demo built from this array see the true colours.
bgr_img = cv2.imread(image_path)
assert bgr_img is not None, f"cv2 could not read {image_path}"
random_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
plt.imshow(random_img);

In [None]:
# Thanks to https://www.kaggle.com/khoongweihao/insect-augmentation-et-al
# Demo-only list: every augmentation uses p=1 so its effect always shows in the grid.
# The order matters — it must line up with `titles_list` (after its leading "Original").
# NOTE(review): RandomBrightness, RandomContrast, JpegCompression, Cutout and the IAA*
# wrappers appear to be deprecated/removed in newer albumentations releases — confirm
# against the installed version (or pin it).
albumentation_list = [A.RandomSunFlare(p=1), 
                      A.RandomFog(p=1), 
                      A.RandomBrightness(p=1),
                      A.RandomCrop(p=1,height = 512, width = 512),  # NOTE(review): presumably errors on images smaller than 512px per side — confirm
                      A.Rotate(p=1, limit=90),
                      A.RGBShift(p=1), 
                      A.RandomSnow(p=1),
                      A.HorizontalFlip(p=1), 
                      A.VerticalFlip(p=1), 
                      A.RandomContrast(limit = 0.5,p = 1),
                      A.HueSaturationValue(p=1,hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=50),
                      A.Cutout(p=1),
                      A.Transpose(p=1), 
                      A.JpegCompression(p=1),
                      A.CoarseDropout(p=1),
                      # scale presumably = (0.01*255, 0.05*255), i.e. 1%-5% of the 8-bit range
                      A.IAAAdditiveGaussianNoise(loc=0, scale=(2.5500000000000003, 12.75), per_channel=False, p=1),
                      A.IAAAffine(scale=1.0, translate_percent=None, translate_px=None, rotate=0.0, shear=0.0, order=1, cval=0, mode='reflect', p=1),
                      A.IAAAffine(rotate=90., p=1),
                      A.IAAAffine(rotate=180., p=1)]

In [None]:
# Apply each demo augmentation once to the sample image. The untouched original goes
# first so that it pairs with the "Original" caption below.
img_matrix_list = [random_img] + [aug(image=random_img)['image'] for aug in albumentation_list]

titles_list = ["Original","RandomSunFlare","RandomFog","RandomBrightness",
               "RandomCrop","Rotate", "RGBShift", "RandomSnow","HorizontalFlip", "VerticalFlip", "RandomContrast","HSV",
               "Cutout","Transpose","JpegCompression","CoarseDropout","IAAAdditiveGaussianNoise","IAAAffine","IAAAffineRotate90","IAAAffineRotate180"]

# Images and captions are paired positionally when plotting. zip() would silently drop
# any surplus, so fail loudly if the two lists drift apart.
assert len(titles_list) == len(img_matrix_list), (len(titles_list), len(img_matrix_list))

def plot_multiple_img(img_matrix_list, title_list, ncols, nrows=5,  main_title=""):
    """Draw images on an `nrows` x `ncols` grid, captioning each one with its title.

    Args:
        img_matrix_list: images (anything `Axes.imshow` accepts), placed row-major.
        title_list: captions, paired positionally with `img_matrix_list`; surplus
            entries in the longer of the two lists are ignored.
        ncols: number of grid columns.
        nrows: number of grid rows.
        main_title: figure-level title.

    Raises:
        ValueError: if there are more image/title pairs than grid cells.
    """
    pairs = list(zip(img_matrix_list, title_list))
    n_cells = nrows * ncols
    # Check capacity before creating the figure, so a bad call does not leave a half-drawn plot.
    if len(pairs) > n_cells:
        raise ValueError(f"{len(pairs)} images do not fit in a {nrows}x{ncols} grid")

    fig, axes = plt.subplots(figsize=(20, 15), nrows=nrows, ncols=ncols, squeeze=False)
    fig.suptitle(main_title, fontsize=30)
    fig.subplots_adjust(wspace=0.3, hspace=0.3)
    flat_axes = axes.ravel()
    for ax, (img, title) in zip(flat_axes, pairs):
        ax.imshow(img)
        ax.set_title(title, fontsize=15)
    # Hide grid cells with no image so they don't render as empty frames.
    for ax in flat_axes[len(pairs):]:
        ax.axis('off')
    plt.show()

plot_multiple_img(img_matrix_list, titles_list, ncols=4, main_title="Different Types of Augmentations with Albumentations")

In [None]:
# Dataset root as a Path; get_x builds per-image file paths from it.
path = Path('../input') / 'happy-whale-and-dolphin'

In [None]:
def get_x(x):
    "File path (as a string) of the training image named by row `x`'s `image` field."
    image_dir = str(path / 'train_images')
    return os.path.sep.join((image_dir, x['image']))

def get_y(y):
    "Species label stored in row `y`."
    species = y['species']
    return species

In [None]:
class AlbumentationsTransform(RandTransform):
    "A transform handler for multiple `Albumentation` transforms"
    # Wraps two albumentations pipelines as one fastai item transform: `train_aug` is used
    # for the training split and `valid_aug` for everything else.
    # split_idx=None lets fastai apply this transform on every split; the train/valid
    # choice is made inside `encodes` from the split index recorded in `before_call`.
    # order=2: NOTE(review) presumably chosen so this runs while items are still PIL
    # images (before fastai's ToTensor) — confirm.
    split_idx,order=None,2
    def __init__(self, train_aug, valid_aug): store_attr()  # saves both pipelines as self.train_aug / self.valid_aug
    
    def before_call(self, b, split_idx):
        # Overrides RandTransform's hook: instead of drawing a random "apply or not"
        # decision, just remember which split the current item comes from.
        self.idx = split_idx
    
    def encodes(self, img: PILImage):
        # fastai type-dispatches on the annotation, so this handles PILImage inputs only.
        # albumentations works on NumPy arrays, so convert in and wrap the result back
        # into a PILImage. idx 0 is fastai's training split; any other value (including
        # None) takes the validation pipeline.
        if self.idx == 0:
            aug_img = self.train_aug(image=np.array(img))['image']
        else:
            aug_img = self.valid_aug(image=np.array(img))['image']
        return PILImage.create(aug_img)

In [None]:
def get_train_aug(sz):
    """Build the training-time albumentations pipeline, producing `sz` x `sz` images.

    Args:
        sz: output side length in pixels (same as the validation pipeline).

    Returns:
        An `A.Compose` pipeline to be called as `pipeline(image=array)['image']`.
    """
    return A.Compose([
        # Resize first: without it training images keep their source sizes, so batches of
        # mixed shapes cannot be collated. Resizing first also runs the ops below on smaller arrays.
        A.Resize(sz, sz),
        A.Transpose(p=0.5),
        A.ShiftScaleRotate(p=0.5),
        # NOTE(review): limits of 0.2 are tiny if these are 8-bit images (shifts are in
        # pixel-value units) — confirm the intended magnitude.
        A.HueSaturationValue(
            hue_shift_limit=0.2,
            sat_shift_limit=0.2,
            val_shift_limit=0.2,
            p=0.5,
        ),
        # p=0.5 rather than p=1: an always-on flip pair is a fixed 180-degree rotation and
        # always-on snow alters every training image. Either makes training images differ
        # systematically from the validation images, which only get resized.
        A.RandomSnow(p=0.5),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomBrightnessContrast(
            brightness_limit=(-0.1, 0.1),
            contrast_limit=(-0.1, 0.1),
            p=0.5,
        ),
        A.CoarseDropout(p=0.9),
        A.Cutout(p=0.5),
    ])

def get_valid_aug(sz):
    """Build the validation pipeline: a plain, deterministic resize to `sz` x `sz`."""
    resize = A.Resize(sz, sz)
    return A.Compose([resize], p=1.0)

In [None]:
def get_dls(sz, valid_pct=0.2, seed=42):
    """Build the species-classification `DataBlock` for images of side `sz`.

    Despite the name, this returns a `DataBlock`, not `DataLoaders`; call
    `.dataloaders(df, ...)` on the result to get loaders.

    Args:
        sz: side length passed to the train/valid albumentations pipelines.
        valid_pct: fraction of rows held out for validation.
        seed: seed passed to `RandomSplitter`, so the train/valid split is fixed instead
            of depending on how much torch RNG state earlier cells consumed.
    """
    item_tfms = AlbumentationsTransform(get_train_aug(sz), get_valid_aug(sz))
    return DataBlock(
        blocks=(ImageBlock, CategoryBlock),
        splitter=RandomSplitter(valid_pct=valid_pct, seed=seed),
        get_x=get_x,
        get_y=get_y,
        item_tfms=item_tfms,
        batch_tfms=[*aug_transforms(), Normalize.from_stats(*imagenet_stats)],
    )

In [None]:
# NOTE(review): 512 differs from CFG.image_size (224) — confirm which size is intended.
block = get_dls(sz=512)

In [None]:
# fastai's `dataloaders` takes `bs` for the training loader and `val_bs` for the validation
# loader (`batch_size` is only a PyTorch-compatibility alias). Passing both means
# CFG.valid_bs actually takes effect.
dls = block.dataloaders(train_df, bs=CFG.train_bs, val_bs=CFG.valid_bs)
dls.valid.show_batch(max_n=8, nrows=2)

### Work In Progress :)