In [None]:
!pip install efficientnet_pytorch
!pip install torchsummary 

In [None]:
import torch
from torchvision import transforms
from torch.utils.data import Dataset,DataLoader
from efficientnet_pytorch import EfficientNet
from torchsummary import summary
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm_notebook as tqdm

import numpy as np
import cv2
import matplotlib.pyplot as plt
import random
import time
import pandas as pd
import os

from albumentations import (
    HorizontalFlip, VerticalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine, RandomResizedCrop,
    IAASharpen, IAAEmboss, RandomBrightnessContrast, Flip, OneOf, Compose, Normalize, Cutout, CoarseDropout, ShiftScaleRotate, CenterCrop, Resize
)

In [None]:
from albumentations.pytorch import ToTensorV2

def get_train_transforms():
    """Build the albumentations augmentation pipeline used for training.

    The random resized crop outputs a square of side ``cfg['input_shape'][0]``;
    ``cfg`` is the module-level config dict and is looked up at call time.
    """
    side = cfg['input_shape'][0]
    augmentations = [
        RandomResizedCrop(side, side),
        # geometric: transpose / flips / shift-scale-rotate
        Transpose(p=0.5),
        HorizontalFlip(p=0.5),
        VerticalFlip(p=0.5),
        ShiftScaleRotate(p=0.5),
        # mild photometric jitter
        HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.2),
        RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.5),
    ]
    return Compose(augmentations, p=1.)
def get_valid_transforms():
    """Build the deterministic validation pipeline: centre crop, then resize.

    Both steps target a square of side ``cfg['input_shape'][0]`` (``cfg`` is
    the module-level config dict, looked up at call time).
    NOTE(review): the Resize appears redundant after a CenterCrop to the same
    size; kept to preserve the original pipeline.
    """
    side = cfg['input_shape'][0]
    center_crop = CenterCrop(side, side, p=1.)
    resize = Resize(side, side)
    return Compose([center_crop, resize], p=1.)

In [None]:
# Hyperparameters and dataset location shared by the cells below.
train_dir = '../input/cassava-leaf-disease-classification/train_images'

# NOTE(review): not every key is read by the cells visible here
# (e.g. 'folds', 'gamma', 'alpha'); presumably kept for other notebooks — confirm.
cfg = dict(
    batch_size=16,
    input_shape=(512, 512, 3),  # first element is used as the square image side
    epochs=15,
    folds=5,
    lr=0.0001,
    gamma=4.0,
    alpha=2.0,
    device='cuda:0',
)
#helper functions:
def read_image(path, shape=None):
    """Load an image from disk with OpenCV, optionally resizing it.

    Args:
        path: Filesystem path of the image.
        shape: Optional target size passed straight to ``cv2.resize``, i.e.
            ``(width, height)``. ``None`` keeps the stored size.

    Returns:
        The image array as returned by ``cv2.imread`` (OpenCV's BGR channel order).

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``. ``cv2.imread`` signals
            failure by returning ``None`` instead of raising, which would
            otherwise surface later as an opaque resize / NoneType error.
    """
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError('could not read image: {!r}'.format(path))
    if shape is None:
        return img
    return cv2.resize(img, shape)
def get_input(img, dims=(20, 20), n_holes=20):
    """Return a copy of ``img`` with rectangular patches zeroed out.

    In this notebook the masked copy is the model input and the unmasked image
    is the reconstruction target (see the CASSAVA dataset).

    Args:
        img: Image array of shape (H, W) or (H, W, C). Not modified.
        dims: Half-extent (rows, cols) of each patch. A patch centred at (r, c)
            covers rows ``r-dims[0]:r+dims[0]`` and cols ``c-dims[1]:c+dims[1]``,
            clipped to the image.
        n_holes: Number of patches. Centre rows and centre columns are each
            drawn without replacement, so at most ``min(n_holes, H, W)`` patches
            are placed.

    Returns:
        The masked copy, same shape and dtype as ``img``.
    """
    masked = img.copy()
    height, width = img.shape[:2]
    rows = random.sample(range(height), k=min(n_holes, height))
    cols = random.sample(range(width), k=min(n_holes, width))
    for r, c in zip(rows, cols):
        # Clamp the start at 0: a negative slice start wraps around to the end of
        # the axis (usually giving an empty slice), which would silently drop
        # patches near the top/left border.
        masked[max(r - dims[0], 0):r + dims[0], max(c - dims[1], 0):c + dims[1]] = 0
    return masked
