# Generate ResNet18 Kermany Pretrained

## Init

In [18]:
# Auto reload
%load_ext autoreload
%autoreload 2

import torch
from kermany_data import KermanyDataset
import sys
sys.path.append('/home/projects/ronen/sgvdan/workspace/sliver_net')

import torchvision
from train import train, train_loop
import wandb
from sliver_net.data import build_volume_cache, E2ETileDataset

wandb.login()
wandb.init()

resize_transform = torchvision.transforms.Resize((256, 256))

resnet18 = torchvision.models.resnet18(num_classes=4, pretrained=False).cuda()

criterion = torch.nn.functional.cross_entropy
optimizer = torch.optim.Adam(resnet18.parameters(), lr=1e-4)



The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Train on Kermany

In [4]:
from kermany_data import KERMANY_LABELS
from sliver_net import data
data.LABELS = KERMANY_LABELS

val_dataset = KermanyDataset('./kermany/val', transform=resize_transform)  # (https://www.kaggle.com/paultimothymooney/kermany2018)
val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=64, shuffle=True)

train_dataset = KermanyDataset('./kermany/train', transform=resize_transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)

train(resnet18, criterion, optimizer, train_loader, val_loader, 7, 'cuda', 'ResNet18-Kermany')

# Save temporarily, so that the model won't disappear
tmp_model_name = './tmp_model.pth'
torch.save({"model_state_dict": resnet18.state_dict(),
            "optimizer_state_dict": optimizer.state_dict()
            }, tmp_model_name)

print('Done Training.')

 14%|██████████████▏                                                                                    | 1/7 [10:04<1:00:25, 604.33s/it]

Average Accuracy:0.875


 29%|████████████████████████████▊                                                                        | 2/7 [18:04<44:15, 531.07s/it]

Average Accuracy:0.84375


 43%|███████████████████████████████████████████▎                                                         | 3/7 [26:09<34:00, 510.20s/it]

Average Accuracy:0.96875


 57%|█████████████████████████████████████████████████████████▋                                           | 4/7 [34:12<24:58, 499.57s/it]

Average Accuracy:1.0


 71%|████████████████████████████████████████████████████████████████████████▏                            | 5/7 [42:09<16:22, 491.30s/it]

Average Accuracy:0.96875


 86%|██████████████████████████████████████████████████████████████████████████████████████▌              | 6/7 [50:04<08:05, 485.88s/it]

Average Accuracy:1.0


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [58:01<00:00, 497.31s/it]

Average Accuracy:1.0





Done Training.


## Evaluate on Kermany

In [19]:
from kermany_data import KERMANY_LABELS
from sliver_net import data
from train import evaluate
data.LABELS = KERMANY_LABELS

model_path = input("Model to load?:")
states_dict = torch.load(model_path)

resnet18.load_state_dict(states_dict['model_state_dict'])
optimizer.load_state_dict(states_dict['optimizer_state_dict'])

test_dataset = KermanyDataset('./kermany/test', transform=resize_transform)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=64, shuffle=True)

avg_iou, avg_accuracy = evaluate(resnet18, test_loader, 'Kermany2D/Test', 'cuda')

print('Average IOU:{}, Accuracy: {}'.format(avg_iou, avg_accuracy))

Model to load?:resnet18.pth
Average IOU:0.9958846569061279, Accuracy: 0.9979338645935059


# Fine-tune on OUR Dataset

## Build tile dataset

In [13]:
from playground.cache import Cache
from sliver_net.data import build_volume_cache

build_from_scratch = False

volume_test_cache = Cache('volume_test_set')
if build_from_scratch:
    build_volume_cache(volume_test_cache, '../../OCT-DL/Data/test/control', LABELS['HEALTHY'])
    build_volume_cache(volume_test_cache, '../../OCT-DL/Data/test/study', LABELS['SICK'])

volume_train_cache = Cache('volume_train_set')
if build_from_scratch:
    build_volume_cache(volume_train_cache, '../../OCT-DL/Data/train/control', LABELS['HEALTHY'])
    build_volume_cache(volume_train_cache, '../../OCT-DL/Data/train/study', LABELS['SICK'])

volume_validation_cache = Cache('volume_validation_set')
if build_from_scratch:
    build_volume_cache(volume_validation_cache, '../../OCT-DL/Data/validation/control', LABELS['HEALTHY'])
    build_volume_cache(volume_validation_cache, '../../OCT-DL/Data/validation/study', LABELS['SICK'])

if build_from_scratch:
    print('Built tile dataset.')
else:
    print('Loaded tile dataset.')

Loaded tile dataset.


# Load SliverNet

In [14]:
from tqdm.asyncio import tqdm
from playground.cache import Cache
from sliver_net.model import load_backbone
from sliver_net.model import SliverNet2

