In [1]:
from torchvision.models import resnet34
from pathlib import Path
import pickle
import torch, torch.nn as nn
import os
from PIL import Image
import numpy as np
import cv2
from tqdm.auto import tqdm
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as TF
from datetime import datetime as dt
import sklearn.metrics
import wandb
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from enum import Enum

In [2]:
def read_pickle(fname):
    """Load and return the object stored in ``<lbl_folder>/<fname>.pkl``.

    ``lbl_folder`` is looked up in the notebook's global namespace at call
    time, so it must be defined before the first call.

    NOTE: unpickling can execute arbitrary code; only use on trusted files.
    """
    pkl_path = lbl_folder / f'{fname}.pkl'
    with open(pkl_path, 'rb') as fh:
        return pickle.load(fh)

In [3]:
def load_img(pt):
    """Read the image at ``pt`` and return it as a 224x224 RGB array.

    Parameters
    ----------
    pt : str or path-like
        Location of the image file.

    Returns
    -------
    numpy.ndarray
        Array of shape (224, 224, 3) produced by ``cv2.resize``.
    """
    # Use a context manager so the file handle is released immediately
    # (thousands of images are loaded in a row), and force RGB so grayscale,
    # palette and RGBA files all yield the 3 channels that the downstream
    # 3-channel normalisation requires.
    with Image.open(str(pt)) as pil_img:
        img = np.array(pil_img.convert('RGB'))
    img = cv2.resize(img, (224,224))
    return img

In [4]:
# Folder holding the pickled label files and the two mapping pickles.
lbl_folder = Path('_data_lbls')
# Class counts are taken as the lengths of the pickled mapping objects.
NUM_DISEASES = len(read_pickle('mapping_diseases'))
NUM_MORPH = len(read_pickle('mapping_morph'))

In [5]:
NUM_MORPH

23

In [6]:
os.listdir(lbl_folder)

['mapping_diseases.pkl',
 'mapping_morph.pkl',
 'gsa.pkl',
 'atlas_derm.pkl',
 'hellenic.pkl',
 'dermnetnz.pkl',
 'ulb.pkl']

In [7]:
# Names of the source datasets; each is read from `<name>.pkl` in lbl_folder.
datasets = ['gsa', 'ulb', 'atlas_derm', 'hellenic', 'dermnetnz']

In [8]:
# Unpickled label records for every dataset, keyed by dataset name.
lbls = {ds: read_pickle(ds) for ds in datasets}

In [9]:
# Global sample keys: one (dataset name, position within that dataset) pair per record.
idxs = [(ds, i) for ds in datasets for i in range(len(lbls[ds]))]

In [10]:
len(idxs)

7682

In [11]:
# Target per sample key, read from field 2 of each record
# (presumably the disease class index — TODO confirm against the pickles).
diseases = {(ds, i): lbls[ds][i][2] for (ds, i) in idxs}

In [12]:
# Eagerly load every image into memory, one list per dataset.
# Field 1 of each label record is passed to load_img as the image location.
imgs = {}
for ds in datasets:
    imgs[ds] = [load_img(rec[1]) for rec in tqdm(lbls[ds], desc=ds)]

gsa:   0%|          | 0/1969 [00:00<?, ?it/s]



ulb:   0%|          | 0/319 [00:00<?, ?it/s]

atlas_derm:   0%|          | 0/2740 [00:00<?, ?it/s]

hellenic:   0%|          | 0/944 [00:00<?, ?it/s]

dermnetnz:   0%|          | 0/1710 [00:00<?, ?it/s]

In [13]:
# Flat view of the loaded images, addressable by the same (dataset, position) keys as `diseases`.
imgs2 = {(ds, i): imgs[ds][i] for (ds, i) in idxs}

In [14]:
# Per-channel normalisation using the mean/std values that torchvision
# documents for its ImageNet-pretrained models.
normalize = TF.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
# Converts an image array to a tensor before normalisation is applied.
to_tensor = TF.ToTensor()

class DiseaseDataset(Dataset):
    """Dataset over a list of sample keys.

    Each item is ``(image_tensor, label)``: the image is looked up in the
    global ``imgs2`` dict, converted with ``to_tensor`` and ``normalize``;
    the label is looked up in the global ``diseases`` dict.
    """

    def __init__(self, idxs):
        # Keys into `imgs2` / `diseases` that make up this split.
        self.idxs = idxs

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, i):
        key = self.idxs[i]
        raw_img, label = imgs2[key], diseases[key]
        tensor = normalize(to_tensor(raw_img))
        return tensor, label

In [15]:
def mean(L:list):
    """Return the arithmetic mean of the values in ``L``.

    Raises ``ZeroDivisionError`` when ``L`` is empty.
    """
    total = sum(L)
    count = len(L)
    return total / count

class Mode(Enum):
    """Kind of pass run by ``step``: ``Train`` updates the model, any other value runs it in eval mode."""
    Train = 'train'
    Eval = 'eval'
    
def _np(t): return t.detach().cpu().numpy()
def _it(t): return _np(t).item()

