In [None]:
import os, sys
sys.path = ['../input/efficientnet-pytorch/EfficientNet-PyTorch/EfficientNet-PyTorch-master', ] + sys.path

In [None]:
#Basic Python and Machine learning libraries
import random, time, cv2
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import skimage.io
from PIL import Image
from scipy import stats
from sklearn import metrics
from sklearn.model_selection import StratifiedKFold
from sklearn.utils import class_weight
from IPython.display import display
from tqdm.notebook import tqdm

#Pytorch and Albumentations(Data Augmentation Library)
import torch
import albumentations
from albumentations.pytorch import ToTensorV2
from torch import nn, optim
from torch.optim import lr_scheduler
from torch.functional import F 
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models

from efficientnet_pytorch import EfficientNet

In [None]:
class Config:
    """Central place for every tunable setting used by the notebook.

    All values are plain class attributes evaluated once, top to bottom, when
    the class body runs; derived settings therefore depend on the flags that
    appear above them.
    """

    # --- run mode -----------------------------------------------------------
    DEBUG = False
    # Set to True to resume training from a previously saved checkpoint.
    CONTINUE_TRAIN = True
    # Learning rate the previous run ended with (used when resuming).
    last_lr = 2e-5
    # Number of epochs the resumed checkpoint has already been trained for.
    if CONTINUE_TRAIN:
        last_epoch = 10
    else:
        last_epoch = 0

    # --- data ---------------------------------------------------------------
    pwd = '/kaggle/working/'
    # Selects the corresponding tile dataset. Valid values: 16, 36.
    n_tiles = 36
    assert n_tiles in (16, 36)
    if n_tiles == 16:
        data_dir = '../input/panda-dataset-medium-16-256-256/'
    else:
        data_dir = '../input/panda-dataset-medium-36-256-256/'
    train_img_dir = os.path.join(data_dir, 'train_images')
    test_img_dir = os.path.join(data_dir, 'test_images')

    # --- model --------------------------------------------------------------
    backbone = 'efficientnet-b1'
    # True  -> 5 cumulative binary targets (trained with BCE),
    # False -> 6-way classification (trained with cross-entropy).
    SUM_PREDICTION = False
    if SUM_PREDICTION:
        out_dim = 5
    else:
        out_dim = 6

    # --- training -----------------------------------------------------------
    n_images_to_plot = 16
    n_folds = 2 if DEBUG else 5
    image_size = 256
    batch_size = 2
    num_workers = 4
    num_epochs = 2 if DEBUG else 9
    # A resumed run continues from the last learning rate; a fresh run starts at 3e-4.
    lr = last_lr if CONTINUE_TRAIN else 3e-4
    SEED = 2020
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # ImageNet standard per-channel mean and std.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

In [None]:
# Report the device chosen in Config (CUDA when available, otherwise CPU).
print(Config.device)

