In [None]:
import wandb
import pandas as pd
from fastai.vision.all import *
from fastai.callback.wandb import WandbCallback
from params import BDD_CLASSES
import params
from utils import get_predictions, create_dice_table

In [None]:
# Seed the RNGs fastai manages; reproducible=True additionally asks fastai for
# deterministic behaviour so reruns of this notebook line up.
SEED = 42
set_seed(SEED, reproducible=True)

In [None]:
# Hyperparameters for this run. They are passed to wandb.init below so they are
# recorded with the run, and later cells read them back through the run config.
train_config = SimpleNamespace(
    framework="fastai",
    img_size=(45, 80),  # small for speed; (180, 320) is the larger alternative
    batch_size=2,       # kept small so it is manageable in Colab (was 8)
    augment=True,       # use data augmentation (aug_transforms as batch transforms)
    epochs=1,           # for brevity; increase for better results
    lr=2e-3,            # maximum learning rate handed to fit_one_cycle
    pretrained=True     # whether to use a pretrained encoder
)

In [None]:
# Start a W&B training run; train_config is recorded as the run's config.
run = wandb.init(
    project=params.WANDB_PROJECT,
    entity=params.ENTITY,
    job_type="training",
    config=train_config,
)

[34m[1mwandb[0m: Currently logged in as: [33mdarek[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
# Pull the latest version of the processed dataset artifact and load its split
# table (one row per file; later cells rely on its `Stage` and `File_Name` columns).
processed_data_at = run.use_artifact('{}:latest'.format(params.PROCESSED_DATA_AT))
processed_dataset_dir = Path(processed_data_at.download())
df = pd.read_csv(processed_dataset_dir.joinpath('data_split.csv'))

[34m[1mwandb[0m: Downloading large artifact bdd_sample_1k_split:latest, 856.26MB. 4006 files... 
[34m[1mwandb[0m:   4006 of 4006 files downloaded.  
Done. 0:0:1.0


In [None]:
# Keep only train/valid rows (test is held out) and flag validation rows in a
# boolean `is_valid` column for the DataBlock splitter.
df = (
    df.loc[df.Stage != 'test']
    .reset_index(drop=True)
    .assign(is_valid=lambda frame: frame.Stage == 'valid')
)

In [None]:
def label_func(fname):
    """Map an image path to its segmentation-mask path.

    For ``<root>/<subdir>/<stem>.<ext>`` this returns
    ``<root>/labels/<stem>_mask.png``. ``fname`` is expected to be a
    ``pathlib.Path`` (its ``.parent`` and ``.stem`` are used).
    """
    dataset_root = fname.parent.parent
    return dataset_root.joinpath("labels", f"{fname.stem}_mask.png")

In [None]:
# Attach the full image path for each row, then derive the matching mask path.
df["image_fname"] = df.File_Name.map(lambda name: processed_dataset_dir / f'images/{name}')
df["label_fname"] = df.image_fname.map(label_func)

In [None]:
def get_data(df, bs=4, img_size=(180, 320), augment=True):
    """Build fastai segmentation DataLoaders from a DataFrame of file paths.

    Args:
        df: DataFrame with ``image_fname`` and ``label_fname`` path columns, plus
            the column ``ColSplitter()`` splits on (presumably the boolean
            ``is_valid`` flag built earlier in this notebook — confirm).
        bs: batch size for the DataLoaders.
        img_size: size given to ``Resize`` for every image/mask pair.
        augment: when True, ``aug_transforms()`` is used as the batch transforms;
            otherwise no batch transforms are set.

    Returns:
        The DataLoaders produced by ``DataBlock.dataloaders``.
    """
    seg_blocks = (ImageBlock, MaskBlock(codes=BDD_CLASSES))
    datablock = DataBlock(
        blocks=seg_blocks,
        get_x=ColReader("image_fname"),
        get_y=ColReader("label_fname"),
        splitter=ColSplitter(),
        item_tfms=Resize(img_size),
        batch_tfms=(aug_transforms() if augment else None),
    )
    return datablock.dataloaders(df, bs=bs)

In [None]:
# Read hyperparameters back from the active run's config (what W&B recorded).
config = run.config

In [None]:
# DataLoaders sized and augmented according to the run config.
dls = get_data(
    df,
    bs=config.batch_size,
    img_size=config.img_size,
    augment=config.augment,
)

In [None]:
# U-Net learner on a resnet18 backbone. SaveModelCallback keeps the best model
# seen during training (it reports on valid_loss in the training log).
# Mixed precision can be tried by appending .to_fp16() to this call.
learn = unet_learner(
    dls,
    arch=resnet18,
    pretrained=config.pretrained,
    metrics=[foreground_acc, DiceMulti()],
    cbs=SaveModelCallback(),
)



In [None]:
# One-cycle training. WandbCallback streams metrics to the run; log_model=True
# asks it to log the model as well (fastai pairs this with the SaveModelCallback
# attached to the learner).
learn.fit_one_cycle(
    config.epochs,
    config.lr,
    cbs=[WandbCallback(log_preds=False, log_model=True)],
)

epoch,train_loss,valid_loss,foreground_acc,dice_multi,time
0,1.098883,1.070148,0.609322,0.188674,00:49


Better model found at epoch 0 with valid_loss value: 1.0701483488082886.


In [None]:
# Collect model predictions via the project helper and log them to the run as a
# Dice table under the key "pred_table".
samples, outputs, predictions = get_predictions(learn)
table = create_dice_table(samples, outputs, predictions, BDD_CLASSES)
wandb.log(dict(pred_table=table))

In [None]:
# Mark the W&B run as finished.
run.finish()

0,1
dice_multi,▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
eps_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
eps_1,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
eps_2,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
foreground_acc,▁
lr_0,▁▁▂▃▄▅▆▇███████▇▇▇▇▆▆▆▅▅▅▄▄▄▃▃▃▂▂▂▂▁▁▁▁▁
lr_1,▁▁▂▃▄▅▆▇███████▇▇▇▇▆▆▆▅▅▅▄▄▄▃▃▃▂▂▂▂▁▁▁▁▁
lr_2,▁▁▂▃▄▅▆▇███████▇▇▇▇▆▆▆▅▅▅▄▄▄▃▃▃▂▂▂▂▁▁▁▁▁
mom_0,██▇▆▅▄▃▂▁▁▁▁▁▁▁▂▂▂▂▃▃▃▄▄▄▅▅▅▆▆▆▇▇▇▇█████

0,1
dice_multi,0.18867
epoch,1.0
eps_0,1e-05
eps_1,1e-05
eps_2,1e-05
foreground_acc,0.60932
lr_0,0.0
lr_1,0.0
lr_2,0.0
mom_0,0.95
