<a href="https://colab.research.google.com/github/HyamsG/FlexibleRegularization/blob/master/torchvision_finetuning_instance_segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TorchVision 0.3 Object Detection finetuning tutorial

In this tutorial, we will fine-tune a pre-trained [Mask R-CNN](https://arxiv.org/abs/1703.06870) model on the [*Penn-Fudan Database for Pedestrian Detection and Segmentation*](https://www.cis.upenn.edu/~jshi/ped_html/). It contains 170 images with 345 instances of pedestrians, and we will use it to illustrate how to use the new features in torchvision to train an instance segmentation model on a custom dataset.

First, we need to install `pycocotools`. This library will be used for computing the evaluation metrics following the COCO metric for intersection over union.
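As a rough illustration of the quantity these metrics are built on, intersection over union (IoU) for two axis-aligned boxes in `[x0, y0, x1, y1]` format can be sketched as follows. The helper name `box_iou` is ours and is not part of `pycocotools`, which computes IoU internally and also supports masks.

```python
def box_iou(box_a, box_b):
    """IoU of two boxes given as [x0, y0, x1, y1]."""
    # corners of the overlapping rectangle (if any)
    ix0 = max(box_a[0], box_b[0])
    iy0 = max(box_a[1], box_b[1])
    ix1 = min(box_a[2], box_b[2])
    iy1 = min(box_a[3], box_b[3])
    inter = max(0.0, ix1 - ix0) * max(0.0, iy1 - iy0)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


# two 10x10 boxes overlapping in a 5x5 square: 25 / (100 + 100 - 25)
print(box_iou([0, 0, 10, 10], [5, 5, 15, 15]))
```

The COCO evaluation then counts a detection as correct when its IoU with a ground-truth instance exceeds a threshold, and averages over several thresholds.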

In [16]:
%%shell

pip install cython
# Install pycocotools, the version by default in Colab
# has a bug fixed in https://github.com/cocodataset/cocoapi/pull/354
pip install -U 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'

Collecting git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI
  Cloning https://github.com/cocodataset/cocoapi.git to /tmp/pip-req-build-res66lbn
  Running command git clone -q https://github.com/cocodataset/cocoapi.git /tmp/pip-req-build-res66lbn
Building wheels for collected packages: pycocotools
  Building wheel for pycocotools (setup.py) ... done
  Created wheel for pycocotools: filename=pycocotools-2.0-cp36-cp36m-linux_x86_64.whl size=265561 sha256=02403bcfbd84e0501c3b8fae797d019102e72412323dc9c639d740d9785c4f16
  Stored in directory: /tmp/pip-ephem-wheel-cache-lj4antbe/wheels/90/51/41/646daf401c3bc408ff10de34ec76587a9b3ebfac8d21ca5c3a
Successfully built pycocotools
Installing collected packages: pycocotools
  Found existing installation: pycocotools 2.0
    Uninstalling pycocotools-2.0:
      Successfully uninstalled pycocotools-2.0
Successfully installed pycocotools-2.0




## Defining the Dataset

The [torchvision reference scripts for training object detection, instance segmentation and person keypoint detection](https://github.com/pytorch/vision/tree/v0.3.0/references/detection) make it easy to add new custom datasets.
The dataset should inherit from the standard `torch.utils.data.Dataset` class, and implement `__len__` and `__getitem__`.

The only requirement is that the dataset's `__getitem__` should return:

* image: a PIL Image of size (H, W)
* target: a dict containing the following fields
    * `boxes` (`FloatTensor[N, 4]`): the coordinates of the `N` bounding boxes in `[x0, y0, x1, y1]` format, ranging from `0` to `W` and `0` to `H`
    * `labels` (`Int64Tensor[N]`): the label for each bounding box
    * `image_id` (`Int64Tensor[1]`): an image identifier. It should be unique between all the images in the dataset, and is used during evaluation
    * `area` (`Tensor[N]`): The area of the bounding box. This is used during evaluation with the COCO metric, to separate the metric scores between small, medium and large boxes.
    * `iscrowd` (`UInt8Tensor[N]`): instances with `iscrowd=True` will be ignored during evaluation.
    * (optionally) `masks` (`UInt8Tensor[N, H, W]`): The segmentation masks for each one of the objects
    * (optionally) `keypoints` (`FloatTensor[N, K, 3]`): For each one of the `N` objects, it contains the `K` keypoints in `[x, y, visibility]` format, defining the object. `visibility=0` means that the keypoint is not visible. Note that for data augmentation, the notion of flipping a keypoint is dependent on the data representation, and you should probably adapt `references/detection/transforms.py` for your new keypoint representation

If your dataset returns the above fields, it will work for both training and evaluation, and will use the evaluation scripts from pycocotools.
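To make the role of the `area` field concrete, here is a small sketch of how a box area maps to the COCO size buckets. The thresholds (32² and 96² pixels) are the standard COCO ones; the function name `coco_size_bucket` and the example boxes are ours, and the handling of values exactly on a boundary is simplified.

```python
def coco_size_bucket(area):
    """Return the COCO size bucket for a box area given in square pixels."""
    if area < 32 ** 2:
        return "small"
    if area < 96 ** 2:
        return "medium"
    return "large"


# made-up boxes in [x0, y0, x1, y1] format
boxes = [[10, 10, 30, 40], [0, 0, 50, 80], [5, 5, 205, 305]]
for x0, y0, x1, y1 in boxes:
    area = (x1 - x0) * (y1 - y0)
    print(area, coco_size_bucket(area))
```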

Additionally, if you want to use aspect ratio grouping during training (so that each batch only contains images with similar aspect ratio), then it is recommended to also implement a `get_height_and_width` method, which returns the height and the width of the image. If this method is not provided, we query all elements of the dataset via `__getitem__`, which loads each image in memory and is slower than a custom method.
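The idea behind aspect ratio grouping can be sketched in a few lines. This is a simplified illustration, not the code used by the reference scripts: the function name, the bin edges and the image sizes are all made up for the example.

```python
from bisect import bisect_right
from collections import defaultdict


def group_by_aspect_ratio(sizes, bin_edges=(0.5, 1.0, 2.0)):
    """Map each bin index to the dataset indices whose width/height ratio falls in it.

    `sizes` is a list of (height, width) pairs, for instance collected
    from a dataset's `get_height_and_width` method.
    """
    groups = defaultdict(list)
    for idx, (height, width) in enumerate(sizes):
        ratio = width / height
        # bisect_right returns how many bin edges are <= ratio
        groups[bisect_right(bin_edges, ratio)].append(idx)
    return dict(groups)


# hypothetical (height, width) pairs
sizes = [(536, 559), (438, 455), (300, 900), (800, 300)]
print(group_by_aspect_ratio(sizes))
```

A batch sampler can then draw each batch from a single group, so that images padded to a common size waste less memory.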


### Writing a custom dataset for Penn-Fudan

Let's write a dataset for the Penn-Fudan dataset.

First, let's download and extract the data, present in a zip file at https://www.cis.upenn.edu/~jshi/ped_html/PennFudanPed.zip

In [None]:
%%shell

# download the Penn-Fudan dataset into the current folder
wget https://www.cis.upenn.edu/~jshi/ped_html/PennFudanPed.zip
# extract it in the current folder, overwriting files left by earlier runs
unzip -o PennFudanPed.zip

Let's have a look at the dataset and how it is laid out.

The data is structured as follows:
```
PennFudanPed/
  PedMasks/
    FudanPed00001_mask.png
    FudanPed00002_mask.png
    FudanPed00003_mask.png
    FudanPed00004_mask.png
    ...
  PNGImages/
    FudanPed00001.png
    FudanPed00002.png
    FudanPed00003.png
    FudanPed00004.png
```

Here is one example of an image in the dataset, with its corresponding instance segmentation mask

In [None]:
from PIL import Image
Image.open('PennFudanPed/PNGImages/FudanPed00001.png')

In [None]:
mask = Image.open('PennFudanPed/PedMasks/FudanPed00001_mask.png')
# each instance in the mask has a different color index, from 1 to N
# (0 is the background), where N is the number of instances.
# To make visualization easier, let's add a color palette to the mask.
mask.putpalette([
    0, 0, 0, # black background
    255, 0, 0, # index 1 is red
    255, 255, 0, # index 2 is yellow
    255, 153, 0, # index 3 is orange
])
mask

So each image has a corresponding segmentation mask, where each color corresponds to a different instance. Let's write a `torch.utils.data.Dataset` class for this dataset.

In [None]:
import os
import numpy as np
import torch
import torch.utils.data
from PIL import Image


class PennFudanDataset(torch.utils.data.Dataset):
    def __init__(self, root, transforms=None):
        self.root = root
        self.transforms = transforms
        # load all image files, sorting them to
        # ensure that they are aligned
        self.imgs = list(sorted(os.listdir(os.path.join(root, "PNGImages"))))
        self.masks = list(sorted(os.listdir(os.path.join(root, "PedMasks"))))

    def __getitem__(self, idx):
        # load images and masks
        img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
        mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])
        img = Image.open(img_path).convert("RGB")
        # note that we haven't converted the mask to RGB,
        # because each color corresponds to a different instance
        # with 0 being background
        mask = Image.open(mask_path)

        mask = np.array(mask)
        # instances are encoded as different colors
        obj_ids = np.unique(mask)
        # first id is the background, so remove it
        obj_ids = obj_ids[1:]

        # split the color-encoded mask into a set
        # of binary masks
        masks = mask == obj_ids[:, None, None]

        # get bounding box coordinates for each mask
        num_objs = len(obj_ids)
        boxes = []
        for i in range(num_objs):
            pos = np.where(masks[i])
            xmin = np.min(pos[1])
            xmax = np.max(pos[1])
            ymin = np.min(pos[0])
            ymax = np.max(pos[0])
            boxes.append([xmin, ymin, xmax, ymax])

        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        # there is only one class
        labels = torch.ones((num_objs,), dtype=torch.int64)
        masks = torch.as_tensor(masks, dtype=torch.uint8)

        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # suppose all instances are not crowd
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["masks"] = masks
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        return len(self.imgs)

That's all for the dataset. Let's see how the outputs are structured for this dataset

In [None]:
dataset = PennFudanDataset('PennFudanPed/')
dataset[0]

So we can see that by default, the dataset returns a `PIL.Image` and a dictionary
containing several fields, including `boxes`, `labels` and `masks`.
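To see the mask-splitting step of `__getitem__` in isolation, the following sketch applies the same NumPy operations to a tiny hand-made array. The 4x5 mask is invented for the example; real masks are loaded from `PedMasks/`.

```python
import numpy as np

# toy 4x5 "mask image": 0 is background, 1 and 2 are two instances
mask = np.array([
    [0, 1, 1, 0, 0],
    [0, 1, 1, 0, 2],
    [0, 0, 0, 0, 2],
    [0, 0, 0, 2, 2],
])

obj_ids = np.unique(mask)[1:]            # drop the background id
masks = mask == obj_ids[:, None, None]   # shape [N, H, W], one boolean mask per instance

boxes = []
for m in masks:
    ys, xs = np.where(m)                 # row and column indices of the instance pixels
    boxes.append([int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())])

print(masks.shape)  # (2, 4, 5)
print(boxes)        # [[1, 0, 2, 1], [3, 1, 4, 3]]
```

Broadcasting `obj_ids[:, None, None]` against the `[H, W]` mask compares every pixel with every instance id at once, which avoids an explicit loop over instances.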

## Defining your model

In this tutorial, we will be using [Mask R-CNN](https://arxiv.org/abs/1703.06870), which is built on top of [Faster R-CNN](https://arxiv.org/abs/1506.01497). Faster R-CNN is a model that predicts both bounding boxes and class scores for potential objects in the image.

![Faster R-CNN](https://raw.githubusercontent.com/pytorch/vision/temp-tutorial/tutorials/tv_image03.png)

Mask R-CNN adds an extra branch into Faster R-CNN, which also predicts segmentation masks for each instance.

![Mask R-CNN](https://raw.githubusercontent.com/pytorch/vision/temp-tutorial/tutorials/tv_image04.png)

There are two common situations where one might want to modify one of the available models in the torchvision model zoo.
The first is when we want to start from a pre-trained model and just fine-tune the last layer. The other is when we want to replace the backbone of the model with a different one (for faster predictions, for example).

Let's see how to do each of these in the following sections.


### 1 - Finetuning from a pretrained model

Let's suppose that you want to start from a model pre-trained on COCO and want to finetune it for your particular classes. Here is a possible way of doing it:
```
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# load a model pre-trained on COCO
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

# replace the classifier with a new one, that has
# num_classes which is user-defined
num_classes = 2  # 1 class (person) + background
# get number of input features for the classifier
in_features = model.roi_heads.box_predictor.cls_score.in_features
# replace the pre-trained head with a new one
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes) 
```

### 2 - Modifying the model to add a different backbone

Another common situation arises when the user wants to replace the backbone of a detection
model with a different one. For example, the current default backbone (ResNet-50) might be too big for some applications, and smaller models might be necessary.

Here is how we can use the functions provided by torchvision to swap in a different backbone.

```
import torchvision
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator

# load a pre-trained model for classification and return
# only the features
backbone = torchvision.models.mobilenet_v2(pretrained=True).features
# FasterRCNN needs to know the number of
# output channels in a backbone. For mobilenet_v2, it's 1280
# so we need to add it here
backbone.out_channels = 1280

# let's make the RPN generate 5 x 3 anchors per spatial
# location, with 5 different sizes and 3 different aspect
# ratios. We have a Tuple[Tuple[int]] because each feature
# map could potentially have different sizes and
# aspect ratios 
anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
                                   aspect_ratios=((0.5, 1.0, 2.0),))

# let's define what are the feature maps that we will
# use to perform the region of interest cropping, as well as
# the size of the crop after rescaling.
# if your backbone returns a Tensor, featmap_names is expected to
# be [0]. More generally, the backbone should return an
# OrderedDict[Tensor], and in featmap_names you can choose which
# feature maps to use.
roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=[0],
                                                output_size=7,
                                                sampling_ratio=2)

# put the pieces together inside a FasterRCNN model
model = FasterRCNN(backbone,
                   num_classes=2,
                   rpn_anchor_generator=anchor_generator,
                   box_roi_pool=roi_pooler)
```
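The "5 x 3 anchors per spatial location" mentioned in the comments can be checked with a little arithmetic. The sketch below assumes the common convention that an aspect ratio is height divided by width and that each anchor keeps an area of roughly `size ** 2`; it is an illustration written for this tutorial, not torchvision's implementation, and `base_anchor_shapes` is a hypothetical helper.

```python
import math


def base_anchor_shapes(sizes, aspect_ratios):
    """Return (width, height) for every size / aspect-ratio combination."""
    shapes = []
    for ratio in aspect_ratios:
        h_scale = math.sqrt(ratio)   # ratio is taken as height / width
        w_scale = 1.0 / h_scale      # keeps width * height equal to size ** 2
        for size in sizes:
            shapes.append((w_scale * size, h_scale * size))
    return shapes


shapes = base_anchor_shapes((32, 64, 128, 256, 512), (0.5, 1.0, 2.0))
print(len(shapes))  # 15
for width, height in shapes[:3]:
    print(round(width, 1), round(height, 1))
```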

### An Instance segmentation model for PennFudan Dataset

In our case, we want to fine-tune from a pre-trained model, given that our dataset is very small. So we will be following approach number 1.

Here we want to also compute the instance segmentation masks, so we will be using Mask R-CNN:

In [None]:
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

      
def get_instance_segmentation_model(num_classes):
    # load an instance segmentation model pre-trained on COCO
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)

    # get the number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # now get the number of input features for the mask classifier
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    # and replace the mask predictor with a new one
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask,
                                                       hidden_layer,
                                                       num_classes)

    return model

In [None]:
import numpy as np
import attr
import torch

@attr.s
class OnlineAvg:
    """
    Online (running) average calculation.
    For each new value, update the sample count and the running average.
    avg accumulates the average of all values seen so far
    count aggregates the number of samples seen so far
    """
    dim = attr.ib()
    static_calculation = attr.ib(True)
    package = attr.ib('torch')
    reinitiate_every_step = attr.ib(True)
    initial_param = attr.ib(None)

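    def __attrs_post_init__(self):
        self.tensor_package = torch if self.package == 'torch' else np
        if self.package == 'torch':
            # detach so the running average never joins the autograd graph
            self.initial_param = self.initial_param.detach().clone()
        else:
            self.initial_param = np.array(self.initial_param, copy=True)
        self.count = 0
        self.avg = self.initial_param
        # same shape, dtype and device as the tracked parameter
        self.static_avg = self.tensor_package.zeros_like(self.avg)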

    def update(self, new_value):
        self.count += 1
        delta = new_value - self.avg
        self.avg += delta / self.count

    def update_static_mean(self):
        self.static_avg = self._get_avg()
        if self.reinitiate_every_step:
            # restart the running average but keep the snapshot just taken,
            # so that get_static_mean() can serve it until the next refresh
            self.count = 0
            self.avg = self.tensor_package.zeros_like(self.avg)

    def _get_avg(self):
        return self.avg

    def get_static_mean(self):
        if self.static_calculation:
            return self.static_avg
        return self._get_avg()

#
import numpy as np
import attr
import torch

@attr.s
class Welford:
    """
    online variance calculation
    For a new value newValue, compute the new count, new avg, the new M2.
    avg accumulates the avg of the entire dataset
    M2 aggregates the squared distance from the avg
    count aggregates the number of samples seen so far
    """
    dim = attr.ib()
    static_calculation = attr.ib(True)
    divide_var_by_mean_var = attr.ib(True)
    var_normalizer = attr.ib(1)
    # device = attr.ib('cpu')
    package = attr.ib('torch')
    reinitiate_every_step = attr.ib(True)
    initial_param = attr.ib(None)

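    def __attrs_post_init__(self):
        self.tensor_package = torch if self.package == 'torch' else np
        if self.package == 'torch':
            # detach so the running statistics never join the autograd graph
            self.initial_param = self.initial_param.detach().clone()
        else:
            self.initial_param = np.array(self.initial_param, copy=True)
        self.count = 0
        self.mean = self.initial_param
        # allocate the accumulators like the parameter (same shape, dtype, device)
        self.M2 = self.tensor_package.zeros_like(self.mean)
        self.var = self.tensor_package.ones_like(self.mean)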

        
    def update(self, new_value):
        self.count += 1
        delta = new_value - self.mean
        self.mean += delta / self.count  # element-wise; count is a Python int
        delta2 = new_value - self.mean
        # Welford's update multiplies the deviations from the old and the new mean
        self.M2 += delta * delta2  # element-wise product

    def update_var(self):
        self.var = self._get_var()
        if self.reinitiate_every_step:
            # restart the accumulators but keep the variance just computed,
            # so that get_var() can serve it until the next refresh
            self.count = 0
            self.mean = self.tensor_package.zeros_like(self.mean)
            self.M2 = self.tensor_package.zeros_like(self.M2)

    def get_mean(self):
        return self.mean
    
    def get_mle_var(self):
        # maximum-likelihood (biased) estimate: divide by the sample count
        var = self.M2 / max(self.count, 1)
        if self.divide_var_by_mean_var:
            var = var / self.tensor_package.mean(var)  # scalar mean over all coordinates
        var = var * self.var_normalizer
        return var


    def get_var(self):
        if self.static_calculation:
            return self.var
        return self._get_var()
    
    def _get_var(self):
        var = self.M2 / max((self.count - 1), 1)
        if self.divide_var_by_mean_var:
            var = var/self.tensor_package.mean(var)
        var = var * self.var_normalizer
        return var

#

import torch
from torch.optim.optimizer import Optimizer, required

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


class SGD(Optimizer):
    r"""Implements stochastic gradient descent (optionally with momentum).

    Nesterov momentum is based on the formula from
    `On the importance of initialization and momentum in deep learning`__.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()

    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf

    .. note::
        The implementation of SGD with Momentum/Nesterov subtly differs from
        Sutskever et. al. and implementations in some other frameworks.

        Considering the specific case of Momentum, the update can be written as

        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
                p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
            \end{aligned}

        where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the
        parameters, gradient, velocity, and momentum respectively.

        This is in contrast to Sutskever et. al. and
        other frameworks which employ an update of the form

        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\
                p_{t+1} & = p_{t} - v_{t+1}.
            \end{aligned}

        The Nesterov version is analogously modified.
    """

    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False, adaptive_var_weight_decay=False, iter_length=100, device=device,
                 inverse_var=False, adaptive_avg_reg=False, logger=None, static_var_calculation=True,
                 uniform_prior_strength=0.5):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(SGD, self).__init__(params, defaults)
        self.online_param_var_dict, self.avg_param_dict = \
            self.create_online_param_var_dict(
                adaptive_var_weight_decay=adaptive_var_weight_decay, adaptive_avg_reg=adaptive_avg_reg,
                static_var_calculation=static_var_calculation)

        self.num_of_steps = 0
        if adaptive_var_weight_decay or adaptive_avg_reg:
            self.iter_length = iter_length
            self.device = device
            self.inverse_var = inverse_var
            self.uniform_prior_strength = uniform_prior_strength
            self.logger = logger
        else:
            self.iter_length = None
            self.device = None
            self.inverse_var = None
            self.uniform_prior_strength = None
            self.logger = None
        # if adaptive_avg_reg:
        #     self.avg_dict =
        # else:
        #     self.avg_dict = False

    def __setstate__(self, state):
        super(SGD, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    def create_online_param_var_dict(self, adaptive_var_weight_decay, adaptive_avg_reg, static_var_calculation):
        # todo: implement in Pytorch instead numpy
        online_param_var = {} if adaptive_var_weight_decay else None
        avg_param_dict = {} if adaptive_avg_reg else None
        for group_index, param_group in enumerate(self.param_groups):
            for param_index, param in enumerate(param_group['params']):
                param_name = param.name
                if not param_name:
                    param_name = (group_index, param_index)
                # self.params[param_name] = param#.astype(dtype)
                # if self.adaptive_var_reg and 'W' in k:  # or (self.adaptive_dropconnect and k in ('W1', 'W2')):
                    # if self.variance_calculation_method == 'welford':
                if adaptive_var_weight_decay:
                    online_param_var[param_name] = Welford(dim=param.shape, static_calculation=static_var_calculation, package='torch', initial_param=param)
                if adaptive_avg_reg:
                    avg_param_dict[param_name] = OnlineAvg(dim=param.shape, static_calculation=static_var_calculation, package='torch', initial_param=param)
        return online_param_var, avg_param_dict

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group_index, group in enumerate(self.param_groups):
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            for parameter_index, p in enumerate(group['params']):
                if p.grad is None:
                    continue
                d_p = p.grad

                if self.num_of_steps > 0 and self.num_of_steps % 100 == 0:  # self.iter_length == 0:
                    parameter_name = (group_index, parameter_index)
                    param_l2 = torch.norm(p)
                    if self.logger:
                        self.logger.report_scalar(
                            title=f"parameter l2, {weight_decay}", series=str(parameter_name),
                            value=float(param_l2), iteration=self.num_of_steps)

                if weight_decay != 0:
                    if self.online_param_var_dict:
                        parameter_name = (group_index, parameter_index)
                        var_tensor = self.online_param_var_dict[parameter_name].get_var().to(device=self.device)
                        if not self.inverse_var:
                            # element-wise 1 / var; torch.inverse would compute a matrix inverse
                            var_tensor = torch.reciprocal(var_tensor)
                        # reg_p = d_p.add(self.avg_param_dict[parameter_name].get_static_mean.to(device=self.device), alpha=-1)
                        # reg_p = reg_p.mul(reg_p)
                        reg_p = p.mul(var_tensor)  # todo: does it yields per-coordinate multiplication?  YES IT IS!
                        d_p = d_p.add(reg_p, alpha=weight_decay*(1-self.uniform_prior_strength))
                        d_p = d_p.add(p, alpha=weight_decay * self.uniform_prior_strength)
                    else:
                        d_p = d_p.add(p, alpha=weight_decay)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf

                p.add_(d_p, alpha=-group['lr'])
                if self.online_param_var_dict and weight_decay != 0:
                    self.online_param_var_dict[parameter_name].update(p.to(device=self.device))
                    if self.num_of_steps > 0 and self.num_of_steps % self.iter_length == 0:
                        var_calculator = self.online_param_var_dict[parameter_name]
                        if self.logger:
                            # report the raw variance before update_var(), which may reset M2 and count
                            d_var = var_calculator.M2 / max(var_calculator.count - 1, 1)
                            self.logger.report_scalar(
                                title=f"parameter variance, {weight_decay}", series=str(parameter_name),
                                value=float(d_var.mean()), iteration=self.num_of_steps)
                        var_calculator.update_var()
                        print("updating var")

        self.num_of_steps += 1
        return loss

    def update_param_variance_online(self, iteration):
        # refresh the stored variance of every tracked parameter
        if not self.online_param_var_dict:
            return
        # logger = trains.Task.current_task().get_logger()
        for param_name in self.online_param_var_dict:
            self.online_param_var_dict[param_name].update_var()
            # var_calculator = self.model.online_param_var[param_name]
            # d_var = var_calculator.dynamic_var if \
            # self.model.variance_calculation_method == 'GMA' \
            # else var_calculator.M2 / (var_calculator.count - 1)
            # logger.report_scalar(
            #     title=f"parameter variance, {self.model.reg}", series=param_name, value=np.average(d_var), iteration=iteration)
            # if self.model.adaptive_dropconnect:
            #     var = self.model.online_param_var[param_name].get_var()
            #     droconnect_value = 1/2 + np.sqrt(1-4*var) / 2
            #     dropconnect_value = np.nan_to_num(droconnect_value, nan=0.5)
            #     if self.model.divide_var_by_mean_var:
            #         dropconnect_value = dropconnect_value / np.mean(dropconnect_value)
            #     dropconnect_value = dropconnect_value * self.model.dropconnect
            #     self.model.dropconnect_param['adaptive_p'][param_name] = dropconnect_value
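The `Welford` class above is named after Welford's online algorithm, which computes a variance in a single pass by accumulating a running mean and a sum `M2` of squared deviations. The scalar sketch below shows the update rule on plain Python floats and compares it with `statistics.variance`; the function name and the sample data are ours.

```python
import statistics


def welford_variance(values):
    """Single-pass sample variance using Welford's update."""
    count, mean, m2 = 0, 0.0, 0.0
    for x in values:
        count += 1
        delta = x - mean          # deviation from the old mean
        mean += delta / count
        delta2 = x - mean         # deviation from the updated mean
        m2 += delta * delta2
    return m2 / (count - 1) if count > 1 else 0.0


data = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
print(welford_variance(data), statistics.variance(data))
```

The two printed values should agree up to floating-point rounding. The class applies the same recurrence element-wise to whole parameter tensors.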



That's it: with the model-building helper and the optimizer defined, the model is ready to be trained and evaluated on our custom dataset.

## Training and evaluation functions

In `references/detection/`, we have a number of helper functions to simplify training and evaluating detection models.
Here, we will use `references/detection/engine.py`, `references/detection/utils.py` and `references/detection/transforms.py`.

Let's copy those files (and their dependencies) in here so that they are available in the notebook

In [None]:
%%shell

# Download TorchVision repo to use some files from
# references/detection
git clone https://github.com/pytorch/vision.git
cd vision
git checkout v0.3.0

cp references/detection/utils.py ../
cp references/detection/transforms.py ../
cp references/detection/coco_eval.py ../
cp references/detection/engine.py ../
cp references/detection/coco_utils.py ../



Let's write some helper functions for data augmentation / transformation, which leverage the functions in `references/detection` that we have just copied:


In [None]:
from engine import train_one_epoch, evaluate
import utils
import transforms as T


def get_transform(train):
    transforms = []
    # converts the image, a PIL image, into a PyTorch Tensor
    transforms.append(T.ToTensor())
    if train:
        # during training, randomly flip the training images
        # and ground-truth for data augmentation
        transforms.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(transforms)

#### Note that we do not need to add mean/std normalization or image rescaling in the data transforms, as those are handled internally by the Mask R-CNN model.
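When an image is flipped horizontally, its boxes and masks have to be flipped with it, which is why the detection transforms take both the image and the target. The box part of such a flip can be sketched as follows. `hflip_boxes` is a hypothetical helper written for this illustration; the copied `transforms.py` works on tensors and also flips the masks.

```python
def hflip_boxes(boxes, image_width):
    """Mirror [x0, y0, x1, y1] boxes around the vertical axis of the image."""
    flipped = []
    for x0, y0, x1, y1 in boxes:
        # the old right edge becomes the new left edge, and vice versa
        flipped.append([image_width - x1, y0, image_width - x0, y1])
    return flipped


print(hflip_boxes([[10, 20, 50, 80]], image_width=100))  # [[50, 20, 90, 80]]
```

Swapping the two x coordinates keeps `x0 <= x1` after the flip, so the result is still a valid box.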

### Putting everything together

We now have the dataset class, the models and the data transforms. Let's instantiate them

In [None]:
# use our dataset and defined transformations
dataset = PennFudanDataset('PennFudanPed', get_transform(train=True))
dataset_test = PennFudanDataset('PennFudanPed', get_transform(train=False))

# split the dataset in train and test set
torch.manual_seed(1)
indices = torch.randperm(len(dataset)).tolist()
dataset = torch.utils.data.Subset(dataset, indices[:-50])
dataset_test = torch.utils.data.Subset(dataset_test, indices[-50:])

# define training and validation data loaders
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=2, shuffle=True, num_workers=4,
    collate_fn=utils.collate_fn)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=1, shuffle=False, num_workers=4,
    collate_fn=utils.collate_fn)
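Detection samples cannot be stacked into a single tensor by the default collate function, because images differ in size and each target holds a different number of boxes. A custom `collate_fn` can instead keep the samples as tuples. The sketch below shows that idea with placeholder data; we assume `utils.collate_fn` from the reference scripts follows the same pattern, so check the copied `utils.py` if the details matter.

```python
def collate_as_tuples(batch):
    """Turn [(img0, tgt0), (img1, tgt1), ...] into ((img0, img1, ...), (tgt0, tgt1, ...))."""
    return tuple(zip(*batch))


# stand-ins for (image, target) samples; real samples hold tensors and dicts of tensors
batch = [("img0", {"image_id": 0}), ("img1", {"image_id": 1})]
images, targets = collate_as_tuples(batch)
print(images)   # ('img0', 'img1')
print(targets)  # ({'image_id': 0}, {'image_id': 1})
```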

Now let's instantiate the model and the optimizer

In [None]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# our dataset has two classes only - background and person
num_classes = 2

# get the model using our helper function
model = get_instance_segmentation_model(num_classes)
# move model to the right device
model.to(device)

# construct an optimizer
params = [p for p in model.parameters() if p.requires_grad]
optimizer = SGD(params, lr=0.005,
                momentum=0.9, weight_decay=0.0005,
                adaptive_avg_reg=False,
                adaptive_var_weight_decay=True, iter_length=30)

# and a learning rate scheduler which decreases the learning rate by
# 10x every 3 epochs
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                               step_size=3,
                                               gamma=0.1)

And now let's train the model for 20 epochs, evaluating at the end of every epoch.

In [None]:
# let's train it for 20 epochs
num_epochs = 20

for epoch in range(num_epochs):
    # train for one epoch, printing every 10 iterations
    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
    # update the learning rate
    lr_scheduler.step()
    # evaluate on the test dataset
    evaluate(model, data_loader_test, device=device)
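To get a feel for the schedule used above, the step decay can be written in closed form. This is a sketch of the arithmetic only, assuming the scheduler is stepped once per epoch as in the loop above; `step_lr` is a hypothetical helper, not a PyTorch function, and it ignores any warmup applied inside `train_one_epoch`.

```python
def step_lr(base_lr, epoch, step_size=3, gamma=0.1):
    """Learning rate in effect during `epoch` (0-based) under a step decay."""
    return base_lr * gamma ** (epoch // step_size)


for epoch in range(0, 12, 3):
    print(epoch, step_lr(0.005, epoch))
```

With `step_size=3` and `gamma=0.1`, the learning rate has already dropped by a factor of 1000 by epoch 9, so later epochs change the weights very little.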

Now that training has finished, let's have a look at what it actually predicts in a test image

In [None]:
# pick one image from the test set
img, _ = dataset_test[0]
# put the model in evaluation mode
model.eval()
with torch.no_grad():
    prediction = model([img.to(device)])

Printing the prediction shows that we have a list of dictionaries. Each element of the list corresponds to a different image. As we have a single image, there is a single dictionary in the list.
The dictionary contains the predictions for the image we passed. In this case, we can see that it contains `boxes`, `labels`, `masks` and `scores` as fields.

In [None]:
prediction

Let's inspect the image and the predicted segmentation masks.

For that, we need to convert the image back to a PIL image: it has been rescaled to 0-1 and its axes rearranged into `[C, H, W]` format, so we scale it back to 0-255 and permute it to `[H, W, C]`.

In [None]:
Image.fromarray(img.mul(255).permute(1, 2, 0).byte().numpy())

And let's now visualize the top predicted segmentation mask. The masks are predicted as `[N, 1, H, W]`, where `N` is the number of predictions, and are probability maps between 0-1.

In [None]:
Image.fromarray(prediction[0]['masks'][0, 0].mul(255).byte().cpu().numpy())
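Since the predicted masks are probability maps, a binary mask is usually obtained by thresholding. The sketch below uses a small made-up array in place of `prediction[0]['masks'][0, 0]`, and the threshold of 0.5 is a common choice that we assume here rather than a value prescribed by the tutorial.

```python
import numpy as np

# toy stand-in for one predicted mask: per-pixel probabilities
prob_mask = np.array([
    [0.02, 0.10, 0.40],
    [0.30, 0.85, 0.95],
    [0.05, 0.70, 0.60],
])

binary_mask = prob_mask > 0.5          # boolean mask of confident pixels
print(binary_mask.astype(np.uint8))
print(int(binary_mask.sum()))          # 4 pixels above the threshold
```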

Looks pretty good!

## Wrapping up

In this tutorial, you have learned how to create your own training pipeline for instance segmentation models, on a custom dataset.
For that, you wrote a `torch.utils.data.Dataset` class that returns the images and the ground truth boxes and segmentation masks. You also leveraged a Mask R-CNN model pre-trained on COCO train2017 in order to perform transfer learning on this new dataset.

For a more complete example, which includes multi-machine / multi-gpu training, check `references/detection/train.py`, which is present in the [torchvision GitHub repo](https://github.com/pytorch/vision/tree/v0.3.0/references/detection). 

