In [1]:
import os
import glob
import cv2
import time
import copy
import pickle  # Log dictionary data
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F # stateless functions
import torchvision.transforms as T
import torchvision.models as models

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import cohen_kappa_score,confusion_matrix
from tqdm.auto import tqdm
from torch.utils.data import Dataset, DataLoader
from datetime import datetime

In [2]:
class CFG:
    batch_size = 16
    debug = False
    device = torch.device('cuda')
    dtype = torch.float32
    epochs = 30
    lr = 1e-4
    model_name = 'resnet'
    num_classes = 6
    nworkers = 32
    nfolds = 4
    seed = 524
    TRAIN = '../yi_data/panda-16x128x128-tiles-data/train/'
    LABELS = '../data/train.csv'

# Split Data

In [3]:
train = pd.read_csv(CFG.LABELS).set_index('image_id')
files = sorted(set([p[:32] for p in os.listdir(CFG.TRAIN)]))
train = train.loc[files].reset_index()

if CFG.debug:
    df = train.sample(n=50, random_state=CFG.seed).copy()
else:
    df = train.copy()

# Generate train/validation sets containing the same distribution of isup_grade
splits = StratifiedKFold(n_splits=CFG.nfolds, random_state=CFG.seed, shuffle=True)
splits = list(splits.split(df,df.isup_grade))
# Assign split index to training samples
folds_splits = np.zeros(len(df)).astype(np.int)
for i in range(CFG.nfolds):
    folds_splits[splits[i][1]] = i
df['split'] = folds_splits
df.head()

Unnamed: 0,image_id,data_provider,isup_grade,gleason_score,split
0,0005f7aaab2800f6170c399693a96917,karolinska,0,0+0,2
1,000920ad0b612851f8e01bcc880d9b3d,karolinska,0,0+0,3
2,0018ae58b01bdadc8e347995b69f99aa,radboud,4,4+4,1
3,001c62abd11fa4b57bf7a6c603a11bb9,karolinska,4,4+4,1
4,001d865e65ef5d2579c190a0e0350d8f,karolinska,0,0+0,0


