In [1]:
import wandb
import pandas as pd
import torchvision.models as tvmodels
from fastai.vision.all import *
from fastai.callback.wandb import WandbCallback

import params
from utils import (
    get_predictions,
    create_iou_table,
    MIOU,
    BackgroundIOU,
    RoadIOU,
    TrafficLightIOU,
    TrafficSignIOU,
    PersonIOU,
    VehicleIOU,
    BicycleIOU,
)

In [2]:
# Default hyperparameters for one training run. They are passed to
# wandb.init(config=...), so the run records them (and a sweep may override them).
train_config = SimpleNamespace(
    **{
        "framework": "fastai",
        # Handed to fastai's Resize; fastai reads a tuple as (height, width).
        "img_size": (180, 320),
        "batch_size": 8,
        # Apply fastai's default aug_transforms() as batch transforms.
        "augment": True,
        "epochs": 10,
        # Peak learning rate for fit_one_cycle.
        "lr": 2e-3,
        # Name of a torchvision.models constructor used as the U-Net encoder.
        "arch": "resnet18",
        # Whether to start the encoder from pretrained weights.
        "pretrained": True,
        "seed": 42,
        # After training, log a W&B table of predictions.
        "log_preds": True,
    }
)

In [3]:
def download_data():
    """Download the latest processed-dataset artifact and return its local directory.

    Goes through ``wandb.use_artifact``, so it must run while a W&B run is
    active; the artifact name comes from ``params.PROCESSED_DATA_AT``.

    Returns:
        ``Path`` to the directory the artifact was downloaded into.
    """
    artifact = wandb.use_artifact(f"{params.PROCESSED_DATA_AT}:latest")
    return Path(artifact.download())

In [4]:
def label_func(fname):
    """Map an image path to the path of its segmentation mask.

    ``<root>/<dir>/<stem>.<ext>`` becomes ``<root>/labels/<stem>_mask.png``.
    Only the grandparent directory of ``fname`` is reused; the image's own
    folder name is not checked.
    """
    dataset_root = fname.parent.parent
    mask_name = fname.stem + "_mask.png"
    return dataset_root / "labels" / mask_name

In [5]:
def get_df(processed_dataset_dir, is_test=False):
    """Load the dataset split table and attach image and mask paths.

    Args:
        processed_dataset_dir: ``Path`` to the downloaded processed dataset;
            must contain ``data_split.csv`` with ``Stage`` and ``File_Name``
            columns, and the images under ``images/``.
        is_test: If True, keep only rows whose ``Stage`` is ``"test"``.
            Otherwise drop those rows and add a boolean ``is_valid`` column
            (True where ``Stage == "valid"``) for fastai's ``ColSplitter``.

    Returns:
        DataFrame with a fresh RangeIndex plus ``image_fname`` and
        ``label_fname`` columns holding ``Path`` objects.
    """
    split_df = pd.read_csv(processed_dataset_dir / "data_split.csv")
    is_test_row = split_df.Stage == "test"

    if is_test:
        split_df = split_df[is_test_row].reset_index(drop=True)
    else:
        split_df = split_df[~is_test_row].reset_index(drop=True)
        split_df["is_valid"] = split_df.Stage == "valid"

    split_df["image_fname"] = [
        processed_dataset_dir / f"images/{name}" for name in split_df.File_Name.values
    ]
    split_df["label_fname"] = [label_func(path) for path in split_df.image_fname.values]
    return split_df

In [6]:
def get_data(df, bs=4, img_size=(180, 320), augment=True):
    """Build fastai segmentation ``DataLoaders`` from a split DataFrame.

    Args:
        df: Frame as produced by ``get_df``. Inputs come from the
            ``image_fname`` column, masks from ``label_fname``, and the
            train/valid split from ``ColSplitter()`` (its default column is
            ``is_valid``).
        bs: Batch size.
        img_size: Size given to ``Resize`` for every item.
        augment: If True, apply fastai's default ``aug_transforms()`` as
            batch transforms; otherwise no batch transforms.

    Returns:
        The ``DataLoaders`` created by ``DataBlock.dataloaders``.
    """
    batch_transforms = aug_transforms() if augment else None
    segmentation_block = DataBlock(
        blocks=(ImageBlock, MaskBlock(codes=params.BDD_CLASSES)),
        get_x=ColReader("image_fname"),
        get_y=ColReader("label_fname"),
        splitter=ColSplitter(),
        item_tfms=Resize(img_size),
        batch_tfms=batch_transforms,
    )
    return segmentation_block.dataloaders(df, bs=bs)

In [7]:
def log_predictions(learn):
    """Log a W&B table of model predictions under the key ``"pred_table"``.

    ``get_predictions`` (from ``utils``) supplies samples, raw outputs and
    predictions for ``learn`` (presumably on the validation set; confirm in
    ``utils``). ``create_iou_table`` turns them into a table using the BDD
    class names.
    """
    samples, outputs, predictions = get_predictions(learn)
    wandb.log(
        {"pred_table": create_iou_table(samples, outputs, predictions, params.BDD_CLASSES)}
    )

In [8]:
def log_final_metrics(learn):
    """Re-run validation and write the scores into the W&B run summary.

    ``learn.validate()`` yields the loss followed by one value per entry of
    ``learn.metrics``. These are stored as ``final_loss`` and
    ``final_<metric.name>`` respectively.
    """
    scores = learn.validate()
    names = ["final_loss", *(f"final_{metric.name}" for metric in learn.metrics)]
    summary_values = {names[idx]: score for idx, score in enumerate(scores)}
    for key in summary_values:
        wandb.summary[key] = summary_values[key]