In [None]:
def seed_everything(seed):
    """Seed every random number generator the notebook relies on.

    Seeds Python's `random`, NumPy's global RNG and PyTorch (CPU and all CUDA
    devices), and switches cuDNN to its deterministic configuration.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    # Only affects interpreters started after this point; hash randomisation
    # of the current process is fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Safe to call without a GPU; covers every visible CUDA device.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune (and possibly vary) the convolution
    # algorithm between runs, which defeats the determinism requested above.
    torch.backends.cudnn.benchmark = False

# Seed every RNG once, up front, so the cells below start from a fixed random state.
seed_everything(Config.SEED)

In [None]:
# Load the training table; in DEBUG mode keep only a random 100-row subset
# (re-indexed from 0) so the whole notebook runs quickly.
train_df = pd.read_csv(os.path.join(Config.data_dir, 'train.csv'))
if Config.DEBUG:
    train_df = train_df.sample(100).reset_index(drop=True)
display(train_df.head())
len(train_df)

In [None]:
if not Config.DEBUG:
    # Remove rows whose ISUP grade is 2 while the Gleason score is '4+3'.
    # NOTE(review): presumably these are labelling inconsistencies (4+3 would
    # normally map to ISUP grade 3) -- confirm against the dataset description.
    inconsistent_rows = (train_df['isup_grade'] == 2) & (train_df['gleason_score'] == '4+3')
    # Build a new frame instead of mutating in place so the cell can be re-run,
    # and drop the old index so no stray 'index' column is added to the table.
    train_df = train_df.loc[~inconsistent_rows].reset_index(drop=True)
    print(len(train_df))

# Building Dataset

In [None]:
class PANDA_Dataset(Dataset):
    """Dataset that stitches the pre-cut tiles of one slide into a square image.

    For row `index` of `df`, the tiles are read from
    `<Config.train_img_dir>/<image_id>_<k>.png` for k in 0..n_tiles-1, pasted
    row by row into a grid of sqrt(n_tiles) x sqrt(n_tiles) tiles, and returned
    as a channel-first tensor together with the label.

    Args:
        df: DataFrame with at least the columns 'image_id' and 'isup_grade'.
        image_size: Side length, in pixels, of one (square) tile.
        n_tiles: Number of tiles per slide; expected to be a perfect square.
        rand: When True, tiles are placed in a random order on the grid.
        tile_transform: Optional callable applied to every tile, invoked as
            `tile_transform(image=tile)['image']`.
        img_transform: Optional callable applied to the stitched image with
            the same calling convention. When omitted, the image is divided
            by 255 instead.
    """

    def __init__(self,
                 df,
                 image_size,
                 n_tiles,
                 rand=False,
                 tile_transform=None,
                 img_transform=None
                ):

        # Positional access via `.values[index]` below needs a 0..n-1 index.
        self.df = df.reset_index(drop=True)
        self.image_size = image_size
        self.n_tiles = n_tiles
        self.rand = rand
        self.tile_transform = tile_transform
        self.img_transform = img_transform

    def __len__(self):
        """Number of slides (rows of the DataFrame)."""
        return self.df.shape[0]

    def _tile_order(self):
        """Return the tile index to place at each grid position."""
        if self.rand:
            return np.random.permutation(self.n_tiles)
        return np.arange(self.n_tiles)

    def __getitem__(self, index):
        """Return `(image, label)` for the slide at row `index`.

        `image` is a float32 tensor of shape (3, S, S) with
        S = image_size * sqrt(n_tiles). `label` is the raw ISUP grade, or,
        when `Config.SUM_PREDICTION` is set, a length-5 float vector whose
        first `isup_grade` entries are 1.
        """
        tiles_path = os.path.join(Config.train_img_dir, self.df['image_id'].values[index]) + '_'
        if Config.SUM_PREDICTION:
            label = np.zeros(5).astype(np.float32)
            label[:self.df['isup_grade'].values[index]] = 1.
        else:
            label = self.df['isup_grade'].values[index]

        tile_order = self._tile_order()

        n_row_tiles = int(np.sqrt(self.n_tiles))
        side = self.image_size * n_row_tiles
        # Allocate as float32 straight away: tile pixel values are stored
        # exactly and no second full-size copy is needed later.
        images = np.zeros((side, side, 3), dtype=np.float32)
        for h in range(n_row_tiles):
            for w in range(n_row_tiles):
                # Honour the (possibly shuffled) order; previously the shuffle
                # was computed but the grid position was used as the tile index.
                tile_idx = tile_order[h * n_row_tiles + w]
                # The context manager closes the file handle once decoded.
                with Image.open(tiles_path + str(tile_idx) + '.png') as tile_file:
                    tile_i = np.array(tile_file)
                if self.tile_transform is not None:
                    tile_i = self.tile_transform(image=tile_i)['image']
                h1 = h * self.image_size
                w1 = w * self.image_size
                images[h1:h1+self.image_size, w1:w1+self.image_size] = tile_i

        if self.img_transform is not None:
            images = self.img_transform(image=images)['image']
        else:
            images /= 255

        # HWC -> CHW, the layout the model consumes.
        images = images.transpose(2, 0, 1)

        return torch.tensor(images), torch.tensor(label)

In [None]:
# The below code will plot down some images for you, given a list of images
def plot_images(images):
    """Show every `(image, label)` pair on a near-square grid of subplots.

    Args:
        images: Sequence of `(image, label)` pairs where `image` is a
            channel-first tensor (C, H, W) and `label` is anything `str()`
            can render.

    The grid has ceil(sqrt(n)) columns and as many rows as needed, so no image
    is dropped when `n` is not a perfect square. Nothing is drawn for an empty
    sequence.
    """
    n_images = len(images)
    if n_images == 0:
        return

    cols = int(np.ceil(np.sqrt(n_images)))
    rows = int(np.ceil(n_images / cols))

    fig = plt.figure(figsize=(20, 20))
    for position, item in enumerate(images, start=1):
        image, label = item[0], item[1]
        ax = fig.add_subplot(rows, cols, position)
        ax.set_title('ISUP: '+str(label))
        # (C, H, W) -> (H, W, C), the layout imshow expects.
        ax.imshow(image.permute(1, 2, 0).squeeze())
        ax.axis('off')

In [None]:
# Preview: build an un-augmented dataset and plot its first few stitched images.
train_data = PANDA_Dataset(
    df=train_df,
    image_size=Config.image_size,
    n_tiles=Config.n_tiles,
    rand=False,
    tile_transform=None,
)
images = [train_data[sample_idx] for sample_idx in range(Config.n_images_to_plot)]
plot_images(images)

In [None]:
# Give every row the id of the fold in which it serves as validation data.
# Splits are stratified on the ISUP grade so each fold keeps the grade mix.
train_df['fold'] = -1
skf = StratifiedKFold(n_splits=Config.n_folds, shuffle=True, random_state=Config.SEED)
fold_splits = skf.split(train_df, train_df['isup_grade'])
for fold_id, (_train_idx, val_idx) in enumerate(fold_splits):
    # The splitter yields positions; they are used as labels here, which
    # relies on train_df carrying a 0..n-1 index.
    train_df.loc[val_idx, 'fold'] = fold_id
train_df.head()

In [None]:
# Remove metadata columns that the remaining cells do not read. A new frame is
# bound (and already-missing columns are ignored) so re-running this cell does
# not raise a KeyError.
train_df = train_df.drop(columns=['data_provider', 'gleason_score'], errors='ignore')
train_df.head()

In [None]:
def _random_flips():
    """Return fresh horizontal and vertical flip transforms, each with p=0.5."""
    return [
        albumentations.HorizontalFlip(p=0.5),
        albumentations.VerticalFlip(p=0.5),
    ]


def _imagenet_normalize():
    """Return a Normalize transform built from Config.mean and Config.std."""
    return albumentations.Normalize(mean=Config.mean, std=Config.std, always_apply=True)


# Training: random flips per tile, then random flips plus normalisation on the
# stitched image.
train_tile_transforms = albumentations.Compose(_random_flips())
train_img_transforms = albumentations.Compose(_random_flips() + [_imagenet_normalize()])

# Evaluation: no tile-level transform, normalisation only on the stitched image.
test_tile_transforms = None
test_img_transforms = albumentations.Compose([_imagenet_normalize()])

# Building Model

In [None]:
# Weight file per backbone: the saved fold checkpoint when resuming training,
# otherwise the stock EfficientNet weights.
if Config.CONTINUE_TRAIN:
    _efficientnet_b1_weights = '../input/efnetb1fold2epoch10/efficientnet-b1_fold_2_epoch_10.pt'
else:
    _efficientnet_b1_weights = '../input/efficientnet-pytorch/efficientnet-b1-dbc7070a.pth'

pretrained_model = {
    'efficientnet-b1': _efficientnet_b1_weights,
}

In [None]:
class BasicModel(nn.Module):
    """EfficientNet feature extractor followed by a single linear head.

    Args:
        backbone: EfficientNet variant name; also the key into the
            module-level `pretrained_model` mapping of weight files.
        continue_training: When True the backbone weights are NOT loaded here;
            the caller is expected to load a full checkpoint afterwards.
        out_dim: Number of outputs of the linear head.
    """

    def __init__(self, backbone, continue_training=False, out_dim=6):
        super(BasicModel, self).__init__()
        self.enet = EfficientNet.from_name(backbone)
        if not continue_training:
            # Load onto the CPU explicitly: the freshly built network lives on
            # the CPU, and this keeps loading working on hosts without a GPU
            # even if the file was saved from one. The caller moves the whole
            # model to the target device afterwards.
            # Loaded before `_fc` is swapped out below, while the backbone
            # still has its original layer layout.
            self.enet.load_state_dict(torch.load(pretrained_model[backbone], map_location='cpu'))

        # New head sized from the backbone's own classifier input width;
        # the backbone's classifier becomes a pass-through.
        self.fc = nn.Linear(self.enet._fc.in_features, out_dim)
        self.enet._fc = nn.Identity()

    def forward(self, x):
        """Return the head's raw outputs (logits) for a batch of images."""
        x = self.enet(x)
        x = self.fc(x)
        return x