In [4]:
# https://www.kaggle.com/yasufuminakama/panda-se-resnext50-regression-baseline
class TrainDataset(Dataset):
    """Prostate Cancer Biopsy Dataset"""
    
    def __init__(self, df, labels, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file
            root_dir (string): Path to the directory with all images
            transform (callable, optional): Optional transform to be applied on an image sample
        """
        # Shuffle dataframes with fixed seed; otherwise, validation set only get cancerous samples
        self.df = df
        self.labels = labels
        self.transform = transform
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        #worker = torch.utils.data.get_worker_info()
        #worker_id = worker.id if worker is not None else -1
        #start = time.time()
        # https://stackoverflow.com/questions/33369832/read-multiple-images-on-a-folder-in-opencv-python
        img_fns = [fn for fn in glob.glob(f"{CFG.TRAIN}/{self.df['image_id'][idx]}_*.png")]
        imgs = [cv2.imread(fn) for fn in img_fns]
        # (D,W,H)
        img = cv2.hconcat([cv2.vconcat([imgs[0], imgs[1], imgs[2], imgs[3]]),
                           cv2.vconcat([imgs[4], imgs[5], imgs[6], imgs[7]]),
                           cv2.vconcat([imgs[8], imgs[9], imgs[10], imgs[11]]),
                           cv2.vconcat([imgs[12], imgs[13], imgs[14], imgs[15]])])
        
        if self.transform:
            img = self.transform(img)
            
        label = torch.tensor(self.labels[idx])
        #end = time.time()
        return img, label

## Transforms

In [5]:
def get_transforms(phase):
    assert phase in {'train', 'val'}
    
    if phase == 'train':
        return T.Compose([
            T.ToTensor(),
            T.Normalize(
                mean=[0.8776, 0.8186, 0.9090],
                std=[0.1659, 0.2507, 0.1357],
            ),
        ])
    else:
        return T.Compose([
            T.ToTensor(),
            T.Normalize(
                mean=[0.8776, 0.8186, 0.9090],
                std=[0.1659, 0.2507, 0.1357],
            ),
        ])

In [6]:
# Use fold idx as validation set
def data_loader(fold_idx):
    train_idx = df[df['split'] != fold_idx].index
    val_idx = df[df['split'] == fold_idx].index

    train_dataset = TrainDataset(df.loc[train_idx].reset_index(drop=True),
                                 df.loc[train_idx].reset_index(drop=True)['isup_grade'],
                                 transform = get_transforms(phase='train'))
    val_dataset = TrainDataset(df.loc[val_idx].reset_index(drop=True),
                               df.loc[val_idx].reset_index(drop=True)['isup_grade'],
                               transform = get_transforms(phase='train'))
    
    train_loader = DataLoader(train_dataset, batch_size=CFG.batch_size, shuffle=True, num_workers=CFG.nworkers)
    val_loader = DataLoader(val_dataset, batch_size=CFG.batch_size, shuffle=False, num_workers=CFG.nworkers)
    return train_loader, val_loader

In [7]:
def train_model(model, dataloaders, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()
    
    # Send the model to GPU/CPU
    model = model.to(device=CFG.device)
    
    train_acc_history = []
    val_acc_history = []
    loss_history = []
    
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    
    preds, targets = [], []
    
    for epoch in range(num_epochs):
        
        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()   # Set model to training phase
            else:
                model.eval()    # Set model to evaluate phase
            
            avg_loss = 0.0
            running_corrects = 0
            
            for inputs, labels in tqdm(dataloaders[phase],
                                       desc='{}/{}({:5s})'.format(epoch+1,num_epochs,phase)):
                inputs = inputs.to(device=CFG.device, dtype=CFG.dtype)
                labels = labels.to(device=CFG.device, dtype=torch.long)
                
                # Zero the parameter gradients
                optimizer.zero_grad()
                
                # Forward, track history if only in training
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    
                    pred = torch.argmax(outputs, 1)

                    # Backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                    
                # Statistics
                avg_loss += loss.item()*(inputs.size(0)/len(dataloaders[phase].dataset))  # len(dataloaders[phase].dataset) not len(dataloaders[phase])
                running_corrects += torch.sum(pred == labels)
                preds.append(pred)
                targets.append(labels)
            
            # End of epoch
            with torch.no_grad():
                epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

                if phase == 'val':
                    val_acc_history.append(epoch_acc)
                    p = torch.cat(preds).cpu()
                    t = torch.cat(targets).cpu()
                    kappa = cohen_kappa_score(t, p, weights='quadratic')
                    #print(confusion_matrix(t,p)) https://www.dataschool.io/simple-guide-to-confusion-matrix-terminology/
                    # deep copy the model
                    if epoch_acc > best_acc:
                        best_acc = epoch_acc
                        best_model_wts = copy.deepcopy(model.state_dict())
                else:
                    train_acc_history.append(epoch_acc)
                    loss_history.append(avg_loss)
                print('{} Loss: {:4f} Acc: {:4f} Kaap: {:4f}'.format(
                          phase, avg_loss, epoch_acc, kappa if phase=='val' else 1))

                if scheduler is not None and phase == 'train':
                    scheduler.step()       
        #print()
    
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:0f}s'.format(time_elapsed//60, time_elapsed%60))
    print('Best val Acc: {:4f}'.format(best_acc))
    print()
    
    model.load_state_dict(best_model_wts)
    return model, loss_history, train_acc_history, val_acc_history

## AlexNet

In [8]:
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False
            

def initialize_model(model_name, num_classes, feature_extract=False, use_pretrained=False):
    """
    Params:
        feature_extract
            True - fine tunning
            False - fix the model
    """
    model_ft = None
    
    if model_name == 'alexnet':
        """AlexNet
        """
        model_ft = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
    elif model_name == 'resnet':
        """Resnet
        """
        model_ft = models.resnet18(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
    
    return model_ft

In [None]:
def train_fn(fold):
    model_ft = initialize_model(CFG.model_name, CFG.num_classes, use_pretrained=True)

    optimizer = optim.SGD(model_ft.parameters(),
                          lr=CFG.lr,
                          momentum=.9,
                          nesterov=True)

    print(f'### FOLD: {fold} ###')
    loader_train, loader_val = data_loader(fold)
    best_model, loss_history, train_acc_history, val_acc_history = train_model(model_ft, {'train': loader_train, 'val': loader_val}, F.cross_entropy, optimizer, None, CFG.epochs)

    #print(loss_history)
    #print(train_acc_history)
    #print(val_acc_history)
    return best_model, loss_history, train_acc_history, val_acc_history

log_dict = {'batch_size': CFG.batch_size,
            'epochs': CFG.epochs,
            'learning_rate': CFG.lr,
            'model': CFG.model_name,
            'nworkers': CFG.nworkers,
            'nfolds': CFG.nfolds,
            'random_seed': CFG.seed}
for fold in range(CFG.nfolds):
    best_model, loss_history, train_acc_history, val_acc_history = train_fn(fold)
    log_dict[f'best_mode_{fold}'] = best_model
    log_dict[f'loss_history_{fold}'] = loss_history
    log_dict[f'train_acc_history_{fold}'] = train_acc_history
    log_dict[f'val_acc_history_{fold}'] = val_acc_history

    
# Log results
log_file = f'{CFG.model_name}_{datetime.now().strftime("%m_%d_%Y_%H_%M")}.pkl'
with open(log_file, 'wb') as pkl_file:
    pickle.dump(log_dict, pkl_file)

### FOLD: 0 ###


HBox(children=(FloatProgress(value=0.0, description='1/30(train)', max=493.0, style=ProgressStyle(description_…


train Loss: 1.597193 Acc: 0.340180 Kaap: 1.000000


HBox(children=(FloatProgress(value=0.0, description='1/30(val  )', max=165.0, style=ProgressStyle(description_…


val Loss: 1.484559 Acc: 0.411944 Kaap: 0.360660


HBox(children=(FloatProgress(value=0.0, description='2/30(train)', max=493.0, style=ProgressStyle(description_…


train Loss: 1.449412 Acc: 0.425257 Kaap: 1.000000


HBox(children=(FloatProgress(value=0.0, description='2/30(val  )', max=165.0, style=ProgressStyle(description_…


val Loss: 1.407039 Acc: 0.435527 Kaap: 0.437617


HBox(children=(FloatProgress(value=0.0, description='3/30(train)', max=493.0, style=ProgressStyle(description_…


train Loss: 1.373367 Acc: 0.457081 Kaap: 1.000000


HBox(children=(FloatProgress(value=0.0, description='3/30(val  )', max=165.0, style=ProgressStyle(description_…


val Loss: 1.333834 Acc: 0.467478 Kaap: 0.478278


HBox(children=(FloatProgress(value=0.0, description='4/30(train)', max=493.0, style=ProgressStyle(description_…


train Loss: 1.308976 Acc: 0.489286 Kaap: 1.000000


HBox(children=(FloatProgress(value=0.0, description='4/30(val  )', max=165.0, style=ProgressStyle(description_…


val Loss: 1.290547 Acc: 0.475466 Kaap: 0.510595


HBox(children=(FloatProgress(value=0.0, description='5/30(train)', max=493.0, style=ProgressStyle(description_…


train Loss: 1.277310 Acc: 0.501458 Kaap: 1.000000


HBox(children=(FloatProgress(value=0.0, description='5/30(val  )', max=165.0, style=ProgressStyle(description_…


val Loss: 1.258031 Acc: 0.489159 Kaap: 0.533860


HBox(children=(FloatProgress(value=0.0, description='6/30(train)', max=493.0, style=ProgressStyle(description_…


train Loss: 1.241997 Acc: 0.513376 Kaap: 1.000000


HBox(children=(FloatProgress(value=0.0, description='6/30(val  )', max=165.0, style=ProgressStyle(description_…


val Loss: 1.238577 Acc: 0.495626 Kaap: 0.551801


HBox(children=(FloatProgress(value=0.0, description='7/30(train)', max=493.0, style=ProgressStyle(description_…


train Loss: 1.222309 Acc: 0.518194 Kaap: 1.000000


HBox(children=(FloatProgress(value=0.0, description='7/30(val  )', max=165.0, style=ProgressStyle(description_…


val Loss: 1.217474 Acc: 0.500571 Kaap: 0.564430


HBox(children=(FloatProgress(value=0.0, description='8/30(train)', max=493.0, style=ProgressStyle(description_…


train Loss: 1.184283 Acc: 0.539369 Kaap: 1.000000


HBox(children=(FloatProgress(value=0.0, description='8/30(val  )', max=165.0, style=ProgressStyle(description_…


val Loss: 1.200544 Acc: 0.508178 Kaap: 0.576490


HBox(children=(FloatProgress(value=0.0, description='9/30(train)', max=493.0, style=ProgressStyle(description_…


train Loss: 1.166696 Acc: 0.546976 Kaap: 1.000000


HBox(children=(FloatProgress(value=0.0, description='9/30(val  )', max=165.0, style=ProgressStyle(description_…


val Loss: 1.197496 Acc: 0.511601 Kaap: 0.588192


HBox(children=(FloatProgress(value=0.0, description='10/30(train)', max=493.0, style=ProgressStyle(description…


train Loss: 1.145462 Acc: 0.551541 Kaap: 1.000000


HBox(children=(FloatProgress(value=0.0, description='10/30(val  )', max=165.0, style=ProgressStyle(description…


val Loss: 1.176380 Acc: 0.523013 Kaap: 0.596677


HBox(children=(FloatProgress(value=0.0, description='11/30(train)', max=493.0, style=ProgressStyle(description…


train Loss: 1.118148 Acc: 0.566375 Kaap: 1.000000


HBox(children=(FloatProgress(value=0.0, description='11/30(val  )', max=165.0, style=ProgressStyle(description…


val Loss: 1.165417 Acc: 0.527577 Kaap: 0.604730


HBox(children=(FloatProgress(value=0.0, description='12/30(train)', max=493.0, style=ProgressStyle(description…


train Loss: 1.095175 Acc: 0.581210 Kaap: 1.000000


HBox(children=(FloatProgress(value=0.0, description='12/30(val  )', max=165.0, style=ProgressStyle(description…


val Loss: 1.159290 Acc: 0.529099 Kaap: 0.612224


HBox(children=(FloatProgress(value=0.0, description='13/30(train)', max=493.0, style=ProgressStyle(description…


train Loss: 1.065194 Acc: 0.594142 Kaap: 1.000000


HBox(children=(FloatProgress(value=0.0, description='13/30(val  )', max=165.0, style=ProgressStyle(description…


val Loss: 1.147369 Acc: 0.530240 Kaap: 0.619487


HBox(children=(FloatProgress(value=0.0, description='14/30(train)', max=493.0, style=ProgressStyle(description…


train Loss: 1.048838 Acc: 0.597439 Kaap: 1.000000


HBox(children=(FloatProgress(value=0.0, description='14/30(val  )', max=165.0, style=ProgressStyle(description…


val Loss: 1.141541 Acc: 0.535184 Kaap: 0.625810


HBox(children=(FloatProgress(value=0.0, description='15/30(train)', max=493.0, style=ProgressStyle(description…


train Loss: 1.026403 Acc: 0.608343 Kaap: 1.000000


HBox(children=(FloatProgress(value=0.0, description='15/30(val  )', max=165.0, style=ProgressStyle(description…


val Loss: 1.137598 Acc: 0.536326 Kaap: 0.632152


HBox(children=(FloatProgress(value=0.0, description='16/30(train)', max=493.0, style=ProgressStyle(description…


train Loss: 0.993206 Acc: 0.625206 Kaap: 1.000000


HBox(children=(FloatProgress(value=0.0, description='16/30(val  )', max=165.0, style=ProgressStyle(description…


val Loss: 1.140455 Acc: 0.537847 Kaap: 0.638109


HBox(children=(FloatProgress(value=0.0, description='17/30(train)', max=493.0, style=ProgressStyle(description…


train Loss: 0.964928 Acc: 0.642323 Kaap: 1.000000


HBox(children=(FloatProgress(value=0.0, description='17/30(val  )', max=165.0, style=ProgressStyle(description…


val Loss: 1.124986 Acc: 0.545835 Kaap: 0.643895


HBox(children=(FloatProgress(value=0.0, description='18/30(train)', max=493.0, style=ProgressStyle(description…


train Loss: 0.940284 Acc: 0.645239 Kaap: 1.000000


HBox(children=(FloatProgress(value=0.0, description='18/30(val  )', max=165.0, style=ProgressStyle(description…


val Loss: 1.151936 Acc: 0.534424 Kaap: 0.649291


HBox(children=(FloatProgress(value=0.0, description='19/30(train)', max=493.0, style=ProgressStyle(description…


train Loss: 0.897380 Acc: 0.671104 Kaap: 1.000000


HBox(children=(FloatProgress(value=0.0, description='19/30(val  )', max=165.0, style=ProgressStyle(description…


val Loss: 1.123679 Acc: 0.543933 Kaap: 0.654812


HBox(children=(FloatProgress(value=0.0, description='20/30(train)', max=493.0, style=ProgressStyle(description…


train Loss: 0.878748 Acc: 0.682769 Kaap: 1.000000


HBox(children=(FloatProgress(value=0.0, description='20/30(val  )', max=165.0, style=ProgressStyle(description…


val Loss: 1.132435 Acc: 0.548498 Kaap: 0.660412


HBox(children=(FloatProgress(value=0.0, description='21/30(train)', max=493.0, style=ProgressStyle(description…


train Loss: 0.843662 Acc: 0.698364 Kaap: 1.000000


HBox(children=(FloatProgress(value=0.0, description='21/30(val  )', max=165.0, style=ProgressStyle(description…


val Loss: 1.132062 Acc: 0.540510 Kaap: 0.665516


HBox(children=(FloatProgress(value=0.0, description='22/30(train)', max=493.0, style=ProgressStyle(description…


train Loss: 0.808177 Acc: 0.709522 Kaap: 1.000000


HBox(children=(FloatProgress(value=0.0, description='22/30(val  )', max=165.0, style=ProgressStyle(description…


val Loss: 1.137245 Acc: 0.537086 Kaap: 0.670601


HBox(children=(FloatProgress(value=0.0, description='23/30(train)', max=493.0, style=ProgressStyle(description…


train Loss: 0.755524 Acc: 0.732852 Kaap: 1.000000


HBox(children=(FloatProgress(value=0.0, description='23/30(val  )', max=165.0, style=ProgressStyle(description…


val Loss: 1.152415 Acc: 0.543553 Kaap: 0.676032


HBox(children=(FloatProgress(value=0.0, description='24/30(train)', max=493.0, style=ProgressStyle(description…


train Loss: 0.718412 Acc: 0.756435 Kaap: 1.000000


HBox(children=(FloatProgress(value=0.0, description='24/30(val  )', max=165.0, style=ProgressStyle(description…


val Loss: 1.147797 Acc: 0.535945 Kaap: 0.681097


HBox(children=(FloatProgress(value=0.0, description='25/30(train)', max=493.0, style=ProgressStyle(description…


train Loss: 0.676252 Acc: 0.776721 Kaap: 1.000000


HBox(children=(FloatProgress(value=0.0, description='25/30(val  )', max=165.0, style=ProgressStyle(description…


val Loss: 1.167077 Acc: 0.538227 Kaap: 0.686411


HBox(children=(FloatProgress(value=0.0, description='26/30(train)', max=493.0, style=ProgressStyle(description…


train Loss: 0.633354 Acc: 0.797642 Kaap: 1.000000


HBox(children=(FloatProgress(value=0.0, description='26/30(val  )', max=165.0, style=ProgressStyle(description…


val Loss: 1.187197 Acc: 0.532522 Kaap: 0.691287


HBox(children=(FloatProgress(value=0.0, description='27/30(train)', max=493.0, style=ProgressStyle(description…


train Loss: 0.568295 Acc: 0.825662 Kaap: 1.000000


HBox(children=(FloatProgress(value=0.0, description='27/30(val  )', max=165.0, style=ProgressStyle(description…


val Loss: 1.209458 Acc: 0.535945 Kaap: 0.696827


HBox(children=(FloatProgress(value=0.0, description='28/30(train)', max=493.0, style=ProgressStyle(description…


train Loss: 0.529723 Acc: 0.840877 Kaap: 1.000000


HBox(children=(FloatProgress(value=0.0, description='28/30(val  )', max=165.0, style=ProgressStyle(description…


val Loss: 1.193802 Acc: 0.533283 Kaap: 0.701891


HBox(children=(FloatProgress(value=0.0, description='29/30(train)', max=493.0, style=ProgressStyle(description…


train Loss: 0.474847 Acc: 0.867250 Kaap: 1.000000


HBox(children=(FloatProgress(value=0.0, description='29/30(val  )', max=165.0, style=ProgressStyle(description…


val Loss: 1.224756 Acc: 0.537467 Kaap: 0.707175


HBox(children=(FloatProgress(value=0.0, description='30/30(train)', max=493.0, style=ProgressStyle(description…


train Loss: 0.440500 Acc: 0.879168 Kaap: 1.000000


HBox(children=(FloatProgress(value=0.0, description='30/30(val  )', max=165.0, style=ProgressStyle(description…


val Loss: 1.259604 Acc: 0.526816 Kaap: 0.712061
Training complete in 664m 31.610615s
Best val Acc: 0.548498

### FOLD: 1 ###


HBox(children=(FloatProgress(value=0.0, description='1/30(train)', max=493.0, style=ProgressStyle(description_…

"""
Graphs
1. loss vs. iterations
2. Train/Validation accuracy along epoch
"""
plt.subplot(2,1,1)
plt.plot(log_dict['loss_history_0'], 'o')
plt.xlabel('epoch')
plt.ylabel('loss')

plt.subplot(2,1,2)
plt.plot(log_dict['train_acc_history_0'], '-o')
plt.plot(log_dict['val_acc_history_0'], '-o')
plt.legend(['train', 'val'], loc='upper left')
plt.xlabel('epoch')
plt.ylabel('accuracy')

plt.tight_layout(pad=3)
plt.show()

# read python dict back from the file
with open('alexnet_05_26_2020_22_58.pkl', 'rb') as pfile:
    test_dict = pickle.load(pfile)

"""
Graphs
1. loss vs. iterations
2. Train/Validation accuracy along epoch
"""
plt.subplot(2,1,1)
plt.plot(test_dict['loss_history_0'], 'o')
plt.xlabel('epoch')
plt.ylabel('loss')

plt.subplot(2,1,2)
plt.plot(test_dict['train_acc_history_0'], '-o')
plt.plot(test_dict['val_acc_history_0'], '-o')
plt.legend(['train', 'val'], loc='upper left')
plt.xlabel('epoch')
plt.ylabel('accuracy')

plt.tight_layout(pad=3)
plt.show()