In [None]:
import torch
import numpy as np
from matplotlib import pyplot as plt
from scipy import ndimage
import os, sys
import math
import pickle
import data_utils as datutil
import datetime as dt
import hmc
import torch.utils.data as Data
from models import *
import gpytorch
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm

Retinopathy dataset definition: removes images that can't be read, and moves the first 30000 images of the Kaggle test set into the training set so that the train:test split is roughly 75:25.

The training dataset also exposes a `weights` attribute holding per-class weights proportional to the inverse class frequency (normalized to sum to 1), for use with a weighted loss.

In [None]:
class retinopathy_dataset(Data.Dataset):
    '''DIABETIC RETINOPATHY dataset downloaded from
    kaggle link : https://www.kaggle.com/c/diabetic-retinopathy-detection/data

    Expected layout under ``root``:
        DIABETIC_RETINOPATHY_CRPD/train/*.jpeg   and  trainLabels.csv
        DIABETIC_RETINOPATHY_CRPD/test/*.jpeg    and  testLabels.csv

    The training split is every row of trainLabels.csv plus the first
    ``test_split_index`` rows of testLabels.csv; the test split is the
    remaining testLabels.csv rows (roughly 75:25 with the default).

    root : directory that contains ``DIABETIC_RETINOPATHY_CRPD``
    train : whether training dataset (True) or test (False)
    transform : callable applied to each PIL image (None: return the PIL image)
    binary : whether healthy '0' vs damaged '1,2,3,4' binary detection (True)
            or multiclass (False)
    balance : whether to compute class weights for the training split. When
            computed, ``weights[c]`` is the weight of label ``c`` (inverse class
            frequency normalised to sum to 1; uniform when the coefficient of
            variation of the class counts is at most 0.05) and
            ``sample_weights[i]`` is the weight of image ``i``. Otherwise both
            attributes are None.
    test_split_index : number of leading testLabels.csv rows moved into the
            training split
    bad_images : image ids to discard (files that cannot be read)
    '''

    def __init__(self, root, train, transform, binary=True, balance=True,
                 test_split_index=30000, bad_images=('10_left',)):
        base_dir = os.path.join(root, 'DIABETIC_RETINOPATHY_CRPD')
        # Build both directories explicitly instead of str.replace('train', 'test'),
        # which would also rewrite any 'train' substring that happens to be in ``root``.
        train_dir = os.path.join(base_dir, 'train', '')
        test_dir = os.path.join(base_dir, 'test', '')
        test_rows = self._read_labels(os.path.join(base_dir, 'testLabels.csv'))

        if train:
            train_rows = self._read_labels(os.path.join(base_dir, 'trainLabels.csv'))
            extra_rows = test_rows[:test_split_index]
            rows = train_rows + extra_rows
            row_dirs = [train_dir] * len(train_rows) + [test_dir] * len(extra_rows)
            self.img_dir = train_dir
        else:
            rows = test_rows[test_split_index:]
            row_dirs = [test_dir] * len(rows)
            self.img_dir = test_dir

        # Discard bad images (all occurrences, in a single pass).
        skip = set(bad_images)
        kept = [(row[0], int(row[1]), row_dir)
                for row, row_dir in zip(rows, row_dirs) if row[0] not in skip]
        self.imgs = [name for name, _, _ in kept]
        self.labels = [label for _, label, _ in kept]
        # Directory of every image: rows borrowed from testLabels.csv live in test/.
        self.img_dirs = [row_dir for _, _, row_dir in kept]

        self.transform = transform
        self.binary = binary
        if self.binary:
            self.labels = [min(label, 1) for label in self.labels]

        self.weights = None
        self.sample_weights = None
        if train and balance and self.labels:
            self._compute_class_weights()

    @staticmethod
    def _read_labels(csv_path):
        '''Return ``[image_id, level]`` string pairs from a label CSV, skipping the header and blank lines.'''
        with open(csv_path, 'r') as label_file:
            next(label_file, None)  # header row
            return [line.strip().split(',')[:2] for line in label_file if line.strip()]

    def _compute_class_weights(self, max_deviation=0.05):
        '''Set ``weights`` (indexed by label) and ``sample_weights`` (one per image).'''
        counts = np.bincount(np.asarray(self.labels, dtype=np.int64))
        present = counts > 0
        present_counts = counts[present]
        # Coefficient of variation of the class sizes, i.e. how imbalanced the split is.
        deviation = np.std(present_counts) / np.mean(present_counts)
        weights = np.zeros(len(counts), dtype=np.float32)
        if deviation > max_deviation:
            weights[present] = 1.0 / present_counts.astype(np.float32)
        else:
            weights[present] = 1.0
        weights /= weights.sum()
        self.weights = weights.tolist()
        self.sample_weights = [self.weights[label] for label in self.labels]
        print('Class weights calculated as ',
              {int(c): self.weights[int(c)] for c in np.flatnonzero(present)})

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        path = os.path.join(self.img_dirs[idx], self.imgs[idx] + '.jpeg')
        # The context manager closes the file; convert('RGB') loads the pixels
        # before that happens and guarantees three channels.
        with Image.open(path) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[idx]

Data loaders and image transformations (random color augmentation is applied to the training split only)

In [None]:
directory = '/data02/'
# Random colour augmentation belongs on the training split only; evaluation
# images must be deterministic so test metrics are reproducible.
train_transform = transforms.Compose([transforms.Resize((512, 512)),
                                      transforms.ColorJitter(brightness=(0.7, 1.3), contrast=(0.7, 1.3)),
                                      transforms.ToTensor()])
test_transform = transforms.Compose([transforms.Resize((512, 512)),
                                     transforms.ToTensor()])
transform = train_transform  # kept for any cell that still refers to `transform`

trainData = retinopathy_dataset(root=directory, train=True, transform=train_transform, binary=True, balance=True)
trainloader = Data.DataLoader(trainData, batch_size=200, shuffle=True, num_workers=10)
print("Training dataset length:", len(trainloader.dataset))

# Class weights are only computed for the training split, and evaluation order
# does not need to be shuffled.
testData = retinopathy_dataset(root=directory, train=False, transform=test_transform, binary=True, balance=False)
testloader = Data.DataLoader(testData, batch_size=50, shuffle=False, num_workers=10)
print("Test dataset length:", len(testloader.dataset))

In [None]:
def learning_rate_mod_factor(epoch, lr_init, lr_end, end_epoch):
    """Multiplicative learning-rate factor for a hold / linear-decay / hold schedule.

    With training progress ``p = epoch / end_epoch``:
      * p < 0.4         -> 1.0 (keep the initial learning rate)
      * 0.4 <= p <= 0.9 -> linear ramp from 1.0 down to lr_end / lr_init
      * p > 0.9         -> lr_end / lr_init (keep the final learning rate)

    Multiply ``lr_init`` by the returned value to get the learning rate for ``epoch``.
    """
    final_ratio = lr_end / lr_init
    progress = epoch / (end_epoch * 1.0)
    if progress < 0.4:
        return 1.0
    if progress <= 0.9:
        return 1.0 - (1.0 - final_ratio) * (progress - 0.4) / 0.5
    return final_ratio

The model (`DiabRetinModelSimpleR256`) is taken from `models/diab_retin_kaggle.py`.

In [None]:
# Training configuration shared by the cells below.
device = 'cuda:1'
num_classes = 2          # binary task: label 0 vs. labels 1-4 collapsed to 1
weight_decay = 0.0003    # L2 penalty passed to the SGD optimizers

# nn.Module.to() moves the parameters in place and returns the same module.
final_model = DiabRetinModelSimpleR256().to(device)
final_model

Visualize a sample training image after the transformations are applied.

In [None]:
image_iterator = iter(trainloader)
inputs, labels = next(image_iterator)
img_np = inputs.cpu().numpy()
# Batches are channels-first; matplotlib wants channels last.
first_image = img_np[0].transpose(1, 2, 0)
plt.imshow(first_image)

Learning-rate range test (LR finder). SGD is used here. Adam can be swapped in, but with Adam the loss diverged (jumping to about 1e17).

In [None]:
# LR range test: sweep the learning rate exponentially up to 10 over 200
# mini-batches with the class-weighted loss, and save the loss-vs-LR plot.
optimizer = torch.optim.SGD(final_model.parameters(), weight_decay=weight_decay, lr=1e-5, momentum=0.6)
class_weights = torch.tensor(trainData.weights).to(device)
criterion = nn.CrossEntropyLoss(weight=class_weights)
# LRFinder is handed final_model and its optimizer, so the sweep presumably
# updates the weights in place (towards divergence at high LRs). Snapshot them
# so the training cell below starts from the original initialisation.
initial_state = {name: tensor.detach().clone() for name, tensor in final_model.state_dict().items()}
lr_finder = LRFinder(model=final_model, optimizer=optimizer, criterion=criterion, device=device)
lr_finder.range_test(trainloader, end_lr=10, num_iter=200, step_mode="exp")
lr_finder.plot(fname='lr_probing.pdf')
final_model.load_state_dict(initial_state)

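Training loop: SGD with a hold-then-linear-decay learning-rate schedule, evaluated on the test set after every epoch.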

In [None]:
running_loss = 0
epoch_count = 20
lr = 0.01
end_lr = 0.0001
# Use the same inverse-frequency class weights as the LR-finder cell, so the
# learning rate chosen there matches the loss optimised here.
class_weights = getattr(trainData, 'weights', None)
if class_weights is not None:
    class_weights = torch.tensor(class_weights, dtype=torch.float)
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = torch.optim.SGD(final_model.parameters(), weight_decay=weight_decay, lr=lr, momentum=0.6)
criterion.to(device)  # also moves the registered `weight` buffer
# lr_scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.001, max_lr=0.08, step_size_up=5, step_size_down=10)

# Log about 10 times per epoch; max() avoids modulo-by-zero on tiny loaders.
log_every = max(1, len(trainloader) // 10)

for epoch in range(epoch_count):  # loop over the dataset multiple times

    factor = learning_rate_mod_factor(epoch, lr, end_lr, epoch_count)
    for i, g in enumerate(optimizer.param_groups):
        g['lr'] = lr * factor
        print("Learning rate for param group %d is now %.6f" % (i, g['lr']))

    # Re-enable training mode every epoch: the validation call at the end of the
    # previous epoch may have switched the model to eval mode (TODO confirm in nbutils).
    final_model.train()
    for i, (inputs, labels) in enumerate(trainloader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = final_model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Exponential moving average of the batch loss (seeded with the first batch).
        running_loss = 0.9*running_loss + 0.1*loss.item() if running_loss != 0 else loss.item()

        if i % log_every == 0:
            print('[%d, %5d] loss: %.4f' % (epoch + 1, i, running_loss))

    print("=== Accuracy using SGD params ===")
    accuracy, ece, sce = nbutils.validate(model=final_model, dataloader=testloader, device=device)