In [1]:
import os
from pathlib import Path
import pandas as pd
from ml_collections import config_dict

import wandb
from fastai.vision.all import *
from fastai.callback.wandb import WandbCallback

In [2]:
# Experiment configuration: every tunable value lives in a single ConfigDict so
# it can be handed to `train` and stored with the W&B run.
cfg = config_dict.ConfigDict(dict(
    img_size=256,              # side length given to Resize when building the dataloaders
    target_column='mold',      # dataframe column used as the classification label
    bs=32,                     # batch size
    seed=42,                   # random seed
    epochs=2,                  # number of epochs passed to fine_tune
    lr=2e-3,                   # learning rate
    arch='resnet18',           # architecture name passed to vision_learner
    log_model=True,            # forwarded to WandbCallback; also gates the predictions table
    PROJECT_NAME='lemon-project',                          # W&B project
    ENTITY='wandb_course',                                 # W&B entity
    PROCESSED_DATA_AT='lemon_dataset_split_data:latest',   # W&B artifact holding the split dataset
))

# Seed the random number generators before anything stochastic runs.
set_seed(cfg.seed, reproducible=True)

In [3]:
def prepare_data(PROCESSED_DATA_AT):
    """Download the processed dataset artifact and load its train/valid split.

    Args:
        PROCESSED_DATA_AT: name of the W&B artifact passed to `wandb.use_artifact`.

    Returns:
        A `(df, processed_dataset_dir)` tuple: the contents of `data_split.csv`
        without the rows whose `stage` is 'test', plus a boolean `valid`
        column, and the local directory the artifact was downloaded to.
    """
    artifact = wandb.use_artifact(PROCESSED_DATA_AT)
    data_dir = Path(artifact.download())
    splits = pd.read_csv(data_dir / 'data_split.csv')
    # Hold the test rows out entirely; only train and valid rows are returned.
    not_test = splits['stage'] != 'test'
    trainval = splits.loc[not_test].reset_index(drop=True)
    # Boolean flag marking the validation rows.
    trainval['valid'] = trainval['stage'].eq('valid')
    return trainval, data_dir

A good way of making experiments reproducible: `set_seed` seeds numpy, torch and random, and (with `reproducible=True`) configures cudnn for deterministic behavior.

In [4]:
def get_dataloaders(df, path, seed, target_column, img_size, bs):
    """Build the train/valid dataloaders from a dataframe.

    Args:
        df: dataframe with a `file_name` column, a boolean `valid` column and
            the label column named by `target_column`.
        path: base directory forwarded to `ImageDataLoaders.from_df`.
        seed: seed forwarded to `ImageDataLoaders.from_df`.
        target_column: name of the label column.
        img_size: size every item is resized to via `Resize`.
        bs: batch size.

    Returns:
        The object returned by `ImageDataLoaders.from_df`.
    """
    return ImageDataLoaders.from_df(
        df,
        path=path,
        seed=seed,
        fn_col='file_name',
        label_col=target_column,
        valid_col='valid',
        item_tfms=Resize(img_size),
        bs=bs,
    )

In [5]:
def log_predictions(learn):
    """Log a wandb.Table with model predictions on the validation dataset.

    The table has one row per item returned by `learn.get_preds`, with the
    columns `image`, `probability`, `prediction` and `target`. `probability`
    is column 1 of the prediction tensor, so this presumably expects a
    two-class problem with the positive class at index 1 (confirm against the
    dataloaders' vocab).

    Args:
        learn: learner exposing `get_preds(with_input=..., with_decoded=...)`.
    """
    inputs, probs, targets, decoded = learn.get_preds(with_input=True, with_decoded=True)
    rows = [
        # permute(1, 2, 0) moves the first axis last before handing the tensor
        # to wandb.Image (presumably channels-first -> channels-last).
        (wandb.Image(img.permute(1, 2, 0)), prob, pred, targ)
        for img, prob, pred, targ in zip(
            inputs,
            probs[:, 1].numpy().tolist(),
            decoded.numpy().tolist(),
            targets.numpy().tolist(),
        )
    ]
    table_df = pd.DataFrame(rows, columns=['image', 'probability', 'prediction', 'target'])
    wandb.log({'predictions_table': wandb.Table(dataframe=table_df)})

