<a href="https://colab.research.google.com/github/wandb/examples/blob/master/colabs/fastai/Semantic_Segmentation_Demo_with_W&B.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
<!--- @wandbcode{fastai-semantic-seg} -->

# ❌ This notebook uses fastai v1, which is no longer available on Colab

# Semantic Segmentation of Driving Scenes

This demo shows how to log image masks for a semantic segmentation model. The model here is smaller and simpler than the full version, and training is shortened for demo purposes. See [this W&B report](https://app.wandb.ai/stacey/deep-drive/reports/The-View-from-the-Driver's-Seat--Vmlldzo1MTg5NQ) for the more detailed model and [this GitHub repository by Boris Dayma](https://github.com/borisdayma/semantic-segmentation) for the best reference code.

![demoshot](https://i.imgur.com/GY969Y8.png)

Note that everyone's runs from this demo are logged to a [shared Weights & Biases project page](https://app.wandb.ai/wandb/segment_demo) by default.



## Resources

* [Semantic segmentation API →](https://docs.wandb.com/library/log#logging-image-masks-semantic-segmentation)
* [Read more about this feature →](https://app.wandb.ai/stacey/deep-drive/reports/Image-Masks-for-Semantic-Segmentation--Vmlldzo4MTUwMw)
* [Read more about the problem and modeling approaches →](https://app.wandb.ai/stacey/deep-drive/reports/The-View-from-the-Driver's-Seat--Vmlldzo1MTg5NQ)
* [Repo: Semantic Segmentation for Self-Driving Cars→](https://github.com/borisdayma/semantic-segmentation)


## About

Weights & Biases helps you log experiments, visualize and analyze them faster, collaborate with others, and share your findings. Here we use Google Colab as a convenient hosted environment, but you can run your own training scripts from *any local or cloud setup* with W&B.

## Setup

In [None]:
# The code below targets the fastai v1 API, so pin fastai below 2.0
!pip install -qq wandb "fastai<2"

In [None]:
from functools import partial
from pathlib import Path

import numpy as np
import torch
import wandb
from fastai.vision import *
from fastai.callbacks.hooks import *
from fastai.callback import Callback
# fastai v1 integration shipped with wandb
# (fastai.callback.wandb is the fastai v2 module path)
from wandb.fastai import WandbCallback

Log in to your W&B account so you can log all your metrics.

In [None]:
wandb.login()

In [None]:
# Download the training data: this is a subset of the Berkeley Deep Drive 100K dataset,
# which is available at https://bdd-data.berkeley.edu/ 
!curl -sSL https://storage.googleapis.com/wandb_datasets/BDD100K_seg_demo.zip -o BDD100K_seg_demo.zip
!unzip -qq BDD100K_seg_demo.zip

Segmentation labels extracted from dataset source code

In [None]:
# See https://github.com/ucbdrive/bdd-data/blob/master/bdd_data/label.py
segmentation_classes = [
    'road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light',
    'traffic sign', 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car',
    'truck', 'bus', 'train', 'motorcycle', 'bicycle', 'void'
]

def labels():
    # map each class id (its index in the list above) to its class name
    return {i: label for i, label in enumerate(segmentation_classes)}

Data paths

In [None]:
path_data = Path('segment_demo')
path_lbl = path_data / 'labels'
path_img = path_data / 'images'

Fraction of the dataset to use. Keep it low for faster iteration.

In [None]:
USE_DATA_FRACTION=0.1
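
To make the effect of this fraction concrete, here is a minimal standard-library sketch of random subsampling. It only approximates what `use_partial_data` does in the next cell; `sample_fraction` and the file names are hypothetical and not part of fastai.

```python
import random

def sample_fraction(items, fraction, seed=None):
    """Return a random subset holding roughly `fraction` of `items`."""
    rng = random.Random(seed)
    keep = int(fraction * len(items))
    return rng.sample(list(items), keep)

# hypothetical file names standing in for the image folder
image_names = [f"img_{i:04d}.jpg" for i in range(1000)]
subset = sample_fraction(image_names, 0.1, seed=0)
print(len(subset))  # 100
```

With a fraction of 0.1, each epoch touches about a tenth of the images, which is why the demo trains quickly but reaches lower accuracy than a full run.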

Associate each input image with its label mask

In [None]:
get_y_fn = lambda x: path_lbl / x.parts[-2] / f'{x.stem}_train_id.png'

# Load data into train & validation sets
src = (SegmentationItemList.from_folder(path_img).use_partial_data(USE_DATA_FRACTION)
       .split_by_folder(train='train', valid='val')
       .label_from_func(get_y_fn, classes=segmentation_classes))
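
The mapping in `get_y_fn` can be checked without the dataset. The sketch below mirrors it with `PurePosixPath` so it behaves the same on any operating system; `label_path_for` and the file name `0001.jpg` are hypothetical and used only for illustration.

```python
from pathlib import PurePosixPath

labels_root = PurePosixPath("segment_demo/labels")

def label_path_for(image_path):
    """Mirror of get_y_fn: same split folder, image stem plus '_train_id.png'."""
    image_path = PurePosixPath(image_path)
    split = image_path.parts[-2]  # 'train' or 'val'
    return labels_root / split / f"{image_path.stem}_train_id.png"

example = label_path_for("segment_demo/images/train/0001.jpg")
print(example)  # segment_demo/labels/train/0001_train_id.png
```

The split folder is taken from the image path itself, so the same function serves both the training and the validation images.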

In [None]:
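# Build a W&B image with the predicted and ground-truth masks overlaid.
# Note: this helper is called by the callback below but was not defined in
# this notebook; this is a minimal version using the `masks` argument of wandb.Image.
def wb_mask(bg_img, pred_mask, true_mask):
  class_labels = dict(enumerate(segmentation_classes))
  return wandb.Image(bg_img, masks={
    "prediction": {"mask_data": pred_mask, "class_labels": class_labels},
    "ground_truth": {"mask_data": true_mask, "class_labels": class_labels}})

# fast.ai callback extension to log image masks
class LogImagesCallback(Callback):

  def __init__(self, learn):
    self.learn = learn
    self.num_to_log = 5

  def on_epoch_end(self, **kwargs):
    input_batch = self.learn.data.valid_ds[:self.num_to_log]
    mask_list = []
    table_data = []
    for i, img_pair in enumerate(input_batch):
      original_image = img_pair[0]
      # the raw background image as a numpy array
      bg_image = image2np(original_image.data*255).astype(np.uint8)
      # run the model on that image (use the learner stored on the callback,
      # not a global variable)
      prediction = self.learn.predict(original_image)[0]
      prediction_mask = image2np(prediction.data).astype(np.uint8)

      # ground truth mask
      ground_truth = img_pair[1]
      true_mask = image2np(ground_truth.data).astype(np.uint8)
      # keep a list of composite images
      masked_img = wb_mask(bg_image, prediction_mask, true_mask)
      mask_list.append(masked_img)
      # add row id and image to table list
      table_data.append([i, masked_img])

    # log all composite images to W&B alongside a table containing those images
    masked_img_table = wandb.Table(data=table_data, columns=["id", "img"])
    wandb.log({"predictions" : mask_list, "prediction_table": masked_img_table})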

In [None]:
# Accuracy metrics for a few different classes
# You could define more for other classes, or consider intersection over union, 
# which seems to give better performance in the latest version of this model
void_code = 19
# overall accuracy: across all classes, ignore unlabeled pixels
def acc(input, target):
    target = target.squeeze(1)
    mask = target != void_code
    try:
      i = (input.argmax(dim=1)[mask] == target[mask]).float()
      m_i = i.mean()
      return m_i
    except:
      return torch.tensor([0.0])

# only consider traffic infrastructure classes: poles, traffic lights, and traffic signs
def traffic_acc(input, target):
    target = target.squeeze(1)
    mask_pole = target == 5
    mask_light = target == 6
    mask_sign = target == 7
    mask_traffic = mask_pole | mask_light | mask_sign
    try:
      i = (input.argmax(dim=1)[mask_traffic] == target[mask_traffic]).float()
      m_i = i.mean()
      return m_i
    except:
      return torch.tensor([0.0])

# only consider cars
def car_acc(input, target):
    target = target.squeeze(1)
    mask = target == 13
    try:
        intersection = input.argmax(dim=1)[mask] == target[mask]
        mean_intersection = intersection.float().mean()
        return mean_intersection
    except:
        return torch.tensor([0.0])
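
The metrics above share one pattern: select pixels with a boolean mask, then average correctness over the selection. The comment at the top of the cell also mentions intersection over union (IoU). The toy sketch below shows both ideas on flat lists of class ids, independent of torch; `masked_accuracy` and `class_iou` are hypothetical names, and the pixel values are made up for illustration.

```python
def masked_accuracy(pred, target, keep):
    """Fraction of correct predictions among pixels where keep(target_id) is True."""
    selected = [(p, t) for p, t in zip(pred, target) if keep(t)]
    if not selected:
        return 0.0  # no pixels of interest
    return sum(p == t for p, t in selected) / len(selected)

def class_iou(pred, target, class_id):
    """Intersection over union for a single class."""
    pairs = list(zip(pred, target))
    intersection = sum(p == class_id and t == class_id for p, t in pairs)
    union = sum(p == class_id or t == class_id for p, t in pairs)
    return intersection / union if union else 0.0

VOID, CAR = 19, 13            # class ids used in this notebook
target = [0, 0, 13, 13, 13, 19]
pred   = [0, 13, 13, 13, 0, 0]

overall = masked_accuracy(pred, target, lambda t: t != VOID)  # 3 of 5 labeled pixels
cars = masked_accuracy(pred, target, lambda t: t == CAR)      # 2 of 3 car pixels
car_iou = class_iou(pred, target, CAR)                        # 2 shared / 4 in union
print(overall, cars, car_iou)
```

Unlike per-class accuracy, IoU also penalizes pixels wrongly predicted as the class (the second pixel here), which is one reason it is often preferred for segmentation.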

In [None]:
# This cell configures your experiment: you can modify the hyperparameters here,
# and make sure to rerun this cell for every new training run you launch
# Initialize W&B project 
wandb.init(project="semseg_demo")

# Define hyperparameters
config = wandb.config           # for shortening
config.framework = "fast.ai"    # AI framework used (for when we create other versions)
config.img_size = (360, 640)    # dimensions of resized image - can be 1 dim or tuple

config.batch_size = 2           # Batch size during training -- setting this super low to avoid CUDA OOM error
config.epochs = 4               # Number of epochs for training -- set this to 10+ for better results

config.encoder = "resnet18"     # could be resnet18 or alexnet (but watch out for CUDA memory)
encoder = models.resnet18

config.pretrained = True        # whether we use a frozen pre-trained encoder
config.weight_decay = 0.097     # weight decay applied on layers
config.bn_weight_decay = True   # whether weight decay is applied on batch norm layers
config.one_cycle = True         # use the "1cycle" policy -> https://arxiv.org/abs/1803.09820
config.learning_rate = 0.001    # learning rate

# Resize, augment, load in batch & normalize (so we can use pre-trained networks)
data = (src.transform(get_transforms(), size=config.img_size, tfm_y=True)
        .databunch(bs=config.batch_size)
        .normalize(imagenet_stats))

# Track how much data we actually use in this run
config.num_train = len(data.train_ds) 
config.num_valid = len(data.valid_ds)

In [None]:
%%wandb
# This cell launches the W&B experiment run and shows you the training progress
# in real-time

# Create model
learn = unet_learner(
    data,
    arch=encoder,
    pretrained=config.pretrained,
    metrics=[acc, car_acc, traffic_acc],
    wd=config.weight_decay,
    bn_wd=config.bn_weight_decay,
    callback_fns=partial(WandbCallback, monitor='acc'))

# Train
learn.fit_one_cycle(
    config.epochs // 2,
    max_lr=slice(config.learning_rate),
    callbacks=[LogImagesCallback(learn)])
learn.unfreeze()
learn.fit_one_cycle(
    config.epochs // 2,
    max_lr=slice(config.learning_rate / 100,
                 config.learning_rate / 10),
    # WandbCallback is already attached through callback_fns above,
    # so it is not added a second time here
    callbacks=[LogImagesCallback(learn)])
wandb.run.finish()
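
The training cell splits `config.epochs` into a frozen stage and an unfrozen stage, and the second stage passes a learning-rate slice. The sketch below works through the numbers for the default config. The geometric spread across layer groups reflects how fastai v1 is understood to expand `slice(low, high)`; the count of three layer groups is an assumption, and `geometric_spread` is a hypothetical helper, not a fastai function.

```python
def geometric_spread(low, high, n_groups):
    """Geometrically spaced values from low to high, one per layer group."""
    if n_groups == 1:
        return [high]
    ratio = (high / low) ** (1 / (n_groups - 1))
    return [low * ratio ** i for i in range(n_groups)]

epochs, learning_rate = 4, 0.001  # values from the config cell
n_groups = 3                      # assumption: number of layer groups in the learner

stage_epochs = epochs // 2        # epochs per stage (frozen, then unfrozen)
frozen_max_lr = learning_rate     # stage 1: slice(learning_rate)
unfrozen_lrs = geometric_spread(learning_rate / 100, learning_rate / 10, n_groups)
print(stage_epochs, frozen_max_lr, unfrozen_lrs)
```

Under these assumptions the earliest layers train at about 1e-5 and the last group at about 1e-4 after unfreezing, so the pretrained encoder weights change more slowly than the decoder.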

# How to See Live Results in Shared Project
1. Check out the [project page](https://app.wandb.ai/wandb/neurips-demo/) to see your results in the shared project. 
1. Press 'option+space' to expand the runs table, comparing all the results from everyone who has tried this script. 
1. Click on the name of a run to dive deeper into that specific run on its own run page.

![project page](https://i.imgur.com/I1PM9YJ.png)


## Visualize Relationships

Use a parallel coordinates chart to see the relationship between hyperparameters and output metrics. Here, we can see how the learning rate and the other hyperparameters saved in `config` affect loss and accuracy.

![parallel coordinates plot](https://i.imgur.com/cg1uodx.png)

# More about Weights & Biases
We're always free for academics and open source projects. Email carey@wandb.com with any questions or feature suggestions. Here are some more resources:

1. [Documentation](http://docs.wandb.com) - Python docs
2. [Gallery](https://app.wandb.ai/gallery) - example reports in W&B
3. [Articles](https://www.wandb.com/articles) - blog posts and tutorials
4. [Community](https://bit.ly/wandb-forum) - join our Slack community forum