In [1]:
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
from dataset import *
from models import ResNet18, ResNet50
import medmnist
from medmnist import INFO, Evaluator
from medmnist import BloodMNIST

import os
from collections import OrderedDict
from copy import deepcopy

import numpy as np
from torch.utils.tensorboard import SummaryWriter
from torchvision.models import resnet18, resnet34, resnet50
from torchvision.models import swin_v2_t


In [2]:
data_flag = 'bloodmnist'
download = True

DEVICE = 'mps'
NUM_EPOCHS = 30
BATCH_SIZE = 128
milestones = [0.5 * NUM_EPOCHS, 0.75 * NUM_EPOCHS]
lr = 0.001
gamma = 0.1

output_root = './output2d'

info = INFO[data_flag]
n_channels = info['n_channels']
n_classes = len(info['label'])
task = info['task']

train_loader, train_loader_at_eval, val_loader, test_loader = extract_data(data_flag, download, BATCH_SIZE)


# TORCH 
# model = swin_v2_t(pretrained=True)
# model.head.out_features = n_classes
# num_ftrs = model.head.in_features
# model.head = nn.Linear(num_ftrs, n_classes)
# model.name = 'swinv2_t'

# model = resnet18(pretrained=True)
# model.fc.out_features = n_classes
# num_ftrs = model.fc.in_features
# model.fc = nn.Linear(num_ftrs, n_classes)
# model.name = 'resnet18'

model = resnet34(pretrained=True)
model.fc.out_features = n_classes
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, n_classes)
model.name = 'resnet34'

train_2d_model(model, DEVICE, data_flag, milestones, gamma, output_root, task, train_loader, train_loader_at_eval, val_loader, test_loader, NUM_EPOCHS, lr, True)


# model = resnet50(pretrained=True)
# model.fc.out_features = n_classes
# num_ftrs = model.fc.in_features
# model.fc = nn.Linear(num_ftrs, n_classes)
# model.name = 'resnet50'


# OWN
# model = ResNet18(in_channels=n_channels, num_classes=n_classes)

Using downloaded and verified file: /Users/vemundlund/.medmnist/bloodmnist.npz
Using downloaded and verified file: /Users/vemundlund/.medmnist/bloodmnist.npz
Using downloaded and verified file: /Users/vemundlund/.medmnist/bloodmnist.npz
Using downloaded and verified file: /Users/vemundlund/.medmnist/bloodmnist.npz


train:  82%|████████▏ | 77/94 [00:09<00:02,  8.18it/s]


KeyboardInterrupt: 

In [3]:
model = model.to(DEVICE)

criterion = nn.CrossEntropyLoss()

train_evaluator = medmnist.Evaluator(data_flag, 'train')
val_evaluator = medmnist.Evaluator(data_flag, 'val')
test_evaluator = medmnist.Evaluator(data_flag, 'test')

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=gamma)

logs = ['loss', 'auc', 'acc']
train_logs = ['train_'+log for log in logs]
val_logs = ['val_'+log for log in logs]
test_logs = ['test_'+log for log in logs]
log_dict = OrderedDict.fromkeys(train_logs+val_logs+test_logs, 0)

writer = SummaryWriter(log_dir=os.path.join(output_root, model.name))

best_auc = 0
best_epoch = 0
best_model = deepcopy(model)
tb_twod = True

for epoch in range(NUM_EPOCHS):        
    print(f"Epoch {epoch} of {NUM_EPOCHS}")
    train_loss = train(model, train_loader, task, criterion, optimizer, DEVICE, writer)
    
    train_metrics = test(model, train_evaluator, train_loader_at_eval, task, criterion, DEVICE, 'model1')
    val_metrics = test(model, val_evaluator, val_loader, task, criterion, DEVICE, 'model1')
    test_metrics = test(model, test_evaluator, test_loader, task, criterion, DEVICE, 'model1')
    
    scheduler.step()
    
    for i, key in enumerate(train_logs):
        log_dict[key] = train_metrics[i]
    for i, key in enumerate(val_logs):
        log_dict[key] = val_metrics[i]
    for i, key in enumerate(test_logs):
        log_dict[key] = test_metrics[i]

    for key, value in log_dict.items():
        writer.add_scalar(key, value, epoch)
        
    cur_auc = val_metrics[1]
    if cur_auc > best_auc:
        best_epoch = epoch
        best_auc = cur_auc
        best_model = deepcopy(model)
        print('cur_best_auc:', best_auc)
        print('cur_best_epoch', best_epoch)

state = {
    'net': best_model.state_dict(),
}

Epoch 0 of 30


train: 100%|██████████| 94/94 [00:41<00:00,  2.27it/s]


Done with batches
Epoch loss 0.4581768531748589


test: 100%|██████████| 47/47 [00:08<00:00,  5.59it/s]
test: 100%|██████████| 14/14 [00:02<00:00,  6.13it/s]
test: 100%|██████████| 14/14 [00:02<00:00,  5.74it/s]


cur_best_auc: 0.9840345281461182
cur_best_epoch 0
Epoch 1 of 30


train: 100%|██████████| 94/94 [00:39<00:00,  2.37it/s]


Done with batches
Epoch loss 0.23446571581224177


test: 100%|██████████| 47/47 [00:07<00:00,  6.56it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  8.93it/s]
test: 100%|██████████| 14/14 [00:02<00:00,  6.92it/s]


cur_best_auc: 0.990942505171432
cur_best_epoch 1
Epoch 2 of 30


train: 100%|██████████| 94/94 [00:39<00:00,  2.35it/s]


Done with batches
Epoch loss 0.17934337258338928


test: 100%|██████████| 47/47 [00:07<00:00,  6.58it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  8.71it/s]
test: 100%|██████████| 14/14 [00:02<00:00,  6.93it/s]


Epoch 3 of 30


train: 100%|██████████| 94/94 [00:37<00:00,  2.50it/s]


Done with batches
Epoch loss 0.1556876078248024


test: 100%|██████████| 47/47 [00:06<00:00,  6.89it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  9.17it/s]
test: 100%|██████████| 14/14 [00:02<00:00,  6.89it/s]


cur_best_auc: 0.9938789817128981
cur_best_epoch 3
Epoch 4 of 30


train: 100%|██████████| 94/94 [00:39<00:00,  2.41it/s]


Done with batches
Epoch loss 0.1329318735193699


test: 100%|██████████| 47/47 [00:07<00:00,  6.40it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  8.99it/s]
test: 100%|██████████| 14/14 [00:02<00:00,  6.79it/s]


Epoch 5 of 30


train: 100%|██████████| 94/94 [00:38<00:00,  2.44it/s]


Done with batches
Epoch loss 0.1232184138783115


test: 100%|██████████| 47/47 [00:07<00:00,  6.60it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  9.51it/s]
test: 100%|██████████| 14/14 [00:02<00:00,  6.88it/s]


Epoch 6 of 30


train: 100%|██████████| 94/94 [00:38<00:00,  2.46it/s]


Done with batches
Epoch loss 0.09547543876427919


test: 100%|██████████| 47/47 [00:07<00:00,  6.44it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  8.85it/s]
test: 100%|██████████| 14/14 [00:02<00:00,  6.97it/s]


cur_best_auc: 0.9963218248278017
cur_best_epoch 6
Epoch 7 of 30


train: 100%|██████████| 94/94 [00:39<00:00,  2.41it/s]


Done with batches
Epoch loss 0.08472500844521726


test: 100%|██████████| 47/47 [00:07<00:00,  6.63it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  8.93it/s]
test: 100%|██████████| 14/14 [00:02<00:00,  6.91it/s]


cur_best_auc: 0.9966207510045169
cur_best_epoch 7
Epoch 8 of 30


train: 100%|██████████| 94/94 [00:39<00:00,  2.38it/s]


Done with batches
Epoch loss 0.05883202907886911


test: 100%|██████████| 47/47 [00:07<00:00,  6.45it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  8.43it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  7.00it/s]


Epoch 9 of 30


train: 100%|██████████| 94/94 [00:39<00:00,  2.40it/s]


Done with batches
Epoch loss 0.06642739164662805


test: 100%|██████████| 47/47 [00:07<00:00,  6.52it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  8.91it/s]
test: 100%|██████████| 14/14 [00:02<00:00,  6.86it/s]


Epoch 10 of 30


train: 100%|██████████| 94/94 [00:38<00:00,  2.42it/s]


Done with batches
Epoch loss 0.0519229507549329


test: 100%|██████████| 47/47 [00:07<00:00,  6.53it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  8.91it/s]
test: 100%|██████████| 14/14 [00:02<00:00,  6.87it/s]


cur_best_auc: 0.9972654749199492
cur_best_epoch 10
Epoch 11 of 30


train: 100%|██████████| 94/94 [00:38<00:00,  2.43it/s]


Done with batches
Epoch loss 0.05535496256631264


test: 100%|██████████| 47/47 [00:07<00:00,  6.66it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  9.30it/s]
test: 100%|██████████| 14/14 [00:02<00:00,  6.79it/s]


Epoch 12 of 30


train: 100%|██████████| 94/94 [00:40<00:00,  2.31it/s]


Done with batches
Epoch loss 0.04684382564130616


test: 100%|██████████| 47/47 [00:07<00:00,  6.54it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  8.60it/s]
test: 100%|██████████| 14/14 [00:02<00:00,  6.79it/s]


Epoch 13 of 30


train: 100%|██████████| 94/94 [00:41<00:00,  2.26it/s]


Done with batches
Epoch loss 0.03762064724070753


test: 100%|██████████| 47/47 [00:07<00:00,  6.11it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  8.09it/s]
test: 100%|██████████| 14/14 [00:02<00:00,  6.44it/s]


Epoch 14 of 30


train: 100%|██████████| 94/94 [00:38<00:00,  2.43it/s]


Done with batches
Epoch loss 0.03855591019159777


test: 100%|██████████| 47/47 [00:07<00:00,  6.66it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  9.45it/s]
test: 100%|██████████| 14/14 [00:02<00:00,  6.79it/s]


Epoch 15 of 30


train: 100%|██████████| 94/94 [00:38<00:00,  2.44it/s]


Done with batches
Epoch loss 0.01427458774947383


test: 100%|██████████| 47/47 [00:07<00:00,  6.60it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  8.65it/s]
test: 100%|██████████| 14/14 [00:02<00:00,  6.77it/s]


cur_best_auc: 0.9980284312767113
cur_best_epoch 15
Epoch 16 of 30


train: 100%|██████████| 94/94 [00:38<00:00,  2.42it/s]


Done with batches
Epoch loss 0.0043399468967765375


test: 100%|██████████| 47/47 [00:07<00:00,  6.55it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  9.11it/s]
test: 100%|██████████| 14/14 [00:02<00:00,  6.88it/s]


cur_best_auc: 0.9981108572311519
cur_best_epoch 16
Epoch 17 of 30


train: 100%|██████████| 94/94 [00:38<00:00,  2.45it/s]


Done with batches
Epoch loss 0.0036823688374653936


test: 100%|██████████| 47/47 [00:07<00:00,  6.53it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  9.12it/s]
test: 100%|██████████| 14/14 [00:02<00:00,  6.82it/s]


cur_best_auc: 0.9982136531106904
cur_best_epoch 17
Epoch 18 of 30


train: 100%|██████████| 94/94 [00:38<00:00,  2.43it/s]


Done with batches
Epoch loss 0.002629115997053227


test: 100%|██████████| 47/47 [00:07<00:00,  6.62it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  9.06it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  7.02it/s]


Epoch 19 of 30


