In [None]:
%matplotlib inline

Finetuning Torchvision Models
=============================

Author: Wataru Uegami, MD

2021/3/27

In [None]:
from __future__ import print_function 
from __future__ import division
import torch
from torchvision.models import resnet
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
from pathlib import Path
import shutil
from tqdm import tqdm
from functools import partial

from mocotools import mocoutil

import pandas as pd
print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)

# Architecture tag. The training call only compares it against "inception"
# to decide whether the auxiliary-output loss path is used.
model_name = "resnet"
# Mini-batch size handed to both the train and val DataLoaders.
batch_size = 64
# Number of epochs handed to train_model.
num_epochs = 150

# Flag for feature extracting. When False, we finetune the whole model,
#   when True we only update the reshaped layer params
feature_extract = False

# 2.5x model
The code below refers to the 2.5x model as "2x".

Run this cell only when fine-tuning the 2.5x model (skip it for 5x and 20x).

In [None]:
# Top level data directory. Here we assume the format of the directory conforms
#   to the ImageFolder structure (train/<class>/..., val/<class>/...).
# The tile-copy cell writes into this directory and ImageFolder reads from it.
data_dir = Path('path/to/tiles/2x')

# Number of classes in the dataset; used as the output size of the new
# linear head. The remap below yields 4 labels besides 'Exclude'.
num_classes = 4

# Load pretrained CNN feature extractor.
# The checkpoint is later indexed with 'state_dict'.
checkpoint = torch.load('/path/to/checkpoints/2x_epoch200.pth')

# Clustering table (first CSV column becomes the index). The cells below read
# its 'k30' and 'path' columns.
cluster = pd.read_csv('/path/to/cluster_results/mgn2x.csv', index_col=0)
# Directory into which train_model writes the per-epoch checkpoints.
models_out = '/path/to/finalmodel_Mar10/2x/'

def remap(col):
    """Translate a cluster id into the tissue-class label it was grouped under.

    Returns 'NearNormal', 'CellularTissue', 'AcellularFibroticIP', 'Exclude'
    or 'Other'. An id that is a member of none of the groups yields None.
    """
    grouping = (
        ('NearNormal', {0, 3, 4, 7, 14, 19, 22, 26, 27, 29}),
        ('CellularTissue', {2, 13, 15, 16, 24, 25, 28}),
        ('AcellularFibroticIP', {1, 5, 8, 10, 12, 18}),
        ('Exclude', {11, 17, 20}),
        ('Other', {6, 9, 21, 23}),
    )
    for label, members in grouping:
        if col in members:
            return label
    # No group matched: same (implicit) result as falling off the if/elif chain.
    return None
    
# Attach the merged class label to every row, derived from its 'k30' cluster id.
cluster['feat'] = cluster['k30'].apply(remap)

# Per class, the first n_train sampled tiles go to the training split in the
# tile-copy cell; the remaining sampled tiles go to validation.
n_train = 13000

# Presumably the class with the fewest tiles — TODO confirm against the data.
minority = 'CellularTissue'

# 5x

Run this cell only when fine-tuning the 5x model (skip it for 2.5x and 20x).

In [None]:
# Top level data directory. Here we assume the format of the directory conforms
#   to the ImageFolder structure (train/<class>/..., val/<class>/...).
# The tile-copy cell writes into this directory and ImageFolder reads from it.
data_dir = Path('path/to/tiles/5x')

# Number of classes in the dataset; used as the output size of the new
# linear head. The remap below yields 8 labels besides 'Exclude'.
num_classes = 8

# Load pretrained CNN feature extractor.
# The checkpoint is later indexed with 'state_dict'.
checkpoint = torch.load('/path/to/checkpoints/5x_epoch160.pth')

# Clustering table (first CSV column becomes the index). The cells below read
# its 'k80' and 'path' columns.
cluster5x = pd.read_csv('/path/to/cluster_results/mgn5x.csv', index_col=0)
# Directory into which train_model writes the per-epoch checkpoints.
models_out = '/path/to/finalmodel_Mar10/5x/'

# Need to define how to integrate the cluster referring the montage
def remap(col):
    """Translate a cluster id into the merged class label it was assigned to.

    Ids that are not listed in any group (3, 25, 29, 31, 40, 70 and anything
    else) are labelled 'Other'.
    """
    members_by_label = {
        'LymphoidFollicle': (16, 28, 61),
        'CellularIP_NSIP': (0, 2, 5, 6, 7, 12, 26, 30, 32, 37, 42, 45, 46, 51, 57, 58, 64),
        'CellularFibroticIP': (9, 10, 38, 43, 44, 49, 56, 60, 74, 78, 79),
        'CompleteNormal': (8, 13, 15, 19, 24, 39, 54, 67),
        'Exclude': (11, 21, 22, 23, 27, 33, 36, 41, 47, 48, 55, 59, 65, 66, 69, 71, 72, 73, 76),
        'Accellular_fibrosis': (4, 14, 17, 18, 50, 53, 68, 77),
        'edge': (1, 35, 63, 75),
        'pale': (20, 34, 52, 62),
    }
    # Invert to id -> label; the groups are disjoint, so no id is overwritten.
    label_of = {
        cluster_id: label
        for label, cluster_ids in members_by_label.items()
        for cluster_id in cluster_ids
    }
    return label_of.get(col, 'Other')
    
   
# Attach the merged class label to every row, derived from its 'k80' cluster id.
cluster5x['feat'] = cluster5x['k80'].apply(remap)

# The tile-copy cell works on the generic name `cluster`.
cluster = cluster5x

# Per class, the first n_train sampled tiles go to the training split in the
# tile-copy cell; the remaining sampled tiles go to validation.
n_train = 2000

# Presumably the class with the fewest tiles — TODO confirm against the data.
minority = 'LymphoidFollicle'

# 20x
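In 20x, the clusters and images have already been reclassified manually. Because the number of tiles differs between classes, here we take a fixed number of randomly chosen tiles per class (up to 4000: the first 3500 for training, the rest for validation) and copy them into `train` and `val` folders.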

In [None]:
import random
import glob

# Top level data directory read by ImageFolder (expects train/ and val/ inside).
# NOTE(review): the copy loop in this cell writes to /path/to/finetune_Mar10/20x,
# not to data_dir — confirm the two are meant to be the same location.
data_dir = Path('/path/to/tiles/20x')

# Number of classes in the dataset (matches the 8 entries of `features`).
num_classes = 8

# Load pretrained CNN feature extractor.
# The checkpoint is later indexed with 'state_dict'.
checkpoint = torch.load('/path/to/checkpoints/20x_epoch200.pth')

# NOTE(review): cluster20x is loaded but not referenced again in the visible
# notebook, and this cell does not set `cluster` / `n_train`, which the
# generic tile-copy cell further down reads — confirm that cell is meant to
# be skipped for 20x.
cluster20x = pd.read_csv('/path/to/cluster_results/mgn20x_2.csv', index_col=0)

# Directory into which train_model writes the per-epoch checkpoints.
models_out = '/path/to/finalmodel_Mar10/mgn20x_4/'

# Class folder names; one sub-directory per entry is read and written below.
features = ['DF_true', 'elastosis', 'fat', 'Immature_fibroblasts',
            'lymphocyte_dense', 'resp_epithelium', 'mucos', 'other']

# Per-class split sizes: at most `num_cases` tiles are used per class; the
# first `num_train` of the shuffled tiles go to train, the remainder to val.
num_cases = 4000
num_train = 3500

# One root for both splits, so the directories that get created are exactly
# the ones the tiles are copied into (the val copy previously targeted a
# relative 'path/to/...' that differed from the created '/path/to/...').
finetune_root = '/path/to/finetune_Mar10/20x'

for feat in features:

    print(f'Copy files: {feat}')
    imgs = glob.glob('/path/to/train20x_4000each/' + feat + '/*.jpeg')

    # The random order decides train/val membership. The shuffle is unseeded,
    # so the split differs from run to run.
    random.shuffle(imgs)
    # Cap before copying so that exactly `num_cases` tiles are used
    # (the old post-copy `i == num_cases` check let one extra tile through).
    imgs = imgs[:num_cases]

    train_dir = f'{finetune_root}/train/{feat}'
    val_dir = f'{finetune_root}/val/{feat}'
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(val_dir, exist_ok=True)

    for i, img in enumerate(tqdm(imgs)):
        shutil.copy(img, train_dir if i < num_train else val_dir)

Inputs
------

`data_dir` must be laid out like this:

```
data_dir
 |- train
     |- cls0
         |- 0_1.jpeg
         |- 0_2.jpeg
         ...
     |- cls1
     |- cls3
     ....
 |- val
     |- cls0
         |- 0_1.jpeg
         |- 0_2.jpeg
         ...
     |- cls1
     |- cls3
     ....
```

In [None]:
import random
# Balance the classes: every class is randomly down-sampled to the size of
# the smallest non-excluded class. `n_shortest` used to be referenced without
# ever being defined, which raised NameError on a fresh kernel.
n_shortest = int(cluster.loc[cluster['feat'] != 'Exclude', 'feat'].value_counts().min())

for feat in pd.unique(cluster['feat']):
    if feat == 'Exclude':
        continue
    print(feat)
    class_df = cluster[cluster['feat'] == feat]

    # Row order after sampling is random; the first `n_train` sampled rows
    # form the training split and the remaining rows the validation split.
    df_downsample = class_df.sample(n_shortest)
    df_downsample = df_downsample.assign(
        train=[i < n_train for i in range(len(df_downsample))])

    rows = zip(df_downsample['path'], df_downsample['feat'], df_downsample['train'])
    for tile_path, label, is_train in tqdm(rows, total=len(df_downsample)):
        tile_path = Path(tile_path)
        split = 'train' if is_train else 'val'

        # The parent folder name is prepended to the file name — presumably so
        # that equally named tiles from different cases do not collide.
        case_name = tile_path.parent.name

        dst = data_dir.joinpath(split).joinpath(label).joinpath(case_name + tile_path.name)
        os.makedirs(dst.parent, exist_ok=True)

        shutil.copy(tile_path, dst)

In [None]:
# Build the MoCo wrapper (resnet18 backbone) and move it to the GPU.
# NOTE(review): these hyper-parameters presumably have to match the ones used
# for pre-training so that the checkpoint keys/shapes line up — confirm.
model = mocoutil.ModelMoCo(dim=128, K=4096,m=0.99,T=0.1,arch='resnet18').cuda()

# Print whatever load_state_dict returns; assuming ModelMoCo is a torch
# nn.Module, that is the report of missing / unexpected keys.
print(model.load_state_dict(checkpoint['state_dict']))

# Keep only the `encoder_q` sub-module (presumably the MoCo query encoder).
# Rebinding `model` drops this notebook's reference to the rest of the wrapper.
model = model.encoder_q

Helper Functions
----------------

In [None]:
def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_inception=False):
    """Fine-tune `model` and return it loaded with its best validation weights.

    Every epoch runs a 'train' phase followed by a 'val' phase. A checkpoint
    (epoch, model state_dict, optimizer state_dict) is written to the
    module-level `models_out` directory once per epoch, for epoch 1 and for
    every even-numbered epoch.

    Args:
        model: network to train. Batches are moved to the device of its first
            parameter, so the model must already live on the target device.
        dataloaders: dict with 'train' and 'val' entries, each iterable over
            (inputs, labels) batches and exposing a sized `.dataset`.
        criterion: loss callable taking (outputs, labels).
        optimizer: optimizer that updates the parameters of `model`.
        num_epochs: number of epochs to run.
        is_inception: when True, the training-phase forward pass is expected
            to return (outputs, aux_outputs) and the loss becomes
            criterion(outputs) + 0.4 * criterion(aux_outputs).

    Returns:
        (model, val_acc_history): the model carrying the weights of the epoch
        with the highest validation accuracy, and the list of per-epoch
        validation accuracies (0-dim tensors).
    """
    # Read the device off the model instead of relying on a notebook global
    # that is only defined in a later cell.
    device = next(model.parameters()).device

    # torch.save does not create missing directories.
    os.makedirs(models_out, exist_ok=True)

    val_acc_history = []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward; track history only in the training phase
                with torch.set_grad_enabled(is_train):
                    if is_inception and is_train:
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4*loss2
                    else:
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    _, preds = torch.max(outputs, 1)

                    # backward + optimize only if in training phase
                    if is_train:
                        loss.backward()
                        optimizer.step()

                # statistics (loss is a batch mean, so weight it by batch size)
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
            if phase == 'val':
                val_acc_history.append(epoch_acc)

        # Checkpoint once per epoch, after both phases. Writing it inside the
        # phase loop saved the same file twice per epoch.
        if epoch == 1 or epoch % 2 == 0:
            torch.save({'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict(),},
                       os.path.join(models_out, f'ep{epoch}.pth'))

        print()

    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)

    return model, val_acc_history

def set_parameter_requires_grad(model, feature_extracting):
    """Freeze every parameter of `model` when `feature_extracting` is truthy.

    With a falsy flag the model is left untouched. Nothing is returned; the
    parameters are modified in place.
    """
    if not feature_extracting:
        return
    for weight in model.parameters():
        weight.requires_grad = False

Initialize and Reshape the Networks
-----------------------------------

In [None]:
# Honour the `feature_extract` flag: freeze the pretrained backbone *before*
# the new head is created, so that in feature-extraction mode only the freshly
# added Linear layer keeps requires_grad=True. With feature_extract=False this
# call does nothing.
set_parameter_requires_grad(model, feature_extract)

# Keep the encoder layers up to and including the global average pool, then
# attach a new classification head in place of whatever followed the pool.
snet = []
for name, module in model.net.named_children():
    snet.append(module)
    if isinstance(module, nn.AdaptiveAvgPool2d):
        snet.append(nn.Flatten(1))
        # 512 input features: assumes the pooled feature size of the resnet18
        # backbone — TODO adjust if another architecture is used.
        snet.append(nn.Linear(512, num_classes))
        break
else:
    # Without the pooling layer no classifier would have been attached and
    # training would silently run on the wrong output.
    raise RuntimeError('No nn.AdaptiveAvgPool2d found in model.net; cannot attach the classification head.')
model.net = nn.Sequential(*snet)

### Load Data

In [None]:
# Per-channel statistics used to normalise every tile; shared by both splits
# (previously defined here but then re-typed inline three times).
normalize = transforms.Normalize(mean=[0.85, 0.7, 0.78], std=[0.15, 0.24, 0.2])

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomRotation(20),
        transforms.CenterCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        # p=1: the colour jitter is applied to every training tile.
        transforms.RandomApply([
            transforms.ColorJitter(brightness=0.08, # 0.4
                                   contrast=0.2, # 0.4
                                   saturation=0.7,
                                   hue=0.03)  # not strengthened  # 0.1
        ], p=1),
        transforms.ToTensor(),
        normalize,
    ]),
    # NOTE(review): train uses CenterCrop(224) while val uses Resize(224);
    # the two only see the same scale if tiles are already 224x224 — confirm.
    'val': transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        normalize,
    ]),
}

print("Initializing Datasets and Dataloaders...")

# Create training and validation datasets
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
# Create training and validation dataloaders; only the training split needs
# shuffling (validation metrics are sums over the whole split).
dataloaders_dict = {
    x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size,
                                   shuffle=(x == 'train'), num_workers=4)
    for x in ['train', 'val']
}

# Detect if we have a GPU available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Show the class-name -> label-index mapping as the cell output.
image_datasets['train'].class_to_idx

Create the Optimizer
--------------------

In [None]:
# Move the network onto the selected device before the optimizer is built.
model = model.to(device)

# Decide what the optimizer receives:
#  - fine-tuning (feature_extract is False): every parameter of the model;
#  - feature extraction: only the parameters whose requires_grad is True.
# In both modes the names of all trainable parameters are listed.
print("Params to learn:")
params_to_update = [] if feature_extract else model.parameters()
for name, param in model.named_parameters():
    if param.requires_grad:
        if feature_extract:
            params_to_update.append(param)
        print("\t",name)

# Adam over the selected parameters
optimizer_ft = optim.Adam(params_to_update, lr=0.0001)

Run Training and Validation
--------------------------------

In [None]:
# Setup the loss fxn
criterion = nn.CrossEntropyLoss()

# Train and evaluate.
# `model` is rebound to the network returned by train_model (loaded with its
# best validation weights); `hist` is the per-epoch validation accuracy list.
# is_inception evaluates to False here because model_name is "resnet".
model, hist = train_model(model, dataloaders_dict, criterion, optimizer_ft, num_epochs=num_epochs, is_inception=(model_name=="inception"))