PyTorch Dataset: https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

NOTE:
1. The validation range is changed to small size for debug.
2. Mean:  [0.90960454 0.81946206 0.87811487]
   Std:  [0.13244118 0.24944844 0.16392948]

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F # stateless functions
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import sampler

import os
import pandas as pd
from skimage import io, transform
import numpy as np
import copy

import matplotlib.pyplot as plt

#import torchvision.datasets as dset
import torchvision.transforms as T
import torchvision.models as models

import time
from collections import defaultdict

In [2]:
image_cropped_dir = '../yi_data/train_crop_128'
NUM_TRAIN = 6800 #8492

USE_GPU = False
dtype = torch.float32   # use float throughout the training
print_every = 100

SCALE_SZ = 128
BATCH_SZ = 8

if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
    
print('using device:', device)

using device: cpu


In [3]:
# Customized Dataset
class ProstateCancerDataset(Dataset):
    """Prostate Cancer Biopsy Dataset"""
    
    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file
            root_dir (string): Path to the directory with all images
            transform (callable, optional): Optional transform to be applied on an image sample
        """
        # Shuffle dataframes with fixed seed; otherwise, validation set only get cancerous samples
        self.cancer_df = pd.read_csv(csv_file).sample(frac=1, random_state=1)
        self.root_dir = root_dir
        self.transform = transform
        
    def __len__(self):
        return len(self.cancer_df)
    
    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, f'{self.cancer_df.iloc[idx, 0]}_0.png')
        # (D,W,H)
        img = io.imread(img_path)
        isup = self.cancer_df.iloc[idx, 2]
        gleason = self.cancer_df.iloc[idx, 3]
        
        if self.transform:
            img = self.transform(img)
        sample = {'image': img, 'isup_grade': isup, 'gleason_score': gleason}
        return sample        

##Customize Transforms
class Rescale(object):
    """Rescale the image sample to the given size
    Args:
        output_size (tuple): Desired output size. Output is matched to output_size.
    """
    
    def __init__(self, output_size):
        assert isinstance(output_size, tuple)
        self.output_size = output_size
        
    def __call__(self, sample):
        img = transform.resize(sample['image'], self.output_size)
        return {'image': img, 'isup_grade': sample['isup_grade'], 'gleason_score': sample['gleason_score']}
    
"""
class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""
    
    def __call__(self, sample):
        img = sample['image']        
        # Swap color axis to [C,H,W]
        img = img.transpose(2,0,1)
        return {'image': img, 'isup_grade': sample['isup_grade'], 'gleason_score': sample['gleason_score']}
"""

# Data Preparation

In [4]:
# Compse tranforms of totensor
# More transformer to try: crop, normalize, etc.
# And transformer is a useful tool for data augmentation
biopsy_train = ProstateCancerDataset(csv_file='train_512.csv',
                                     root_dir=image_cropped_dir,
                                     transform=T.Compose([
                                                 T.ToTensor(),
                                                 T.Normalize((0.90960454,0.81946206,0.87811487),
                                                             (0.13244118,0.24944844,0.16392948)),
                                     ]))

loader_train = DataLoader(biopsy_train, batch_size=BATCH_SZ, num_workers=4,
                          sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN)))

loader_val = DataLoader(biopsy_train, batch_size=BATCH_SZ, num_workers=4,
                        sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN, 8492)))  #NUM_TRAIN+200  Use smaller size for debug

isup_distr = defaultdict(int)
for batch_i, batch_sample in enumerate(loader_train):
    #for grade in batch_sample['isup_grade']:
        #print(int(grade))
    #    isup_distr[int(grade)] += 1
    print(batch_sample['image'][0].shape)
    #for s_i, sample in enumerate(batch_sample['isup_grade']):
    #    print(sample)
    print(batch_sample['isup_grade'])
    print(batch_sample['image'][0])
    plt.imshow(batch_sample['image'][0].permute(1,2,0))
    assert False
print(isup_distr)

# Two-Layer Network

In [31]:
def check_accuracy(loader, model):
    num_correct = 0
    num_samples = 0
    model.eval() # Set model to evaluation mode
    with torch.no_grad():
        for batch in loader:
            x = batch['image'].to(device=device, dtype=dtype)
            y = batch['isup_grade'].to(device=device, dtype=torch.long)
            scores = model(x)
            _, preds = scores.max(1)
            num_correct += (preds==y).sum()
            num_samples += preds.size(0)
            #print(batch['isup_grade'].unique(), batch['isup_grade'][preds==y].unique())
            #print(batch['isup_grade'])
        acc = float(num_correct) / num_samples
        print('Got {:d}/{:d} correct {:.2f}'.format(num_correct, num_samples, acc*100))

def flatten(x):
    N = x.shape[0] # read in N, C, H, W
    return x.view(N, -1)

def train_sequential(model, optimizer, scheduler, epochs=1):
    """
    Train a model using PyTorch Sequential API.
    
    Inputs:
    - model: A PyTorch Module giving the model to train.
    - optimizer: An Optimizer object we will use to train the model.
    - epochs: The expected usage number of each image.
    
    Output: Print model accuracies.
    """
    model = model.to(device=device)
    for e in range(epochs):
        print(f'Epoch {e}')
        for t, batch in enumerate(loader_train):
            x = batch['image'].to(device=device, dtype=dtype)
            y = batch['isup_grade'].to(device=device, dtype=torch.long)
            
            scores = model(x)
            loss = F.cross_entropy(scores, y)
            
            # Zero out all of the gradients for the variables which the optimizer
            # will update.
            optimizer.zero_grad()
            
            # Backward pass: compute the gradient of the loss with respect to
            # each parameter of the model.
            loss.backward()
            
            # Update the parameters of the model using the gradients computed by
            # the backward pass.
            optimizer.step()
            
            if t % print_every == 0:
                print('Iteration {:d}, loss = {:.4f}'.format(t, loss.item()))
                check_accuracy(loader_val, model)
                # Decay learning rate after each validation check
        scheduler.step()

In [33]:
class Flatten(nn.Module):
    def forward(self, x):
        return flatten(x)
    
hidden_layer_size = 100
learning_rate = 3e-3

model = nn.Sequential(
    Flatten(),
    nn.Linear(3*SCALE_SZ*SCALE_SZ, hidden_layer_size),
    nn.ReLU(),
    nn.Linear(hidden_layer_size, 6)
)

# Use Nesterov momentum
"""
optimizer = optim.SGD(model.parameters(),
                      lr=learning_rate,
                      momentum=.9,
                      nesterov=True)