train: 100%|██████████| 94/94 [00:38<00:00,  2.43it/s]


Done with batches
Epoch loss 0.0025063517185775165


test: 100%|██████████| 47/47 [00:07<00:00,  6.63it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  9.02it/s]
test: 100%|██████████| 14/14 [00:02<00:00,  6.78it/s]


Epoch 20 of 30


train: 100%|██████████| 94/94 [00:38<00:00,  2.42it/s]


Done with batches
Epoch loss 0.0018028033661353898


test: 100%|██████████| 47/47 [00:07<00:00,  6.56it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  9.12it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  7.17it/s]


Epoch 21 of 30


train: 100%|██████████| 94/94 [00:38<00:00,  2.44it/s]


Done with batches
Epoch loss 0.002147444369527759


test: 100%|██████████| 47/47 [00:07<00:00,  6.50it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  8.76it/s]
test: 100%|██████████| 14/14 [00:02<00:00,  6.84it/s]


Epoch 22 of 30


train: 100%|██████████| 94/94 [00:39<00:00,  2.41it/s]


Done with batches
Epoch loss 0.0017744434026486062


test: 100%|██████████| 47/47 [00:07<00:00,  6.49it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  8.94it/s]
test: 100%|██████████| 14/14 [00:02<00:00,  6.92it/s]


Epoch 23 of 30


train: 100%|██████████| 94/94 [00:38<00:00,  2.43it/s]


Done with batches
Epoch loss 0.0014772378245710416


test: 100%|██████████| 47/47 [00:07<00:00,  6.54it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  9.09it/s]
test: 100%|██████████| 14/14 [00:02<00:00,  6.97it/s]


cur_best_auc: 0.9982318659782098
cur_best_epoch 23
Epoch 24 of 30


train: 100%|██████████| 94/94 [00:39<00:00,  2.41it/s]


Done with batches
Epoch loss 0.0008369118828689333


test: 100%|██████████| 47/47 [00:07<00:00,  6.49it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  9.16it/s]
test: 100%|██████████| 14/14 [00:02<00:00,  6.78it/s]


Epoch 25 of 30


train: 100%|██████████| 94/94 [00:40<00:00,  2.33it/s]


Done with batches
Epoch loss 0.0012941622887236472


test: 100%|██████████| 47/47 [00:07<00:00,  5.99it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  8.68it/s]
test: 100%|██████████| 14/14 [00:02<00:00,  6.80it/s]


Epoch 26 of 30


train: 100%|██████████| 94/94 [00:42<00:00,  2.22it/s]


Done with batches
Epoch loss 0.001165221084405896


test: 100%|██████████| 47/47 [00:07<00:00,  6.30it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  8.12it/s]
test: 100%|██████████| 14/14 [00:02<00:00,  5.90it/s]


cur_best_auc: 0.9982669635345245
cur_best_epoch 26
Epoch 27 of 30


train: 100%|██████████| 94/94 [00:43<00:00,  2.15it/s]


Done with batches
Epoch loss 0.0020970775319468925


test: 100%|██████████| 47/47 [00:07<00:00,  6.44it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  8.73it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  7.14it/s]


Epoch 28 of 30


train: 100%|██████████| 94/94 [00:42<00:00,  2.19it/s]


Done with batches
Epoch loss 0.000958896016140414


test: 100%|██████████| 47/47 [00:07<00:00,  5.96it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  7.63it/s]
test: 100%|██████████| 14/14 [00:02<00:00,  6.15it/s]


Epoch 29 of 30


train: 100%|██████████| 94/94 [00:43<00:00,  2.17it/s]


Done with batches
Epoch loss 0.0010725041566787541


test: 100%|██████████| 47/47 [00:07<00:00,  6.54it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  9.60it/s]
test: 100%|██████████| 14/14 [00:02<00:00,  6.37it/s]


In [4]:

path = os.path.join(output_root, f'{model.name}_best_model.pth')
torch.save(state, path)

train_metrics = test(best_model, train_evaluator, train_loader_at_eval, task, criterion, DEVICE, 'model', output_root)
val_metrics = test(best_model, val_evaluator, val_loader, task, criterion, DEVICE, 'model', output_root)
test_metrics = test(best_model, test_evaluator, test_loader, task, criterion, DEVICE, 'model', output_root)

train_log = 'train  auc: %.5f  acc: %.5f\n' % (train_metrics[1], train_metrics[2])
val_log = 'val  auc: %.5f  acc: %.5f\n' % (val_metrics[1], val_metrics[2])
test_log = 'test  auc: %.5f  acc: %.5f\n' % (test_metrics[1], test_metrics[2])

log = '%s\n' % (data_flag) + train_log + val_log + test_log
print(log)
        
with open(os.path.join(output_root, '%s_log.txt' % (data_flag)), 'a') as f:
    f.write(log)  

writer.close()

test: 100%|██████████| 47/47 [00:07<00:00,  6.25it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  8.47it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  7.26it/s]

bloodmnist
train  auc: 1.00000  acc: 1.00000
val  auc: 0.99827  acc: 0.96262
test  auc: 0.99765  acc: 0.96054






In [5]:
data_flag = 'bloodmnist'
download = True

DEVICE = 'mps'
NUM_EPOCHS = 30
BATCH_SIZE = 128
milestones = [0.5 * NUM_EPOCHS, 0.75 * NUM_EPOCHS]
lr = 0.001
gamma = 0.1

output_root = './output2d'

info = INFO[data_flag]
n_channels = info['n_channels']
n_classes = len(info['label'])
task = info['task']

train_loader, train_loader_at_eval, val_loader, test_loader = extract_data(data_flag, download, BATCH_SIZE)

model = resnet18(pretrained=True)
model.fc.out_features = n_classes
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, n_classes)
model.name = 'resnet18'



Using downloaded and verified file: /Users/vemundlund/.medmnist/bloodmnist.npz
Using downloaded and verified file: /Users/vemundlund/.medmnist/bloodmnist.npz
Using downloaded and verified file: /Users/vemundlund/.medmnist/bloodmnist.npz
Using downloaded and verified file: /Users/vemundlund/.medmnist/bloodmnist.npz




In [6]:
model = model.to(DEVICE)

criterion = nn.CrossEntropyLoss()

train_evaluator = medmnist.Evaluator(data_flag, 'train')
val_evaluator = medmnist.Evaluator(data_flag, 'val')
test_evaluator = medmnist.Evaluator(data_flag, 'test')

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=gamma)

logs = ['loss', 'auc', 'acc']
train_logs = ['train_'+log for log in logs]
val_logs = ['val_'+log for log in logs]
test_logs = ['test_'+log for log in logs]
log_dict = OrderedDict.fromkeys(train_logs+val_logs+test_logs, 0)

writer = SummaryWriter(log_dir=os.path.join(output_root, model.name))

best_auc = 0
best_epoch = 0
best_model = deepcopy(model)
tb_twod = True

for epoch in range(NUM_EPOCHS):        
    print(f"Epoch {epoch} of {NUM_EPOCHS}")
    train_loss = train(model, train_loader, task, criterion, optimizer, DEVICE, writer)
    
    train_metrics = test(model, train_evaluator, train_loader_at_eval, task, criterion, DEVICE, 'model1')
    val_metrics = test(model, val_evaluator, val_loader, task, criterion, DEVICE, 'model1')
    test_metrics = test(model, test_evaluator, test_loader, task, criterion, DEVICE, 'model1')
    
    scheduler.step()
    
    for i, key in enumerate(train_logs):
        log_dict[key] = train_metrics[i]
    for i, key in enumerate(val_logs):
        log_dict[key] = val_metrics[i]
    for i, key in enumerate(test_logs):
        log_dict[key] = test_metrics[i]

    for key, value in log_dict.items():
        writer.add_scalar(key, value, epoch)
        
    cur_auc = val_metrics[1]
    if cur_auc > best_auc:
        best_epoch = epoch
        best_auc = cur_auc
        best_model = deepcopy(model)
        print('cur_best_auc:', best_auc)
        print('cur_best_epoch', best_epoch)

state = {
    'net': best_model.state_dict(),
}


path = os.path.join(output_root, f'{model.name}_best_model.pth')
torch.save(state, path)

train_metrics = test(best_model, train_evaluator, train_loader_at_eval, task, criterion, DEVICE, 'model', output_root)
val_metrics = test(best_model, val_evaluator, val_loader, task, criterion, DEVICE, 'model', output_root)
test_metrics = test(best_model, test_evaluator, test_loader, task, criterion, DEVICE, 'model', output_root)

train_log = 'train  auc: %.5f  acc: %.5f\n' % (train_metrics[1], train_metrics[2])
val_log = 'val  auc: %.5f  acc: %.5f\n' % (val_metrics[1], val_metrics[2])
test_log = 'test  auc: %.5f  acc: %.5f\n' % (test_metrics[1], test_metrics[2])

log = '%s\n' % (data_flag) + train_log + val_log + test_log
print(log)
        
with open(os.path.join(output_root, '%s_log.txt' % (data_flag)), 'a') as f:
    f.write(log)  

writer.close()

Epoch 0 of 30


train: 100%|██████████| 94/94 [00:24<00:00,  3.91it/s]


Done with batches
Epoch loss 0.4344546544425031


test: 100%|██████████| 47/47 [00:05<00:00,  7.92it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  9.84it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  8.82it/s]


cur_best_auc: 0.9848569441644346
cur_best_epoch 0
Epoch 1 of 30


train: 100%|██████████| 94/94 [00:25<00:00,  3.75it/s]


Done with batches
Epoch loss 0.2316137507874915


test: 100%|██████████| 47/47 [00:05<00:00,  9.08it/s]
test: 100%|██████████| 14/14 [00:01<00:00, 12.45it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  8.42it/s]


cur_best_auc: 0.9906155822643211
cur_best_epoch 1
Epoch 2 of 30


train: 100%|██████████| 94/94 [00:25<00:00,  3.66it/s]


Done with batches
Epoch loss 0.19077993437964866


test: 100%|██████████| 47/47 [00:05<00:00,  7.90it/s]
test: 100%|██████████| 14/14 [00:01<00:00, 12.99it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  7.45it/s]


cur_best_auc: 0.9943210015024965
cur_best_epoch 2
Epoch 3 of 30


train: 100%|██████████| 94/94 [00:26<00:00,  3.60it/s]


Done with batches
Epoch loss 0.15251892951733254


test: 100%|██████████| 47/47 [00:05<00:00,  7.99it/s]
test: 100%|██████████| 14/14 [00:01<00:00, 11.47it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  7.62it/s]


Epoch 4 of 30


train: 100%|██████████| 94/94 [00:24<00:00,  3.87it/s]


Done with batches
Epoch loss 0.13075703410233588


test: 100%|██████████| 47/47 [00:05<00:00,  8.87it/s]
test: 100%|██████████| 14/14 [00:01<00:00, 12.78it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  9.36it/s]


cur_best_auc: 0.9954547334614744
cur_best_epoch 4
Epoch 5 of 30


train: 100%|██████████| 94/94 [00:25<00:00,  3.69it/s]


Done with batches
Epoch loss 0.11446186737652789


test: 100%|██████████| 47/47 [00:05<00:00,  8.18it/s]
test: 100%|██████████| 14/14 [00:01<00:00, 13.44it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  9.53it/s]


Epoch 6 of 30


train: 100%|██████████| 94/94 [00:25<00:00,  3.65it/s]


Done with batches
Epoch loss 0.09098067693412304


test: 100%|██████████| 47/47 [00:05<00:00,  7.85it/s]
test: 100%|██████████| 14/14 [00:01<00:00, 10.59it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  8.15it/s]


