In [None]:
# Enable IPython's autoreload extension in mode 2 so edits to imported modules
# (e.g. the local `segmentation` package) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import wandb
import torch
from fastai.vision.all import *
from typing import List, Union, Tuple
from fastai.callback.wandb import WandbCallback

from segmentation.model import SegmentationModel
from segmentation.train_utils import benchmark_inference_time, save_model_to_artifacts

In [None]:
# --- Weights & Biases run identity -------------------------------------------
PROJECT = "CamVid"
ENTITY = "av-demo"
RUN_NAME = "baseline-train-1"
JOB_TYPE = "baseline-train"

# --- Dataset -----------------------------------------------------------------
# Fully-qualified W&B artifact reference passed to wandb.use_artifact().
ARTIFACT_ID = "av-demo/CamVid/camvid-dataset:v0"
# Original image size; presumably (height, width) -- confirm against the dataset.
IMAGE_SHAPE = (720, 960)
# Each side of IMAGE_SHAPE is integer-divided by this factor before training.
IMAGE_RESIZE_FACTOR = 4
# Forwarded as `valid_pct` when building the dataloaders.
VALIDATION_SPLIT = 0.2
BATCH_SIZE = 16
SEED = 123

# --- Model -------------------------------------------------------------------
BACKBONE = "mobilenetv2_100"
HIDDEN_DIM = 256

# --- Optimisation ------------------------------------------------------------
LEARNING_RATE = 1e-3
TRAIN_EPOCHS = 10

# --- Inference benchmark -----------------------------------------------------
INFERENCE_BATCH_SIZE = 8
NUM_WARMUP_ITERS = 10
NUM_INFERENCE_BENCHMARK_ITERS = 50

In [None]:
# Seed the random number generators before any data splitting or model creation.
# `set_seed` is presumably fastai's helper, brought in by the star import above.
set_seed(SEED)

In [None]:
# Start the W&B run and record every tunable constant in its config so the
# experiment's settings are stored alongside its metrics.
run = wandb.init(
    project=PROJECT,
    name=RUN_NAME,
    entity=ENTITY,
    job_type=JOB_TYPE,
    config={
        "artifact_id": ARTIFACT_ID,
        "image_shape": IMAGE_SHAPE,
        "batch_size": BATCH_SIZE,
        "image_resize_factor": IMAGE_RESIZE_FACTOR,
        "validation_split": VALIDATION_SPLIT,
        "hidden_dims": HIDDEN_DIM,
        "backbone": BACKBONE,
        "learning_rate": LEARNING_RATE,
        "train_epochs": TRAIN_EPOCHS,
        "inference_batch_size": INFERENCE_BATCH_SIZE,
        "num_warmup_iters": NUM_WARMUP_ITERS,
        "num_inference_benchmark_iters": NUM_INFERENCE_BENCHMARK_ITERS,
    },
)

## DataLoader for Segmentation

In [None]:
def label_func(fn):
    """Map an image path to the path of its segmentation mask.

    The mask is looked up in a ``labels`` directory that sits next to the
    image's parent directory, and its file name is the image's stem followed
    by ``_P`` and the image's original suffix.
    """
    dataset_root = fn.parent.parent
    mask_name = f"{fn.stem}_P{fn.suffix}"
    return dataset_root / "labels" / mask_name


def get_dataloader(
    artifact_id: str,
    batch_size: int,
    image_shape: Tuple[int, int],
    resize_factor: int,
    validation_split: float,
    seed: int
):
    """Download a W&B dataset artifact and build fastai segmentation dataloaders from it.

    Args:
        artifact_id: W&B artifact reference, fetched with ``type='dataset'``.
        batch_size: Batch size forwarded to the dataloaders as ``bs``.
        image_shape: Original image size; each entry is integer-divided by
            ``resize_factor`` to obtain the resize target.
        resize_factor: Integer divisor applied to both entries of ``image_shape``.
        validation_split: Fraction of files held out, forwarded as ``valid_pct``.
        seed: Seed forwarded to the dataloader factory for the train/valid split.

    Returns:
        A 2-tuple ``(dataloaders, class_labels)`` where ``class_labels`` maps each
        integer index to the corresponding entry of the artifact's ``codes.txt``.
    """
    # Requires an active W&B run: use_artifact registers the dataset as a run input.
    dataset_dir = Path(wandb.use_artifact(artifact_id, type='dataset').download())

    codes = np.loadtxt(dataset_dir / 'codes.txt', dtype=str)
    image_files = get_image_files(dataset_dir / "images")
    class_labels = dict(enumerate(codes))

    target_size = (
        image_shape[0] // resize_factor,
        image_shape[1] // resize_factor,
    )
    dataloaders = SegmentationDataLoaders.from_label_func(
        dataset_dir,
        bs=batch_size,
        fnames=image_files,
        label_func=label_func,
        codes=codes,
        item_tfms=Resize(target_size),
        valid_pct=validation_split,
        seed=seed,
    )
    return dataloaders, class_labels