In [6]:
def train(cfg):
    """Run one training job inside a W&B run.

    Seeds the RNGs, opens a W&B run with `cfg` as its config, downloads the
    data, builds the dataloaders, fine-tunes a `vision_learner` and, when
    `cfg.log_model` is truthy, logs a table of validation predictions.

    Args:
        cfg: configuration object exposing `to_dict()` and the attributes
            `seed`, `PROJECT_NAME`, `ENTITY`, `PROCESSED_DATA_AT`,
            `target_column`, `img_size`, `bs`, `arch`, `log_model`, `epochs`
            and `lr`.

    Returns:
        None.
    """
    # Same reproducibility setting as the config cell, so calling `train` on
    # its own is seeded the same way.
    set_seed(cfg.seed, reproducible=True)
    with wandb.init(project=cfg.PROJECT_NAME, entity=cfg.ENTITY, job_type="training", config=cfg.to_dict()):
        # From here on read settings from the run's config rather than the
        # argument (presumably so values overridden by W&B, e.g. in a sweep,
        # take effect - confirm).
        cfg = wandb.config
        df, path = prepare_data(cfg.PROCESSED_DATA_AT)
        dls = get_dataloaders(df, path, cfg.seed, cfg.target_column, cfg.img_size, cfg.bs)
        learn = vision_learner(dls,
                               cfg.arch,
                               metrics=[accuracy, Precision(), Recall(), F1Score()],
                               cbs=[WandbCallback(log_preds=False, log_model=cfg.log_model),
                                    SaveModelCallback(fname=cfg.arch, monitor='f1_score')])
        # Use the configured learning rate; previously cfg.lr was logged to
        # W&B but never reached the optimizer.
        learn.fine_tune(cfg.epochs, base_lr=cfg.lr)
        if cfg.log_model:
            log_predictions(learn)

Let's check it works by re-running the baseline

In [7]:
# Sanity check: launch a single training run with the config defined above.
train(cfg)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcapecape[0m ([33mwandb_course[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Downloading large artifact lemon_dataset_split_data:latest, 137.87MB. 2695 files... Done. 0:0:0.3


epoch,train_loss,valid_loss,accuracy,precision_score,recall_score,f1_score,time
0,0.711309,0.346772,0.909524,0.611111,0.814815,0.698413,00:14


Better model found at epoch 0 with f1_score value: 0.6984126984126984.


epoch,train_loss,valid_loss,accuracy,precision_score,recall_score,f1_score,time
0,0.241471,0.325147,0.957143,0.78125,0.925926,0.847458,00:18
1,0.136446,0.260904,0.957143,0.821429,0.851852,0.836364,00:18


Better model found at epoch 0 with f1_score value: 0.847457627118644.


VBox(children=(Label(value='89.484 MB of 89.484 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, m…

0,1
accuracy,▁██
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
eps_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
eps_1,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
f1_score,▁█▇
lr_0,▁▁▂▂▃▃▄▅▆▇▇██▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr_1,▁▁▂▂▃▃▄▅▆▇▇██▂▂▂▃▃▄▄▄▅▅▄▄▄▄▄▃▃▃▃▂▂▂▁▁▁▁▁
mom_0,███▇▇▆▅▄▃▂▂▁▁██▇▆▅▃▂▁▁▁▁▁▂▂▃▃▄▄▅▆▆▇▇▇███
mom_1,███▇▇▆▅▄▃▂▂▁▁██▇▆▅▃▂▁▁▁▁▁▂▂▃▃▄▄▅▆▆▇▇▇███
precision_score,▁▇█

0,1
accuracy,0.95714
epoch,3.0
eps_0,1e-05
eps_1,1e-05
f1_score,0.83636
lr_0,0.0
lr_1,0.0
mom_0,0.94998
mom_1,0.94998
precision_score,0.82143