cur_best_auc: 0.9955183864604747
cur_best_epoch 6
Epoch 7 of 30


train: 100%|██████████| 94/94 [00:25<00:00,  3.67it/s]


Done with batches
Epoch loss 0.08551610429632536


test: 100%|██████████| 47/47 [00:05<00:00,  8.25it/s]
test: 100%|██████████| 14/14 [00:01<00:00, 12.60it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  9.31it/s]


Epoch 8 of 30


train: 100%|██████████| 94/94 [00:25<00:00,  3.69it/s]


Done with batches
Epoch loss 0.0811936145529468


test: 100%|██████████| 47/47 [00:05<00:00,  8.63it/s]
test: 100%|██████████| 14/14 [00:01<00:00, 12.67it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  8.16it/s]


Epoch 9 of 30


train: 100%|██████████| 94/94 [00:26<00:00,  3.60it/s]


Done with batches
Epoch loss 0.07512779543770755


test: 100%|██████████| 47/47 [00:05<00:00,  7.90it/s]
test: 100%|██████████| 14/14 [00:01<00:00, 11.81it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  8.19it/s]


cur_best_auc: 0.9966339419830779
cur_best_epoch 9
Epoch 10 of 30


train: 100%|██████████| 94/94 [00:24<00:00,  3.81it/s]


Done with batches
Epoch loss 0.07023204936388325


test: 100%|██████████| 47/47 [00:05<00:00,  8.44it/s]
test: 100%|██████████| 14/14 [00:01<00:00, 12.26it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  8.06it/s]


Epoch 11 of 30


train: 100%|██████████| 94/94 [00:26<00:00,  3.57it/s]


Done with batches
Epoch loss 0.06606452079846505


test: 100%|██████████| 47/47 [00:05<00:00,  8.77it/s]
test: 100%|██████████| 14/14 [00:01<00:00, 11.36it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  7.29it/s]


Epoch 12 of 30


train: 100%|██████████| 94/94 [00:26<00:00,  3.51it/s]


Done with batches
Epoch loss 0.05024395492720477


test: 100%|██████████| 47/47 [00:05<00:00,  8.59it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  9.90it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  9.00it/s]


Epoch 13 of 30


train: 100%|██████████| 94/94 [00:25<00:00,  3.75it/s]


Done with batches
Epoch loss 0.054726305705039426


test: 100%|██████████| 47/47 [00:05<00:00,  8.40it/s]
test: 100%|██████████| 14/14 [00:01<00:00, 10.29it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  7.78it/s]


Epoch 14 of 30


train: 100%|██████████| 94/94 [00:26<00:00,  3.54it/s]


Done with batches
Epoch loss 0.054524748407779856


test: 100%|██████████| 47/47 [00:05<00:00,  8.17it/s]
test: 100%|██████████| 14/14 [00:01<00:00, 10.86it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  8.09it/s]


Epoch 15 of 30


train: 100%|██████████| 94/94 [00:26<00:00,  3.58it/s]


Done with batches
Epoch loss 0.020680824180747917


test: 100%|██████████| 47/47 [00:06<00:00,  7.47it/s]
test: 100%|██████████| 14/14 [00:01<00:00, 12.51it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  9.16it/s]


cur_best_auc: 0.9984240183526527
cur_best_epoch 15
Epoch 16 of 30


train: 100%|██████████| 94/94 [00:25<00:00,  3.67it/s]


Done with batches
Epoch loss 0.008572752233916656


test: 100%|██████████| 47/47 [00:06<00:00,  7.43it/s]
test: 100%|██████████| 14/14 [00:01<00:00, 11.89it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  8.67it/s]


cur_best_auc: 0.9984747389678661
cur_best_epoch 16
Epoch 17 of 30


train: 100%|██████████| 94/94 [00:26<00:00,  3.56it/s]


Done with batches
Epoch loss 0.005372302458705777


test: 100%|██████████| 47/47 [00:05<00:00,  8.03it/s]
test: 100%|██████████| 14/14 [00:01<00:00, 12.01it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  7.92it/s]


cur_best_auc: 0.9984908156889105
cur_best_epoch 17
Epoch 18 of 30


train: 100%|██████████| 94/94 [00:26<00:00,  3.61it/s]


Done with batches
Epoch loss 0.0037604419977721225


test: 100%|██████████| 47/47 [00:05<00:00,  8.03it/s]
test: 100%|██████████| 14/14 [00:01<00:00, 10.88it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  8.17it/s]


Epoch 19 of 30


train: 100%|██████████| 94/94 [00:26<00:00,  3.52it/s]


Done with batches
Epoch loss 0.003865645024331009


test: 100%|██████████| 47/47 [00:05<00:00,  7.87it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  9.57it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  7.57it/s]


Epoch 20 of 30


train: 100%|██████████| 94/94 [00:25<00:00,  3.73it/s]


Done with batches
Epoch loss 0.0030001110779344442


test: 100%|██████████| 47/47 [00:05<00:00,  9.11it/s]
test: 100%|██████████| 14/14 [00:01<00:00, 13.72it/s]
test: 100%|██████████| 14/14 [00:01<00:00, 10.01it/s]


Epoch 21 of 30


train: 100%|██████████| 94/94 [00:21<00:00,  4.40it/s]


Done with batches
Epoch loss 0.003396149525020093


test: 100%|██████████| 47/47 [00:04<00:00, 10.86it/s]
test: 100%|██████████| 14/14 [00:00<00:00, 15.27it/s]
test: 100%|██████████| 14/14 [00:01<00:00, 10.84it/s]


cur_best_auc: 0.9985287017516133
cur_best_epoch 21
Epoch 22 of 30


train: 100%|██████████| 94/94 [00:18<00:00,  4.99it/s]


Done with batches
Epoch loss 0.00409687376080175


test: 100%|██████████| 47/47 [00:04<00:00, 11.08it/s]
test: 100%|██████████| 14/14 [00:00<00:00, 15.61it/s]
test: 100%|██████████| 14/14 [00:01<00:00, 11.35it/s]


Epoch 23 of 30


train: 100%|██████████| 94/94 [00:17<00:00,  5.24it/s]


Done with batches
Epoch loss 0.0029344330405558835


test: 100%|██████████| 47/47 [00:04<00:00, 11.31it/s]
test: 100%|██████████| 14/14 [00:00<00:00, 16.19it/s]
test: 100%|██████████| 14/14 [00:01<00:00, 11.68it/s]


Epoch 24 of 30


train: 100%|██████████| 94/94 [00:17<00:00,  5.39it/s]


Done with batches
Epoch loss 0.0022344725378950006


test: 100%|██████████| 47/47 [00:04<00:00, 11.62it/s]
test: 100%|██████████| 14/14 [00:00<00:00, 16.68it/s]
test: 100%|██████████| 14/14 [00:01<00:00, 12.02it/s]


Epoch 25 of 30


train: 100%|██████████| 94/94 [00:17<00:00,  5.42it/s]


Done with batches
Epoch loss 0.0015705677164697454


test: 100%|██████████| 47/47 [00:04<00:00, 11.57it/s]
test: 100%|██████████| 14/14 [00:00<00:00, 16.95it/s]
test: 100%|██████████| 14/14 [00:01<00:00, 11.99it/s]


cur_best_auc: 0.9985363852121187
cur_best_epoch 25
Epoch 26 of 30


train: 100%|██████████| 94/94 [00:17<00:00,  5.42it/s]


Done with batches
Epoch loss 0.0011431714026796126


test: 100%|██████████| 47/47 [00:03<00:00, 12.62it/s]
test: 100%|██████████| 14/14 [00:00<00:00, 15.23it/s]
test: 100%|██████████| 14/14 [00:01<00:00, 11.11it/s]


Epoch 27 of 30


train: 100%|██████████| 94/94 [00:17<00:00,  5.30it/s]


Done with batches
Epoch loss 0.0016668276114982523


test: 100%|██████████| 47/47 [00:04<00:00, 11.28it/s]
test: 100%|██████████| 14/14 [00:00<00:00, 15.26it/s]
test: 100%|██████████| 14/14 [00:01<00:00, 11.74it/s]


Epoch 28 of 30


train: 100%|██████████| 94/94 [00:17<00:00,  5.24it/s]


Done with batches
Epoch loss 0.001597073323593693


test: 100%|██████████| 47/47 [00:04<00:00, 11.19it/s]
test: 100%|██████████| 14/14 [00:00<00:00, 14.88it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  9.70it/s]


Epoch 29 of 30


train: 100%|██████████| 94/94 [00:19<00:00,  4.74it/s]


Done with batches
Epoch loss 0.0019709859222162545


test: 100%|██████████| 47/47 [00:04<00:00, 11.17it/s]
test: 100%|██████████| 14/14 [00:00<00:00, 15.06it/s]
test: 100%|██████████| 14/14 [00:01<00:00, 10.23it/s]
test: 100%|██████████| 47/47 [00:05<00:00,  9.03it/s]
test: 100%|██████████| 14/14 [00:00<00:00, 14.58it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  9.74it/s]


bloodmnist
train  auc: 1.00000  acc: 0.99992
val  auc: 0.99854  acc: 0.95678
test  auc: 0.99763  acc: 0.96083



In [7]:

data_flag = 'bloodmnist'
download = True

DEVICE = 'mps'
NUM_EPOCHS = 30
BATCH_SIZE = 128
milestones = [0.5 * NUM_EPOCHS, 0.75 * NUM_EPOCHS]
lr = 0.001
gamma = 0.1

output_root = './output2d'

info = INFO[data_flag]
n_channels = info['n_channels']
n_classes = len(info['label'])
task = info['task']

train_loader, train_loader_at_eval, val_loader, test_loader = extract_data(data_flag, download, BATCH_SIZE)

model = resnet50(pretrained=True)
model.fc.out_features = n_classes
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, n_classes)
model.name = 'resnet50'

Using downloaded and verified file: /Users/vemundlund/.medmnist/bloodmnist.npz
Using downloaded and verified file: /Users/vemundlund/.medmnist/bloodmnist.npz
Using downloaded and verified file: /Users/vemundlund/.medmnist/bloodmnist.npz
Using downloaded and verified file: /Users/vemundlund/.medmnist/bloodmnist.npz




In [8]:
model = model.to(DEVICE)

criterion = nn.CrossEntropyLoss()

train_evaluator = medmnist.Evaluator(data_flag, 'train')
val_evaluator = medmnist.Evaluator(data_flag, 'val')
test_evaluator = medmnist.Evaluator(data_flag, 'test')

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=gamma)

logs = ['loss', 'auc', 'acc']
train_logs = ['train_'+log for log in logs]
val_logs = ['val_'+log for log in logs]
test_logs = ['test_'+log for log in logs]
log_dict = OrderedDict.fromkeys(train_logs+val_logs+test_logs, 0)

writer = SummaryWriter(log_dir=os.path.join(output_root, model.name))

best_auc = 0
best_epoch = 0
best_model = deepcopy(model)
tb_twod = True

