In [1]:
import wandb
import pandas as pd
from fastai.vision.all import *
from fastai.callback.wandb import WandbCallback

import params
from utils import get_predictions, create_iou_table, MIOU, BackgroundIOU, \
                  RoadIOU, TrafficLightIOU, TrafficSignIOU, PersonIOU, VehicleIOU, BicycleIOU

In [2]:
# Seed the run with a fixed value (42) and ask fastai for reproducible mode,
# so repeated Restart & Run All executions are comparable.
set_seed(42, reproducible=True)

In [3]:
# Hyperparameters for this training run; passed to wandb.init as the run config.
train_config = SimpleNamespace(
    framework="fastai",
    # Twice the (180, 320) default used by get_data.
    img_size=(360, 640),
    # Keep small so the run stays manageable on Colab.
    batch_size=8,
    # Toggle data augmentation.
    augment=True,
    # Kept low for brevity; increase for better results.
    epochs=10,
    lr=2e-3,
    # Whether to use a pretrained encoder.
    pretrained=True,
)

In [4]:
# Start a W&B run tagged as a "training" job; train_config is recorded as the
# run's config so the hyperparameters travel with the logged results.
run = wandb.init(
    project=params.WANDB_PROJECT,
    entity=params.ENTITY,
    job_type="training",
    config=train_config,
)