In [None]:
# Build the train/validation dataloaders from the dataset artifact. `class_labels`
# maps each integer class index to its name and is reused for logging below.
data_loader, class_labels = get_dataloader(
    artifact_id=ARTIFACT_ID,
    batch_size=BATCH_SIZE,
    image_shape=IMAGE_SHAPE,
    resize_factor=IMAGE_RESIZE_FACTOR,
    validation_split=VALIDATION_SPLIT,
    seed=SEED
)

# Preview a few image/mask pairs from one batch.
# NOTE(review): vmin/vmax presumably bound the mask colour range -- confirm 30
# matches the number of classes in codes.txt.
data_loader.show_batch(max_n=4, vmin=1, vmax=30, figsize=(14, 10))

## Training and Inference

In [None]:
def get_model_parameters(model):
    """Return the total number of scalar parameters in ``model``.

    Despite the name, this returns a count rather than the parameters
    themselves. Every tensor yielded by ``model.parameters()`` contributes its
    element count, regardless of whether it requires gradients.

    Args:
        model: Any object exposing ``parameters()`` that yields tensors
            (e.g. a ``torch.nn.Module``).

    Returns:
        The summed element count as an ``int`` (``0`` for a parameterless model).
    """
    # numel() only reads tensor metadata, so no autograd guard is required.
    return sum(param.numel() for param in model.parameters())


def get_predictions(learner, max_n: int = 36):
    """Run the learner's prediction pass and decode a sample of the results.

    Args:
        learner: fastai ``Learner``; its ``dls.valid`` loader is used to decode
            the inputs, targets and decoded predictions into displayable items.
        max_n: Upper bound on how many items are decoded. Defaults to 36, the
            value that was previously hard-coded.

    Returns:
        A 3-tuple ``(samples, outputs, predictions)``. ``samples`` and
        ``outputs`` are the decoded items returned by ``show_results`` and are
        consumed pairwise by ``create_wandb_table``; ``predictions`` is the raw
        prediction tensor returned by ``learner.get_preds``.
    """
    inputs, predictions, targets, decoded = learner.get_preds(
        with_input=True, with_decoded=True
    )
    batch = tuplify(inputs) + tuplify(targets)
    # show=False makes show_results return the decoded items instead of plotting;
    # the first two returned values are not needed here.
    _, _, samples, outputs = learner.dls.valid.show_results(
        batch, decoded, show=False, max_n=max_n
    )
    return samples, outputs, predictions


def create_wandb_table(samples, outputs, class_labels):
    """Build a W&B table showing each image beside its predicted and true masks.

    Args:
        samples: Iterable of ``(image, label)`` pairs.
        outputs: Iterable aligned with ``samples``; the first element of each
            entry is used as the predicted mask.
        class_labels: Mapping from class index to class name, attached to every
            mask overlay.

    Returns:
        A ``wandb.Table`` with columns "Image", "Predicted Mask" and
        "Ground Truth", one row per sample.
    """
    def overlay(picture, mask_key, mask):
        # One W&B image carrying a single named mask layer.
        return wandb.Image(
            picture,
            masks={
                mask_key: {
                    'mask_data': mask.numpy(),
                    'class_labels': class_labels,
                }
            },
        )

    table = wandb.Table(columns=["Image", "Predicted Mask", "Ground Truth"])
    for sample, decoded in zip(samples, outputs):
        image, label = sample
        # Move the first axis last; presumably channels-first -> channels-last
        # for W&B -- confirm against the tensors produced by get_predictions.
        picture = image.permute(1, 2, 0)
        table.add_data(
            wandb.Image(picture),
            overlay(picture, "predictions", decoded[0]),
            overlay(picture, "ground truths", label),
        )
    return table