for epoch in range(NUM_EPOCHS):        
    print(f"Epoch {epoch} of {NUM_EPOCHS}")
    train_loss = train(model, train_loader, task, criterion, optimizer, DEVICE, writer)
    
    train_metrics = test(model, train_evaluator, train_loader_at_eval, task, criterion, DEVICE, 'model1')
    val_metrics = test(model, val_evaluator, val_loader, task, criterion, DEVICE, 'model1')
    test_metrics = test(model, test_evaluator, test_loader, task, criterion, DEVICE, 'model1')
    
    scheduler.step()
    
    for i, key in enumerate(train_logs):
        log_dict[key] = train_metrics[i]
    for i, key in enumerate(val_logs):
        log_dict[key] = val_metrics[i]
    for i, key in enumerate(test_logs):
        log_dict[key] = test_metrics[i]

    for key, value in log_dict.items():
        writer.add_scalar(key, value, epoch)
        
    cur_auc = val_metrics[1]
    if cur_auc > best_auc:
        best_epoch = epoch
        best_auc = cur_auc
        best_model = deepcopy(model)
        print('cur_best_auc:', best_auc)
        print('cur_best_epoch', best_epoch)

state = {
    'net': best_model.state_dict(),
}


path = os.path.join(output_root, f'{model.name}_best_model.pth')
torch.save(state, path)

train_metrics = test(best_model, train_evaluator, train_loader_at_eval, task, criterion, DEVICE, 'model', output_root)
val_metrics = test(best_model, val_evaluator, val_loader, task, criterion, DEVICE, 'model', output_root)
test_metrics = test(best_model, test_evaluator, test_loader, task, criterion, DEVICE, 'model', output_root)

train_log = 'train  auc: %.5f  acc: %.5f\n' % (train_metrics[1], train_metrics[2])
val_log = 'val  auc: %.5f  acc: %.5f\n' % (val_metrics[1], val_metrics[2])
test_log = 'test  auc: %.5f  acc: %.5f\n' % (test_metrics[1], test_metrics[2])

log = '%s\n' % (data_flag) + train_log + val_log + test_log
print(log)
        
with open(os.path.join(output_root, '%s_log.txt' % (data_flag)), 'a') as f:
    f.write(log)  

writer.close()

Epoch 0 of 30


train: 100%|██████████| 94/94 [01:07<00:00,  1.40it/s]


Done with batches
Epoch loss 0.4749878576778351


test: 100%|██████████| 47/47 [00:15<00:00,  3.01it/s]
test: 100%|██████████| 14/14 [00:03<00:00,  4.41it/s]
test: 100%|██████████| 14/14 [00:04<00:00,  2.95it/s]


cur_best_auc: 0.9904587566110492
cur_best_epoch 0
Epoch 1 of 30


train: 100%|██████████| 94/94 [00:58<00:00,  1.62it/s]


Done with batches
Epoch loss 0.24241460130569784


test: 100%|██████████| 47/47 [00:14<00:00,  3.24it/s]
test: 100%|██████████| 14/14 [00:02<00:00,  5.80it/s]
test: 100%|██████████| 14/14 [00:04<00:00,  3.39it/s]


cur_best_auc: 0.9943615801046436
cur_best_epoch 1
Epoch 2 of 30


train: 100%|██████████| 94/94 [00:59<00:00,  1.59it/s]


Done with batches
Epoch loss 0.1755893936658159


test: 100%|██████████| 47/47 [00:14<00:00,  3.15it/s]
test: 100%|██████████| 14/14 [00:02<00:00,  5.68it/s]
test: 100%|██████████| 14/14 [00:04<00:00,  3.30it/s]


cur_best_auc: 0.9956904238598128
cur_best_epoch 2
Epoch 3 of 30


train: 100%|██████████| 94/94 [01:01<00:00,  1.53it/s]


Done with batches
Epoch loss 0.15825890242419344


test: 100%|██████████| 47/47 [00:14<00:00,  3.17it/s]
test: 100%|██████████| 14/14 [00:02<00:00,  5.72it/s]
test: 100%|██████████| 14/14 [00:04<00:00,  3.37it/s]


Epoch 4 of 30


train: 100%|██████████| 94/94 [01:02<00:00,  1.51it/s]


Done with batches
Epoch loss 0.12086110974245882


test: 100%|██████████| 47/47 [00:15<00:00,  3.09it/s]
test: 100%|██████████| 14/14 [00:02<00:00,  5.57it/s]
test: 100%|██████████| 14/14 [00:04<00:00,  3.28it/s]


Epoch 5 of 30


train: 100%|██████████| 94/94 [01:02<00:00,  1.51it/s]


Done with batches
Epoch loss 0.11519131355700975


test: 100%|██████████| 47/47 [00:15<00:00,  3.04it/s]
test: 100%|██████████| 14/14 [00:02<00:00,  5.53it/s]
test: 100%|██████████| 14/14 [00:04<00:00,  3.28it/s]


Epoch 6 of 30


train: 100%|██████████| 94/94 [01:05<00:00,  1.44it/s]


Done with batches
Epoch loss 0.07999804921131184


test: 100%|██████████| 47/47 [00:14<00:00,  3.19it/s]
test: 100%|██████████| 14/14 [00:02<00:00,  6.19it/s]
test: 100%|██████████| 14/14 [00:04<00:00,  3.50it/s]


Epoch 7 of 30


train: 100%|██████████| 94/94 [00:58<00:00,  1.61it/s]


Done with batches
Epoch loss 0.07525763419912533


test: 100%|██████████| 47/47 [00:15<00:00,  3.06it/s]
test: 100%|██████████| 14/14 [00:02<00:00,  5.42it/s]
test: 100%|██████████| 14/14 [00:04<00:00,  3.39it/s]


Epoch 8 of 30


train: 100%|██████████| 94/94 [00:51<00:00,  1.81it/s]


Done with batches
Epoch loss 0.08129607368894715


test: 100%|██████████| 47/47 [00:10<00:00,  4.29it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  8.59it/s]
test: 100%|██████████| 14/14 [00:02<00:00,  4.77it/s]


Epoch 9 of 30


train: 100%|██████████| 94/94 [00:31<00:00,  3.02it/s]


Done with batches
Epoch loss 0.0758998922389397


test: 100%|██████████| 47/47 [00:09<00:00,  5.05it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  9.82it/s]
test: 100%|██████████| 14/14 [00:02<00:00,  5.28it/s]


Epoch 10 of 30


train: 100%|██████████| 94/94 [00:29<00:00,  3.14it/s]


Done with batches
Epoch loss 0.05680358434650809


test: 100%|██████████| 47/47 [00:09<00:00,  4.92it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  9.57it/s]
test: 100%|██████████| 14/14 [00:02<00:00,  5.07it/s]


cur_best_auc: 0.9964891493278027
cur_best_epoch 10
Epoch 11 of 30


train: 100%|██████████| 94/94 [00:32<00:00,  2.86it/s]


Done with batches
Epoch loss 0.049362083659210104


test: 100%|██████████| 47/47 [00:10<00:00,  4.66it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  9.04it/s]
test: 100%|██████████| 14/14 [00:02<00:00,  4.89it/s]


Epoch 12 of 30


train: 100%|██████████| 94/94 [00:34<00:00,  2.72it/s]


Done with batches
Epoch loss 0.04841818471260844


test: 100%|██████████| 47/47 [00:10<00:00,  4.51it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  8.58it/s]
test: 100%|██████████| 14/14 [00:03<00:00,  4.65it/s]


Epoch 13 of 30


train: 100%|██████████| 94/94 [00:47<00:00,  2.00it/s]


Done with batches
Epoch loss 0.040461111323353144


test: 100%|██████████| 47/47 [00:16<00:00,  2.85it/s]
test: 100%|██████████| 14/14 [00:02<00:00,  4.88it/s]
test: 100%|██████████| 14/14 [00:05<00:00,  2.78it/s]


Epoch 14 of 30


train: 100%|██████████| 94/94 [01:03<00:00,  1.47it/s]


Done with batches
Epoch loss 0.0351566317629941


test: 100%|██████████| 47/47 [00:13<00:00,  3.54it/s]
test: 100%|██████████| 14/14 [00:02<00:00,  6.92it/s]
test: 100%|██████████| 14/14 [00:03<00:00,  4.06it/s]


Epoch 15 of 30


train: 100%|██████████| 94/94 [00:39<00:00,  2.37it/s]


Done with batches
Epoch loss 0.012383285294376076


test: 100%|██████████| 47/47 [00:10<00:00,  4.33it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  8.19it/s]
test: 100%|██████████| 14/14 [00:03<00:00,  4.58it/s]


cur_best_auc: 0.9979673047661631
cur_best_epoch 15
Epoch 16 of 30


train: 100%|██████████| 94/94 [00:37<00:00,  2.50it/s]


Done with batches
Epoch loss 0.00538193485067979


test: 100%|██████████| 47/47 [00:10<00:00,  4.34it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  8.07it/s]
test: 100%|██████████| 14/14 [00:03<00:00,  4.52it/s]


cur_best_auc: 0.9980075036949215
cur_best_epoch 16
Epoch 17 of 30


train: 100%|██████████| 94/94 [00:38<00:00,  2.44it/s]


Done with batches
Epoch loss 0.003861676952355799


test: 100%|██████████| 47/47 [00:11<00:00,  4.23it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  7.84it/s]
test: 100%|██████████| 14/14 [00:03<00:00,  4.41it/s]


cur_best_auc: 0.9980735926131157
cur_best_epoch 17
Epoch 18 of 30


train: 100%|██████████| 94/94 [00:39<00:00,  2.39it/s]


Done with batches
Epoch loss 0.0025728137521543837


test: 100%|██████████| 47/47 [00:11<00:00,  4.18it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  7.77it/s]
test: 100%|██████████| 14/14 [00:03<00:00,  4.37it/s]


Epoch 19 of 30


train: 100%|██████████| 94/94 [00:40<00:00,  2.33it/s]


Done with batches
Epoch loss 0.0020350473301306844


test: 100%|██████████| 47/47 [00:11<00:00,  4.14it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  7.75it/s]
test: 100%|██████████| 14/14 [00:03<00:00,  4.35it/s]


Epoch 20 of 30


train: 100%|██████████| 94/94 [00:39<00:00,  2.38it/s]


Done with batches
Epoch loss 0.0016392269191899376


test: 100%|██████████| 47/47 [00:11<00:00,  4.21it/s]
test: 100%|██████████| 14/14 [00:01<00:00,  7.87it/s]
test: 100%|██████████| 14/14 [00:03<00:00,  4.43it/s]


cur_best_auc: 0.9981387831670104
cur_best_epoch 20
Epoch 21 of 30


train: 100%|██████████| 94/94 [00:55<00:00,  1.69it/s]


Done with batches
Epoch loss 0.0017018493822530565


test: 100%|██████████| 47/47 [00:24<00:00,  1.89it/s]
test: 100%|██████████| 14/14 [00:03<00:00,  3.81it/s]
test: 100%|██████████| 14/14 [00:07<00:00,  2.00it/s]


cur_best_auc: 0.9981637750531535
cur_best_epoch 21
Epoch 22 of 30


train: 100%|██████████| 94/94 [01:22<00:00,  1.14it/s]


Done with batches
Epoch loss 0.0013072524488187752


test: 100%|██████████| 47/47 [00:24<00:00,  1.93it/s]
test: 100%|██████████| 14/14 [00:03<00:00,  3.92it/s]
test: 100%|██████████| 14/14 [00:06<00:00,  2.01it/s]


Epoch 23 of 30


train: 100%|██████████| 94/94 [01:23<00:00,  1.13it/s]


Done with batches
Epoch loss 0.0009514508015622306


test: 100%|██████████| 47/47 [00:24<00:00,  1.95it/s]
test: 100%|██████████| 14/14 [00:03<00:00,  3.85it/s]
test: 100%|██████████| 14/14 [00:06<00:00,  2.04it/s]


cur_best_auc: 0.9981986249809919
cur_best_epoch 23
Epoch 24 of 30


