In [1]:
import os
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision

import matplotlib.pyplot as plt
%matplotlib inline

%load_ext autoreload
%autoreload 2

## GPU

In [2]:
USE_GPU = True
dtype = torch.float32 # we will be using float throughout this tutorial

if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print('using device:', device)

using device: cuda


## Dataloader

In [3]:
from dataloader import PanDataset
from torch.utils.data import DataLoader

PATH_IMG = "/kuacc/users/skoc21/dataset/pannet/wsi-tiles/annotated-qupath-v1/"

dataset = PanDataset(img_dir=PATH_IMG, extension='jpg')
# dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

In [4]:
# Display image and label.
data_dict = next(iter(dataloader))
train_features, train_labels, fname = data_dict['image'], data_dict['ann'], data_dict['name_img']
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
print(f"File name: {fname}")

NameError: name 'dataloader' is not defined

In [None]:
FromTensor = torchvision.transforms.ToPILImage()
img = train_features.squeeze()
t = FromTensor(img)
plt.imshow(t) # matplotlib
print(f"Label: {train_labels}")

### Train-Val-Test Split - Stratified with Class Counts

In [4]:
import pandas as pd
from sklearn.model_selection import train_test_split

In [5]:
df = pd.DataFrame({"ids":dataset.ids})
train_idx, valid_idx = train_test_split(np.arange(len(df)), test_size=0.3, shuffle=True, stratify=df.ids)

df_valtest = df.loc[ df.index.isin(valid_idx), : ]
val_idx, test_idx = train_test_split(np.arange(len(df_valtest)), test_size=0.5, shuffle=True, stratify=df_valtest.ids)

# Pring Train, Val, Test sizes
print(f"Size Train Set: {len(train_idx)}\tSize Val Set: {len(val_idx)}\tSize Test Set: {len(test_idx)}")

# Class Distributions
print([f"{key}: {round(100*value/len(train_idx),3) }%" for key, value in df.loc[df.index.isin(train_idx), : ].ids.value_counts().to_dict().items()])
print([f"{key}: {round(100*value/len(val_idx),3) }%" for key, value in df.loc[df.index.isin(val_idx), : ].ids.value_counts().to_dict().items()])
print([f"{key}: {round(100*value/len(test_idx),3) }%" for key, value in df.loc[df.index.isin(test_idx), : ].ids.value_counts().to_dict().items()])


Size Train Set: 4243	Size Val Set: 909	Size Test Set: 910
['pannet: 33.349%', 'normal: 33.325%', 'Stroma: 33.325%']
['normal: 35.424%', 'pannet: 33.113%', 'Stroma: 31.463%']
['pannet: 34.615%', 'Stroma: 34.286%', 'normal: 31.099%']


## Data Sampler

In [6]:
train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
val_sampler = torch.utils.data.SubsetRandomSampler(val_idx)
test_sampler = torch.utils.data.SubsetRandomSampler(test_idx)

In [7]:
pd.DataFrame({"test_idx":test_idx}).to_csv('test_idx.csv')

In [8]:
dataloader_tr = DataLoader(dataset, batch_size=64, sampler=train_sampler)
dataloader_val = DataLoader(dataset, batch_size=64, sampler=val_sampler)
dataloader_ts = DataLoader(dataset, batch_size=64, sampler=test_sampler)

In [9]:
dataloader = {'train': dataloader_tr,
              'val': dataloader_val,
              'test': dataloader_ts}

## Pretrained Model - ResNet with H&E Weights

In [10]:
MODEL_PATH = 'tenpercent_resnet18.ckpt'
RETURN_PREACTIVATION = False  # return features from the model, if false return classification logits
NUM_CLASSES = 3  # only used if RETURN_PREACTIVATION = False


In [11]:
def load_model_weights(model, weights):

    model_dict = model.state_dict()
    weights = {k: v for k, v in weights.items() if k in model_dict}
    if weights == {}:
        print('No weight could be loaded..')
    model_dict.update(weights)
    model.load_state_dict(model_dict)

    return model