In [None]:
def get_learner(
    data_loader,
    backbone: str,
    hidden_dim: int,
    num_classes: int,
    checkpoint_file: Union[None, str, Path],
    loss_func,
    metrics: List,
    log_preds: bool = False
):
    """Assemble a fastai ``Learner`` around the project's ``SegmentationModel``.

    Args:
        data_loader: Dataloaders the learner trains and validates on.
        backbone: Backbone name passed to ``SegmentationModel``; also used to
            name the checkpoint written by ``SaveModelCallback``.
        hidden_dim: Hidden dimension passed to ``SegmentationModel``.
        num_classes: Number of output classes for the model.
        checkpoint_file: If not ``None``, loaded into the learner via
            ``learner.load`` before it is returned.
        loss_func: Loss function handed to the learner.
        metrics: Metrics handed to the learner.
        log_preds: Forwarded to ``WandbCallback(log_preds=...)``.

    Returns:
        The configured ``Learner`` with model-saving, mixed-precision and W&B
        callbacks attached.
    """
    network = SegmentationModel(backbone, hidden_dim, num_classes=num_classes)
    callbacks = [
        SaveModelCallback(fname=f"unet_{backbone}"),
        MixedPrecision(),
        WandbCallback(log_preds=log_preds),
    ]
    learner = Learner(
        data_loader,
        network,
        loss_func=loss_func,
        metrics=metrics,
        cbs=callbacks,
    )
    if checkpoint_file is None:
        return learner
    # Resume from previously saved weights when a checkpoint is supplied.
    learner.load(checkpoint_file)
    return learner

In [None]:
# Build the baseline learner from scratch (no checkpoint), with one output class
# per entry in class_labels and prediction logging to W&B enabled.
learner = get_learner(
    data_loader,
    backbone=BACKBONE,
    hidden_dim=HIDDEN_DIM,
    num_classes=len(class_labels),
    checkpoint_file=None,
    loss_func=FocalLossFlat(axis=1),
    metrics=[DiceMulti(), foreground_acc],
    log_preds=True
)

In [None]:
# Train with fastai's one-cycle schedule for TRAIN_EPOCHS epochs, passing
# LEARNING_RATE as the schedule's learning rate.
learner.fit_one_cycle(TRAIN_EPOCHS, LEARNING_RATE)

In [None]:
# Visual check: display a handful of predictions from the trained learner.
learner.show_results(max_n=4, figsize=(14, 10))

In [None]:
# Decode a sample of predictions and log them to W&B as a table of
# image / predicted mask / ground-truth mask rows, keyed by the run name.
samples, outputs, _ = get_predictions(learner)
table = create_wandb_table(samples, outputs, class_labels)
wandb.log({f"Baseline_Predictions_{run.name}": table})

In [None]:
# Switch the trained model to evaluation mode and release cached GPU memory
# before measuring, then log the parameter count and the inference benchmark.
model = learner.model.eval()
torch.cuda.empty_cache()
wandb.log({"Model_Parameters": get_model_parameters(model)})
# NOTE(review): the benchmark is given the full IMAGE_SHAPE, whereas training
# used images downscaled by IMAGE_RESIZE_FACTOR -- confirm this is intended.
wandb.log({
    "Inference_Time": benchmark_inference_time(
        model=model,
        batch_size=INFERENCE_BATCH_SIZE,
        image_shape=IMAGE_SHAPE,
        num_warmup_iters=NUM_WARMUP_ITERS,
        num_iter=NUM_INFERENCE_BENCHMARK_ITERS,
        seed=SEED
    )
})

In [None]:
# Hand the trained model to the project's artifact helper, attaching the
# settings needed to rebuild it (backbone, hidden size, input size, class names).
# NOTE(review): input_size is the full IMAGE_SHAPE, not the downscaled training
# resolution -- confirm what save_model_to_artifacts expects.
save_model_to_artifacts(
    model=model,
    model_name=f"unet_{BACKBONE}",
    image_shape=IMAGE_SHAPE,
    artifact_name=f"{run.name}-saved-model",
    metadata={
        "backbone": BACKBONE,
        "hidden_dims": HIDDEN_DIM,
        "input_size": IMAGE_SHAPE,
        "class_labels": class_labels
    }
)

In [None]:
# Close the W&B run started at the top of the notebook.
run.finish()