In [None]:
# default_exp models.maskrcnn

# MaskRCNN

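> Mask R-CNN instance-segmentation models with ResNet-FPN backbones, built on `torchvision.models.detection.MaskRCNN`, with optional COCO-pretrained weights.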

In [None]:
#hide
%load_ext autoreload
%autoreload 2
from nbdev.showdoc import *

import warnings
warnings.filterwarnings("ignore")

In [None]:
#export

from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.utils import load_state_dict_from_url
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models.detection import MaskRCNN
from torchvision.ops.misc import FrozenBatchNorm2d
from functools import partial

In [None]:
#export 

_model_urls = {
    'maskrcnn_resnet50_fpn_coco':
        'https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth',
}

In [None]:
# export

def _load_coco_weights(model, arch_str):
    "Best-effort copy of COCO-pretrained Mask R-CNN weights for `arch_str` into `model`."
    name = "maskrcnn_" + arch_str
    url = _model_urls.get(name + "_fpn_coco")
    if url is None:
        print("No pretrained coco model found for " + name)
        print("This does not affect the backbone.")
        return
    try:
        pretrained_dict = load_state_dict_from_url(url, progress=True)
    except Exception as e:
        # Stay best-effort (e.g. offline use), but report the real cause instead of
        # claiming that no pretrained weights exist for this architecture.
        print(f"Could not load pretrained coco weights for {name}: {e}")
        print("This does not affect the backbone.")
        return

    model_dict = model.state_dict()
    # Copy only tensors that exist in this model with the same shape. Heads built for
    # a different `num_classes` therefore keep their fresh initialisation.
    model_dict.update({k: v for k, v in pretrained_dict.items()
                       if k in model_dict and model_dict[k].shape == v.shape})
    model.load_state_dict(model_dict)

    # Reset eps on every FrozenBatchNorm2d after loading the checkpoint.
    # NOTE(review): presumably this matches how these legacy COCO weights were trained
    # (torchvision applies the same `overwrite_eps(model, 0.0)` step) — confirm.
    for module in model.modules():
        if isinstance(module, FrozenBatchNorm2d):
            module.eps = 0.0


def get_maskrcnn_model(arch_str, num_classes, pretrained=False, pretrained_backbone=True,
                       trainable_layers=5, **kwargs):
    """Build a torchvision `MaskRCNN` with a ResNet-FPN backbone.

    Args:
        arch_str: backbone name passed to `resnet_fpn_backbone` (e.g. "resnet50").
            It is also used to look up COCO weights under "maskrcnn_<arch_str>_fpn_coco".
        num_classes: number of classes passed to `MaskRCNN`.
        pretrained: if True, try to load COCO-pretrained Mask R-CNN weights. Only tensors
            whose name and shape match the built model are copied. Loading is
            best-effort: if no weights are registered or the download fails, a message
            is printed and the model is returned without them.
        pretrained_backbone: forwarded as `pretrained` to `resnet_fpn_backbone`.
        trainable_layers: forwarded to `resnet_fpn_backbone`.
        **kwargs: extra keyword arguments forwarded to `MaskRCNN`.

    Returns:
        The constructed `MaskRCNN` model.
    """
    backbone = resnet_fpn_backbone(arch_str, pretrained=pretrained_backbone,
                                   trainable_layers=trainable_layers)
    model = MaskRCNN(backbone,
                     num_classes,
                     image_mean=[0.0, 0.0, 0.0],  # already normalized by fastai
                     image_std=[1.0, 1.0, 1.0],
                     **kwargs)

    if pretrained:
        _load_coco_weights(model, arch_str)

    return model

In [None]:
#export

def _maskrcnn_for(arch):
    "Return `get_maskrcnn_model` with its `arch_str` argument fixed to `arch`."
    return partial(get_maskrcnn_model, arch_str=arch)

# One ready-made constructor per ResNet depth; each still takes `num_classes`
# plus the optional keyword arguments of `get_maskrcnn_model`.
maskrcnn_resnet18 = _maskrcnn_for("resnet18")
maskrcnn_resnet34 = _maskrcnn_for("resnet34")
maskrcnn_resnet50 = _maskrcnn_for("resnet50")
maskrcnn_resnet101 = _maskrcnn_for("resnet101")
maskrcnn_resnet152 = _maskrcnn_for("resnet152")

In [None]:
# Build a 4-class Mask R-CNN with a ResNet-50 FPN backbone and display its module tree.
mrcnn_r50 = maskrcnn_resnet50(num_classes=4)
mrcnn_r50

MaskRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