train: 100%|██████████| 94/94 [01:17<00:00,  1.22it/s]


Done with batches
Epoch loss 0.0018880828696289447


test: 100%|██████████| 47/47 [00:24<00:00,  1.93it/s]
test: 100%|██████████| 14/14 [00:03<00:00,  3.87it/s]
test: 100%|██████████| 14/14 [00:06<00:00,  2.02it/s]


Epoch 25 of 30


train: 100%|██████████| 94/94 [01:16<00:00,  1.22it/s]


Done with batches
Epoch loss 0.0009294568012688963


test: 100%|██████████| 47/47 [00:24<00:00,  1.93it/s]
test: 100%|██████████| 14/14 [00:03<00:00,  3.96it/s]
test: 100%|██████████| 14/14 [00:06<00:00,  2.03it/s]


Epoch 26 of 30


train: 100%|██████████| 94/94 [01:22<00:00,  1.14it/s]


Done with batches
Epoch loss 0.0006178963617941841


test: 100%|██████████| 47/47 [00:24<00:00,  1.95it/s]
test: 100%|██████████| 14/14 [00:03<00:00,  3.95it/s]
test: 100%|██████████| 14/14 [00:06<00:00,  2.04it/s]


cur_best_auc: 0.9982282291787552
cur_best_epoch 26
Epoch 27 of 30


train: 100%|██████████| 94/94 [01:24<00:00,  1.11it/s]


Done with batches
Epoch loss 0.0005853050676724958


test: 100%|██████████| 47/47 [00:24<00:00,  1.95it/s]
test: 100%|██████████| 14/14 [00:03<00:00,  3.91it/s]
test: 100%|██████████| 14/14 [00:06<00:00,  2.05it/s]


Epoch 28 of 30


train: 100%|██████████| 94/94 [01:24<00:00,  1.11it/s]


Done with batches
Epoch loss 0.0005037695093747932


test: 100%|██████████| 47/47 [00:24<00:00,  1.95it/s]
test: 100%|██████████| 14/14 [00:03<00:00,  3.88it/s]
test: 100%|██████████| 14/14 [00:06<00:00,  2.02it/s]


Epoch 29 of 30


train: 100%|██████████| 94/94 [01:24<00:00,  1.12it/s]


Done with batches
Epoch loss 0.00043391293033021207


test: 100%|██████████| 47/47 [00:24<00:00,  1.94it/s]
test: 100%|██████████| 14/14 [00:03<00:00,  3.91it/s]
test: 100%|██████████| 14/14 [00:06<00:00,  2.05it/s]
test: 100%|██████████| 47/47 [00:24<00:00,  1.92it/s]
test: 100%|██████████| 14/14 [00:03<00:00,  3.92it/s]
test: 100%|██████████| 14/14 [00:06<00:00,  2.02it/s]


bloodmnist
train  auc: 1.00000  acc: 1.00000
val  auc: 0.99823  acc: 0.96262
test  auc: 0.99781  acc: 0.96405



In [21]:

data_flag = 'bloodmnist'
download = True

DEVICE = 'mps'
NUM_EPOCHS = 30
BATCH_SIZE = 128
milestones = [0.5 * NUM_EPOCHS, 0.75 * NUM_EPOCHS]
lr = 0.001
gamma = 0.1

output_root = './output2d'

info = INFO[data_flag]
n_channels = info['n_channels']
n_classes = len(info['label'])
task = info['task']

train_loader, train_loader_at_eval, val_loader, test_loader = extract_data(data_flag, download, BATCH_SIZE)

from torchvision.models import convnext_small

model = convnext_small(weights=True)
print(model.classifier)
model.classifier[2] = nn.Linear(768, n_classes)
model.name = 'convnext_small'

model = model.to(DEVICE)

criterion = nn.CrossEntropyLoss()

train_evaluator = medmnist.Evaluator(data_flag, 'train')
val_evaluator = medmnist.Evaluator(data_flag, 'val')
test_evaluator = medmnist.Evaluator(data_flag, 'test')

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=gamma)

logs = ['loss', 'auc', 'acc']
train_logs = ['train_'+log for log in logs]
val_logs = ['val_'+log for log in logs]
test_logs = ['test_'+log for log in logs]
log_dict = OrderedDict.fromkeys(train_logs+val_logs+test_logs, 0)

writer = SummaryWriter(log_dir=os.path.join(output_root, model.name))

best_auc = 0
best_epoch = 0
best_model = deepcopy(model)
tb_twod = True

for epoch in range(NUM_EPOCHS):        
    print(f"Epoch {epoch} of {NUM_EPOCHS}")
    train_loss = train(model, train_loader, task, criterion, optimizer, DEVICE, writer)
    
    train_metrics = test(model, train_evaluator, train_loader_at_eval, task, criterion, DEVICE, 'model1')
    val_metrics = test(model, val_evaluator, val_loader, task, criterion, DEVICE, 'model1')
    test_metrics = test(model, test_evaluator, test_loader, task, criterion, DEVICE, 'model1')
    
    scheduler.step()
    
    for i, key in enumerate(train_logs):
        log_dict[key] = train_metrics[i]
    for i, key in enumerate(val_logs):
        log_dict[key] = val_metrics[i]
    for i, key in enumerate(test_logs):
        log_dict[key] = test_metrics[i]

    for key, value in log_dict.items():
        writer.add_scalar(key, value, epoch)
        
    cur_auc = val_metrics[1]
    if cur_auc > best_auc:
        best_epoch = epoch
        best_auc = cur_auc
        best_model = deepcopy(model)
        print('cur_best_auc:', best_auc)
        print('cur_best_epoch', best_epoch)

state = {
    'net': best_model.state_dict(),
}


path = os.path.join(output_root, f'{model.name}_best_model.pth')
torch.save(state, path)

train_metrics = test(best_model, train_evaluator, train_loader_at_eval, task, criterion, DEVICE, 'model', output_root)
val_metrics = test(best_model, val_evaluator, val_loader, task, criterion, DEVICE, 'model', output_root)
test_metrics = test(best_model, test_evaluator, test_loader, task, criterion, DEVICE, 'model', output_root)

train_log = 'train  auc: %.5f  acc: %.5f\n' % (train_metrics[1], train_metrics[2])
val_log = 'val  auc: %.5f  acc: %.5f\n' % (val_metrics[1], val_metrics[2])
test_log = 'test  auc: %.5f  acc: %.5f\n' % (test_metrics[1], test_metrics[2])

log = '%s\n' % (data_flag) + train_log + val_log + test_log
print(log)
        
with open(os.path.join(output_root, '%s_log.txt' % (data_flag)), 'a') as f:
    f.write(log)  

writer.close()

Using downloaded and verified file: /Users/vemundlund/.medmnist/bloodmnist.npz
Using downloaded and verified file: /Users/vemundlund/.medmnist/bloodmnist.npz
Using downloaded and verified file: /Users/vemundlund/.medmnist/bloodmnist.npz
Using downloaded and verified file: /Users/vemundlund/.medmnist/bloodmnist.npz




Sequential(
  (0): LayerNorm2d((768,), eps=1e-06, elementwise_affine=True)
  (1): Flatten(start_dim=1, end_dim=-1)
  (2): Linear(in_features=768, out_features=1000, bias=True)
)
Epoch 0 of 30


train:   0%|          | 0/94 [00:00<?, ?it/s]


RuntimeError: Calculated padded input size per channel: (1 x 1). Kernel size: (2 x 2). Kernel size can't be greater than actual input size

In [13]:

data_flag = 'bloodmnist'
download = True

DEVICE = 'mps'
NUM_EPOCHS = 30
BATCH_SIZE = 128
milestones = [0.5 * NUM_EPOCHS, 0.75 * NUM_EPOCHS]
lr = 0.001
gamma = 0.1

output_root = './output2d'

info = INFO[data_flag]
n_channels = info['n_channels']
n_classes = len(info['label'])
task = info['task']

train_loader, train_loader_at_eval, val_loader, test_loader = extract_data(data_flag, download, BATCH_SIZE)

from torchvision.models import vit_b_16

model = vit_b_16(pretrained=True)
model.heads = nn.Linear(768, n_classes)
model.name = 'vit_b_16'

model = model.to(DEVICE)

criterion = nn.CrossEntropyLoss()

train_evaluator = medmnist.Evaluator(data_flag, 'train')
val_evaluator = medmnist.Evaluator(data_flag, 'val')
test_evaluator = medmnist.Evaluator(data_flag, 'test')

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=gamma)

logs = ['loss', 'auc', 'acc']
train_logs = ['train_'+log for log in logs]
val_logs = ['val_'+log for log in logs]
test_logs = ['test_'+log for log in logs]
log_dict = OrderedDict.fromkeys(train_logs+val_logs+test_logs, 0)

writer = SummaryWriter(log_dir=os.path.join(output_root, model.name))

best_auc = 0
best_epoch = 0
best_model = deepcopy(model)
tb_twod = True

for epoch in range(NUM_EPOCHS):        
    print(f"Epoch {epoch} of {NUM_EPOCHS}")
    train_loss = train(model, train_loader, task, criterion, optimizer, DEVICE, writer)
    
    train_metrics = test(model, train_evaluator, train_loader_at_eval, task, criterion, DEVICE, 'model1')
    val_metrics = test(model, val_evaluator, val_loader, task, criterion, DEVICE, 'model1')
    test_metrics = test(model, test_evaluator, test_loader, task, criterion, DEVICE, 'model1')
    
    scheduler.step()
    
    for i, key in enumerate(train_logs):
        log_dict[key] = train_metrics[i]
    for i, key in enumerate(val_logs):
        log_dict[key] = val_metrics[i]
    for i, key in enumerate(test_logs):
        log_dict[key] = test_metrics[i]

    for key, value in log_dict.items():
        writer.add_scalar(key, value, epoch)
        
    cur_auc = val_metrics[1]
    if cur_auc > best_auc:
        best_epoch = epoch
        best_auc = cur_auc
        best_model = deepcopy(model)
        print('cur_best_auc:', best_auc)
        print('cur_best_epoch', best_epoch)

state = {
    'net': best_model.state_dict(),
}


path = os.path.join(output_root, f'{model.name}_best_model.pth')
torch.save(state, path)

train_metrics = test(best_model, train_evaluator, train_loader_at_eval, task, criterion, DEVICE, 'model', output_root)
val_metrics = test(best_model, val_evaluator, val_loader, task, criterion, DEVICE, 'model', output_root)
test_metrics = test(best_model, test_evaluator, test_loader, task, criterion, DEVICE, 'model', output_root)

train_log = 'train  auc: %.5f  acc: %.5f\n' % (train_metrics[1], train_metrics[2])
val_log = 'val  auc: %.5f  acc: %.5f\n' % (val_metrics[1], val_metrics[2])
test_log = 'test  auc: %.5f  acc: %.5f\n' % (test_metrics[1], test_metrics[2])

log = '%s\n' % (data_flag) + train_log + val_log + test_log
print(log)
        
with open(os.path.join(output_root, '%s_log.txt' % (data_flag)), 'a') as f:
    f.write(log)  

writer.close()

Using downloaded and verified file: /Users/vemundlund/.medmnist/bloodmnist.npz
Using downloaded and verified file: /Users/vemundlund/.medmnist/bloodmnist.npz
Using downloaded and verified file: /Users/vemundlund/.medmnist/bloodmnist.npz
Using downloaded and verified file: /Users/vemundlund/.medmnist/bloodmnist.npz




Epoch 0 of 30


train:   0%|          | 0/94 [00:00<?, ?it/s]


AssertionError: Wrong image height! Expected 224 but got 28!