In [9]:
def train(config):
    """Train a fastai U-Net segmentation model and track it in a W&B run.

    Opens a ``job_type="training"`` run and downloads the latest processed
    dataset artifact. It then fits with the one-cycle policy, keeping the best
    ``miou`` weights via ``SaveModelCallback``, optionally logs a prediction
    table, and writes the final validation scores to the run summary.

    Args:
        config: Hyperparameters, e.g. the ``train_config`` namespace. It must
            provide ``seed``, ``batch_size``, ``img_size``, ``augment``,
            ``arch`` (an attribute name of ``torchvision.models``),
            ``pretrained``, ``epochs``, ``lr`` and ``log_preds``.

    The run is used as a context manager, so it is always finished. If
    anything below raises, the run is marked failed instead of being left open
    in the notebook kernel.
    """
    with wandb.init(
        project=params.WANDB_PROJECT,
        entity=params.ENTITY,
        job_type="training",
        config=config,
    ) as run:
        # From here on use the run's config: it reflects any sweep overrides.
        config = run.config
        # Seed after init, from the effective config. The seed recorded in
        # W&B is then the one actually used, and nothing wandb.init does can
        # perturb the seeded RNG state before data loading / model creation.
        set_seed(config.seed, reproducible=True)

        processed_dataset_dir = download_data()
        df = get_df(processed_dataset_dir)

        dls = get_data(
            df, bs=config.batch_size, img_size=config.img_size, augment=config.augment
        )

        metrics = [
            MIOU(),
            BackgroundIOU(),
            RoadIOU(),
            TrafficLightIOU(),
            TrafficSignIOU(),
            PersonIOU(),
            VehicleIOU(),
            BicycleIOU(),
        ]

        learn = unet_learner(
            dls,
            arch=getattr(tvmodels, config.arch),
            pretrained=config.pretrained,
            metrics=metrics,
        )

        callbacks = [
            # "miou" contains no "loss", so fastai treats higher as better.
            # The best epoch's weights are reloaded at the end of fit, so the
            # evaluation below runs on them.
            SaveModelCallback(monitor="miou"),
            WandbCallback(log_preds=False, log_model=True),
        ]

        learn.fit_one_cycle(config.epochs, config.lr, cbs=callbacks)

        if config.log_preds:
            log_predictions(learn)

        log_final_metrics(learn)

In [10]:
# Launch one tracked training run with the defaults defined above.
train(config=train_config)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mreynald-havard[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Downloading large artifact bdd_simple_1k_split:latest, 846.07MB. 4010 files... 
[34m[1mwandb[0m:   4010 of 4010 files downloaded.  
Done. 0:0:14.6


epoch,train_loss,valid_loss,miou,background_iou,road_iou,traffic_light_iou,traffic_sign_iou,person_iou,vehicle_iou,bicycle_iou,time
0,0.470066,0.300414,0.323163,0.874089,0.759075,0.0,0.0,0.0,0.628977,0.0,00:45
1,0.399541,0.337993,0.30505,0.87309,0.749559,0.0,0.0,0.0,0.512698,0.0,00:48
2,0.361387,0.319983,0.325481,0.87558,0.75618,0.0,0.0,0.0,0.646611,0.0,00:50
3,0.31468,0.304456,0.327561,0.885774,0.754995,0.0,0.0,0.0,0.652155,0.0,01:00
4,0.281292,0.272417,0.342014,0.900364,0.820419,0.0,0.0,0.0,0.673316,0.0,00:51
5,0.256844,0.242138,0.352813,0.911518,0.829185,0.0,0.0,0.0,0.728989,0.0,00:54
6,0.230782,0.238664,0.357606,0.910491,0.831318,0.030359,0.0,0.0,0.731072,0.0,00:52
7,0.207542,0.245496,0.369571,0.909231,0.818452,0.125835,0.0,0.0,0.733481,0.0,00:48
8,0.190269,0.233762,0.375046,0.92086,0.843007,0.101474,0.001072,0.0,0.75891,0.0,00:52
9,0.176146,0.228436,0.380274,0.920715,0.844428,0.134212,0.001408,0.0,0.761157,0.0,00:55


Better model found at epoch 0 with miou value: 0.3231629872756783.
Better model found at epoch 2 with miou value: 0.32548146349889046.
Better model found at epoch 3 with miou value: 0.3275605775453746.
Better model found at epoch 4 with miou value: 0.34201422847302654.
Better model found at epoch 5 with miou value: 0.35281312693754285.
Better model found at epoch 6 with miou value: 0.3576057822658045.
Better model found at epoch 7 with miou value: 0.36957129881631473.
Better model found at epoch 8 with miou value: 0.375046092644009.
Better model found at epoch 9 with miou value: 0.3802742171429146.


  state = torch.load(file, map_location=device, **torch_load_kwargs)


0,1
background_iou,▁▁▁▃▅▇▆▆██
bicycle_iou,▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
eps_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
eps_1,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
eps_2,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr_0,▁▂▂▃▄▅▆▇███████▇▇▇▇▆▆▆▅▅▅▄▄▄▃▃▃▂▂▂▁▁▁▁▁▁
lr_1,▁▂▂▃▄▅▆▇███████▇▇▇▇▆▆▆▅▅▅▄▄▄▃▃▃▂▂▂▁▁▁▁▁▁
lr_2,▁▂▂▃▄▅▆▇███████▇▇▇▇▆▆▆▅▅▅▄▄▄▃▃▃▂▂▂▁▁▁▁▁▁
miou,▃▁▃▃▄▅▆▇██

0,1
background_iou,0.92072
bicycle_iou,0.0
epoch,10.0
eps_0,1e-05
eps_1,1e-05
eps_2,1e-05
final_background_iou,0.92072
final_bicycle_iou,0.0
final_loss,0.22844
final_miou,0.38027
