![header](https://i.imgur.com/sAPM7Yy.png)

# Learning to Segment Videos from Pre-Computed Features

Segmenting videos is challenging,
and even though the DAVIS dataset is nearly half a gigabyte in size,
it's not nearly big enough for a neural network to learn
everything about video segmentation from scratch.

In order to succeed, you'll need to take advantage of _pre-training_:
letting the network learn part of a task from a larger dataset
that demonstrates a related task,
and then learning the rest of the task from a smaller dataset.

Pre-training is ubiquitous in human and animal learning.
Before learning to read sheet music,
we often first learn to read natural language.
Before learning to read a language,
we learn related things:
how to speak that language,
how to recognize objects and symbols.
This speeds up our learning immensely.
We'd like to do the same for our neural network,
showing it some related data that will help it
learn to segment videos.

In particular, there are lots of great datasets of images out there,
and even more neural networks trained on those datasets.
These networks already know a lot about images,
which overlaps quite substantially with what they need to learn about videos
in order to segment them.

In this notebook,
we'll work through how to apply a pretrained image model,
[AlexNet](https://arxiv.org/abs/1404.5997),
to the video segmentation problem.

In [None]:
%%capture
!pip install "git+https://www.github.com/wandb/davis-contest.git#egg=contest[torch]"

In [None]:
from functools import lru_cache
import os

import numpy as np
import pytorch_lightning as pl
import skimage.io
import torch
import torchvision.models as models
import wandb

import contest
from contest.utils import clips, paths

# Working with a Pretrained Model

The basic idea is that the pre-trained model extracts _features_,
a representation of the inputs that pulls out the useful information
from the cacophony of pixels.
In convolutional networks,
these features are detected throughout the entire image,
creating a _feature map_,
a value of the feature for each spatial location.

For example, one feature might be "contains a dog":
each pixel value in this feature map is large where the associated
region of the image appears, to the network,
to contain a dog.
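
As a toy illustration of a feature map (a hand-written sketch, not how AlexNet computes its features), the snippet below slides a tiny edge-detecting filter over a small grayscale image. The helper `correlate2d`, the image, and the filter are all made up for this example. The output has one value per spatial location, and it is large only where the filter's pattern (dark on the left, bright on the right) appears.

```python
def correlate2d(image, kernel):
    """Slide `kernel` over `image` (no padding) and return the feature map."""
    kh, kw = len(kernel), len(kernel[0])
    out_h = len(image) - kh + 1
    out_w = len(image[0]) - kw + 1
    feature_map = []
    for i in range(out_h):
        row = []
        for j in range(out_w):
            # weighted sum of the image patch under the kernel
            value = sum(kernel[a][b] * image[i + a][j + b]
                        for a in range(kh) for b in range(kw))
            row.append(value)
        feature_map.append(row)
    return feature_map


# a 3 x 5 image that is dark on the left and bright on the right
image = [[0, 0, 1, 1, 1],
         [0, 0, 1, 1, 1],
         [0, 0, 1, 1, 1]]
# responds where brightness increases from left to right
edge_kernel = [[-1, 1]]

feature_map = correlate2d(image, edge_kernel)
for row in feature_map:
    print(row)  # each row is [0, 1, 0, 0]: the edge sits at column 1
```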

For more on the features extracted by convolutional networks,
see [this paper from the Distill project](https://distill.pub/2017/feature-visualization/)
or dive deeper with the
[Circuits series from OpenAI](https://distill.pub/2020/circuits/).

From these features -- which are much smaller than the input frame --
a separate neural network learns to build segmentations.
For that smaller network, DAVIS can provide more than enough data!
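
To see how much smaller, we can trace a frame's spatial size through the convolution and pooling layers of `alexnet.features`. This is a sketch under assumptions: it takes a 480 x 854 frame that reaches the network without resizing (the contest's default image transform may do otherwise), and it uses the kernel, stride, and padding values from torchvision's AlexNet definition.

```python
def out_size(size, kernel, stride=1, padding=0):
    """Output length of a convolution or pooling layer along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


# (kernel, stride, padding) for each spatial layer in alexnet.features
ALEXNET_FEATURE_LAYERS = [
    (11, 4, 2),  # conv1
    (3, 2, 0),   # max pool
    (5, 1, 2),   # conv2
    (3, 2, 0),   # max pool
    (3, 1, 1),   # conv3
    (3, 1, 1),   # conv4
    (3, 1, 1),   # conv5
    (3, 2, 0),   # max pool
]


def feature_map_shape(height, width, out_channels=256):
    for kernel, stride, padding in ALEXNET_FEATURE_LAYERS:
        height = out_size(height, kernel, stride, padding)
        width = out_size(width, kernel, stride, padding)
    return (out_channels, height, width)


features_shape = feature_map_shape(480, 854)  # assumed frame size
n_frame_values = 3 * 480 * 854
n_feature_values = features_shape[0] * features_shape[1] * features_shape[2]

print(features_shape)                                 # (256, 14, 25)
print(round(n_frame_values / n_feature_values, 1))    # 13.7
```

Under these assumptions, each frame shrinks by more than a factor of ten before the second network ever sees it.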

## Data Engineering

Contemporary ML engineering projects
generally have two separate components:

1) a _data engineering_ component,
which involves fetching data
and getting it into the GPU,
and

2) a _model engineering_ component,
which involves building a model
that consumes that data
through training.

Getting both pieces right
is critical for a successful ML project.
Below, we dive into how pre-training
impacts each of these components.


### Why Use Precomputed Features? 

Two major bottlenecks for the data engineering pipeline
are reading data from disk
and transferring data from CPU RAM to GPU RAM.

We can get around both bottlenecks by loading the entirety
of our dataset into GPU RAM at once.

But GPU RAM also needs to hold our model,
its intermediate computations
and its gradients,
which means space is at a premium.

We can save a tremendous amount of space
by recognizing that, for training,
we don't actually need the videos themselves.
We aren't training the network that computes the features,
and so those features will be fixed throughout training.

So let's instead treat the features as the input data,
rather than videos.
They're much smaller and so will fit comfortably in memory
with everything else.
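
A back-of-the-envelope comparison makes the savings concrete. The numbers here are assumptions for illustration only: `n_frames` is a hypothetical dataset size, frames are taken to be 3 x 480 x 854, features are taken to be 256 x 14 x 25 (what AlexNet's convolutional layers would produce for a frame of that size), and both are stored as 32-bit floats.

```python
BYTES_PER_FLOAT32 = 4


def tensor_gib(n_items, shape, bytes_per_value=BYTES_PER_FLOAT32):
    """Size in GiB of `n_items` tensors of the given shape."""
    n_values = n_items
    for dim in shape:
        n_values *= dim
    return n_values * bytes_per_value / 2**30


n_frames = 2000  # hypothetical number of frames, not the actual dataset size

frames_gib = tensor_gib(n_frames, (3, 480, 854))
features_gib = tensor_gib(n_frames, (256, 14, 25))

print(f"raw frames: {frames_gib:.2f} GiB")
print(f"features:   {features_gib:.2f} GiB")
```

With these made-up numbers, the raw frames would need several gigabytes of GPU RAM, while the features would fit in well under one.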

This process is known as _feature extraction_
or using _precomputed features_.
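
A minimal sketch of the idea in PyTorch, using a small randomly initialized stand-in network rather than AlexNet and random tensors rather than video frames (both are placeholders for illustration): freeze the featurizer, run it once over the data without tracking gradients, and keep only the outputs.

```python
import torch

torch.manual_seed(0)

# stand-in for a pretrained featurizer (assumption: any frozen conv net works the same way)
featurizer = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3, stride=2),
    torch.nn.ReLU(),
)
featurizer.eval()  # put the network in inference mode
for param in featurizer.parameters():
    param.requires_grad = False  # the featurizer is never trained

# stand-in for the raw data: 10 random "frames" of size 3 x 32 x 32
frames = torch.rand(10, 3, 32, 32)
loader = torch.utils.data.DataLoader(frames, batch_size=4)

# run the featurizer once, batch by batch, and keep only the features
with torch.no_grad():
    features = torch.cat([featurizer(batch) for batch in loader])

print(features.shape)  # torch.Size([10, 8, 15, 15])
```

From here on, `features` can play the role of the input data, and the featurizer is no longer needed during training.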

The code below constructs a `Dataset` object
and an associated `LightningDataModule`
that apply a `featurizer` network to their inputs.

For more on `LightningDataModule`s and
data engineering in PyTorch,
see [this video](https://www.youtube.com/watch?v=L---MBeSXFw).

In [None]:
class FeaturizedDataset(torch.utils.data.Dataset):

  def __init__(self, featurized_xs, paths_df=None, mask_transform=None):

    self.featurized_xs = featurized_xs

    self.paths_df = paths_df
    if self.paths_df is not None:
      self.annotation_paths = self.paths_df["annotation"]
    else:
      self.annotation_paths = None
    self.mask_transform = mask_transform

    self.len = len(self.featurized_xs)

  def __len__(self):
    return self.len

  @lru_cache(maxsize=None)
  def __getitem__(self, idx):
    # convert tensor indices to plain Python values before indexing
    if torch.is_tensor(idx):
      idx = idx.tolist()

    x = self.featurized_xs[idx]

    if self.annotation_paths is None:
      return x
    else:
      annotation_name = self.annotation_paths.iloc[idx]
      annotation = skimage.io.imread(annotation_name)
      if self.mask_transform is not None:
        annotation = self.mask_transform(annotation)
  
      return x, annotation

  @staticmethod
  def _apply_featurizer(featurizer, dataloader):
    featurized_xs = []
    # the featurizer is frozen, so there is no need to track gradients
    with torch.no_grad():
      for batch in dataloader:
        xs, ys = batch
        featurized_xs.append(featurizer(xs))
    featurized_xs = torch.cat(featurized_xs)

    return featurized_xs


  @classmethod
  def from_raw_data(cls, featurizer, raw_dataloader,
                    paths_df=None, mask_transform=None):
    featurized_xs = FeaturizedDataset._apply_featurizer(featurizer, raw_dataloader)
    return cls(featurized_xs, paths_df, mask_transform)


class FeaturizedDataModule(pl.LightningDataModule):

  def __init__(self, featurizer, paths_df, has_annotations=True, num_workers=0,
               image_transform=None, mask_transform=None, batch_size=None,
               featurizer_batch_size=None):
    super().__init__()

    self.paths_df = paths_df
    self.has_annotations = has_annotations
    self.num_workers = num_workers 
    self.featurizer = featurizer

    if image_transform is None:
      self.image_transform = contest.torch.data.default_image_transform
    else:
      self.image_transform = image_transform
    
    if mask_transform is None:
      self.mask_transform = contest.torch.data.default_mask_transform
    else:
      self.mask_transform = mask_transform

    if batch_size is None:
      self.batch_size = len(paths_df)
    else:
      self.batch_size = batch_size

    if featurizer_batch_size is None:
      self.featurizer_batch_size = 32
    else:
      self.featurizer_batch_size = featurizer_batch_size

  def setup(self, stage=None):
    # compute the features once, up front, from the raw frames
    self.raw_dataset = contest.torch.data.VidSegDataset(
      self.paths_df, self.has_annotations,
      image_transform=self.image_transform,
      mask_transform=self.mask_transform)
    self.raw_dataloader = torch.utils.data.DataLoader(
      self.raw_dataset, batch_size=self.featurizer_batch_size,
      num_workers=self.num_workers)

    self.featurized_dataset = FeaturizedDataset.from_raw_data(
      self.featurizer, self.raw_dataloader, self.paths_df,
      mask_transform=self.mask_transform)

    # Lightning calls prepare_data before setup and without a stage,
    # so the training dataset is assigned here instead
    if stage == "fit" or stage is None:
      self.train_dataset = self.featurized_dataset

  def train_dataloader(self):
    return torch.utils.data.DataLoader(self.featurized_dataset, batch_size=self.batch_size,
                                       num_workers=self.num_workers)

### Applying the Pretrained Model to the Training Data

The deep learning community has developed numerous tools
for sharing and distributing pre-trained models.

Specifically for computer vision, the
[`torchvision.models`](https://pytorch.org/vision/0.8/models.html)
module provides easy access to a variety of
widely-used and performant pre-trained convolutional neural networks.

In [None]:
def get_alexnet():
  alexnet = models.alexnet(pretrained=True)
  alexnet.eval()  # switch to inference mode; `eval` is a method, not a flag
  for param in alexnet.parameters():
    param.requires_grad = False  # freeze the pretrained weights
      
  featurizer = alexnet.features
  return featurizer

In order to apply the pre-trained model to the training data,
we need the training data.

The training data for the contest is stored and distributed using
Weights & Biases [Artifacts](https://docs.wandb.ai/artifacts/api).
For more on using Artifacts, see the
[starter colabs](https://github.com/wandb/davis-contest/tree/main/colabs).

In [None]:
# picking out the training data artifact by name

entity = "wandb"  # artifacts are associated with an entity -- a user or team
project = "davis"  # artifacts are associated with a project -- a collection of ML experiments
split = "train"  # the train and val data are both stored in the same format
tag = "contest"  # different versions of an Artifact have different tags

# artifact ids always use forward slashes, so join with "/" rather than os.path.join
training_data_artifact_id = "/".join([entity, project, f"davis2016-{split}"]) + ":" + tag
training_data_artifact_id

In [None]:
def apply_featurization(data_artifact_id, featurizer, output_artifact_name):
  if featurizer == "alexnet":
    featurizer = get_alexnet()
  else:
    raise ValueError(f"unknown featurizer: {featurizer}")

  # use the active run rather than relying on a global variable named `run`
  data_artifact = wandb.run.use_artifact(data_artifact_id)
  paths_df = paths.artifact_paths(data_artifact)

  fdm = FeaturizedDataModule(featurizer, paths_df)
  fdm.setup()

  output_artifact = wandb.Artifact(output_artifact_name,
                                   type="featurized-data")

  featurized_array = fdm.featurized_dataset.featurized_xs.numpy()
  np.save("features.npy", featurized_array)
  output_artifact.add_file("features.npy", "features_array")

  paths_df.to_json("paths.json")
  output_artifact.add_file("paths.json")

  wandb.run.log_artifact(output_artifact)

  try:
    output_artifact.wait()
  except AttributeError:
    pass


  return "/".join([wandb.run.entity, wandb.run.project, output_artifact.name])

In [None]:
config = {"featurizer": "alexnet"}

with wandb.init(project=project, job_type="featurize", config=config) as run:

  output_artifact_name = f"{wandb.config['featurizer']}-featurized-train"

  output_artifact_id = apply_featurization(
    training_data_artifact_id, wandb.config["featurizer"], output_artifact_name) 

### Loading Featurized Data

Now that we've precomputed the features,
we don't need to keep the original data around.

The cell below defines
a new `LightningDataModule` and `Dataset`
that make use of the saved precomputed features from above,
rather than working from the original data.

We'll use these below in our training loop.
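
The round trip through disk can be sketched with NumPy alone. The array here is random placeholder data and the file names are local temporary paths, not the artifact's. The point of the sketch is that `np.load` recognizes the `.npy` format from the file's contents, so it still works when the file is stored under a name without the extension, as with `features_array` above.

```python
import os
import tempfile

import numpy as np

# placeholder features: 4 items with an assumed shape of 256 x 14 x 25
features = np.random.rand(4, 256, 14, 25).astype(np.float32)

with tempfile.TemporaryDirectory() as tmpdir:
    # save in .npy format
    npy_path = os.path.join(tmpdir, "features.npy")
    np.save(npy_path, features)

    # store it under a name without the extension
    renamed_path = os.path.join(tmpdir, "features_array")
    os.rename(npy_path, renamed_path)

    # loading does not depend on the file name
    loaded = np.load(renamed_path)

print(loaded.shape, loaded.dtype)
```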

In [None]:
class PrecomputedFeaturesDataModule(pl.LightningDataModule):

  def __init__(self, features_file, annotation_files=None, batch_size=32):
    super().__init__()
    self.batch_size = batch_size
    self.features_file = features_file
    self.annotation_files = annotation_files

  def setup(self, stage=None):
    self.dataset = PrecomputedFeaturesDataset(
      self.features_file, self.annotation_files)

  def train_dataloader(self):
    return torch.utils.data.DataLoader(self.dataset, batch_size=self.batch_size)


class PrecomputedFeaturesDataset(torch.utils.data.Dataset):

  def __init__(self, features_file, annotation_files=None, mask_transform=None):
    self.features_file = features_file
    self.annotation_files = annotation_files
    if mask_transform is None:
      mask_transform = contest.torch.data.default_mask_transform
    self.mask_transform = mask_transform

    self.load_features(self.features_file)

  def __len__(self):
    return len(self.featurized_xs)

  @lru_cache(maxsize=None)
  def __getitem__(self, idx):
    # convert tensor indices to plain Python values before indexing
    if torch.is_tensor(idx):
      idx = idx.tolist()

    x = self.featurized_xs[idx]
    if self.annotation_files is None:
      return x
    else:
      annotation_name = self.annotation_files.iloc[idx]
      annotation = skimage.io.imread(annotation_name)
      annotation = self.mask_transform(annotation)

      return x, annotation

  def load_features(self, features_file):
    self.featurized_xs = torch.Tensor(np.load(features_file))

## Model Engineering

With our big network doing most of the work for us,
we can get pretty good performance without doing too much ourselves.

Here, we build the simplest possible network on top:
a single linear (convolutional) layer,
followed by a `sigmoid` function so that the results are scaled appropriately.
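
To put "simplest possible" in numbers, we can count parameters by hand. The helper `conv_params` is a hypothetical name introduced for this sketch. The featurizer's layer sizes are those of the convolutional layers in torchvision's AlexNet, and the decoder is a single 3 x 3 convolution from 256 channels down to 1.

```python
def conv_params(in_channels, out_channels, kernel_size):
    """Weights plus biases of a square 2D convolution."""
    return out_channels * (in_channels * kernel_size ** 2 + 1)


# (in_channels, out_channels, kernel_size) for the five AlexNet conv layers
alexnet_convs = [(3, 64, 11), (64, 192, 5), (192, 384, 3),
                 (384, 256, 3), (256, 256, 3)]

featurizer_params = sum(conv_params(*layer) for layer in alexnet_convs)
decoder_params = conv_params(256, 1, 3)

print(featurizer_params)  # 2469696
print(decoder_params)     # 2305
```

The decoder we train is roughly a thousand times smaller than the frozen featurizer it sits on top of.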

We use the `B`inary `C`ross `E`ntropy `Loss` function,
which penalizes confident mistakes especially heavily:
for example, confidently segmenting areas where there is no subject in the ground truth,
or confidently marking as background a pixel that belongs to the subject.
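
For a single pixel with predicted probability p and label y, the loss is -(y log p + (1 - y) log(1 - p)). A plain-Python sketch shows how quickly the penalty grows with misplaced confidence. The helper `bce` is written here for illustration and is not the PyTorch implementation.

```python
import math


def bce(p, y):
    """Binary cross-entropy for one prediction p in (0, 1) and label y in {0, 1}."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


# p is the predicted probability that the pixel belongs to the subject
for p in (0.1, 0.5, 0.9, 0.99):
    print(f"p={p:<4}  background pixel: {bce(p, 0):.3f}   subject pixel: {bce(p, 1):.3f}")
```

Going from 90% to 99% confidence on a background pixel doubles the loss, and the same holds in the other direction for subject pixels.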

You might try others!

In [None]:
model_name = "simple-decoder"

class SimpleDecoder(pl.LightningModule):

  def __init__(self, target_size=(480, 854)):
    super().__init__()

    self.conv = torch.nn.Conv2d(256, 1, kernel_size=3)
    self.resize = torch.nn.AdaptiveAvgPool2d(target_size)
    self.cost = torch.nn.BCELoss()

  def forward(self, xs):
    xs = self.conv(xs)
    xs = torch.sigmoid(xs)
    return self.resize(xs)

  def loss(self, outs, ys):
    return self.cost(outs, ys)

### Training and Logging with Weights & Biases

In the cell below, we define a `train`ing function
that glues together our precomputed features and our `SimpleDecoder` model.

Included are some Weights & Biases logging tools:
in particular, tracking the predictions and the ground truth
so that we can look at how the network's outputs compare to the correct answers
and how they develop during training.
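
Because training here is measured in steps rather than epochs, the loop has to start over whenever the dataloader runs out of batches. That pattern can be sketched in isolation. Here `cycle` is a hypothetical helper, and a plain list stands in for the dataloader.

```python
def cycle(loader):
    """Yield batches forever, restarting the loader whenever it is exhausted."""
    while True:
        for batch in loader:
            yield batch


batches = ["batch-0", "batch-1", "batch-2"]  # stand-in for a DataLoader
steps = 7

batch_stream = cycle(batches)
seen = [next(batch_stream) for _ in range(steps)]
print(seen)
```

Seven steps over three batches means the data is traversed a little more than twice, with the stream wrapping around to `batch-0` each time the list is exhausted.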

In [None]:
def train(model, optimizer, dataloader, model_name, steps=1, log_freq=10, device="cuda",
          run=None):

  if run is None:
    run = wandb.init(project="davis", job_type="train")

  model.train()  # switch to training mode; `train` is a method, not a flag
  model.to(device)
  
  model_artifact = wandb.Artifact(model_name, type="trained-model",
                                  metadata={})

  class_labels = {0: "background", 1: "object"}
  dataiterator = iter(dataloader)

  for step in range(steps):
    try:
      xs, ys = next(dataiterator)
    except StopIteration:
      dataiterator = iter(dataloader)
      xs, ys = next(dataiterator)
    xs, ys = xs.to(device), ys.to(device)
      
    outs = model(xs)
    loss = model.loss(outs, ys)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if not step % log_freq:
      img, out, target = (xs[0].detach().to("cpu"),
                          outs[0].detach().to("cpu"),
                          ys[0].detach().to("cpu"))
      mask = torch.round(out)
      img = img.permute(1, 2, 0)
      mask = torch.squeeze(mask)
      target = torch.squeeze(target)
      img = target  # for featurized models

      mask_img = wandb.Image(img.numpy(), masks={
          "predictions": {
              "mask_data": mask.numpy(),
              "class_labels": class_labels
          },
          "ground_truth": {
              "mask_data": target.numpy(),
              "class_labels": class_labels
          }
      })
      wandb.log({"loss": float(loss),
                  "prediction": mask_img},
                step=step)

      filename = f"model-{str(step).zfill(8)}.pt"
      torch.save(model.state_dict(), filename)

      model_artifact.add_file(filename)

  model_artifact.add_file(filename, "final_model")

  run.log_artifact(model_artifact)

In [None]:
config = {"batch_size": 32,
          "lr": 5e-4,
          "betas": (0.9, 0.999),
          "steps": 101,
          "log_freq": 10,
          "featurizer": "alexnet"}


with wandb.init(project="davis", job_type="train", config=config) as run:

  featurized_artifact_id = f"davis/{wandb.config['featurizer']}-featurized-train:latest"
  precomputed_features = run.use_artifact(featurized_artifact_id)
  precomputed_features_dir = precomputed_features.download()
  precomputed_features_path = os.path.join(precomputed_features_dir, "features_array")

  raw_data_artifact = run.use_artifact(training_data_artifact_id)
  raw_paths_df = paths.artifact_paths(raw_data_artifact)

  pcfdm = PrecomputedFeaturesDataModule(precomputed_features_path,
                                        raw_paths_df["annotation"],
                                        batch_size=wandb.config["batch_size"])
  pcfdm.setup()
  tdl = pcfdm.train_dataloader()

  model = SimpleDecoder()
  wandb.watch(model, log_freq=wandb.config["log_freq"])
  optimizer = torch.optim.Adam(model.parameters(),
                               lr=wandb.config["lr"],
                               betas=wandb.config["betas"])

  train(model, optimizer, tdl, model_name,
        steps=wandb.config["steps"], log_freq=wandb.config["log_freq"],
        run=run)

# Packaging Results for Evaluation

Submissions to the contest need to be put into a particular format in order to be considered and evaluated.

Below, we'll package up the results of our pre-trained model
into this format.

See the [contest instructions](https://github.com/wandb/davis-contest)
and the [starter notebooks](https://github.com/wandb/davis-contest/tree/main/colabs)
for more details on this format.

In [None]:
split = "val"
validation_data_artifact_id = "/".join([entity, project, f"davis2016-{split}"]) + ":" + tag
validation_data_artifact_id

## Featurizing the Validation Set

In [None]:
config = {"featurizer": "alexnet"}

with wandb.init(project=project, job_type="featurize", config=config) as run:

  output_artifact_name = f"{wandb.config['featurizer']}-featurized-val"

  apply_featurization(
    validation_data_artifact_id, wandb.config["featurizer"], output_artifact_name) 

## Running the Model on the Featurized Validation Set

In [None]:
model_artifact_id = f"davis/{model_name}:latest"

In [None]:
output_dir = os.path.join("outputs")
!rm -rf {output_dir}
!mkdir -p {output_dir}

result_artifact_name = model_name + "-result"

config = {"batch_size": 32,
          "featurizer": "alexnet",
          "model": model_name}

with wandb.init(project="davis", job_type="run-val", config=config) as run:

  # get and set up data
  featurized_artifact_id = f"davis/{wandb.config['featurizer']}-featurized-val:latest"
  precomputed_features = run.use_artifact(featurized_artifact_id)
  precomputed_features_dir = precomputed_features.download()
  precomputed_features_path = os.path.join(precomputed_features_dir, "features_array")

  raw_data_artifact = run.use_artifact(validation_data_artifact_id)
  raw_paths_df = paths.artifact_paths(raw_data_artifact)

  pcfdm = PrecomputedFeaturesDataModule(precomputed_features_path,
                                        batch_size=wandb.config["batch_size"])
  pcfdm.setup()
  tdl = pcfdm.train_dataloader()

  # get and set up featurizer and model
  if wandb.config["featurizer"] == "alexnet":
    featurizer = get_alexnet()
  else:
    raise ValueError(f"unknown featurizer {wandb.config['featurizer']}")
  model = contest.torch.utils.load_model_from_artifact(model_artifact_id, SimpleDecoder)

  # profiling metadata
  ## don't forget to include the parameters from your featurizing model!
  nparams = contest.torch.profile.count_params(featurizer) +\
            contest.torch.profile.count_params(model) 

  profiling_metadata = {"nparams": nparams}
  wandb.log(profiling_metadata)

  output_paths = contest.torch.evaluate.run(model, tdl, len(pcfdm.dataset), output_dir)

  result_artifact = contest.evaluate.make_result_artifact(
    output_paths, result_artifact_name, metadata=profiling_metadata
  )

  run.log_artifact(result_artifact)