In [None]:
# # FOR EXTRACTING IMAGE SIZE FROM ENETS
# EfficientNet.get_image_size('efficientnet-b0')

In [None]:
def count_parameters(model):
    """Return the total element count of the model's trainable parameters."""
    trainable_total = 0
    for param in model.parameters():
        # Frozen parameters (requires_grad=False) are not counted.
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total

def epoch_time(start_time, end_time):
    """Split the span between two timestamps into whole minutes and seconds.

    Args:
        start_time: Start timestamp in seconds.
        end_time: End timestamp in seconds.

    Returns:
        Tuple `(minutes, seconds)`, both truncated to integers.
    """
    duration = end_time - start_time
    whole_minutes = int(duration / 60)
    leftover_seconds = int(duration - (whole_minutes * 60))
    return whole_minutes, leftover_seconds

# model = BasicModel(Config.backbone, out_dim=Config.out_dim).to(Config.device)
# print(f'The model has {count_parameters(model):,} trainable parameters')

# Defining Training Loop

In [None]:
def train(model, iterator, optimizer, criterion, device):
    """Run one optimisation pass over `iterator` and return the mean batch loss.

    Args:
        model: Network to optimise; switched to training mode here.
        iterator: Iterable of `(data, target)` batches that supports `len()`.
        optimizer: Optimiser stepped once per batch.
        criterion: Loss called as `criterion(predictions, target)`.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Sum of the per-batch losses divided by the number of batches.
    """
    model.train()
    running_loss = 0

    for inputs, labels in tqdm(iterator):
        inputs, labels = inputs.to(device), labels.to(device)

        # Clear gradients left over from the previous batch, then do the
        # usual forward / backward / update cycle.
        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()

        # Accumulate a detached host-side copy so no graph is kept alive.
        running_loss += batch_loss.detach().cpu().numpy()

    return running_loss / len(iterator)