In [22]:

data_flag = 'bloodmnist'
download = True

DEVICE = 'mps'
NUM_EPOCHS = 30
BATCH_SIZE = 128
milestones = [0.5 * NUM_EPOCHS, 0.75 * NUM_EPOCHS]
lr = 0.001
gamma = 0.1

output_root = './output2d'

info = INFO[data_flag]
n_channels = info['n_channels']
n_classes = len(info['label'])
task = info['task']

train_loader, train_loader_at_eval, val_loader, test_loader = extract_data(data_flag, download, BATCH_SIZE)

from torchvision.models import efficientnet_v2_s

model = efficientnet_v2_s(pretrained=True)
model.classifier[1] = nn.Linear(1280, n_classes)
model.name = 'efficientnet_v2_s'

model = model.to(DEVICE)

criterion = nn.CrossEntropyLoss()

train_evaluator = medmnist.Evaluator(data_flag, 'train')
val_evaluator = medmnist.Evaluator(data_flag, 'val')
test_evaluator = medmnist.Evaluator(data_flag, 'test')

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=gamma)

logs = ['loss', 'auc', 'acc']
train_logs = ['train_'+log for log in logs]
val_logs = ['val_'+log for log in logs]
test_logs = ['test_'+log for log in logs]
log_dict = OrderedDict.fromkeys(train_logs+val_logs+test_logs, 0)

writer = SummaryWriter(log_dir=os.path.join(output_root, model.name))

best_auc = 0
best_epoch = 0
best_model = deepcopy(model)
tb_twod = True

for epoch in range(NUM_EPOCHS):        
    print(f"Epoch {epoch} of {NUM_EPOCHS}")
    train_loss = train(model, train_loader, task, criterion, optimizer, DEVICE, writer)
    
    train_metrics = test(model, train_evaluator, train_loader_at_eval, task, criterion, DEVICE, 'model1')
    val_metrics = test(model, val_evaluator, val_loader, task, criterion, DEVICE, 'model1')
    test_metrics = test(model, test_evaluator, test_loader, task, criterion, DEVICE, 'model1')
    
    scheduler.step()
    
    for i, key in enumerate(train_logs):
        log_dict[key] = train_metrics[i]
    for i, key in enumerate(val_logs):
        log_dict[key] = val_metrics[i]
    for i, key in enumerate(test_logs):
        log_dict[key] = test_metrics[i]

    for key, value in log_dict.items():
        writer.add_scalar(key, value, epoch)
        
    cur_auc = val_metrics[1]
    if cur_auc > best_auc:
        best_epoch = epoch
        best_auc = cur_auc
        best_model = deepcopy(model)
        print('cur_best_auc:', best_auc)
        print('cur_best_epoch', best_epoch)

state = {
    'net': best_model.state_dict(),
}


path = os.path.join(output_root, f'{model.name}_best_model.pth')
torch.save(state, path)

train_metrics = test(best_model, train_evaluator, train_loader_at_eval, task, criterion, DEVICE, 'model', output_root)
val_metrics = test(best_model, val_evaluator, val_loader, task, criterion, DEVICE, 'model', output_root)
test_metrics = test(best_model, test_evaluator, test_loader, task, criterion, DEVICE, 'model', output_root)

train_log = 'train  auc: %.5f  acc: %.5f\n' % (train_metrics[1], train_metrics[2])
val_log = 'val  auc: %.5f  acc: %.5f\n' % (val_metrics[1], val_metrics[2])
test_log = 'test  auc: %.5f  acc: %.5f\n' % (test_metrics[1], test_metrics[2])

log = '%s\n' % (data_flag) + train_log + val_log + test_log
print(log)
        
with open(os.path.join(output_root, '%s_log.txt' % (data_flag)), 'a') as f:
    f.write(log)  

writer.close()

Using downloaded and verified file: /Users/vemundlund/.medmnist/bloodmnist.npz
Using downloaded and verified file: /Users/vemundlund/.medmnist/bloodmnist.npz
Using downloaded and verified file: /Users/vemundlund/.medmnist/bloodmnist.npz
Using downloaded and verified file: /Users/vemundlund/.medmnist/bloodmnist.npz




Epoch 0 of 30


train: 100%|██████████| 94/94 [03:08<00:00,  2.01s/it]


Done with batches
Epoch loss 0.699226622885846


test: 100%|██████████| 47/47 [00:32<00:00,  1.43it/s]
test: 100%|██████████| 14/14 [00:08<00:00,  1.68it/s]
test: 100%|██████████| 14/14 [00:10<00:00,  1.32it/s]


cur_best_auc: 0.9921354199736113
cur_best_epoch 0
Epoch 1 of 30


train: 100%|██████████| 94/94 [02:57<00:00,  1.89s/it]


Done with batches
Epoch loss 0.24732976486074162


test: 100%|██████████| 47/47 [00:26<00:00,  1.76it/s]
test: 100%|██████████| 14/14 [00:04<00:00,  3.05it/s]
test: 100%|██████████| 14/14 [00:07<00:00,  1.83it/s]


cur_best_auc: 0.9952164848633365
cur_best_epoch 1
Epoch 2 of 30


train: 100%|██████████| 94/94 [02:58<00:00,  1.90s/it]


Done with batches
Epoch loss 0.1559623882491538


test: 100%|██████████| 47/47 [00:26<00:00,  1.74it/s]
test: 100%|██████████| 14/14 [00:05<00:00,  2.62it/s]
test: 100%|██████████| 14/14 [00:07<00:00,  1.80it/s]


cur_best_auc: 0.9955135620505534
cur_best_epoch 2
Epoch 3 of 30


train: 100%|██████████| 94/94 [02:58<00:00,  1.90s/it]


Done with batches
Epoch loss 0.11317271547035333


test: 100%|██████████| 47/47 [00:26<00:00,  1.75it/s]
test: 100%|██████████| 14/14 [00:05<00:00,  2.72it/s]
test: 100%|██████████| 14/14 [00:07<00:00,  1.80it/s]


Epoch 4 of 30


train: 100%|██████████| 94/94 [02:56<00:00,  1.88s/it]


Done with batches
Epoch loss 0.09667549191161673


test: 100%|██████████| 47/47 [00:27<00:00,  1.74it/s]
test: 100%|██████████| 14/14 [00:05<00:00,  2.76it/s]
test: 100%|██████████| 14/14 [00:07<00:00,  1.79it/s]


cur_best_auc: 0.9955448285703988
cur_best_epoch 4
Epoch 5 of 30


train: 100%|██████████| 94/94 [02:58<00:00,  1.89s/it]


Done with batches
Epoch loss 0.06842732636575052


test: 100%|██████████| 47/47 [00:26<00:00,  1.74it/s]
test: 100%|██████████| 14/14 [00:05<00:00,  2.78it/s]
test: 100%|██████████| 14/14 [00:07<00:00,  1.81it/s]


cur_best_auc: 0.996798096016476
cur_best_epoch 5
Epoch 6 of 30


train: 100%|██████████| 94/94 [02:57<00:00,  1.88s/it]


Done with batches
Epoch loss 0.047978765648254686


test: 100%|██████████| 47/47 [00:27<00:00,  1.74it/s]
test: 100%|██████████| 14/14 [00:05<00:00,  2.76it/s]
test: 100%|██████████| 14/14 [00:07<00:00,  1.80it/s]


Epoch 7 of 30


train: 100%|██████████| 94/94 [02:56<00:00,  1.88s/it]


Done with batches
Epoch loss 0.04465303018173956


test: 100%|██████████| 47/47 [00:26<00:00,  1.75it/s]
test: 100%|██████████| 14/14 [00:05<00:00,  2.75it/s]
test: 100%|██████████| 14/14 [00:07<00:00,  1.80it/s]


Epoch 8 of 30


train: 100%|██████████| 94/94 [02:55<00:00,  1.87s/it]


Done with batches
Epoch loss 0.04347133204375612


test: 100%|██████████| 47/47 [00:26<00:00,  1.76it/s]
test: 100%|██████████| 14/14 [00:04<00:00,  2.98it/s]
test: 100%|██████████| 14/14 [00:07<00:00,  1.81it/s]


Epoch 9 of 30


train: 100%|██████████| 94/94 [02:56<00:00,  1.88s/it]


Done with batches
Epoch loss 0.047669640246857985


test: 100%|██████████| 47/47 [00:27<00:00,  1.71it/s]
test: 100%|██████████| 14/14 [00:05<00:00,  2.70it/s]
test: 100%|██████████| 14/14 [00:07<00:00,  1.78it/s]


cur_best_auc: 0.9971200348639943
cur_best_epoch 9
Epoch 10 of 30


train: 100%|██████████| 94/94 [02:55<00:00,  1.87s/it]


Done with batches
Epoch loss 0.0328751548609518


test: 100%|██████████| 47/47 [00:27<00:00,  1.72it/s]
test: 100%|██████████| 14/14 [00:05<00:00,  2.78it/s]
test: 100%|██████████| 14/14 [00:07<00:00,  1.77it/s]


Epoch 11 of 30


train: 100%|██████████| 94/94 [02:55<00:00,  1.86s/it]


Done with batches
Epoch loss 0.043658772951606264


test: 100%|██████████| 47/47 [00:27<00:00,  1.71it/s]
test: 100%|██████████| 14/14 [00:05<00:00,  2.75it/s]
test: 100%|██████████| 14/14 [00:07<00:00,  1.78it/s]


Epoch 12 of 30


train: 100%|██████████| 94/94 [02:56<00:00,  1.88s/it]


Done with batches
Epoch loss 0.02412294277158427


test: 100%|██████████| 47/47 [00:26<00:00,  1.75it/s]
test: 100%|██████████| 14/14 [00:05<00:00,  2.76it/s]
test: 100%|██████████| 14/14 [00:07<00:00,  1.80it/s]


Epoch 13 of 30


train: 100%|██████████| 94/94 [02:55<00:00,  1.87s/it]


Done with batches
Epoch loss 0.03199309179995288


test: 100%|██████████| 47/47 [00:26<00:00,  1.75it/s]
test: 100%|██████████| 14/14 [00:04<00:00,  2.81it/s]
test: 100%|██████████| 14/14 [00:07<00:00,  1.77it/s]


Epoch 14 of 30


train: 100%|██████████| 94/94 [02:56<00:00,  1.88s/it]


Done with batches
Epoch loss 0.023197477274732863


test: 100%|██████████| 47/47 [00:26<00:00,  1.75it/s]
test: 100%|██████████| 14/14 [00:05<00:00,  2.75it/s]
test: 100%|██████████| 14/14 [00:07<00:00,  1.79it/s]


cur_best_auc: 0.9973058412095329
cur_best_epoch 14
Epoch 15 of 30


train: 100%|██████████| 94/94 [02:55<00:00,  1.87s/it]


Done with batches
Epoch loss 0.01059373075336322


test: 100%|██████████| 47/47 [00:27<00:00,  1.74it/s]
test: 100%|██████████| 14/14 [00:05<00:00,  2.75it/s]
test: 100%|██████████| 14/14 [00:07<00:00,  1.79it/s]


cur_best_auc: 0.9977487620791338
cur_best_epoch 15
Epoch 16 of 30


train: 100%|██████████| 94/94 [02:57<00:00,  1.89s/it]


Done with batches
Epoch loss 0.006137219952584116


test: 100%|██████████| 47/47 [00:27<00:00,  1.74it/s]
test: 100%|██████████| 14/14 [00:05<00:00,  2.75it/s]
test: 100%|██████████| 14/14 [00:07<00:00,  1.77it/s]


