In [None]:
# based on https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html

In [1]:
import numpy as np
import pandas as pd
import torch
import torch.optim as optim
import torch.nn as nn
from torch.optim import lr_scheduler
import matplotlib.pyplot as plt
import matplotlib
import tqdm.auto as tqdm

import os
from PIL import Image
from sklearn.metrics import accuracy_score
import torchvision
from torchvision import transforms
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset, DataLoader
import time
import copy

# matplotlib 3.6 renamed the bundled seaborn style to 'seaborn-v0_8' and 3.8
# removed the old 'seaborn' name; fall back to it only on older versions.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')

In [2]:
# Project-local helper module that provides the data splits.
import DiagnosisFunctions.tools as DiagTools
# get_splits() returns four items. The first three are consumed later as
# DiagnosisDataset(*split, ...), so each is presumably a (paths, targets) pair.
# The fourth item is discarded here — TODO(review): document what it holds.
train_split, val_split, test_split, _ = DiagTools.get_splits()

In [3]:
class DiagnosisDataset(Dataset):
    '''
    Map-style dataset of diagnosis images.

    Each item is a tuple ``(image, target, path)``:
      * ``image``  - float32 tensor of shape (3, H, W) scaled to [0, 1], then
                     passed through ``transforms`` when one is given;
      * ``target`` - the entry of ``target`` at the same index;
      * ``path``   - the image file path that was loaded.
    '''
    def __init__(self, path, target, transforms=None):
        '''
        Args:
            path:       indexable sequence of image file paths.
            target:     indexable sequence of target diagnoses, same length as ``path``.
            transforms: optional callable applied to each image tensor (e.g. an
                        ``nn.Sequential`` of torchvision transforms). ``None``
                        (default) leaves the image unchanged.
        '''
        assert len(path) == len(target), 'path and target should be the same length.'

        self.path = path
        self.target = target
        # A None default (instead of a default nn.Sequential() instance) avoids
        # sharing one module object between every dataset built without transforms.
        self.transforms = transforms

    def __len__(self):
        return len(self.path)

    def __getitem__(self, idx):
        # NOTE(review): if `path`/`target` are pandas Series, `[idx]` is a
        # label-based lookup, so their index must be 0..len-1 — confirm against
        # DiagTools.get_splits().
        path = self.path[idx]
        target = self.target[idx]

        # Force 3-channel RGB: drops an alpha channel and expands grayscale /
        # palette images, so every image matches the 3-channel Normalize and the
        # model input. The context manager closes the file handle promptly.
        with Image.open(path) as img:
            rgb = np.array(img.convert('RGB'))
        # (H, W, 3) uint8 -> (3, H, W) float32 in [0, 1].
        im = torch.tensor(rgb, dtype=torch.float32).permute(2, 0, 1) / 255.

        if self.transforms is not None:
            im = self.transforms(im)

        return im, target, path

In [8]:
# Standard ImageNet channel statistics used by torchvision's pretrained models;
# the backbone below is loaded with pretrained weights, so inputs are normalised
# the same way.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training: random resized crop + horizontal flip augmentation, then normalisation.
train_transforms = nn.Sequential(transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD))

# Evaluation: deterministic resize + centre crop, then the same normalisation.
val_transforms = nn.Sequential(transforms.Resize(256),
                               transforms.CenterCrop(224),
                               transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD))

train_dataset = DiagnosisDataset(*train_split, transforms=train_transforms)
val_dataset = DiagnosisDataset(*val_split, transforms=val_transforms)
# The test set must be preprocessed exactly like the validation set. Without
# transforms its images are neither resized/cropped (batches of differently
# sized images cannot be collated) nor normalised as the model expects.
test_dataset = DiagnosisDataset(*test_split, transforms=val_transforms)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

