In [19]:
# data
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import data_handling

# training
import model_saving, training, model_creating
import wandb
import torch, torch.nn as nn
from tqdm.auto import tqdm

In [2]:
# Demo switch: selects the dataset list in the data-loading cell and names the run 'demo'.
IS_DEMO = True

# Run options; handed to wandb.init and read back through run.config.
config = dict(
    device=1,        # CUDA device index used to build the torch device
    lr=1e-4,         # learning rate for the Adam optimizer
    epochs=1000,     # number of passes of the training loop
    plateau=False,   # when True, a ReduceLROnPlateau scheduler is created and stepped
    save=False,      # when True, a BestModelSaver is created and updated each epoch
    demo=IS_DEMO,    # recorded alongside the run
)

### Data loading

In [5]:
# Fixed seed so the train/validation split is the same after every kernel restart.
SPLIT_SEED = 42
BATCH_SIZE = 30

datasets = data_handling.get_ds_names(IS_DEMO)
print('Using datasets: ', ', '.join(datasets))

idxs, imgs, diseases = data_handling.load_data(datasets)

# 80/20 split, stratified on the disease label looked up for each index.
trn_idxs, val_idxs = train_test_split(
    idxs,
    train_size=0.8,
    stratify=[diseases[idx] for idx in idxs],
    random_state=SPLIT_SEED,
)

trn_ds = data_handling.DiseaseDataset(trn_idxs, imgs, diseases)
val_ds = data_handling.DiseaseDataset(val_idxs, imgs, diseases)

# Shuffle only the training loader: without it every epoch would see the same
# batches in the same order. Validation order is kept fixed.
# NOTE(review): shuffle=True assumes DiseaseDataset is a map-style dataset — confirm.
trn_dl = DataLoader(trn_ds, batch_size=BATCH_SIZE, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE)

Using datasets:  gsa
Hadling 1969 data instances


gsa:   0%|          | 0/1969 [00:00<?, ?it/s]



In [15]:
# Class count taken from the data_handling module; passed to model_creating.create_model below.
NUM_CLASSES = data_handling.NUM_DISEASES

### Training

In [10]:
# Start a Weights & Biases run in the 'derm-dis-morph' project, attaching `config` to it.
# Later cells read the options back through run.config.
run = wandb.init(project='derm-dis-morph', config=config)

[34m[1mwandb[0m: Currently logged in as: [33mtanyapole[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.31 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [11]:
# Demo runs are always called 'demo'; real runs are labelled with their learning rate.
baseline_name = f'baseline lr={run.config.lr}'
run_name = 'demo' if IS_DEMO else baseline_name
run.name = run_name

In [13]:
# Use the configured GPU when it exists; otherwise fall back to the CPU so the
# notebook still runs on machines without that device (the configured index is
# hard-coded in `config`, which would otherwise fail when the model is moved).
gpu_index = run.config.device
if torch.cuda.is_available() and 0 <= gpu_index < torch.cuda.device_count():
    device = torch.device(f'cuda:{gpu_index}')
else:
    print(f'cuda:{gpu_index} is not available, falling back to CPU')
    device = torch.device('cpu')

loss_fn = nn.CrossEntropyLoss()

# Cap the number of CPU threads torch uses for intra-op parallelism.
torch.set_num_threads(2)

In [16]:
# Build the network for NUM_CLASSES outputs and place it on the chosen device.
model = model_creating.create_model(NUM_CLASSES).to(device)

# Adam with the learning rate recorded in the run config.
lr = run.config.lr
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

# Optional scheduler; 'max' mode because the training loop steps it with an accuracy value.
if run.config.plateau:
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='max',
        patience=5,
        threshold=0.001,
    )

In [17]:
# Checkpointing is opt-in via the 'save' option of the run config.
if run.config.save:
    save_folder = model_saving.get_save_folder()
    # Record the folder in the run config so it is stored with the W&B run.
    run.config.save_folder = save_folder
    model_saver = model_saving.BestModelSaver(save_folder, device)
    # Show where the saver will write.
    print(model_saver.save_fldr)

In [20]:
# Arguments shared by every training.step call.
common_params = dict(model=model, optimizer=optimizer, loss_fn=loss_fn, device=device)

# The two passes run each epoch: (loader, mode, progress-bar label, metric prefix).
phases = (
    (trn_dl, training.Mode.Train, 'Train', 'trn'),
    (val_dl, training.Mode.Eval, 'Valid', 'val'),
)

for epoch in tqdm(range(run.config.epochs), desc='Epoch'):
    # Everything logged for this epoch is accumulated in D.
    D = {'epoch': epoch}

    for loader, mode, label, prefix in phases:
        losses, preds, targs = training.step(loader, mode, label, **common_params)
        metrics = training.compute_metrics(prefix, losses, preds, targs)
        D = training.append_dict(D, metrics)

    wandb.log(D)

    if run.config.save:
        model_saver.update(model, D)
    if run.config.plateau:
        # Drive the LR schedule from validation accuracy.
        scheduler.step(D['val/acc'])

Epoch:   0%|          | 0/1000 [00:00<?, ?it/s]

Train:   0%|          | 0/53 [00:00<?, ?it/s]

Valid:   0%|          | 0/14 [00:00<?, ?it/s]

Train:   0%|          | 0/53 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [21]:
# Close the W&B run started above. The trailing ';' keeps the call's return value
# from being echoed as the cell result.
run.finish();

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,0.0
trn/loss,2.35606
trn/acc,0.28508
trn/f1,0.13646
val/loss,1.96498
val/acc,0.40863
val/f1,0.23059
_runtime,142.0
_timestamp,1622980662.0
_step,0.0


0,1
epoch,▁
trn/loss,▁
trn/acc,▁
trn/f1,▁
val/loss,▁
val/acc,▁
val/f1,▁
_runtime,▁
_timestamp,▁
_step,▁