cur_best_auc: 0.9979222493290388
cur_best_epoch 16
Epoch 17 of 30


train: 100%|██████████| 94/94 [02:56<00:00,  1.88s/it]


Done with batches
Epoch loss 0.003474655321119923


test: 100%|██████████| 47/47 [00:26<00:00,  1.75it/s]
test: 100%|██████████| 14/14 [00:05<00:00,  2.70it/s]
test: 100%|██████████| 14/14 [00:07<00:00,  1.75it/s]


cur_best_auc: 0.9980333574104123
cur_best_epoch 17
Epoch 18 of 30


train: 100%|██████████| 94/94 [02:57<00:00,  1.89s/it]


Done with batches
Epoch loss 0.004266105096789072


test: 100%|██████████| 47/47 [00:27<00:00,  1.70it/s]
test: 100%|██████████| 14/14 [00:04<00:00,  2.93it/s]
test: 100%|██████████| 14/14 [00:07<00:00,  1.76it/s]


Epoch 19 of 30


train: 100%|██████████| 94/94 [02:56<00:00,  1.88s/it]


Done with batches
Epoch loss 0.0024209179421032265


test: 100%|██████████| 47/47 [00:27<00:00,  1.72it/s]
test: 100%|██████████| 14/14 [00:05<00:00,  2.76it/s]
test: 100%|██████████| 14/14 [00:07<00:00,  1.79it/s]


Epoch 20 of 30


train: 100%|██████████| 94/94 [02:56<00:00,  1.88s/it]


Done with batches
Epoch loss 0.002532624059893622


test: 100%|██████████| 47/47 [00:26<00:00,  1.74it/s]
test: 100%|██████████| 14/14 [00:04<00:00,  2.81it/s]
test: 100%|██████████| 14/14 [00:07<00:00,  1.79it/s]


Epoch 21 of 30


train: 100%|██████████| 94/94 [02:58<00:00,  1.90s/it]


Done with batches
Epoch loss 0.002373281098737201


test: 100%|██████████| 47/47 [00:27<00:00,  1.69it/s]
test: 100%|██████████| 14/14 [00:05<00:00,  2.74it/s]
test: 100%|██████████| 14/14 [00:07<00:00,  1.79it/s]


Epoch 22 of 30


train: 100%|██████████| 94/94 [02:58<00:00,  1.90s/it]


Done with batches
Epoch loss 0.0020535188672356334


test: 100%|██████████| 47/47 [00:27<00:00,  1.73it/s]
test: 100%|██████████| 14/14 [00:05<00:00,  2.71it/s]
test: 100%|██████████| 14/14 [00:07<00:00,  1.78it/s]


Epoch 23 of 30


train: 100%|██████████| 94/94 [02:58<00:00,  1.90s/it]


Done with batches
Epoch loss 0.001268330535086979


test: 100%|██████████| 47/47 [00:26<00:00,  1.75it/s]
test: 100%|██████████| 14/14 [00:05<00:00,  2.74it/s]
test: 100%|██████████| 14/14 [00:07<00:00,  1.79it/s]


Epoch 24 of 30


train: 100%|██████████| 94/94 [02:57<00:00,  1.88s/it]


Done with batches
Epoch loss 0.0014134882914847967


test: 100%|██████████| 47/47 [00:26<00:00,  1.75it/s]
test: 100%|██████████| 14/14 [00:05<00:00,  2.76it/s]
test: 100%|██████████| 14/14 [00:07<00:00,  1.79it/s]


Epoch 25 of 30


train: 100%|██████████| 94/94 [02:57<00:00,  1.89s/it]


Done with batches
Epoch loss 0.0013531768138992893


test: 100%|██████████| 47/47 [00:26<00:00,  1.78it/s]
test: 100%|██████████| 14/14 [00:04<00:00,  3.24it/s]
test: 100%|██████████| 14/14 [00:07<00:00,  1.79it/s]


Epoch 26 of 30


train: 100%|██████████| 94/94 [02:54<00:00,  1.86s/it]


Done with batches
Epoch loss 0.0012615478252050803


test: 100%|██████████| 47/47 [00:27<00:00,  1.71it/s]
test: 100%|██████████| 14/14 [00:05<00:00,  2.74it/s]
test: 100%|██████████| 14/14 [00:07<00:00,  1.78it/s]


Epoch 27 of 30


train: 100%|██████████| 94/94 [02:54<00:00,  1.86s/it]


Done with batches
Epoch loss 0.001296003222628218


test: 100%|██████████| 47/47 [00:26<00:00,  1.75it/s]
test: 100%|██████████| 14/14 [00:05<00:00,  2.64it/s]
test: 100%|██████████| 14/14 [00:07<00:00,  1.79it/s]


Epoch 28 of 30


train: 100%|██████████| 94/94 [02:55<00:00,  1.87s/it]


Done with batches
Epoch loss 0.0013482483401140656


test: 100%|██████████| 47/47 [00:27<00:00,  1.73it/s]
test: 100%|██████████| 14/14 [00:04<00:00,  2.95it/s]
test: 100%|██████████| 14/14 [00:07<00:00,  1.78it/s]


cur_best_auc: 0.9980650900255572
cur_best_epoch 28
Epoch 29 of 30


train: 100%|██████████| 94/94 [02:54<00:00,  1.86s/it]


Done with batches
Epoch loss 0.0017613798560036076


test: 100%|██████████| 47/47 [00:26<00:00,  1.75it/s]
test: 100%|██████████| 14/14 [00:05<00:00,  2.74it/s]
test: 100%|██████████| 14/14 [00:07<00:00,  1.80it/s]
test: 100%|██████████| 47/47 [00:27<00:00,  1.74it/s]
test: 100%|██████████| 14/14 [00:04<00:00,  2.85it/s]
test: 100%|██████████| 14/14 [00:07<00:00,  1.79it/s]

bloodmnist
train  auc: 1.00000  acc: 1.00000
val  auc: 0.99807  acc: 0.96028
test  auc: 0.99770  acc: 0.96141






In [19]:

data_flag = 'bloodmnist'
download = True

DEVICE = 'mps'
NUM_EPOCHS = 30
BATCH_SIZE = 128
milestones = [0.5 * NUM_EPOCHS, 0.75 * NUM_EPOCHS]
lr = 0.001
gamma = 0.1

output_root = './output2d'

info = INFO[data_flag]
n_channels = info['n_channels']
n_classes = len(info['label'])
task = info['task']

train_loader, train_loader_at_eval, val_loader, test_loader = extract_data(data_flag, download, BATCH_SIZE)

from torchvision.models import inception_v3

model = inception_v3(pretrained=True)
model.fc = nn.Linear(2048, n_classes)
model.name = 'inception_v3'


model = model.to(DEVICE)

criterion = nn.CrossEntropyLoss()

train_evaluator = medmnist.Evaluator(data_flag, 'train')
val_evaluator = medmnist.Evaluator(data_flag, 'val')
test_evaluator = medmnist.Evaluator(data_flag, 'test')

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=gamma)

logs = ['loss', 'auc', 'acc']
train_logs = ['train_'+log for log in logs]
val_logs = ['val_'+log for log in logs]
test_logs = ['test_'+log for log in logs]
log_dict = OrderedDict.fromkeys(train_logs+val_logs+test_logs, 0)

writer = SummaryWriter(log_dir=os.path.join(output_root, model.name))

best_auc = 0
best_epoch = 0
best_model = deepcopy(model)
tb_twod = True

for epoch in range(NUM_EPOCHS):        
    print(f"Epoch {epoch} of {NUM_EPOCHS}")
    train_loss = train(model, train_loader, task, criterion, optimizer, DEVICE, writer)
    
    train_metrics = test(model, train_evaluator, train_loader_at_eval, task, criterion, DEVICE, 'model1')
    val_metrics = test(model, val_evaluator, val_loader, task, criterion, DEVICE, 'model1')
    test_metrics = test(model, test_evaluator, test_loader, task, criterion, DEVICE, 'model1')
    
    scheduler.step()
    
    for i, key in enumerate(train_logs):
        log_dict[key] = train_metrics[i]
    for i, key in enumerate(val_logs):
        log_dict[key] = val_metrics[i]
    for i, key in enumerate(test_logs):
        log_dict[key] = test_metrics[i]

    for key, value in log_dict.items():
        writer.add_scalar(key, value, epoch)
        
    cur_auc = val_metrics[1]
    if cur_auc > best_auc:
        best_epoch = epoch
        best_auc = cur_auc
        best_model = deepcopy(model)
        print('cur_best_auc:', best_auc)
        print('cur_best_epoch', best_epoch)

state = {
    'net': best_model.state_dict(),
}


path = os.path.join(output_root, f'{model.name}_best_model.pth')
torch.save(state, path)

train_metrics = test(best_model, train_evaluator, train_loader_at_eval, task, criterion, DEVICE, 'model', output_root)
val_metrics = test(best_model, val_evaluator, val_loader, task, criterion, DEVICE, 'model', output_root)
test_metrics = test(best_model, test_evaluator, test_loader, task, criterion, DEVICE, 'model', output_root)

train_log = 'train  auc: %.5f  acc: %.5f\n' % (train_metrics[1], train_metrics[2])
val_log = 'val  auc: %.5f  acc: %.5f\n' % (val_metrics[1], val_metrics[2])
test_log = 'test  auc: %.5f  acc: %.5f\n' % (test_metrics[1], test_metrics[2])

log = '%s\n' % (data_flag) + train_log + val_log + test_log
print(log)
        
with open(os.path.join(output_root, '%s_log.txt' % (data_flag)), 'a') as f:
    f.write(log)  

writer.close()

Using downloaded and verified file: /Users/vemundlund/.medmnist/bloodmnist.npz
Using downloaded and verified file: /Users/vemundlund/.medmnist/bloodmnist.npz
Using downloaded and verified file: /Users/vemundlund/.medmnist/bloodmnist.npz
Using downloaded and verified file: /Users/vemundlund/.medmnist/bloodmnist.npz
Linear(in_features=2048, out_features=1000, bias=True)
Epoch 0 of 30


train:   0%|          | 0/94 [00:00<?, ?it/s]


RuntimeError: Calculated padded input size per channel: (1 x 1). Kernel size: (3 x 3). Kernel size can't be greater than actual input size

In [23]:
data_flag = 'bloodmnist'
download = True

DEVICE = 'mps'
NUM_EPOCHS = 30
BATCH_SIZE = 128
milestones = [0.5 * NUM_EPOCHS, 0.75 * NUM_EPOCHS]
lr = 0.001
gamma = 0.1

output_root = './output2d'

info = INFO[data_flag]
n_channels = info['n_channels']
n_classes = len(info['label'])
task = info['task']

train_loader, train_loader_at_eval, val_loader, test_loader = extract_data(data_flag, download, BATCH_SIZE)

from torchvision.models import swin_v2_s

model = swin_v2_s(pretrained=True)
model.head = nn.Linear(768, n_classes)
model.name = 'swin_v2_s'

model = model.to(DEVICE)

criterion = nn.CrossEntropyLoss()

train_evaluator = medmnist.Evaluator(data_flag, 'train')
val_evaluator = medmnist.Evaluator(data_flag, 'val')
test_evaluator = medmnist.Evaluator(data_flag, 'test')

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=gamma)

logs = ['loss', 'auc', 'acc']
train_logs = ['train_'+log for log in logs]
val_logs = ['val_'+log for log in logs]
test_logs = ['test_'+log for log in logs]
log_dict = OrderedDict.fromkeys(train_logs+val_logs+test_logs, 0)

