In [None]:
import wandb
import pandas as pd
import torchvision.models as tvmodels
from fastai.vision.all import *
from fastai.callback.wandb import WandbCallback

from params import BDD_CLASSES
from utils import get_predictions, create_dice_table

We bring those parameters together in a config object. This is optional, but it helps keep everything organised, and the same config is passed to `wandb.init` so it is recorded with the run!

In [None]:
# All tunable parameters for this notebook live here; train() receives this object
# and wandb.init records it with each run.
default_config = SimpleNamespace(
    WANDB_PROJECT="BDD100k",
    ENTITY=None,  # wandb team
    RAW_DATA_AT='bdd_sample_1k',
    PROCESSED_DATA_AT='bdd_sample_1k_split',
    framework="fastai",
    img_size=45,  # image height; get_data derives a 16:9 width, i.e. (45, 80)
    batch_size=2,  # keep small in Colab to be manageable (e.g. 8 otherwise)
    augment=True,  # use data augmentation
    epochs=1,  # for brevity, increase for better results :)
    lr=2e-3,
    pretrained=True,  # whether to use pretrained encoder
    mixed_precision=False,  # use automatic mixed precision
    arch="resnet18",  # name of a constructor in torchvision.models
    seed=42,
    log_preds=True,  # log a prediction table after training
)
# Seed from the config itself so the global seed and the logged config cannot drift apart.
set_seed(default_config.seed, reproducible=True)

In [None]:
# Display the configuration that will be passed to wandb.init / train()
default_config

namespace(WANDB_PROJECT='BDD100k',
          ENTITY=None,
          RAW_DATA_AT='bdd_sample_1k',
          PROCESSED_DATA_AT='bdd_sample_1k_split',
          framework='fastai',
          img_size=(45, 80),
          batch_size=2,
          augment=True,
          epochs=1,
          lr=0.002,
          pretrained=True,
          mixed_precision=False,
          arch='resnet18',
          seed=1,
          log_preds=True)

In [None]:
# Open an exploratory W&B run; the artifact download below is done through it.
run = wandb.init(
    project=default_config.WANDB_PROJECT,
    entity=default_config.ENTITY,
    job_type="training",
    config=default_config,
)