def evaluate(model, iterator, criterion, device):
    """Score `model` on `iterator` without updating it.

    Args:
        model: Network to evaluate; switched to eval mode here.
        iterator: Iterable of `(data, target)` batches that supports `len()`.
        criterion: Loss called as `criterion(logits, target)`.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Tuple `(mean_loss, kappa)`: the sum of per-batch losses divided by the
        number of batches, and the quadratic-weighted Cohen's kappa between
        predicted and true grades over the whole iterator.
    """
    epoch_loss = 0
    preds_list = []
    targets_list = []
    model.eval()

    with torch.no_grad():

        for (data, target) in tqdm(iterator):

            data = data.to(device)
            target = target.to(device)

            logits = model(data)
            loss = criterion(logits, target)

            if target.dim() > 1:
                # Cumulative multi-hot targets (first k entries are 1 for
                # grade k, as built by the dataset when Config.SUM_PREDICTION
                # is set): the grade is the number of active outputs. argmax
                # would be meaningless here and the kappa metric needs 1-D
                # integer labels.
                pred = logits.sigmoid().sum(dim=1).round().long()
                grade = target.sum(dim=1).round().long()
            else:
                # Plain class-index targets: pick the highest-scoring class.
                pred = torch.argmax(logits, dim=1)
                grade = target

            preds_list.append(pred)
            targets_list.append(grade)

            loss_np = loss.detach().cpu().numpy()
            epoch_loss += loss_np

    preds_list = torch.cat(preds_list).cpu().numpy()
    targets_list = torch.cat(targets_list).cpu().numpy()

    metric = metrics.cohen_kappa_score(preds_list, targets_list, weights='quadratic')

    return epoch_loss/len(iterator), metric

