In [None]:
import os
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import PIL
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils, models

from ds import *
from networks import *
from utils import *

## Get data loaders

In [None]:
# Augmentation pipeline handed to the training dataset only.
# NOTE(review): ColorJitter() is called with no arguments, which (per the
# torchvision defaults) likely applies no jitter at all — confirm intent.
train_tfm = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.ColorJitter(),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)

# Every split reads its images from the same directory.
retino_root = '<path to>/retinopathy_small_balanced/balanced_subset_train3000_val100_test400/'

train_dset = CxVAE_retino_Dset(
    csv_file='<path to>/retinopathy_small_balanced/balanced_binary_train_labels.csv',
    root_dir=retino_root,
    tfm=train_tfm,
)
# Validation and test datasets are built without a `tfm` argument.
val_dset = CxVAE_retino_Dset(
    csv_file='<path to>/retinopathy_small_balanced/balanced_binary_val_labels.csv',
    root_dir=retino_root,
)
test_dset = CxVAE_retino_Dset(
    csv_file='<path to>/retinopathy_small_balanced/balanced_binary_test_labels.csv',
    root_dir=retino_root,
)

# Settings shared by all three loaders; only shuffling and pinning differ.
loader_kwargs = dict(batch_size=2, num_workers=8)
train_loader = DataLoader(train_dset, shuffle=True, pin_memory=True, **loader_kwargs)
val_loader = DataLoader(val_dset, shuffle=False, pin_memory=False, **loader_kwargs)
test_loader = DataLoader(test_dset, shuffle=False, pin_memory=False, **loader_kwargs)

## Define model and pass to the training loop

In [None]:
# Start from torchvision's ResNet-50 with pretrained weights
# (download progress bar disabled).
Net = models.resnet50(pretrained=True, progress=False)
print(Net)

# Freeze every pretrained parameter so autograd computes no gradient for it.
# The attribute must be spelled `requires_grad`: assigning to a misspelled
# name such as `require_grad` merely attaches an unused Python attribute to
# the tensor and leaves the whole backbone trainable.
for param in Net.parameters():
    param.requires_grad = False

# Swap in a new classification head. Parameters of a newly created module
# have requires_grad=True by default, so after the freeze above this layer
# is the only part of the network that is still trainable.
num_features = Net.fc.in_features
# Two output units — presumably one per class of the "balanced_binary" label
# CSVs loaded earlier in this notebook; verify against the dataset.
fc_new = torch.nn.Linear(num_features, 2)
Net.fc = fc_new
print(Net)

In [None]:
# Keyword settings forwarded unchanged to the training helper.
train_settings = dict(
    n_epochs=100,
    init_lr=1e-5,
    eval_every=5,
    dtype=torch.cuda.FloatTensor,
    device='cuda',
    ckpt_path='../ckpt/ResNet50_retino',
)
# Positional order is (train loader, validation loader, model).
train_classifier_loop(train_loader, val_loader, Net, **train_settings)

In [None]:
# Load the saved "_best" checkpoint into the model, then run the evaluation
# helper on the test loader.
best_ckpt_file = '../ckpt/ResNet50_retino_best.pth'
Net.load_state_dict(torch.load(best_ckpt_file))

eval_settings = dict(dtype=torch.cuda.FloatTensor, device='cuda')
eval_classifier_loop(test_loader, Net, **eval_settings)