[34m[1mwandb[0m: Currently logged in as: [33mcapecape[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
def download_dataset(at_name, alias="latest"):
    """Grab data from a W&B artifact via the active run and return its local directory.

    Args:
        at_name: artifact name, e.g. ``'bdd_sample_1k_split'``.
        alias: version or alias to fetch; defaults to ``'latest'``, which was
            previously hard-coded.

    Returns:
        ``Path`` of the directory the artifact was downloaded into.

    Raises:
        RuntimeError: if there is no active W&B run (``wandb.init`` not called).
    """
    active_run = wandb.run
    if active_run is None:
        # use_artifact is called on the active run, so one must exist
        raise RuntimeError("download_dataset needs an active W&B run; call wandb.init() first")
    artifact = active_run.use_artifact(f"{at_name}:{alias}")
    return Path(artifact.download())

In [None]:
# Fetch the train/valid/test split artifact through the run opened above
processed_dataset_dir = download_dataset(at_name=default_config.PROCESSED_DATA_AT)

[34m[1mwandb[0m: Downloading large artifact bdd_sample_1k_split:latest, 823.45MB. 4006 files... Done. 0:0:0.2


In [None]:
def label_func(fname):
    "Return the mask path `<root>/labels/<stem>_mask.png` for image `fname`, where `<root>` is the directory two levels above `fname`."
    mask_name = f"{fname.stem}_mask.png"
    return fname.parent.parent / "labels" / mask_name

In [None]:
def prepare_df(processed_dataset_dir, label_func, is_test=False):
    """Set absolute path image names and split.

    Reads ``data_split.csv`` from ``processed_dataset_dir`` and keeps either only
    the ``'test'`` rows (``is_test=True``) or every other row. In the second case
    a boolean ``is_valid`` column marks the ``'valid'`` rows.

    Adds ``image_fname`` (``<dir>/images/<File_Name>``) and ``label_fname``
    (``label_func`` applied to each image path) columns and returns the frame.
    """
    split_df = pd.read_csv(processed_dataset_dir / 'data_split.csv')
    test_rows = split_df.Stage == 'test'

    if not is_test:
        df = split_df[~test_rows].reset_index(drop=True)
        df['is_valid'] = df.Stage == 'valid'
    else:
        # grab the test part of the split
        df = split_df[test_rows].reset_index(drop=True)

    image_paths = [processed_dataset_dir / f'images/{name}' for name in df.File_Name.values]
    df["image_fname"] = image_paths
    df["label_fname"] = [label_func(path) for path in df.image_fname.values]
    return df

In [None]:
# Train + valid rows (test rows are excluded) with absolute image/mask paths
proc_df = prepare_df(processed_dataset_dir, label_func=label_func)

In [None]:
# Peek at the first rows to sanity-check the split and the derived paths
proc_df.head(n=5)

Unnamed: 0,File_Name,Stage,is_valid,image_fname,label_fname
0,a59131a5-00000000.jpg,train,False,artifacts/bdd_sample_1k_split:v1/images/a59131a5-00000000.jpg,artifacts/bdd_sample_1k_split:v1/labels/a59131a5-00000000_mask.png
1,6886b3d9-6ab2b28d.jpg,train,False,artifacts/bdd_sample_1k_split:v1/images/6886b3d9-6ab2b28d.jpg,artifacts/bdd_sample_1k_split:v1/labels/6886b3d9-6ab2b28d_mask.png
2,115e4aff-00000000.jpg,train,False,artifacts/bdd_sample_1k_split:v1/images/115e4aff-00000000.jpg,artifacts/bdd_sample_1k_split:v1/labels/115e4aff-00000000_mask.png
3,b803d91d-671b8cff.jpg,train,False,artifacts/bdd_sample_1k_split:v1/images/b803d91d-671b8cff.jpg,artifacts/bdd_sample_1k_split:v1/labels/b803d91d-671b8cff_mask.png
4,6b293d3e-59d5f868.jpg,train,False,artifacts/bdd_sample_1k_split:v1/images/6b293d3e-59d5f868.jpg,artifacts/bdd_sample_1k_split:v1/labels/6b293d3e-59d5f868_mask.png


In [None]:
def get_data(df, bs=4, img_size=180, augment=True):
    """Build fastai segmentation DataLoaders from a frame produced by `prepare_df`.

    Args:
        df: DataFrame with ``image_fname``, ``label_fname`` and ``is_valid`` columns.
        bs: batch size.
        img_size: either the image height as an int, with the width derived from a
            16:9 aspect ratio (e.g. 45 -> (45, 80)), or an explicit
            ``(height, width)`` tuple/list.
        augment: when True, apply ``aug_transforms()`` as batch transforms.

    Returns:
        fastai ``DataLoaders`` split into train/valid by the ``is_valid`` column.
    """
    if isinstance(img_size, (tuple, list)):
        # explicit (h, w); lists are accepted in case the config round-trips tuples as lists
        size = tuple(img_size)
    else:
        size = (img_size, int(img_size * 16 / 9))

    block = DataBlock(
        blocks=(ImageBlock, MaskBlock(codes=BDD_CLASSES)),
        get_x=ColReader("image_fname"),
        get_y=ColReader("label_fname"),
        splitter=ColSplitter(),  # fastai default column is `is_valid`, added by prepare_df
        item_tfms=Resize(size),
        # batch_tfms is a DataBlock argument, not a Resize argument
        batch_tfms=aug_transforms() if augment else None,
    )
    return block.dataloaders(df, bs=bs)

In [None]:
# Build the DataLoaders with the same config values that train() uses,
# rather than get_data's generic defaults (bs=4, img_size=180)
dls = get_data(proc_df,
               bs=default_config.batch_size,
               img_size=default_config.img_size,
               augment=default_config.augment)

In [None]:
def log_predictions(learn):
    """Log a prediction table for `learn` to W&B under the ``"pred_table"`` key.

    The table is built by `utils.create_dice_table` from the output of
    `utils.get_predictions` (presumably per-sample Dice scores; see `utils`).
    """
    samples, outputs, predictions = get_predictions(learn)
    wandb.log({"pred_table": create_dice_table(samples, outputs, predictions, BDD_CLASSES)})

In [None]:
# Close the exploratory run opened earlier; train() below opens its own run
wandb.finish()

## Putting everything together...

In [None]:
def train(config):
    """Run one W&B-tracked training job for the U-Net segmentation model.

    Args:
        config: object with attribute access holding the hyper-parameters
            (e.g. ``default_config``). It is passed to ``wandb.init`` and then
            read back from ``wandb.config``, so a sweep can override any value,
            including the seed.
    """
    with wandb.init(project=config.WANDB_PROJECT, entity=config.ENTITY, job_type="training", config=config):

        # good practice to inject params using sweeps
        config = wandb.config
        # seed only after reading wandb.config so a sweep-provided seed is honoured
        set_seed(config.seed)

        # prepare data
        processed_dataset_dir = download_dataset(config.PROCESSED_DATA_AT)
        proc_df = prepare_df(processed_dataset_dir, label_func)
        dls = get_data(proc_df, bs=config.batch_size, img_size=config.img_size, augment=config.augment)

        # SaveModelCallback keeps the best checkpoint by valid_loss; mixed precision is opt-in
        cbs = [SaveModelCallback()] + ([MixedPrecision()] if config.mixed_precision else [])

        learn = unet_learner(dls, arch=getattr(tvmodels, config.arch), pretrained=config.pretrained,
                             metrics=[foreground_acc, DiceMulti()], cbs=cbs)

        learn.fit_one_cycle(config.epochs, config.lr, cbs=[WandbCallback(log_preds=False, log_model=True)])
        if config.log_preds:
            log_predictions(learn)

Let's check that everything works by re-running the baseline with `default_config`.

In [None]:
# Baseline run with the default hyper-parameters
train(config=default_config)

[34m[1mwandb[0m: Downloading large artifact bdd_sample_1k_split:latest, 823.45MB. 4006 files... Done. 0:0:0.2


epoch,train_loss,valid_loss,foreground_acc,dice_multi,time
0,1.144104,1.069724,0.620211,0.192138,01:17


Better model found at epoch 0 with valid_loss value: 1.0697239637374878.


0,1
dice_multi,▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
eps_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
eps_1,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
eps_2,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
foreground_acc,▁
lr_0,▁▁▂▃▄▅▆▇███████▇▇▇▇▆▆▆▅▅▅▄▄▄▃▃▃▂▂▂▂▁▁▁▁▁
lr_1,▁▁▂▃▄▅▆▇███████▇▇▇▇▆▆▆▅▅▅▄▄▄▃▃▃▂▂▂▂▁▁▁▁▁
lr_2,▁▁▂▃▄▅▆▇███████▇▇▇▇▆▆▆▅▅▅▄▄▄▃▃▃▂▂▂▂▁▁▁▁▁
mom_0,██▇▆▅▄▃▂▁▁▁▁▁▁▁▂▂▂▂▃▃▃▄▄▄▅▅▅▆▆▆▇▇▇▇█████

0,1
dice_multi,0.19214
epoch,1.0
eps_0,1e-05
eps_1,1e-05
eps_2,1e-05
foreground_acc,0.62021
lr_0,0.0
lr_1,0.0
lr_2,0.0
mom_0,0.95