In [12]:
model = torchvision.models.__dict__['resnet18'](pretrained=False)
state = torch.load(MODEL_PATH, map_location='cuda:0')
state_dict = state['state_dict']

In [13]:
for key in list(state_dict.keys()):
    state_dict[key.replace('model.', '').replace('resnet.', '')] = state_dict.pop(key)

model = load_model_weights(model, state_dict)

if torch.cuda.is_available():
    model.cuda()

for param in model.parameters():
    param.requires_grad = True
# param.requires_grad = True

In [14]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

# def accuracy(out, labels):
#     _,pred = torch.max(out, dim=1)
#     return torch.sum(pred==labels).item()

## TRAINING

In [15]:
from typing import (Dict, IO, List, Tuple)
from pathlib import Path
from utils import calculate_confusion_matrix
import time

In [16]:
since = time.time()

# Initialize all the tensors to be used in training and validation.
# Do this outside the loop since it will be written over entirely at each
# epoch and doesn't need to be reallocated each time.
train_all_labels = torch.empty(size=(len(train_idx), ),
                               dtype=torch.long).cpu()
train_all_predicts = torch.empty(size=(len(train_idx), ),
                                 dtype=torch.long).cpu()
val_all_labels = torch.empty(size=(len(val_idx), ),
                             dtype=torch.long).cpu()
val_all_predicts = torch.empty(size=(len(val_idx), ),
                               dtype=torch.long).cpu()

In [23]:
train_helper(dataloader, 
             {"train":len(train_idx), "val":len(val_idx)},
             criterion,
             optimizer,
             scheduler,
             device,
             num_epochs=100,
             start_epoch=1,
             batch_size=64,
             save_interval=10,
             checkpoints_folder=Path('saved_models'),
             num_layers=18,
             classes=[0,1,2],
             num_classes=3
             )

Predicted    0    1    2
Actual                  
0         0.99 0.00 0.00
1         0.01 0.99 0.01
2         0.00 0.01 0.99
Predicted    0    1    2
Actual                  
0         1.00 0.00 0.00
1         0.01 0.99 0.00
2         0.00 0.02 0.98
1,0.0357,0.9896,0.0416,0.9901

Epoch 1 with lr 0.0010000: t_loss: 0.0357 t_acc: 0.9896 v_loss: 0.0416 v_acc: 0.9901

Predicted    0    1    2
Actual                  
0         0.99 0.00 0.01
1         0.01 0.99 0.01
2         0.00 0.01 0.99
Predicted    0    1    2
Actual                  
0         1.00 0.00 0.00
1         0.00 0.99 0.01
2         0.00 0.02 0.98
2,0.0394,0.9885,0.0405,0.9901

Epoch 2 with lr 0.0010000: t_loss: 0.0394 t_acc: 0.9885 v_loss: 0.0405 v_acc: 0.9901

Predicted    0    1    2
Actual                  
0         1.00 0.00 0.00
1         0.01 0.99 0.00
2         0.00 0.01 0.99
Predicted    0    1    2
Actual                  
0         1.00 0.00 0.00
1         0.00 0.99 0.01
2         0.00 0.02 0.98
3,0.0330,0.9910,

