In [None]:
import wandb
import pandas as pd
import torchvision.models as tvmodels
import torch
from fastai.vision.all import *
from fastai.callback.wandb import WandbCallback

import params
from utils import get_predictions, create_iou_table, MIOU, BackgroundIOU, \
                  RoadIOU, TrafficLightIOU, TrafficSignIOU, PersonIOU, VehicleIOU, BicycleIOU, t_or_f, display_diagnostics

In [None]:
# Baseline hyperparameters for a single training run. `train` passes this
# object to `wandb.init(config=...)`, so every field below is logged with the run.
train_config = SimpleNamespace(**{
    "framework": "fastai",
    # Fed straight into `Resize(img_size)` by `get_data`.
    # NOTE(review): the original note here reads "(180, 320) in 16:9 proportions",
    # but a bare int presumably yields a square 180x180 resize in fastai - confirm
    # whether a (180, 320) tuple was intended.
    "img_size": 180,
    "batch_size": 8,          # keep small in Colab to be manageable
    "augment": True,          # use data augmentation
    "epochs": 1,              # for brevity, increase for better results :)
    "lr": 2e-3,
    "pretrained": True,       # whether to use pretrained encoder
    "mixed_precision": True,  # use automatic mixed precision
    "arch": "resnet18",       # resolved with getattr(torchvision.models, arch) in `train`
    "seed": 42,
    "log_preds": False,       # when True, `train` also logs a prediction table
})

In [None]:
def download_data():
    """Fetch the latest processed-dataset artifact and return its local folder.

    Declares `params.PROCESSED_DATA_AT:latest` as an input of the current W&B
    run via `wandb.use_artifact`, downloads it, and returns the download
    location as a `Path`. `train` calls this inside its `wandb.init` context.
    """
    artifact_ref = f'{params.PROCESSED_DATA_AT}:latest'
    local_dir = wandb.use_artifact(artifact_ref).download()
    return Path(local_dir)

In [None]:
def label_func(fname):
    """Map an image path to the path of its segmentation mask.

    The mask lives in a `labels` folder that is a sibling of the image's own
    folder and is named `<image stem>_mask.png`.

    `fname` must be a `pathlib.Path`-like object (it is read via `.parent`
    and `.stem`).
    """
    mask_name = f"{fname.stem}_mask.png"
    dataset_root = fname.parent.parent
    return dataset_root.joinpath("labels", mask_name)

In [None]:
def get_df(processed_dataset_dir, is_test=False):
    """Load `data_split.csv` and attach image / mask paths for one split.

    Args:
        processed_dataset_dir: `Path` of the downloaded dataset; must contain
            `data_split.csv` with at least `Stage` and `File_Name` columns.
        is_test: when False (default) keep every row whose `Stage` is not
            'test' and add a boolean `is_valid` column (True where `Stage` is
            'valid'); when True keep only the 'test' rows and add no
            `is_valid` column.

    Returns:
        A DataFrame with a fresh 0..n-1 index plus `image_fname` and
        `label_fname` columns holding paths.
    """
    split_table = pd.read_csv(processed_dataset_dir / 'data_split.csv')

    if is_test:
        df = split_table[split_table.Stage == 'test'].reset_index(drop=True)
    else:
        df = split_table[split_table.Stage != 'test'].reset_index(drop=True)
        df['is_valid'] = df.Stage == 'valid'

    # Build the image paths once, then derive each mask path from its image path.
    image_paths = [processed_dataset_dir/f'images/{f}.jpg' for f in df.File_Name.values]
    df["image_fname"] = image_paths
    df["label_fname"] = [label_func(image_path) for image_path in image_paths]
    return df

In [None]:
def get_data(df, bs=4, img_size=(180, 320), augment=True):
    """Create fastai segmentation DataLoaders from a dataframe of paths.

    Args:
        df: dataframe with `image_fname` and `label_fname` columns, plus the
            column `ColSplitter()` reads by default (presumably `is_valid`,
            which `get_df` adds for non-test data - confirm against fastai).
        bs: batch size handed to `DataBlock.dataloaders`.
        img_size: passed unchanged to `Resize` as the per-item transform.
        augment: when True, `aug_transforms()` is used as the batch
            transform; otherwise no batch transform is configured.

    Returns:
        The DataLoaders built from `df`.
    """
    batch_augmentations = aug_transforms() if augment else None
    segmentation_block = DataBlock(
        blocks=(ImageBlock, MaskBlock(codes=params.BDD_CLASSES)),
        get_x=ColReader("image_fname"),
        get_y=ColReader("label_fname"),
        splitter=ColSplitter(),
        item_tfms=Resize(img_size),
        batch_tfms=batch_augmentations,
    )
    return segmentation_block.dataloaders(df, bs=bs)