[34m[1mwandb[0m: Currently logged in as: [33mdarek[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [5]:
# Declare the latest version of the processed-data artifact as an input of this run.
processed_data_at = run.use_artifact(f'{params.PROCESSED_DATA_AT}:latest')
# Fetch the artifact files; the returned location is wrapped in Path so it can
# be joined with `/` below and in later cells.
processed_dataset_dir = Path(processed_data_at.download())
# Split table shipped with the artifact; later cells read its Stage and
# File_Name columns.
df = pd.read_csv(processed_dataset_dir / 'data_split.csv')

[34m[1mwandb[0m: Downloading large artifact bdd_simple_1k_split:latest, 813.25MB. 4010 files... Done. 0:0:0.5


In [6]:
# Hold the test split out of training entirely, then flag validation rows in
# an `is_valid` column for the DataBlock splitter.
df = (
    df.loc[df["Stage"] != "test"]
    .reset_index(drop=True)
    .assign(is_valid=lambda frame: frame["Stage"] == "valid")
)

In [7]:
def label_func(fname):
    """Return the mask path that corresponds to an image path.

    The mask is looked up in a ``labels`` directory that sits next to the
    image's own parent directory, and is named ``<image stem>_mask.png``.

    Args:
        fname: path-like object exposing ``.parent`` and ``.stem``
            (e.g. ``pathlib.Path``) pointing at an image file.

    Returns:
        Path of the matching segmentation mask.
    """
    dataset_root = fname.parent.parent
    mask_name = f"{fname.stem}_mask.png"
    return dataset_root / "labels" / mask_name

In [8]:
# Resolve each row's image path inside the downloaded artifact, then derive
# the matching mask path from it.
df["image_fname"] = [
    processed_dataset_dir / f'images/{file_name}' for file_name in df["File_Name"]
]
df["label_fname"] = [label_func(image_path) for image_path in df["image_fname"]]

In [9]:
def get_data(df, bs=4, img_size=(180, 320), augment=True):
    """Build fastai DataLoaders for semantic segmentation from a dataframe.

    Args:
        df: dataframe with ``image_fname`` and ``label_fname`` columns holding
            the image and mask paths. The splitter is ``ColSplitter()`` with
            its defaults (fastai's default split column is ``is_valid``, which
            this notebook creates before calling this function).
        bs: batch size handed to ``dataloaders``.
        img_size: size passed to the per-item ``Resize`` transform.
        augment: when true, ``aug_transforms()`` is applied as batch
            transforms; otherwise no batch transforms are used.

    Returns:
        The DataLoaders produced by ``DataBlock.dataloaders``.
    """
    batch_tfms = aug_transforms() if augment else None
    segmentation_block = DataBlock(
        blocks=(ImageBlock, MaskBlock(codes=params.BDD_CLASSES)),
        get_x=ColReader("image_fname"),
        get_y=ColReader("label_fname"),
        splitter=ColSplitter(),
        item_tfms=Resize(img_size),
        batch_tfms=batch_tfms,
    )
    return segmentation_block.dataloaders(df, bs=bs)

In [10]:
# From here on, read hyperparameters from the W&B config object rather than
# from train_config directly, so the cells below consume the values W&B holds
# for this run.
config = wandb.config

In [11]:
# Build the DataLoaders from the hyperparameters stored on the run config.
dls = get_data(
    df,
    bs=config.batch_size,
    img_size=config.img_size,
    augment=config.augment,
)

  ret = func(*args, **kwargs)


In [12]:
# Mean IoU plus one IoU metric per class, all imported from utils.
metrics = [
    MIOU(),
    BackgroundIOU(),
    RoadIOU(),
    TrafficLightIOU(),
    TrafficSignIOU(),
    PersonIOU(),
    VehicleIOU(),
    BicycleIOU(),
]

# U-Net with a resnet18 encoder. SaveModelCallback is attached to the learner
# itself, so it is active for every fit call.
# Mixed precision (`.to_fp16()`) is intentionally left disabled here.
learn = unet_learner(
    dls,
    arch=resnet18,
    pretrained=config.pretrained,
    metrics=metrics,
    cbs=SaveModelCallback(),
)

In [13]:
# Train with the one-cycle schedule. The W&B callback is attached only for
# this fit call: log_preds=False because a prediction table is logged
# explicitly in a later cell, log_model=True to have the model logged to the run.
learn.fit_one_cycle(
    config.epochs,
    config.lr,
    cbs=[WandbCallback(log_preds=False, log_model=True)],
)

epoch,train_loss,valid_loss,miou,background_iou,road_iou,traffic_light_iou,traffic_sign_iou,person_iou,vehicle_iou,bicycle_iou,time
0,0.46227,0.387875,0.322154,0.86122,0.727246,0.0,0.0,0.0,0.666611,0.0,03:12
1,0.414273,0.474491,0.325923,0.864821,0.72922,0.0,0.0,0.0,0.687418,0.0,03:09
2,0.335936,0.282546,0.344916,0.902314,0.808683,0.0,0.0,0.0,0.703413,0.0,03:06
3,0.283595,0.236887,0.35688,0.914699,0.82239,0.0,0.0,0.0,0.761071,0.0,03:05
4,0.270183,0.248005,0.357903,0.917094,0.832349,0.0,0.0,0.0,0.755876,0.0,03:05
5,0.231179,0.229477,0.372141,0.925408,0.840127,0.059531,0.0,0.0,0.779918,0.0,03:04
6,0.210274,0.212178,0.368595,0.929823,0.859233,0.001815,0.0,0.0,0.789294,0.0,03:04
7,0.184017,0.203875,0.397377,0.93159,0.861519,0.042615,0.0,0.150203,0.795714,0.0,03:04
8,0.16974,0.201497,0.420925,0.933232,0.86093,0.135202,0.001635,0.209717,0.805756,0.0,03:04
9,0.159584,0.200046,0.425591,0.9333,0.862763,0.150477,0.005655,0.22136,0.805582,0.0,03:04


Better model found at epoch 0 with valid_loss value: 0.3878748416900635.
Better model found at epoch 2 with valid_loss value: 0.2825457751750946.
Better model found at epoch 3 with valid_loss value: 0.23688741028308868.
Better model found at epoch 5 with valid_loss value: 0.2294769287109375.
Better model found at epoch 6 with valid_loss value: 0.21217767894268036.
Better model found at epoch 7 with valid_loss value: 0.20387493073940277.
Better model found at epoch 8 with valid_loss value: 0.20149725675582886.
Better model found at epoch 9 with valid_loss value: 0.20004601776599884.


In [14]:
# Gather samples, raw outputs and predictions from the trained learner via the
# utils helper (presumably over the validation set; confirm in utils.get_predictions).
samples, outputs, predictions = get_predictions(learn)
# Build a table from those results and the class list
# (NOTE(review): expected to hold per-class IoU per sample; verify in utils.create_iou_table).
table = create_iou_table(samples, outputs, predictions, params.BDD_CLASSES)
# Attach the table to the active W&B run under the key "pred_table".
wandb.log({"pred_table":table})

In [15]:
# Close the W&B run opened by wandb.init above; the run history/summary
# printed below is emitted by this call.
wandb.finish()

VBox(children=(Label(value='142.754 MB of 142.763 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.99…

0,1
background_iou,▁▁▅▆▆▇████
bicycle_iou,▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
eps_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
eps_1,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
eps_2,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr_0,▁▂▂▃▄▅▆▇███████▇▇▇▇▆▆▆▅▅▅▄▄▄▃▃▃▂▂▂▁▁▁▁▁▁
lr_1,▁▂▂▃▄▅▆▇███████▇▇▇▇▆▆▆▅▅▅▄▄▄▃▃▃▂▂▂▁▁▁▁▁▁
lr_2,▁▂▂▃▄▅▆▇███████▇▇▇▇▆▆▆▅▅▅▄▄▄▃▃▃▂▂▂▁▁▁▁▁▁
miou,▁▁▃▃▃▄▄▆██

0,1
background_iou,0.9333
bicycle_iou,0.0
epoch,10.0
eps_0,1e-05
eps_1,1e-05
eps_2,1e-05
lr_0,0.0
lr_1,0.0
lr_2,0.0
miou,0.42559