In [20]:
def train_helper(
#                  model: torchvision.models.resnet.ResNet,
                 dataloaders: Dict[str, torch.utils.data.DataLoader],
                 dataset_sizes: Dict[str, int],
                 criterion: torch.nn.modules.loss, 
                 optimizer: torch.optim,
                 scheduler: torch.optim.lr_scheduler, 
#                  writer: IO, 
                 device: torch.device, 
                 num_epochs: int,
                 start_epoch: int,
                 batch_size: int, 
                 save_interval: int, 
                 checkpoints_folder: Path,
                 num_layers: int, 
                 classes: List[str],num_classes: int
                 ) -> None:
    
    # Train for specified number of epochs.
    for epoch in range(start_epoch, num_epochs):

        # Training phase.
        model.train(mode=True)

        train_running_loss = 0.0
        train_running_corrects = 0

        # Train over all training data.
        for idx, data_dict in enumerate(dataloaders["train"]):
            train_inputs = data_dict['image'].to(device=device)
            train_labels = data_dict['ann'].to(device=device)
            optimizer.zero_grad()

            # Forward and backpropagation.
            with torch.set_grad_enabled(mode=True):
                train_outputs = model(train_inputs)
                __, train_preds = torch.max(train_outputs, dim=1)
                train_loss = criterion(input=train_outputs,
                                       target=train_labels)
                train_loss.backward()
                optimizer.step()

            # Update training diagnostics.
            train_running_loss += train_loss.item() * train_inputs.size(0)
            train_running_corrects += torch.sum(
                train_preds == train_labels.data, dtype=torch.double)

            start = idx * batch_size
            end = start + batch_size

            train_all_labels[start:end] = train_labels.detach().cpu()
            train_all_predicts[start:end] = train_preds.detach().cpu()

        calculate_confusion_matrix(all_labels=train_all_labels.numpy(),
                                   all_predicts=train_all_predicts.numpy(),
                                   classes=classes,
                                   num_classes=num_classes)

        # Store training diagnostics.
        train_loss = train_running_loss / dataset_sizes["train"]
        train_acc = train_running_corrects / dataset_sizes["train"]

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Validation phase.
        model.train(mode=False)

        val_running_loss = 0.0
        val_running_corrects = 0

        # Feed forward over all the validation data.
        for idx, data_dict_val in enumerate(dataloaders["val"]):
            val_inputs = data_dict_val['image'].to(device=device)
            val_labels = data_dict_val['ann'].to(device=device)

            # Feed forward.
            with torch.set_grad_enabled(mode=False):
                val_outputs = model(val_inputs)
                _, val_preds = torch.max(val_outputs, dim=1)
                val_loss = criterion(input=val_outputs, target=val_labels)

            # Update validation diagnostics.
            val_running_loss += val_loss.item() * val_inputs.size(0)
            val_running_corrects += torch.sum(val_preds == val_labels.data,
                                              dtype=torch.double)

            start = idx * batch_size
            end = start + batch_size

            val_all_labels[start:end] = val_labels.detach().cpu()
            val_all_predicts[start:end] = val_preds.detach().cpu()

        calculate_confusion_matrix(all_labels=val_all_labels.numpy(),
                                   all_predicts=val_all_predicts.numpy(),
                                   classes=classes,
                                   num_classes=num_classes)

        # Store validation diagnostics.
        val_loss = val_running_loss / dataset_sizes["val"]
        val_acc = val_running_corrects / dataset_sizes["val"]

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        scheduler.step()

        current_lr = None
        for group in optimizer.param_groups:
            current_lr = group["lr"]

        # Remaining things related to training.
        if epoch % save_interval == 0:
            epoch_output_path = checkpoints_folder.joinpath(
                f"resnet{num_layers}_e{epoch}_va{val_acc:.5f}.pt")

            # Confirm the output directory exists.
            epoch_output_path.parent.mkdir(parents=True, exist_ok=True)

            # Save the model as a state dictionary.
            torch.save(obj={
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "epoch": epoch + 1
            },
                       f=str(epoch_output_path))

        print(f"{epoch},{train_loss:.4f},"
                     f"{train_acc:.4f},{val_loss:.4f},{val_acc:.4f}\n")

        # Print the diagnostics for each epoch.
        print(f"Epoch {epoch} with lr "
              f"{current_lr:.7f}: "
              f"t_loss: {train_loss:.4f} "
              f"t_acc: {train_acc:.4f} "
              f"v_loss: {val_loss:.4f} "
              f"v_acc: {val_acc:.4f}\n")

    # Print training information at the end.
    print(f"\ntraining complete in "
          f"{(time.time() - since) // 60:.2f} minutes")

## ARCHIVE