writer = SummaryWriter(log_dir=os.path.join(output_root, model.name))

best_auc = 0
best_epoch = 0
best_model = deepcopy(model)
tb_twod = True

for epoch in range(NUM_EPOCHS):        
    print(f"Epoch {epoch} of {NUM_EPOCHS}")
    train_loss = train(model, train_loader, task, criterion, optimizer, DEVICE, writer)
    
    train_metrics = test(model, train_evaluator, train_loader_at_eval, task, criterion, DEVICE, 'model1')
    val_metrics = test(model, val_evaluator, val_loader, task, criterion, DEVICE, 'model1')
    test_metrics = test(model, test_evaluator, test_loader, task, criterion, DEVICE, 'model1')
    
    scheduler.step()
    
    for i, key in enumerate(train_logs):
        log_dict[key] = train_metrics[i]
    for i, key in enumerate(val_logs):
        log_dict[key] = val_metrics[i]
    for i, key in enumerate(test_logs):
        log_dict[key] = test_metrics[i]

    for key, value in log_dict.items():
        writer.add_scalar(key, value, epoch)
        
    cur_auc = val_metrics[1]
    if cur_auc > best_auc:
        best_epoch = epoch
        best_auc = cur_auc
        best_model = deepcopy(model)
        print('cur_best_auc:', best_auc)
        print('cur_best_epoch', best_epoch)

state = {
    'net': best_model.state_dict(),
}


path = os.path.join(output_root, f'{model.name}_best_model.pth')
torch.save(state, path)

train_metrics = test(best_model, train_evaluator, train_loader_at_eval, task, criterion, DEVICE, 'model', output_root)
val_metrics = test(best_model, val_evaluator, val_loader, task, criterion, DEVICE, 'model', output_root)
test_metrics = test(best_model, test_evaluator, test_loader, task, criterion, DEVICE, 'model', output_root)

train_log = 'train  auc: %.5f  acc: %.5f\n' % (train_metrics[1], train_metrics[2])
val_log = 'val  auc: %.5f  acc: %.5f\n' % (val_metrics[1], val_metrics[2])
test_log = 'test  auc: %.5f  acc: %.5f\n' % (test_metrics[1], test_metrics[2])

log = '%s\n' % (data_flag) + train_log + val_log + test_log
print(log)
        
with open(os.path.join(output_root, '%s_log.txt' % (data_flag)), 'a') as f:
    f.write(log)  

writer.close()

Using downloaded and verified file: /Users/vemundlund/.medmnist/bloodmnist.npz
Using downloaded and verified file: /Users/vemundlund/.medmnist/bloodmnist.npz
Using downloaded and verified file: /Users/vemundlund/.medmnist/bloodmnist.npz
Using downloaded and verified file: /Users/vemundlund/.medmnist/bloodmnist.npz




Epoch 0 of 30


train: 100%|██████████| 94/94 [10:21<00:00,  6.61s/it]


Done with batches
Epoch loss 1.8202674820068034


test: 100%|██████████| 47/47 [02:24<00:00,  3.07s/it]
test: 100%|██████████| 14/14 [00:24<00:00,  1.73s/it]
test: 100%|██████████| 14/14 [00:43<00:00,  3.09s/it]


cur_best_auc: 0.816666601302069
cur_best_epoch 0
Epoch 1 of 30


train: 100%|██████████| 94/94 [10:12<00:00,  6.51s/it]


Done with batches
Epoch loss 1.276784487861268


test: 100%|██████████| 47/47 [02:16<00:00,  2.90s/it]
test: 100%|██████████| 14/14 [00:19<00:00,  1.41s/it]
test: 100%|██████████| 14/14 [00:38<00:00,  2.79s/it]


cur_best_auc: 0.8697195492559324
cur_best_epoch 1
Epoch 2 of 30


train: 100%|██████████| 94/94 [10:12<00:00,  6.51s/it]


Done with batches
Epoch loss 1.0282108631539852


test: 100%|██████████| 47/47 [02:17<00:00,  2.92s/it]
test: 100%|██████████| 14/14 [00:19<00:00,  1.42s/it]
test: 100%|██████████| 14/14 [00:39<00:00,  2.79s/it]


cur_best_auc: 0.9314671514099747
cur_best_epoch 2
Epoch 3 of 30


train: 100%|██████████| 94/94 [10:12<00:00,  6.52s/it]


Done with batches
Epoch loss 0.8134541498853806


test: 100%|██████████| 47/47 [02:16<00:00,  2.91s/it]
test: 100%|██████████| 14/14 [00:19<00:00,  1.41s/it]
test: 100%|██████████| 14/14 [00:38<00:00,  2.79s/it]


cur_best_auc: 0.9503430175625864
cur_best_epoch 3
Epoch 4 of 30


train: 100%|██████████| 94/94 [10:10<00:00,  6.50s/it]


Done with batches
Epoch loss 0.6205128371081454


test: 100%|██████████| 47/47 [02:16<00:00,  2.90s/it]
test: 100%|██████████| 14/14 [00:19<00:00,  1.42s/it]
test: 100%|██████████| 14/14 [00:39<00:00,  2.83s/it]


cur_best_auc: 0.9647576183777415
cur_best_epoch 4
Epoch 5 of 30


train: 100%|██████████| 94/94 [10:10<00:00,  6.49s/it]


Done with batches
Epoch loss 0.537291777577806


test: 100%|██████████| 47/47 [02:16<00:00,  2.90s/it]
test: 100%|██████████| 14/14 [00:19<00:00,  1.41s/it]
test: 100%|██████████| 14/14 [00:38<00:00,  2.78s/it]


cur_best_auc: 0.9746571107750753
cur_best_epoch 5
Epoch 6 of 30


train: 100%|██████████| 94/94 [10:09<00:00,  6.48s/it]


Done with batches
Epoch loss 0.63245335031063


test: 100%|██████████| 47/47 [02:16<00:00,  2.90s/it]
test: 100%|██████████| 14/14 [00:19<00:00,  1.42s/it]
test: 100%|██████████| 14/14 [00:38<00:00,  2.78s/it]


Epoch 7 of 30


train: 100%|██████████| 94/94 [10:10<00:00,  6.49s/it]


Done with batches
Epoch loss 0.49664456064396717


test: 100%|██████████| 47/47 [02:17<00:00,  2.92s/it]
test: 100%|██████████| 14/14 [00:21<00:00,  1.51s/it]
test: 100%|██████████| 14/14 [00:39<00:00,  2.80s/it]


cur_best_auc: 0.9761576186961529
cur_best_epoch 7
Epoch 8 of 30


train: 100%|██████████| 94/94 [10:10<00:00,  6.50s/it]


Done with batches
Epoch loss 0.4345693933836957


test: 100%|██████████| 47/47 [02:16<00:00,  2.90s/it]
test: 100%|██████████| 14/14 [00:19<00:00,  1.41s/it]
test: 100%|██████████| 14/14 [00:38<00:00,  2.78s/it]


cur_best_auc: 0.9805573590588217
cur_best_epoch 8
Epoch 9 of 30


train: 100%|██████████| 94/94 [10:11<00:00,  6.51s/it]


Done with batches
Epoch loss 0.41346831905080916


test: 100%|██████████| 47/47 [02:17<00:00,  2.93s/it]
test: 100%|██████████| 14/14 [00:19<00:00,  1.43s/it]
test: 100%|██████████| 14/14 [00:39<00:00,  2.81s/it]


Epoch 10 of 30


train: 100%|██████████| 94/94 [10:10<00:00,  6.50s/it]


Done with batches
Epoch loss 0.40748350417360346


test: 100%|██████████| 47/47 [02:16<00:00,  2.90s/it]
test: 100%|██████████| 14/14 [00:19<00:00,  1.42s/it]
test: 100%|██████████| 14/14 [00:39<00:00,  2.79s/it]


cur_best_auc: 0.9823291598203977
cur_best_epoch 10
Epoch 11 of 30


train: 100%|██████████| 94/94 [10:10<00:00,  6.50s/it]


Done with batches
Epoch loss 0.4728154029617918


test: 100%|██████████| 47/47 [02:16<00:00,  2.90s/it]
test: 100%|██████████| 14/14 [00:19<00:00,  1.41s/it]
test: 100%|██████████| 14/14 [00:38<00:00,  2.78s/it]


Epoch 12 of 30


train: 100%|██████████| 94/94 [10:12<00:00,  6.52s/it]


Done with batches
Epoch loss 0.42515385689887597


test: 100%|██████████| 47/47 [02:16<00:00,  2.91s/it]
test: 100%|██████████| 14/14 [00:19<00:00,  1.41s/it]
test: 100%|██████████| 14/14 [00:39<00:00,  2.79s/it]


cur_best_auc: 0.982506580959852
cur_best_epoch 12
Epoch 13 of 30


train: 100%|██████████| 94/94 [10:09<00:00,  6.49s/it]


Done with batches
Epoch loss 0.36127510920484013


test: 100%|██████████| 47/47 [02:16<00:00,  2.90s/it]
test: 100%|██████████| 14/14 [00:19<00:00,  1.41s/it]
test: 100%|██████████| 14/14 [00:39<00:00,  2.79s/it]


cur_best_auc: 0.9853124556808074
cur_best_epoch 13
Epoch 14 of 30


train: 100%|██████████| 94/94 [10:12<00:00,  6.51s/it]


Done with batches
Epoch loss 0.34142587555849807


test: 100%|██████████| 47/47 [02:16<00:00,  2.90s/it]
test: 100%|██████████| 14/14 [00:19<00:00,  1.42s/it]
test: 100%|██████████| 14/14 [00:38<00:00,  2.79s/it]


cur_best_auc: 0.9877940438336857
cur_best_epoch 14
Epoch 15 of 30


train: 100%|██████████| 94/94 [10:11<00:00,  6.50s/it]


Done with batches
Epoch loss 0.22043304081926954


test: 100%|██████████| 47/47 [02:16<00:00,  2.91s/it]
test: 100%|██████████| 14/14 [00:19<00:00,  1.41s/it]
test: 100%|██████████| 14/14 [00:39<00:00,  2.79s/it]


cur_best_auc: 0.9917914662673875
cur_best_epoch 15
Epoch 16 of 30


train: 100%|██████████| 94/94 [10:12<00:00,  6.51s/it]


Done with batches
Epoch loss 0.19251005494214118


test: 100%|██████████| 47/47 [02:17<00:00,  2.92s/it]
test: 100%|██████████| 14/14 [00:19<00:00,  1.42s/it]
test: 100%|██████████| 14/14 [00:38<00:00,  2.78s/it]


cur_best_auc: 0.9924283187168882
cur_best_epoch 16
Epoch 17 of 30


train: 100%|██████████| 94/94 [10:10<00:00,  6.49s/it]


Done with batches
Epoch loss 0.17180371403377107


test: 100%|██████████| 47/47 [02:16<00:00,  2.91s/it]
test: 100%|██████████| 14/14 [00:19<00:00,  1.41s/it]
test: 100%|██████████| 14/14 [00:39<00:00,  2.79s/it]


cur_best_auc: 0.9926174590649204
cur_best_epoch 17
Epoch 18 of 30


train:  13%|█▎        | 12/94 [13:17<58:29, 42.80s/it]  

In [None]:
%reload_ext tensorboard
%tensorboard --logdir 'output2d'
from tensorboard import notebook
notebook.list()


In [10]:
# train  auc: 0.99237  acc: 0.90894
# val  auc: 0.99129  acc: 0.90362
# test  auc: 0.99095  acc: 0.89711