In [None]:
def fit_model(model, model_name, train_iterator, valid_iterator, optimizer, scheduler, loss_criterion, device, n_epochs, last_epoch, fold):
    """Train for `n_epochs` epochs, checkpointing whenever the metric improves.

    Each epoch runs `train`, then `evaluate`, then steps the scheduler. The
    model's state dict is saved to
    `<model_name>_fold_<fold>_epoch_<absolute epoch>.pt` every time the
    validation metric exceeds the best value seen so far, and a summary of the
    epoch is appended to `training_log.txt` and printed.

    Args:
        model: Network to train.
        model_name: Prefix for checkpoint files and result column names.
        train_iterator: Batches used for optimisation.
        valid_iterator: Batches used for validation.
        optimizer: Optimiser passed to `train`.
        scheduler: Learning-rate scheduler, stepped once per epoch.
        loss_criterion: Loss used for both training and validation.
        device: Device the batches are moved to.
        n_epochs: Number of epochs to run in this call.
        last_epoch: Epochs already completed before this call; only offsets
            the epoch numbers shown in logs and file names.
        fold: Fold identifier used in file and column names.

    Returns:
        DataFrame with one row per epoch and three columns: training loss,
        validation loss and validation metric.
    """
    # Start below any score expected in practice so the first epoch is saved.
    best_valid_metric = -1.
    name_prefix = f'{model_name}_fold_{fold}'

    train_losses, valid_losses, valid_metric_scores = [], [], []

    for epoch in range(n_epochs):
        # Absolute, 1-based epoch number across resumed runs.
        epoch_no = epoch + last_epoch + 1
        started_at = time.time()

        print(f'Epoch: {epoch_no:02} | Training:')
        train_loss = train(model, train_iterator, optimizer, loss_criterion, device)
        print(f'Epoch: {epoch_no:02} | Validating:')
        valid_loss, valid_metric_score = evaluate(model, valid_iterator, loss_criterion, device)

        scheduler.step()

        train_losses.append(train_loss)
        valid_losses.append(valid_loss)
        valid_metric_scores.append(valid_metric_score)

        # Checkpoint only on improvement of the validation metric.
        improved = valid_metric_score > best_valid_metric
        if improved:
            print('Score Increased ({:.6f} --> {:.6f}).  Saving model ...'.format(best_valid_metric, valid_metric_score))
            best_valid_metric = valid_metric_score
            torch.save(model.state_dict(), f'{name_prefix}_epoch_{epoch_no}.pt')

        finished_at = time.time()
        epoch_mins, epoch_secs = epoch_time(started_at, finished_at)

        # The learning rate is read after scheduler.step(), i.e. it is the
        # value the next epoch will use.
        current_lr = optimizer.param_groups[0]["lr"]
        content = '\n'.join([
            f'Epoch: {epoch_no:02}, lr: {current_lr:.7f}  | Epoch Time: {epoch_mins}m {epoch_secs}s',
            f'\tTrain Loss: {train_loss:.3f}',
            f'\t Val. Loss: {valid_loss:.3f} |  Val. Metric Score: {valid_metric_score:.3f}',
        ])
        with open('training_log.txt', 'a') as appender:
            appender.write(content + '\n\n')

        print(content)

    return pd.DataFrame({
        f'{name_prefix}_Training_Loss': train_losses,
        f'{name_prefix}_Validation_Loss': valid_losses,
        f'{name_prefix}_Valid_Metric_Score': valid_metric_scores,
    })