def step(dl, mode, phase):
    """Run one full pass over ``dl`` and collect per-batch results.

    Uses the notebook-level globals ``model``, ``optimizer``, ``loss_fn`` and
    ``device``.

    Parameters
    ----------
    dl : iterable of batches; element 0 is the input, element 1 the target.
    mode : ``Mode.Train`` to update the model after every batch; any other
        value puts the model in eval mode and runs under ``torch.no_grad()``.
    phase : label shown on the progress bar.

    Returns
    -------
    (losses, preds, targs)
        ``losses`` is a list with one scalar per batch; ``preds`` and
        ``targs`` are the model outputs and targets of all batches
        concatenated along axis 0 as NumPy arrays.
    """
    training = mode == Mode.Train
    if training:
        model.train()
    else:
        model.eval()

    def forward(xb, yb):
        # Shared forward pass: model output plus the loss against the targets.
        out = model(xb)
        return out, loss_fn(out, yb)

    losses, preds, targs = [], [], []
    for batch in tqdm(dl, leave=False, desc=phase):
        xb = batch[0].to(device)
        yb = batch[1].to(device)
        if training:
            optimizer.zero_grad()
            out, loss = forward(xb, yb)
            loss.backward()
            optimizer.step()
        else:
            with torch.no_grad():
                out, loss = forward(xb, yb)
        losses.append(_it(loss))
        preds.append(_np(out))
        targs.append(_np(yb))

    all_preds = np.concatenate(preds, axis=0)
    all_targs = np.concatenate(targs, axis=0)
    return losses, all_preds, all_targs

def compute_metrics(phase, losses, preds, targs):
    """Summarise one pass as a dict of metrics prefixed with ``phase``.

    ``preds`` holds per-class scores; the predicted class is the argmax along
    axis 1. Returns the keys ``<phase>/loss`` (mean of ``losses``),
    ``<phase>/acc`` (accuracy) and ``<phase>/f1`` (macro-averaged F1).
    """
    pred_classes = np.argmax(preds, axis=1)
    metrics = {}
    metrics[f'{phase}/loss'] = mean(losses)
    metrics[f'{phase}/acc'] = sklearn.metrics.accuracy_score(targs, pred_classes)
    metrics[f'{phase}/f1'] = sklearn.metrics.f1_score(targs, pred_classes, average='macro')
    return metrics

def append_dict(D1, D2):
    """Return a new dict with the items of ``D1`` followed by those of ``D2``.

    Neither input is modified. Raises ``Exception("common keys")`` if the two
    dicts share any key, so a value can never be silently overwritten.
    """
    if D1.keys() & D2.keys():
        raise Exception("common keys")
    return {**D1, **D2}

In [16]:
# Preferred GPU index for this run. Fall back to the first GPU when the
# machine has fewer devices, and to the CPU when CUDA is unavailable, so the
# notebook still runs on other hardware instead of failing at `model.to(device)`.
_GPU_INDEX = 3
if torch.cuda.is_available():
    device = torch.device(f'cuda:{_GPU_INDEX}' if torch.cuda.device_count() > _GPU_INDEX else 'cuda:0')
else:
    device = torch.device('cpu')

In [17]:
# Start from the pretrained ResNet-34 backbone.
model = resnet34(pretrained=True)
# Replace the classification head with one sized for our classes. Keep a bias
# term exactly when the original head had one; the previous `is None` test was
# inverted and silently created a bias-free layer.
model.fc = nn.Linear(model.fc.in_features, NUM_DISEASES, bias=(model.fc.bias is not None))
model = model.to(device)

In [18]:
# Adam over all model parameters (backbone and new head alike), learning rate 1e-4.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [19]:
# Cross-entropy between the model outputs and the targets; used by `step`.
loss_fn = nn.CrossEntropyLoss()

In [20]:
# Cap the number of threads PyTorch uses for intra-op parallelism on the CPU.
torch.set_num_threads(2)

In [21]:
# Start the Weights & Biases run that the training loop logs its metrics to.
run = wandb.init(project='derm-dis-morph', name='baseline')

[34m[1mwandb[0m: Currently logged in as: [33mtanyapole[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.31 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [22]:
# 80/20 split of the sample keys, stratified on the disease label.
# A fixed random_state makes the split identical across kernel restarts, so
# validation metrics from different runs are comparable.
trn_idxs, val_idxs = train_test_split(idxs, train_size=0.8, stratify=[diseases[idx] for idx in idxs], random_state=42)

In [23]:
# One dataset per split; both read from the same in-memory image and label dicts.
trn_ds = DiseaseDataset(trn_idxs)
val_ds = DiseaseDataset(val_idxs)

In [24]:
# Reshuffle the training samples every epoch so the batches differ between
# epochs; without it the model sees identical batches in the same order each time.
trn_dl = DataLoader(trn_ds, batch_size=30, shuffle=True)
# Validation order does not affect the metrics, so it stays unshuffled.
val_dl = DataLoader(val_ds, batch_size=30)

In [None]:
# Alternate one training pass and one validation pass per epoch and log the
# combined metrics to wandb. The epoch count is deliberately huge; the loop is
# meant to be interrupted manually.
for epoch in tqdm(range(10000), desc='Epoch'):
    losses, preds, targs = step(trn_dl, Mode.Train, 'Train')
    trn_metrics = compute_metrics('trn', losses, preds, targs)

    losses, preds, targs = step(val_dl, Mode.Eval, 'Valid')
    val_metrics = compute_metrics('val', losses, preds, targs)

    # append_dict refuses overlapping keys, so no metric can be overwritten.
    D = append_dict(append_dict({'epoch': epoch}, trn_metrics), val_metrics)
    wandb.log(D)

Epoch:   0%|          | 0/10000 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

Valid:   0%|          | 0/52 [00:00<?, ?it/s]

Train:   0%|          | 0/205 [00:00<?, ?it/s]

In [None]:
# Close the wandb run; the trailing `;` suppresses the cell's output.
run.finish();