"""
optimizer = optim.Adam(model.parameters(),
                       lr=learning_rate)

# Learning Rate scheduler
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)

#train_sequential(model, optimizer, scheduler, 30)

Epoch 0
Iteration 0, loss = 1.7992
Got 436/1692 correct 25.77
Iteration 100, loss = 2.9562
Got 453/1692 correct 26.77
Epoch 1
Iteration 0, loss = 1.9555
Got 443/1692 correct 26.18
Iteration 100, loss = 1.6102
Got 450/1692 correct 26.60
Epoch 2
Iteration 0, loss = 2.2169
Got 452/1692 correct 26.71
Iteration 100, loss = 1.5697
Got 454/1692 correct 26.83
Epoch 3
Iteration 0, loss = 1.6583
Got 450/1692 correct 26.60
Iteration 100, loss = 1.7210
Got 459/1692 correct 27.13
Epoch 4
Iteration 0, loss = 1.5495
Got 460/1692 correct 27.19
Iteration 100, loss = 1.5922
Got 463/1692 correct 27.36
Epoch 5
Iteration 0, loss = 1.5628
Got 461/1692 correct 27.25
Iteration 100, loss = 1.5773
Got 461/1692 correct 27.25
Epoch 6
Iteration 0, loss = 1.5719
Got 461/1692 correct 27.25
Iteration 100, loss = 1.6990
Got 463/1692 correct 27.36
Epoch 7
Iteration 0, loss = 1.6314
Got 463/1692 correct 27.36
Iteration 100, loss = 1.5721
Got 464/1692 correct 27.42
Epoch 8
Iteration 0, loss = 1.6030
Got 464/1692 correct 

# CNN
Reference: https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html

In [5]:
def train_model(model, dataloaders, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()
    
    val_acc_history = []
    
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs-1}')
        print('-'*10)
        
        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()   # Set model to training phase
            else:
                model.eval()    # Set model to evaluate phase
                
            running_loss = 0.0
            running_corrects = 0
            
            for batch in dataloaders[phase]:
                inputs = batch['image'].to(device=device, dtype=dtype)
                labels = batch['isup_grade'].to(device=device, dtype=torch.long)
                
                # Zero the parameter gradients
                optimizer.zero_grad()
                
                # Forward, track history if only in training
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    
                    _, preds = torch.max(outputs, 1)

                    # Backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                    
                # Statistics
                running_loss += loss.item() * inputs.size(0)
                print(loss.item())
                running_corrects += torch.sum(preds == labels)
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
            print('{} Loss: {:4f} Acc: {:4f}'.format(phase, epoch_loss, epoch_acc))
            
            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
            if phase == 'val':
                val_acc_history.append(epoch_acc)
                
            if scheduler is not None and phase == 'train':
                scheduler.step()
                
        print()
    
    tim_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:0f}s'.format(time_elapsed//60, time_elapsed%60))
    print('Best val Acc: {:4f}'.format(best_acc))
    
    model.load_state_dict(best_model_wts)
    return model, val_acc_history


def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False
            

def initialize_model(num_classes, feature_extract=False, use_pretrained=False):
    model_ft = None
    input_size = 0
    
    model_ft = models.alexnet(pretrained=use_pretrained)
    set_parameter_requires_grad(model_ft, feature_extract)
    num_ftrs = model_ft.classifier[6].in_features
    model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
    input_size = 128
    
    return model_ft, input_size

In [None]:
model_ft, input_size = initialize_model(6)
#print(model_ft)
# Send the model to GPU/CPU
model_ft = model_ft.to(device)

optimizer = optim.SGD(model_ft.parameters(),
                      lr=3e-3,
                      momentum=.9,
                      nesterov=True)

train_model(model_ft, {'train': loader_train, 'val': loader_val}, F.cross_entropy, optimizer, None, 30)

Epoch 0/29
----------
1.7987030744552612
1.7921947240829468
1.7934143543243408
1.793198585510254
1.7961546182632446
1.7917536497116089
1.7922892570495605
1.792385220527649
1.7897076606750488
1.7856415510177612
1.7851194143295288
1.78425931930542
1.7910858392715454
1.784632921218872
1.7931833267211914
1.7865554094314575
1.7844675779342651
1.7762876749038696
1.7827672958374023
1.785706639289856
1.7924561500549316
1.788851261138916
1.765915036201477
1.8041828870773315
1.7834748029708862
1.7823799848556519
1.7306456565856934
1.7465015649795532
1.7620561122894287
1.7874761819839478
1.7582727670669556
1.7756716012954712
1.7582067251205444
1.775965690612793
1.7657396793365479
1.7679896354675293
1.7806487083435059
1.788172960281372
1.8154157400131226
1.7610918283462524
1.7702339887619019
1.760696530342102
1.7551591396331787
1.7184319496154785
1.7187687158584595
1.7104848623275757
1.7481317520141602
1.737756609916687
1.6920806169509888
1.682916283607483
1.825808048248291
1.7016147375106812
1.73

1.620553970336914
1.5711740255355835
1.8784053325653076
1.7156336307525635
1.607582688331604
1.500455617904663
1.4382730722427368
1.55473792552948
1.2395954132080078
1.459681749343872
1.2200186252593994
1.2540557384490967
1.6765419244766235
1.8260588645935059
1.764514684677124
1.7051595449447632
1.7724920511245728
1.7358633279800415
1.7339038848876953
1.6828922033309937
1.6453100442886353
1.7669180631637573
1.753909707069397
1.7175527811050415
1.6083180904388428
1.6292331218719482
1.4518362283706665
2.5757904052734375
1.6268682479858398
1.6816997528076172
1.8932547569274902
1.5548014640808105
1.7411856651306152
1.7440028190612793
1.706712007522583
1.3585196733474731
1.8566972017288208
1.7736257314682007
1.5854793787002563
1.6922791004180908
1.8603793382644653
1.9649503231048584
1.82260000705719
1.2833878993988037
1.678894281387329
1.854671597480774
2.0131731033325195
1.7045143842697144
2.009432077407837
1.6525026559829712
1.5831513404846191
1.5005433559417725
1.746402382850647
1.650898

1.4922116994857788
1.5536949634552002
1.7004472017288208
1.511612892150879
1.6069865226745605
1.6852757930755615
1.7052737474441528
1.7154240608215332
1.7981585264205933
1.5521705150604248
1.6818965673446655
1.8293503522872925
1.676047921180725
1.6083037853240967
1.5755259990692139
1.7964463233947754
1.9672935009002686
1.7975906133651733
1.6433981657028198
1.5629537105560303
1.5461704730987549
1.58201265335083
1.540296196937561
1.3994938135147095
1.527445912361145
1.505562663078308
1.848257064819336
1.4995269775390625
1.5328534841537476
1.709526777267456
1.5858681201934814
1.7308367490768433
1.4565699100494385
1.5845714807510376
1.715759515762329
1.9449880123138428
1.747989535331726
1.4579178094863892
1.561248779296875
1.755387544631958
1.505025863647461
1.4999420642852783
1.6866315603256226
1.829169750213623
1.5194271802902222
1.5644103288650513
1.754185676574707
1.6798090934753418
1.522700548171997
1.452595829963684
1.6167634725570679
1.3658844232559204
1.6980769634246826
1.547777295

1.692906379699707
1.7586263418197632
1.5378270149230957
1.8161667585372925
1.561941385269165
1.7157227993011475
1.4875445365905762
1.782739281654358
1.5515114068984985
1.6851840019226074
1.538867712020874
1.9411420822143555
1.5244650840759277
1.4863686561584473
1.8771976232528687
1.5387444496154785
1.2752413749694824
1.722691535949707
1.8264731168746948
1.2997307777404785
1.4314417839050293
1.523948311805725
1.5667555332183838
1.8632055521011353
1.6925053596496582
1.7631605863571167
1.5826400518417358
1.6108272075653076
1.5636632442474365
1.6942932605743408
1.6057603359222412
1.2364166975021362
1.4545246362686157
1.4303858280181885
2.093820333480835
1.5046076774597168
1.6349008083343506
1.6021122932434082
1.8407396078109741
1.6863970756530762
1.7366377115249634
1.7208791971206665
1.777999758720398
1.6315019130706787
1.7359776496887207
1.6243679523468018
1.6567332744598389
1.6802934408187866
1.5841792821884155
1.5524824857711792
1.757834792137146
1.7453752756118774
1.484033465385437
2.0

1.5918970108032227
1.7896312475204468
1.774358868598938
1.6019277572631836
1.5400336980819702
1.7424395084381104
1.6729893684387207
1.5054007768630981
1.4986331462860107
1.8255164623260498
1.6661334037780762
1.7993214130401611
1.489223837852478
1.5076017379760742
1.9242703914642334
1.7980784177780151
1.436691403388977
1.6329180002212524
1.5358890295028687
1.6068414449691772
1.6758490800857544
1.4583264589309692
1.4516061544418335
1.6087915897369385
1.4033293724060059
1.6774191856384277
1.7959251403808594
1.7093777656555176
1.5110646486282349
1.5108927488327026
1.5554594993591309
1.454115629196167
1.5384337902069092
1.556107997894287
1.8024406433105469
1.7548983097076416
1.4455161094665527
1.524855136871338
1.6270482540130615
1.837365984916687
1.6662375926971436
1.5664666891098022
1.4747463464736938
1.6123472452163696
1.9369492530822754
1.5815644264221191
1.731215476989746
1.715935230255127
1.690996766090393
1.6136285066604614
1.4064916372299194
1.9967756271362305
1.30281662940979
1.660

1.8464024066925049
1.6307027339935303
1.6926482915878296
1.6555031538009644
1.7268983125686646
1.706883430480957
1.5863263607025146
1.731635570526123
1.6586296558380127
1.7958011627197266
1.7320363521575928
1.6815197467803955
1.7496429681777954
1.4621002674102783
1.7885212898254395
1.4636452198028564
1.6219935417175293
1.5072396993637085
1.8892860412597656
1.4798041582107544
1.6340463161468506
1.9510273933410645
1.1990900039672852
1.4534229040145874
1.5962026119232178
1.5080595016479492
1.7885284423828125
1.9299007654190063
1.6561393737792969
1.5661267042160034
1.4733939170837402
1.6577496528625488
1.6040633916854858
1.6108191013336182
1.784909963607788
1.6016236543655396
1.642344355583191
1.67603600025177
1.430366039276123
1.6105632781982422
1.4785325527191162
1.6599369049072266
1.5242280960083008
1.4957258701324463
1.4145804643630981
1.639807939529419
1.7872921228408813
1.5884616374969482
1.5067837238311768
1.5624094009399414
1.108612060546875
1.9508280754089355
1.479630947113037
1.6

1.8183631896972656
1.6803756952285767
1.5382440090179443
1.6171555519104004
1.7384930849075317
1.6042311191558838
1.745149850845337
1.6500133275985718
1.5808660984039307
1.7635574340820312
1.6043320894241333
1.6143338680267334
1.5395435094833374
1.5577404499053955
1.4968904256820679
2.1379096508026123
1.5835132598876953
1.4348065853118896
1.8314677476882935
1.6136369705200195
1.4143065214157104
1.5802191495895386
1.5141247510910034
1.579267978668213
1.7293388843536377
1.9513813257217407
1.6746577024459839
1.3277440071105957
1.4078633785247803
1.342188835144043
1.7879621982574463
1.2175965309143066
1.4905846118927002
2.1582798957824707
1.4925369024276733
1.684521198272705
1.4468917846679688
1.4299203157424927
1.7809009552001953
1.5828180313110352
1.272901177406311
1.1412664651870728
1.5812917947769165
1.6310280561447144
1.423741340637207
1.6705784797668457
2.051487922668457
1.3144456148147583
1.4918065071105957


In [8]:
"""
Graphs
1. loss vs. iterations
2. Train/Validation accuracy along epoch
"""
a = torch.tensor([1, 2, 3, 4])
b = torch.tensor([2, 1, 3, 4])
print(a[a==b])

tensor([3, 4])