datasets = {'train': train_dataset, 'val': val_dataset, 'test': test_dataset}
dataset_sizes = {name: len(ds) for name, ds in datasets.items()}
dataloaders = {'train': train_loader, 'val': val_loader, 'test': test_loader}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [9]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=25,
                loaders=None, target_device=None):
    '''
    Train ``model`` with a train/val loop and return it with its best weights.

    Each epoch runs a 'train' phase (gradients on, one optimizer step per
    batch, one scheduler step at the end of the phase) followed by a 'val'
    phase (gradients off). After every val phase the weights are snapshotted
    if validation accuracy improved; the best snapshot is loaded back into
    ``model`` before returning.

    Args:
        model:         network whose output for a batch is (batch, classes) scores.
        criterion:     loss called as ``criterion(outputs, labels)``; assumed to
                       return the batch mean (CrossEntropyLoss default).
        optimizer:     optimizer over the trainable parameters of ``model``.
        scheduler:     learning-rate scheduler, stepped once per epoch.
        num_epochs:    number of train+val epochs.
        loaders:       dict with 'train' and 'val' iterables yielding
                       ``(inputs, labels, extra)`` batches; defaults to the
                       notebook-level ``dataloaders``.
        target_device: device inputs/labels are moved to; defaults to the
                       notebook-level ``device``.

    Returns:
        ``model`` (the same object) with the best validation weights loaded.

    Raises:
        ValueError: if a phase's loader yields no samples.

    If the loop is interrupted (e.g. KeyboardInterrupt) the best weights are
    NOT restored and ``model`` keeps the weights of its last optimizer step.
    '''
    if loaders is None:
        loaders = dataloaders
    if target_device is None:
        target_device = device

    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)  # train() for 'train', eval() for 'val'

            running_loss = 0.0
            running_corrects = 0
            n_samples = 0

            for inputs, labels, _ in loaders[phase]:
                inputs = inputs.to(target_device)
                labels = labels.to(target_device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    preds = outputs.argmax(dim=1)
                    loss = criterion(outputs, labels)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                batch_size = inputs.size(0)
                # loss is a batch mean, so re-weight it to a per-sample sum.
                running_loss += loss.item() * batch_size
                running_corrects += (preds == labels).sum().item()
                # Count samples actually seen (exact even with drop_last/samplers).
                n_samples += batch_size

            if n_samples == 0:
                raise ValueError("dataloader for phase '{}' yielded no samples".format(phase))

            if is_train:
                scheduler.step()

            epoch_loss = running_loss / n_samples
            epoch_acc = running_corrects / n_samples

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))

    model.load_state_dict(best_model_wts)
    return model

In [13]:
NUM_CLASSES = 6  # size of the new output layer (number of diagnosis classes)

# Pretrained EfficientNet-B0 used as a frozen feature extractor.
# NOTE(review): `pretrained=True` is deprecated since torchvision 0.13 in favour of
# `weights=torchvision.models.EfficientNet_B0_Weights.IMAGENET1K_V1`; kept here
# for compatibility with older torchvision versions.
model_ft = torchvision.models.efficientnet_b0(pretrained=True)
for param in model_ft.parameters():
    param.requires_grad = False

# torchvision's EfficientNet head is already Sequential(Dropout, Linear), which is
# why the Linear sits at classifier[1]. Replace the whole head instead of nesting a
# second Dropout inside classifier[1], which would stack two dropout layers.
# The new layers are created after the freeze loop, so they remain trainable.
num_features = model_ft.classifier[1].in_features
model_ft.classifier = nn.Sequential(nn.Dropout(0.4),
                                    nn.Linear(num_features, NUM_CLASSES))

model_ft.to(device)

criterion = nn.CrossEntropyLoss()
# Optimise only the trainable (head) parameters; the frozen backbone never gets gradients.
optimizer_ft = optim.SGD([p for p in model_ft.parameters() if p.requires_grad],
                         lr=0.001, momentum=0.9)

# Decay the learning rate 10x every 7 scheduler steps (train_model steps it once
# per epoch). The misspelled name is kept because the training cell references it.
exp_lr_sceduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

In [14]:
# Fine-tune the new head. On normal completion train_model reloads the weights
# from the epoch with the best validation accuracy. If the run is interrupted
# (KeyboardInterrupt), that reload is skipped: this assignment never happens, and
# model_ft keeps whatever weights the last optimizer step produced.
model_ft = train_model(model=model_ft,
                       criterion=criterion,
                       optimizer=optimizer_ft,
                       scheduler=exp_lr_sceduler,
                       num_epochs=25)

Epoch 0/24
----------
train Loss: 1.8055 Acc: 0.1546
val Loss: 1.8004 Acc: 0.1176

Epoch 1/24
----------
train Loss: 1.8041 Acc: 0.2050
val Loss: 1.7514 Acc: 0.1765

Epoch 2/24
----------
train Loss: 1.7244 Acc: 0.2650
val Loss: 1.6841 Acc: 0.2941

Epoch 3/24
----------
train Loss: 1.6767 Acc: 0.3060
val Loss: 1.6276 Acc: 0.3382

Epoch 4/24
----------
train Loss: 1.5906 Acc: 0.4069
val Loss: 1.5666 Acc: 0.4265

Epoch 5/24
----------
train Loss: 1.5730 Acc: 0.4196
val Loss: 1.5211 Acc: 0.4412

Epoch 6/24
----------
train Loss: 1.5200 Acc: 0.4921
val Loss: 1.4763 Acc: 0.5147

Epoch 7/24
----------
train Loss: 1.4773 Acc: 0.4890
val Loss: 1.4661 Acc: 0.5441

Epoch 8/24
----------
train Loss: 1.4865 Acc: 0.4637
val Loss: 1.4589 Acc: 0.5147

Epoch 9/24
----------
train Loss: 1.4610 Acc: 0.5457
val Loss: 1.4552 Acc: 0.5294

Epoch 10/24
----------


KeyboardInterrupt: 