In [1]:
from tqdm import tqdm
import numpy as np
import torch
import time
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
from dataset import *
from utils import Transform3D, model_to_syncbn
from models import ResNet18, ResNet50
from acsconv.converters import ACSConverter, Conv2_5dConverter, Conv3dConverter

import medmnist
from medmnist import INFO, Evaluator
from medmnist import OrganMNIST3D

from torchvision.models import resnet18, resnet50
from torchvision.models import swin_v2_t

The ``converters`` are currently experimental. It may not support operations including (but not limited to) Functions in ``torch.nn.functional`` that involved data dimension


In [2]:
data_flag = 'organmnist3d'
download = True

DEVICE = 'cpu'
NUM_EPOCHS = 25
BATCH_SIZE = 16
milestones = [0.5 * NUM_EPOCHS, 0.75 * NUM_EPOCHS]
lr = 0.001
gamma = 0.1

output_root = './output3d'

info = INFO[data_flag]
n_channels = info['n_channels']
n_classes = len(info['label'])
task = info['task']

train_loader, train_loader_at_eval, val_loader, test_loader = extract_data_3d(data_flag, download, BATCH_SIZE)

model = resnet18(pretrained=True)
model.fc.out_features = n_classes
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, n_classes)
model.name = 'resnet18'

# model = resnet50(pretrained=True)
# model.fc.out_features = n_classes
# num_ftrs = model.fc.in_features
# model.fc = nn.Linear(num_ftrs, n_classes)
# model.name = 'resnet50'

# model = swin_v2_t(pretrained=True)
# model.head.out_features = n_classes
# num_ftrs = model.head.in_features
# model.head = nn.Linear(num_ftrs, n_classes)
# model.name = 'resnet50'

Using downloaded and verified file: /Users/vemundlund/.medmnist/organmnist3d.npz
Using downloaded and verified file: /Users/vemundlund/.medmnist/organmnist3d.npz
Using downloaded and verified file: /Users/vemundlund/.medmnist/organmnist3d.npz
Using downloaded and verified file: /Users/vemundlund/.medmnist/organmnist3d.npz
==> Building and training model...




In [3]:
conv = 'Conv3d'
pretrained_3d = 'i3d'
model_flag = 'resnet18'

# if model_flag == 'resnet18':
#     model = ResNet18(in_channels=n_channels, num_classes=n_classes)
# elif model_flag == 'resnet50':
#     model = ResNet50(in_channels=n_channels, num_classes=n_classes)
# else:
#     raise NotImplementedError

if conv=='ACSConv':
    model = model_to_syncbn(ACSConverter(model))
if conv=='Conv2_5d':
    model = model_to_syncbn(Conv2_5dConverter(model))
if conv=='Conv3d':
    if pretrained_3d == 'i3d':
        model = model_to_syncbn(Conv3dConverter(model, i3d_repeat_axis=-3))
    else:
        model = model_to_syncbn(Conv3dConverter(model, i3d_repeat_axis=None))

model = model.to(DEVICE)

train_evaluator = medmnist.Evaluator(data_flag, 'train')
val_evaluator = medmnist.Evaluator(data_flag, 'val')
test_evaluator = medmnist.Evaluator(data_flag, 'test')

criterion = nn.CrossEntropyLoss()


optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=gamma)

logs = ['loss', 'auc', 'acc']
train_logs = ['train_'+log for log in logs]
val_logs = ['val_'+log for log in logs]
test_logs = ['test_'+log for log in logs]
log_dict = OrderedDict.fromkeys(train_logs+val_logs+test_logs, 0)

writer = SummaryWriter(log_dir=os.path.join(output_root, 'Tensorboard_Results_3D'))

best_auc = 0
best_epoch = 0
best_model = deepcopy(model)
tb_twod = False

global iteration
iteration = 0

for epoch in range(NUM_EPOCHS):
    train_loss = train(model, train_loader, task, criterion, optimizer, DEVICE, writer)

    train_metrics = test(model, train_evaluator, train_loader_at_eval, task, criterion, DEVICE, 'model3d')
    val_metrics = test(model, train_evaluator, train_loader_at_eval, task, criterion, DEVICE, 'model3d')
    test_metrics = test(model, train_evaluator, train_loader_at_eval, task, criterion, DEVICE, 'model3d')
    
    scheduler.step()
    
    for i, key in enumerate(train_logs):
        log_dict[key] = train_metrics[i]
    for i, key in enumerate(val_logs):
        log_dict[key] = val_metrics[i]
    for i, key in enumerate(test_logs):
        log_dict[key] = test_metrics[i]

    for key, value in log_dict.items():
        writer.add_scalar(key, value, epoch)
        
    cur_auc = val_metrics[1]
    if cur_auc > best_auc:
        best_epoch = epoch
        best_auc = cur_auc
        best_model = deepcopy(model)

        print('cur_best_auc:', best_auc)
        print('cur_best_epoch', best_epoch)

state = {
    'net': model.state_dict(),
}

path = os.path.join(output_root, 'best_model.pth')
torch.save(state, path)

train_metrics = test(model, train_evaluator, train_loader_at_eval, task, criterion, DEVICE, 'model3d')
val_metrics = test(model, train_evaluator, train_loader_at_eval, task, criterion, DEVICE, 'model3d')
test_metrics = test(model, train_evaluator, train_loader_at_eval, task, criterion, DEVICE, 'model3d')

train_log = 'train  auc: %.5f  acc: %.5f\n' % (train_metrics[1], train_metrics[2])
val_log = 'val  auc: %.5f  acc: %.5f\n' % (val_metrics[1], val_metrics[2])
test_log = 'test  auc: %.5f  acc: %.5f\n' % (test_metrics[1], test_metrics[2])

log = '%s\n' % (data_flag) + train_log + val_log + test_log + '\n'
print(log)

with open(os.path.join(output_root, '%s_log.txt' % (data_flag)), 'a') as f:
    f.write(log)        
        
writer.close()

train: 100%|██████████| 61/61 [00:28<00:00,  2.15it/s]


Done with batches
Epoch loss 2.4378304461963842


test: 100%|██████████| 61/61 [00:08<00:00,  6.97it/s]
test: 100%|██████████| 61/61 [00:08<00:00,  6.92it/s]
test: 100%|██████████| 61/61 [00:09<00:00,  6.61it/s]


cur_best_auc: 0.8579103087707044
cur_best_epoch 0


train: 100%|██████████| 61/61 [00:27<00:00,  2.21it/s]


Done with batches
Epoch loss 2.0045319541555937


test: 100%|██████████| 61/61 [00:08<00:00,  6.93it/s]
test: 100%|██████████| 61/61 [00:08<00:00,  7.00it/s]
test: 100%|██████████| 61/61 [00:08<00:00,  7.05it/s]
train: 100%|██████████| 61/61 [00:27<00:00,  2.21it/s]


Done with batches
Epoch loss 1.9154191837936152


test: 100%|██████████| 61/61 [00:09<00:00,  6.62it/s]
test: 100%|██████████| 61/61 [00:08<00:00,  6.88it/s]
test: 100%|██████████| 61/61 [00:09<00:00,  6.60it/s]


cur_best_auc: 0.9006363303784777
cur_best_epoch 2


train:  67%|██████▋   | 41/61 [00:19<00:09,  2.06it/s]


KeyboardInterrupt: 

In [None]:
# %tensorboard --logdir logs
%load_ext tensorboard


In [None]:
from tensorboard import notebook
notebook.list() # View open TensorBoard instances

No known TensorBoard instances running.