def normalize_and_to_tensor(img):
    """Normalize ``img`` per channel and convert it to a tensor via ``ToTensorV2``.

    Pixels are scaled by ``max_pixel_value=255.0`` and then normalized with the
    usual ImageNet mean/std.
    NOTE(review): those statistics are RGB-ordered, while ``cv2.imread`` yields
    BGR images — confirm the channel order is intended.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    pipeline = Compose(
        [
            Normalize(mean=imagenet_mean, std=imagenet_std, max_pixel_value=255.0, p=1.0),
            ToTensorV2(p=1.0),
        ],
        p=1.0,
    )
    result = pipeline(image=img)
    return result['image']

In [None]:
#dataset
class CASSAVA(Dataset):
    """Self-supervised inpainting dataset over a list of image paths.

    Each item is ``(input_image, label)``:
    - ``label`` is the image, resized on load, optionally augmented by
      ``transforms``, then downsized to ``label_size`` x ``label_size``;
    - ``input_image`` is the same augmented image (before the downsize) with
      random patches zeroed out by ``get_input``.

    Args:
        images: Sequence of image file paths.
        input_shape: Shape whose first element is the square side images are
            resized to on load.
        transforms: Optional albumentations-style callable, invoked as
            ``transforms(image=img)['image']``.
        visualize_mode: If True, return raw numpy images (for plotting) instead
            of normalized tensors.
        label_size: Side of the square target image. Defaults to 128,
            presumably to match MY_MODEL's output for 512x512 inputs — confirm.
    """

    def __init__(self, images, input_shape, transforms=None, visualize_mode=False, label_size=128):
        self.images = images
        self.input_size = input_shape[0]
        self.transforms = transforms
        self.visualize_mode = visualize_mode
        self.label_size = label_size

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = read_image(self.images[idx], (self.input_size, self.input_size))
        # Augment only when a pipeline was supplied (calling None would raise).
        if self.transforms is not None:
            image = self.transforms(image=image)['image']
        input_image = get_input(image)
        label = cv2.resize(image, (self.label_size, self.label_size))
        if not self.visualize_mode:
            input_image = normalize_and_to_tensor(input_image)
            label = normalize_and_to_tensor(label)
        return input_image, label
# Smoke test: build the dataset in visualize mode and show one (input, target) pair.
# Sorted so that sample index 3 refers to the same file on every filesystem.
images = sorted(os.listdir(train_dir))
img_dirs = [os.path.join(train_dir, d) for d in images]
dataset = CASSAVA(img_dirs,
                  cfg['input_shape'],
                  transforms=get_train_transforms(),
                  visualize_mode=True)
image, label = dataset[3]
# OpenCV loads images as BGR; matplotlib expects RGB, so convert for display only.
fig, (ax_in, ax_target) = plt.subplots(1, 2, figsize=(10, 5))
ax_in.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
ax_in.set_title('masked input')
ax_target.imshow(cv2.cvtColor(label, cv2.COLOR_BGR2RGB))
ax_target.set_title('target')
plt.show()

In [None]:
class MY_MODEL(nn.Module):
    """EfficientNet-B3 encoder followed by a small upsample + conv decoder.

    ``forward`` runs ``EfficientNet.extract_features`` and then three stages of
    (``nn.Upsample(scale_factor=2)`` -> 3x3 ``Conv2d``), mapping channels
    1536 -> 256 -> 64 -> 3. No activation is applied between decoder convs.

    NOTE(review): ``pretrained`` is stored but unused — the encoder is always
    built with ``EfficientNet.from_pretrained``. Submodule attribute names are
    the keys of saved ``state_dict`` checkpoints, so they must not change.
    """

    def __init__(self, pretrained=False):
        super(MY_MODEL, self).__init__()
        self.pretrained = pretrained
        self.effn = EfficientNet.from_pretrained('efficientnet-b3')
        # Decoder stage 1: 1536 -> 256 channels
        self.upconv1 = nn.Upsample(scale_factor=(2, 2))
        self.conv1 = nn.Conv2d(1536, 256, (3, 3), padding=(1, 1))
        # Decoder stage 2: 256 -> 64 channels
        self.upconv2 = nn.Upsample(scale_factor=(2, 2))
        self.conv2 = nn.Conv2d(256, 64, (3, 3), padding=(1, 1))
        # Decoder stage 3: 64 -> 3 channels (image)
        self.upconv3 = nn.Upsample(scale_factor=(2, 2))
        self.conv3 = nn.Conv2d(64, 3, (3, 3), padding=(1, 1))
        # Fine-tune the whole encoder: explicitly enable gradients on every weight.
        for param in self.effn.parameters():
            param.requires_grad_(True)

    def forward(self, x):
        out = self.effn.extract_features(x)
        stages = (
            (self.upconv1, self.conv1),
            (self.upconv2, self.conv2),
            (self.upconv3, self.conv3),
        )
        for upsample, conv in stages:
            out = conv(upsample(out))
        return out

model = MY_MODEL()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
model.load_state_dict(torch.load('../input/selfsupervision/trained.pt', map_location=device))
# torchsummary defaults to device="cuda"; pass the real device so this also runs on CPU.
summary(model, (3, cfg['input_shape'][0], cfg['input_shape'][0]), device=device.type)

In [None]:
def loss_fn(y_true, y_pred):
    """Per-sample mean absolute error (L1) between target and prediction.

    Args:
        y_true: Target tensor of shape (N, ...). Moved to ``y_pred``'s device.
        y_pred: Prediction tensor with the same shape as ``y_true``.

    Returns:
        Tensor of shape (N,): the mean of ``|y_true - y_pred|`` over all
        non-batch dimensions of each sample. Callers reduce it further
        (e.g. ``.mean()``) before ``backward()``.
    """
    y_true = y_true.to(y_pred.device)
    abs_err = torch.abs(y_true - y_pred)
    # flatten(1) makes this rank-independent; for 4-D (N, C, H, W) input it equals
    # the sum over dims 1..3 divided by C*H*W.
    return abs_err.flatten(start_dim=1).mean(dim=1)

In [None]:
# Self-supervised training: reconstruct the clean image from its masked version.
# NOTE(review): batch size 8 differs from cfg['batch_size'] (16); presumably chosen
# for GPU memory — confirm.
TRAIN_BATCH_SIZE = 8
optim = torch.optim.Adam(model.parameters(), lr=cfg['lr'])
# The monitored quantity is a loss, so the plateau must be detected on a *minimum*;
# mode='max' would cut the LR exactly while the loss was still improving.
scheduler = ReduceLROnPlateau(optimizer=optim, mode='min', patience=1, verbose=True, factor=0.2)
dataset = CASSAVA(img_dirs,
                  cfg['input_shape'],
                  transforms=get_train_transforms(),
                  visualize_mode=False)

train_loader = DataLoader(dataset=dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True, num_workers=2)
for epoch in range(cfg['epochs']):
    start = time.time()
    model.train()
    epoch_loss = 0.0
    for x, y in tqdm(train_loader):
        x = x.to(device)
        y = y.to(device)
        optim.zero_grad()
        y_pred = model(x)
        loss = loss_fn(y, y_pred).mean()  # reduce per-sample losses once
        loss.backward()
        optim.step()
        epoch_loss += loss.item()
    # Mean batch loss: comparable across dataset sizes. A constant rescale does not
    # change ReduceLROnPlateau's decisions under its default relative threshold.
    epoch_loss /= max(len(train_loader), 1)
    print('Epoch {:03}: | Loss: {:.3f} | Training time: {}'.format(
            epoch + 1,
            epoch_loss,
            str(time.time() - start)[:7]))
    scheduler.step(epoch_loss)
    torch.save(model.state_dict(), 'trained.pt')