In [18]:
for epoch in range(1, n_epochs+1):
    running_loss, correct, total = 0.0, 0, 0
    
    print(f'Epoch {epoch}\n')
    for batch_idx, data_dict in enumerate(dataloader['train']):
        data_ = data_dict['image'].to(device)
        target_ = data_dict['ann'].to(device)
        optimizer.zero_grad()
        
        outputs = model(data_)
        loss = criterion(outputs, target_)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _,pred = torch.max(outputs, dim=1)

        correct += torch.sum(pred==target_).item()
        total += target_.size(0)
        if (batch_idx) % print_every == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
                   .format(epoch, n_epochs, batch_idx, total_step, loss.item()))
    train_acc.append(100 * correct / total)
    train_loss.append(running_loss/total_step)
    print(f'\ntrain-loss: {np.mean(train_loss):.4f}, train-acc: {(100 * correct/total):.4f}')
    
    batch_loss, total_t, correct_t = 0, 0, 0
    with torch.no_grad():
        model.eval()
        for data_dict_t in dataloader['val']:
            data_t, target_t = data_dict_t['image'].to(device), data_dict_t['ann'].to(device)
            outputs_t = model(data_t)
            loss_t = criterion(outputs_t, target_t)
            batch_loss += loss_t.item()
            _,pred_t = torch.max(outputs_t, dim=1)
            correct_t += torch.sum(pred_t==target_t).item()
            total_t += target_t.size(0)
        val_acc.append(100 * correct_t/total_t)
        val_loss.append(batch_loss/total_step_val)
        network_learned = batch_loss < valid_loss_min
        print(f'validation loss: {np.mean(val_loss):.4f}, validation acc: {(100 * correct_t/total_t):.4f}\n')
        
        if network_learned:
            valid_loss_min = batch_loss
            torch.save(model.state_dict(), 'resnet.pt')
            print('Improvement-Detected, save-model')
    scheduler.step()
    model.train()

Epoch 1

Epoch [1/50], Step [0/4243], Loss: 7.0570


KeyboardInterrupt: 

#### Train-Validation Accuracy

In [None]:
fig = plt.figure(figsize=(12,8))
plt.title("Train-Validation Accuracy")
plt.plot(train_acc, label='train')
plt.plot(val_acc, label='validation')
plt.xlabel('num_epochs', fontsize=12)
plt.ylabel('accuracy', fontsize=12)
plt.legend(loc='best')

### Visualize Tiles with Prediction

In [None]:
def visualize_model(net, data_source='test', num_images=20):
    images_so_far = 0
    fig = plt.figure(figsize=(150, 100))
    
    for i, data_dict in enumerate(dataloader[data_source]):
        inputs, labels, fname = data_dict['image'], data_dict['ann'], data_dict['name_img']
        if USE_GPU:
            inputs, labels = inputs.cuda(), labels.cuda()
        outputs = net(inputs)
        _, preds = torch.max(outputs.data, 1)
        preds = preds.to(torch.device('cpu')).detach().numpy() if USE_GPU else preds.numpy()
        for j in range(inputs.size()[0]):
            images_so_far += 1
            ax = plt.subplot(4, num_images//4, images_so_far)
            ax.axis('off')
            ax.set_title(f"Label: {labels[j]} Predict: {preds[j]} \n File: {fname[j].split('/')[-1]}",fontdict={'fontsize': 45})
            img = inputs[j].to(torch.device('cpu'))
            img_ndarray = np.asarray(img)
            img_ndarray = img_ndarray.transpose((1, 2, 0))
            
            plt.imshow(img_ndarray)
            
            if images_so_far == num_images:
                return 

plt.ion()
visualize_model(model)
plt.ioff()

In [None]:
correct_ts = 0
with torch.no_grad():
    model.eval()
    for data_dict_ts in dataloader['test']:
        data_ts, target_ts = data_dict_ts['image'].to(device), data_dict_ts['ann'].to(device)
        outputs_ts = model(data_ts)
        _,pred_ts = torch.max(outputs_ts, dim=1)
        correct_ts += torch.sum(pred_ts==target_ts).item()
    print(100 * correct_ts/len(test_idx))