# Generate ResNet18 Kermany Pretrained

## Init

In [1]:
# Auto reload
%load_ext autoreload
%autoreload 2

import torch
from kermany_data import KermanyDataset
import torchvision
from train import train, train_loop
import wandb
from sliver_net.data import build_volume_cache, E2ETileDataset

wandb.login()
wandb.init()

resize_transform = torchvision.transforms.Resize((256, 256))

resnet18 = torchvision.models.resnet18(num_classes=4, pretrained=False).cuda()

criterion = torch.nn.functional.cross_entropy
optimizer = torch.optim.Adam(resnet18.parameters(), lr=1e-4)

[34m[1mwandb[0m: Currently logged in as: [33msgvdan[0m (use `wandb login --relogin` to force relogin)


## Train on Kermany

In [2]:
val_dataset = KermanyDataset('./kermany/val', transform=resize_transform)  # (https://www.kaggle.com/paultimothymooney/kermany2018)
val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=64, shuffle=True)

train_dataset = KermanyDataset('./kermany/train', transform=resize_transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)

train(resnet18, criterion, optimizer, train_loader, val_loader, 7, 'cuda', 'ResNet18-Kermany')

# Save temporarily, so that the model won't disappear
tmp_model_name = './tmp_model.pth'
torch.save({"model_state_dict": resnet18.state_dict(),
            "optimizer_state_dict": optimizer.state_dict()
            }, tmp_model_name)

print('Done Training.')

100%|██████████| 7/7 [55:10<00:00, 472.86s/it]


Done Training.


## Evaluate on Kermany

In [3]:
model_path = './tmp_model.pth' #input("Model to load?:")
states_dict = torch.load(model_path)

resnet18.load_state_dict(states_dict['model_state_dict'])
optimizer.load_state_dict(states_dict['optimizer_state_dict'])

test_dataset = KermanyDataset('./kermany/test', transform=resize_transform)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=64, shuffle=True)

running_test_accuracy = 0.0
for test_images, test_labels in test_loader:
    # test accuracy of batch
    _, accuracy = train_loop(model=resnet18, criterion=criterion, optimizer=optimizer,
                             device='cuda', images=test_images, labels=test_labels, mode='eval')
    running_test_accuracy += accuracy

test_accuracy = round(running_test_accuracy/len(test_loader), 2)
print("Test/accuracy: {}".format(test_accuracy))
wandb.log({'Kermany2D/Test/accuracy': test_accuracy})

Test/accuracy: 0.98


# Fine-tune on OUR Dataset

## Build tile dataset

In [4]:
from playground.cache import Cache
from sliver_net.data import build_volume_cache

LABELS = {'HEALTHY': torch.nn.functional.one_hot(torch.tensor(0), 2), 'SICK': torch.nn.functional.one_hot(torch.tensor(1), 2)}
build_from_scratch = False

volume_test_cache = Cache('volume_test_set')
if build_from_scratch:
    build_volume_cache(volume_test_cache, '../../OCT-DL/Data/test/control', LABELS['HEALTHY'])
    build_volume_cache(volume_test_cache, '../../OCT-DL/Data/test/study', LABELS['SICK'])

volume_train_cache = Cache('volume_train_set')
if build_from_scratch:
    build_volume_cache(volume_train_cache, '../../OCT-DL/Data/train/control', LABELS['HEALTHY'])
    build_volume_cache(volume_train_cache, '../../OCT-DL/Data/train/study', LABELS['SICK'])

volume_validation_cache = Cache('volume_validation_set')
if build_from_scratch:
    build_volume_cache(volume_validation_cache, '../../OCT-DL/Data/validation/control', LABELS['HEALTHY'])
    build_volume_cache(volume_validation_cache, '../../OCT-DL/Data/validation/study', LABELS['SICK'])

if build_from_scratch:
    print('Built tile dataset.')
else:
    print('Loaded tile dataset.')

Loaded tile dataset.


# Load SliverNet

In [5]:
from tqdm.asyncio import tqdm
from playground.cache import Cache
from sliver_net.model import load_backbone
from sliver_net.model import SliverNet2

volume_train_cache = Cache('volume_train_set')
tile_train_dataset = E2ETileDataset(volume_train_cache, transform=resize_transform)
tile_train_loader = torch.utils.data.DataLoader(dataset=tile_train_dataset, batch_size=1, shuffle=True)

volume_validation_cache = Cache('volume_validation_set')
tile_validation_dataset = E2ETileDataset(volume_validation_cache, transform=resize_transform)
tile_validation_loader = torch.utils.data.DataLoader(dataset=tile_validation_dataset, batch_size=1, shuffle=True)

# Load Model
backbone = load_backbone("sgvdan-kermany").cuda() # randomly initialize a resnet18 backbone
sliver_model = SliverNet2(backbone, n_out=2).cuda() # create SLIVER-Net with n_out outputs

print('SliverNet model Loaded!')

Loading sgvdan-kermany model
SliverNet model Loaded!


## Fine Tune last bit of SliverNet on Hadassah tile training set

In [6]:
# First: fine tune on everything but the backbone
epochs = 3
print('Fine tune on last bit of SliverNet for epochs={epochs}'.format(epochs=epochs))
for param in sliver_model.backbone.parameters():
    param.requires_grad = False
optimizer = torch.optim.Adam(sliver_model.parameters(), lr=1e-4)
train(sliver_model, sliver_model.loss_func, optimizer, tile_train_loader, tile_validation_loader, epochs, 'cuda', 'SliverNet/partial')

Fine tune on last bit of SliverNet for epochs=3


100%|██████████| 3/3 [15:02<00:00, 300.85s/it]


## Fine Tune whole of SliverNet on Hadassah tile training set

In [7]:
epochs = 3
print('Fine tune on whole of SliverNet for epochs={epochs}'.format(epochs=epochs))
for param in sliver_model.backbone.parameters():
    param.requires_grad = True
optimizer = torch.optim.Adam(sliver_model.parameters(), lr=1e-5)
train(sliver_model, sliver_model.loss_func, optimizer, tile_train_loader, tile_validation_loader, epochs, 'cuda', 'SliverNet/full')

Fine tune on whole of SliverNet for epochs=3


100%|██████████| 3/3 [16:57<00:00, 339.10s/it]


In [8]:
# Save model
tmp_model_name = './tmp_model.pth'
torch.save({"model_state_dict": sliver_model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict()
            }, tmp_model_name)
print('Done. Saved model: {}'.format(tmp_model_name))

Done. Saved model: ./tmp_model.pth


## Evaluate on Hadassah tile test set

In [9]:
tile_test_dataset = E2ETileDataset(volume_test_cache, transform=resize_transform)
tile_test_loader = torch.utils.data.DataLoader(dataset=tile_test_dataset, batch_size=1, shuffle=True)

running_test_accuracy = 0.0
for test_images, test_labels in tqdm(tile_test_loader):
    # test accuracy of batch
    _, accuracy = train_loop(model=sliver_model, criterion=sliver_model.loss_func, optimizer=optimizer, device='cuda',
                             images=test_images, labels=test_labels, mode='eval')
    running_test_accuracy += accuracy

test_accuracy = round(running_test_accuracy/len(tile_test_loader), 2)
print("Evaluated on Test Set. Accuracy: {}".format(test_accuracy))
wandb.log({'SliverNet/Test/accuracy': test_accuracy})


100%|██████████| 116/116 [00:31<00:00,  3.68it/s]

Evaluated on Test Set. Accuracy: 0.91



