In [None]:
import os
import pandas as pd
from ml_collections import config_dict
from fastai.vision.all import Path, ImageDataLoaders, vision_learner, \
    accuracy, Precision, Recall, F1Score, SaveModelCallback, resnet18, Resize
from fastai.callback.wandb import WandbCallback
import timm

import wandb

In [None]:
# Notebook-wide configuration, built in one shot so every tunable is visible together.
cfg = config_dict.ConfigDict({
    # --- data / dataloader ---
    'img_size': 64,           # size handed to the Resize item transform
    'target_column': 'mold',  # dataframe column used as the label
    'bs': 128,                # batch size for the dataloaders
    'seed': 42,               # seed forwarded to the dataloader factory
    # --- training ---
    'epochs': 2,              # number of epochs passed to fine_tune
    # --- Weights & Biases project and artifact names ---
    'PROJECT_NAME': 'lemon-dataset',
    'RAW_DATA_AT': 'lemon_dataset_raw_data_2690',
    'PROCESSED_DATA_AT': 'lemon_dataset_processed_data',
    # --- model ---
    'arch': 'resnet18',       # architecture name given to vision_learner
})

In [None]:
# functions
def prepare_data(cfg, run):
    """Download the dataset artifacts and build the train/valid split table.

    Args:
        cfg: configuration exposing `RAW_DATA_AT` and `PROCESSED_DATA_AT`
            (artifact names; the `:latest` alias is appended here).
        run: object providing `use_artifact(name)`; the returned artifact must
            provide `download()` returning a local directory.

    Returns:
        A `(df, path)` tuple: `df` is the content of `data_split.csv` from the
        processed artifact with the `stage == 'test'` rows removed, a fresh
        0..n-1 index and an added boolean `valid` column; `path` is the raw
        artifact's download directory wrapped in `Path`.
    """
    # Raw artifact first, then the processed one (same order as the lineage is recorded).
    raw_dir = run.use_artifact(f'{cfg.RAW_DATA_AT}:latest').download()
    processed_dir = run.use_artifact(f'{cfg.PROCESSED_DATA_AT}:latest').download()

    split_table = pd.read_csv(os.path.join(processed_dir, 'data_split.csv'))

    # Hold-out rows are excluded here; only the remaining stages are returned.
    not_test = split_table['stage'] != 'test'
    split_table = split_table.loc[not_test].reset_index(drop=True)

    # Boolean flag marking the validation rows.
    split_table['valid'] = split_table['stage'].eq('valid')

    return split_table, Path(raw_dir)


In [None]:
def _log_predictions(learn, run):
    """Log a `predictions_table` wandb.Table built from `learn.get_preds`.

    Each row holds the input image, the score in column 1 of the prediction
    tensor, the decoded prediction and the target.
    """
    inp, preds, targs, out = learn.get_preds(with_input=True, with_decoded=True)
    # Inputs are permuted to (1, 2, 0) before wrapping -- presumably CHW -> HWC
    # for wandb.Image; TODO confirm.
    imgs = [wandb.Image(t.permute(1, 2, 0)) for t in inp]
    # Column 1 is reported as the probability -- assumes a two-class target
    # whose class index 1 is the positive label; TODO confirm.
    pred_proba = preds[:, 1].numpy().tolist()
    targets = targs.numpy().tolist()
    predictions = out.numpy().tolist()
    predictions_df = pd.DataFrame(
        list(zip(imgs, pred_proba, predictions, targets)),
        columns=['image', 'probability', 'prediction', 'target'])
    run.log({'predictions_table': wandb.Table(dataframe=predictions_df)})


def main(cfg):
    """Train a classifier inside a W&B run and log a table of its predictions.

    Args:
        cfg: configuration exposing `PROJECT_NAME`, `seed`, `target_column`,
            `img_size`, `bs`, `arch` and `epochs`, plus the artifact names
            read by `prepare_data`.

    Returns:
        None. Results are reported through the W&B run (metrics via
        `WandbCallback`, and a `predictions_table`).
    """
    # `tags` is documented as a list of strings, so the single tag is wrapped.
    # Using the run as a context manager guarantees it is closed even when
    # data preparation or training raises, instead of leaving an active run
    # behind in the kernel.
    with wandb.init(project=cfg.PROJECT_NAME, job_type="training", tags=['fastai']) as run:
        run.config.update(cfg)

        split_df, path = prepare_data(cfg, run)
        dls = ImageDataLoaders.from_df(split_df, path=path,
                                       seed=cfg.seed,
                                       fn_col='file_name',
                                       label_col=cfg.target_column,
                                       valid_col='valid',
                                       item_tfms=Resize(cfg.img_size),
                                       bs=cfg.bs
                                      )
        learn = vision_learner(dls,
                               cfg.arch,
                               metrics=[accuracy, Precision(), Recall(), F1Score()],
                               # Predictions are logged explicitly below, so the
                               # callback's own prediction logging is turned off.
                               cbs=[WandbCallback(log_preds=False), SaveModelCallback(monitor='f1_score')])
        learn.fine_tune(cfg.epochs)

        _log_predictions(learn, run)
    

In [None]:
# Run the whole pipeline (data prep, training, prediction logging) with the notebook config.
main(cfg)