In [None]:
def log_predictions(learn):
    """Log a W&B table of model predictions under the key `pred_table`.

    Collects samples, raw outputs and predictions with `get_predictions`,
    turns them into a table with `create_iou_table` (using
    `params.BDD_CLASSES` as class names) and sends it to `wandb.log`.
    """
    inputs, raw_outputs, preds = get_predictions(learn)
    iou_table = create_iou_table(inputs, raw_outputs, preds, params.BDD_CLASSES)
    wandb.log({"pred_table": iou_table})

In [None]:
def final_metrics(learn):
    """Run validation and store the results in the W&B run summary.

    The first value returned by `learn.validate()` is stored as `final_loss`;
    the remaining values are stored as `final_<metric.name>`, following the
    order of `learn.metrics`.
    """
    validation_scores = learn.validate()

    summary_keys = ['final_loss']
    summary_keys.extend(f'final_{metric.name}' for metric in learn.metrics)

    # Pair each score with its key by position before touching the summary.
    final_results = {summary_keys[idx]: score for idx, score in enumerate(validation_scores)}
    for key in final_results:
        wandb.summary[key] = final_results[key]

In [None]:
def save_learner(learn, run):
    """Export the learner and log it to `run` as the `learner` artifact.

    `Artifact.new_file` supplies a file (`fastai_model.pkl`) that belongs to
    the artifact; only its path is used here, as the target of `learn.export`.
    """
    learner_artifact = wandb.Artifact('learner', type="fastai learner")
    with learner_artifact.new_file('fastai_model.pkl') as artifact_file:
        export_path = artifact_file.name
        learn.export(export_path)
    run.log_artifact(learner_artifact)
        
def save_dls(dls, run, nm='train'):
    """Serialize `dls` with `torch.save` and log it to `run` as an artifact.

    Writes `<nm>-dataloader.pkl` relative to the current working directory and
    attaches that file to an artifact named `<nm>_dls`.
    """
    pickle_path = f'{nm}-dataloader.pkl'
    torch.save(dls, pickle_path)

    dls_artifact = wandb.Artifact(f'{nm}_dls', type="fastai dataloaders")
    dls_artifact.add_file(pickle_path)
    run.log_artifact(dls_artifact)

    
def train(config):
    """Train a U-Net segmentation model and track the whole run in W&B.

    Steps, all inside one `wandb.init` run (job type "training"):
      1. resolve the effective config (`wandb.config`) and seed the RNGs,
      2. download the processed dataset and build the DataLoaders,
      3. log the DataLoaders as an artifact,
      4. fit with the one-cycle policy for `config.epochs` at `config.lr`,
      5. optionally log a prediction table, then log final validation
         metrics, the confusion-matrix figure and the exported learner.

    Args:
        config: default hyperparameters passed to `wandb.init(config=...)`.
            The fields read here are `seed`, `batch_size`, `img_size`,
            `augment`, `mixed_precision`, `arch`, `pretrained`, `epochs`,
            `lr` and `log_preds`.
    """
    with wandb.init(project=params.WANDB_PROJECT, entity=params.ENTITY, job_type="training", config=config) as run:

        # good practice to inject params using sweeps
        config = wandb.config

        # Seed only once the effective config is known. Seeding from the
        # function argument before `wandb.init` would ignore a seed supplied
        # by a sweep, so the logged seed and the seed actually used could differ.
        set_seed(config.seed)

        # prepare data
        processed_dataset_dir = download_data()
        proc_df = get_df(processed_dataset_dir)
        dls = get_data(proc_df, bs=config.batch_size, img_size=config.img_size, augment=config.augment)
        save_dls(dls, run)

        metrics = [MIOU(), BackgroundIOU(), RoadIOU(), TrafficLightIOU(),
                   TrafficSignIOU(), PersonIOU(), VehicleIOU(), BicycleIOU()]

        # Prediction logging is handled separately below (see `config.log_preds`),
        # so the callback's own prediction logging stays off.
        cbs = [WandbCallback(log_preds=False, log_model=True),
               SaveModelCallback(monitor='miou'),] + ([MixedPrecision()] if config.mixed_precision else [])

        # `config.arch` is a torchvision model name, resolved to its constructor.
        learn = unet_learner(dls, arch=getattr(tvmodels, config.arch), pretrained=config.pretrained,
                             metrics=metrics)

        learn.fit_one_cycle(config.epochs, config.lr, cbs=cbs)
        if config.log_preds:
            log_predictions(learn)
        final_metrics(learn)
        _, disp = display_diagnostics(learner=learn, return_vals=True)
        wandb.log({"confusion matrix": disp.figure_})
        save_learner(learn, run)

## Run the training

In [None]:
# Launch a single training run with the baseline configuration.
train(config=train_config)