In [None]:
#This will simply plot the training statistics we returned
def plot_training_statistics(train_stats, model_name, fold):
    """Plot the loss curves and the validation metric returned by `fit_model`.

    Args:
        train_stats: Table indexable by the column names
            `<model_name>_fold_<fold>_Training_Loss`, `..._Validation_Loss`
            and `..._Valid_Metric_Score`.
        model_name: Model-name part of the column names.
        fold: Fold part of the column names.

    Draws two stacked panels: both losses on top, the metric below.
    """
    prefix = f'{model_name}_fold_{fold}'
    _, (loss_ax, metric_ax) = plt.subplots(2, figsize=(15,15))

    for suffix in ('Training_Loss', 'Validation_Loss'):
        column = f'{prefix}_{suffix}'
        loss_ax.plot(train_stats[column], label=column)

    metric_column = f'{prefix}_Valid_Metric_Score'
    metric_ax.plot(train_stats[metric_column], label=metric_column)

    loss_ax.set_xlabel("Number of Epochs")
    loss_ax.set_ylabel("Loss")
    metric_ax.set_xlabel("Number of Epochs")
    metric_ax.set_ylabel("Score on Metric")

    loss_ax.legend()
    metric_ax.legend()

# Training with K-Fold CV

In [None]:
# Fold whose rows are held out for validation in this run.
fold = 2
# Make the train and validation DataFrames from the fold assignment.
train_df_fold = train_df[train_df['fold'] != fold]
valid_df_fold = train_df[train_df['fold'] == fold]

# Compute 'balanced' class weights from the training rows only.
# `classes` and `y` are passed by keyword: recent scikit-learn releases accept
# them as keyword-only arguments and reject the positional form.
# NOTE(review): the weight vector has one entry per grade present in this
# fold; it matches the model's outputs only if every grade occurs -- confirm
# for small DEBUG samples.
class_weights = class_weight.compute_class_weight(
    class_weight='balanced',
    classes=np.unique(train_df_fold['isup_grade']),
    y=train_df_fold['isup_grade'].values,
)
print(class_weights)
class_weights = torch.tensor(class_weights, dtype=torch.float32)

In [None]:
print(f"Fitting on Fold {fold+1}")
# Build the datasets and loaders: augmentation for training, none for validation.
train_data = PANDA_Dataset(train_df_fold, Config.image_size, Config.n_tiles, False, train_tile_transforms, train_img_transforms)
valid_data = PANDA_Dataset(valid_df_fold, Config.image_size, Config.n_tiles, False, test_tile_transforms, test_img_transforms)
train_iterator = DataLoader(train_data, shuffle=True, batch_size=Config.batch_size, num_workers=Config.num_workers)
valid_iterator = DataLoader(valid_data, batch_size=Config.batch_size, num_workers=Config.num_workers)

# Initialize the model; when resuming, restore the complete checkpoint
# (backbone and head) on top of the freshly built network.
model = BasicModel(Config.backbone, Config.CONTINUE_TRAIN, Config.out_dim).to(Config.device)
if Config.CONTINUE_TRAIN:
    model.load_state_dict(torch.load(pretrained_model[Config.backbone], map_location=Config.device))

# Loss: BCE on cumulative targets, or class-weighted cross-entropy on grade indices.
if Config.SUM_PREDICTION:
    loss_criterion = nn.BCEWithLogitsLoss().to(Config.device)
else:
    loss_criterion = nn.CrossEntropyLoss(weight=class_weights).to(Config.device)

optimizer = optim.Adam(model.parameters(), lr=Config.lr)
# One cosine half-period spread over the epochs of this run.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, Config.num_epochs)

# Fit the model and visualize the training curves. The model name comes from
# Config.backbone so checkpoint and column names follow the configured
# backbone instead of a hard-coded string.
train_stats = fit_model(model, Config.backbone, train_iterator, valid_iterator,
                        optimizer, scheduler, loss_criterion, Config.device, Config.num_epochs, Config.last_epoch, fold)
plot_training_statistics(train_stats, Config.backbone, fold)

# Visual separator after the run's output.
print('\n')
print('-------------------------------------------------------')
print('\n')