volume_train_cache = Cache('volume_train_set')
tile_train_dataset = E2ETileDataset(volume_train_cache, transform=resize_transform)
tile_train_loader = torch.utils.data.DataLoader(dataset=tile_train_dataset, batch_size=1, shuffle=True)

volume_validation_cache = Cache('volume_validation_set')
tile_validation_dataset = E2ETileDataset(volume_validation_cache, transform=resize_transform)
tile_validation_loader = torch.utils.data.DataLoader(dataset=tile_validation_dataset, batch_size=1, shuffle=True)

# Load Model
backbone = load_backbone("sgvdan-kermany").cuda() # randomly initialize a resnet18 backbone
sliver_model = SliverNet2(backbone, n_out=2).cuda() # create SLIVER-Net with n_out outputs

print('SliverNet model Loaded!')

Loading sgvdan-kermany model
SliverNet model Loaded!


## Fine Tune last bit of SliverNet on Hadassah tile training set

In [20]:
from sliver_net import data
data.LABELS = {'HEALTHY': torch.nn.functional.one_hot(torch.tensor(0), 2), 'SICK': torch.nn.functional.one_hot(torch.tensor(1), 2)}

# First: fine tune on everything but the backbone
epochs = 3
print('Fine tune on last bit of SliverNet for epochs={epochs}'.format(epochs=epochs))
sliver_model.backbone.eval()
for param in sliver_model.backbone.parameters():
    param.requires_grad = False
optimizer = torch.optim.Adam(sliver_model.parameters(), lr=1e-4)
train(sliver_model, sliver_model.loss_func, optimizer, tile_train_loader, tile_validation_loader, epochs, 'cuda', 'SliverNet/partial')

Fine tune on last bit of SliverNet for epochs=3


 33%|█████████████████████████████████▋                                                                   | 1/3 [04:51<09:42, 291.16s/it]

Average IOU:0.534787118434906, Accuracy:0.6129870414733887


 67%|███████████████████████████████████████████████████████████████████▎                                 | 2/3 [09:48<04:54, 294.52s/it]

Average IOU:0.3736386299133301, Accuracy:0.737500011920929


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [14:38<00:00, 292.82s/it]

Average IOU:0.46190476417541504, Accuracy:0.798214316368103





## Fine Tune whole of SliverNet on Hadassah tile training set

In [21]:
from sliver_net import data
data.LABELS = {'HEALTHY': torch.nn.functional.one_hot(torch.tensor(0), 2), 'SICK': torch.nn.functional.one_hot(torch.tensor(1), 2)}

epochs = 3
print('Fine tune on whole of SliverNet for epochs={epochs}'.format(epochs=epochs))
sliver_model.backbone.train()
for param in sliver_model.backbone.parameters():
    param.requires_grad = True
optimizer = torch.optim.Adam(sliver_model.parameters(), lr=1e-5)
train(sliver_model, sliver_model.loss_func, optimizer, tile_train_loader, tile_validation_loader, epochs, 'cuda', 'SliverNet/full')

Fine tune on whole of SliverNet for epochs=3


 33%|█████████████████████████████████▋                                                                   | 1/3 [05:33<11:06, 333.19s/it]

Average IOU:0.40877658128738403, Accuracy:0.762499988079071


 67%|███████████████████████████████████████████████████████████████████▎                                 | 2/3 [11:11<05:36, 336.31s/it]

Average IOU:0.5320436358451843, Accuracy:0.8410714268684387


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [16:35<00:00, 331.90s/it]

Average IOU:0.5478417277336121, Accuracy:0.8500000238418579





In [22]:
# Save model
tmp_model_name = './tmp_model.pth'
torch.save({"model_state_dict": sliver_model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict()
            }, tmp_model_name)
print('Done. Saved model: {}'.format(tmp_model_name))

Done. Saved model: ./tmp_model.pth


## Evaluate on Hadassah tile test set

In [23]:
from playground.train import evaluate
from sliver_net import data
data.LABELS = {'HEALTHY': torch.nn.functional.one_hot(torch.tensor(0), 2), 'SICK': torch.nn.functional.one_hot(torch.tensor(1), 2)}

tile_test_dataset = E2ETileDataset(volume_test_cache, transform=resize_transform)
tile_test_loader = torch.utils.data.DataLoader(dataset=tile_test_dataset, batch_size=1, shuffle=True)

avg_iou, avg_accuracy = evaluate(sliver_model, tile_test_loader, 'SliverNet/Test', device='cuda')
print('Average IOU: {}, Accuracy:{}'.format(avg_iou, avg_accuracy))

Average IOU: 0.48679453134536743, Accuracy:0.7415094375610352
