<a href="https://colab.research.google.com/github/taravatp/Panopic-Feature-Pyramid-Network/blob/main/models/panoptic_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
from torch.nn.modules.activation import ReLU
from torch.nn.modules.normalization import GroupNorm

import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.models.feature_extraction import create_feature_extractor

import time
import numpy as np

In [None]:
# Pick the GPU when PyTorch reports one as available; otherwise stay on the CPU.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'
print('device:', device)

device: cpu


In [None]:
class PanopticSegmentationModel(nn.Module):
  """Mask R-CNN instance branch combined with a semantic head over its FPN levels.

  The instance branch is torchvision's ``maskrcnn_resnet50_fpn`` with its box
  and mask predictors replaced for ``num_things_classes`` outputs. The
  semantic branch takes four FPN levels, brings them to a common resolution
  with shared conv/upsample stages, sums them, and projects the sum to
  ``num_stuff_classes`` channels.

  Args:
    num_things_classes: Output size of the replaced box and mask predictors.
      NOTE(review): torchvision predictors usually count background as a
      class -- confirm the caller includes it.
    num_stuff_classes: Number of output channels of the semantic head.
    flag: ``'train'`` runs the instance model in training mode with targets;
      any other value runs it in eval mode without targets.
  """

  def __init__(self,num_things_classes,num_stuff_classes,flag='train'):
    super().__init__()
    self.num_things_classes = num_things_classes
    self.num_stuff_classes = num_stuff_classes
    # Built before the semantic head so submodule registration order (and thus
    # state_dict key order) stays the same as before.
    self.instance_model = self.get_model_instance_segmentation()
    self.flag = flag
    self._build_semantic_head()

  def _build_semantic_head(self):
    """Create the semantic-branch layers: upsample1, upsample2, conv1, last_layer."""
    # 256 -> 128 channels, then x2 spatial upsampling.
    self.upsample1 = nn.Sequential(
        nn.Conv2d(256,128,kernel_size=3,stride=1,padding=1),
        # NOTE(review): 128 groups over 128 channels normalises every channel
        # on its own; kept as-is so numerical behaviour is unchanged.
        nn.GroupNorm(128,128),
        # nn.ReLU's only argument is `inplace`; the previous `nn.ReLU(128)`
        # was a truthy value for it, so spell the same behaviour explicitly.
        nn.ReLU(inplace=True),
        nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
    )

    # 128 -> 128 channels, then x2 spatial upsampling. The same module (same
    # weights) is applied repeatedly to the coarser levels in the forward pass.
    self.upsample2 = nn.Sequential(
        nn.Conv2d(128,128,kernel_size=3,stride=1,padding=1),
        nn.GroupNorm(128,128),
        nn.ReLU(inplace=True),
        nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
    )

    # Channel reduction for the level that is not upsampled.
    self.conv1 = nn.Conv2d(256,128,kernel_size=3,stride=1,padding=1)

    # No `device=` here: the layer is created like every other layer in the
    # head and moved together with the rest of the model via `.to(...)`,
    # instead of depending on a module-level `device` global.
    self.last_layer = nn.Sequential(
        nn.Conv2d(128,self.num_stuff_classes,kernel_size=1),
        nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)
    )

  def get_model_instance_segmentation(self):
    """Return a pretrained Mask R-CNN whose box/mask heads predict num_things_classes."""
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)

    # Swap the box classifier for one with the requested number of outputs.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features,self.num_things_classes)

    # Swap the mask predictor likewise, with a 256-channel hidden layer.
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask,hidden_layer,self.num_things_classes)
    return model

  def get_FPN_levels(self,model,x):
    """Run ``model.backbone`` on ``x`` and return the four FPN layer-block outputs.

    Returns:
      Tuple ``(p3, p2, p1, p0)`` holding the outputs of ``fpn.layer_blocks.3``
      down to ``fpn.layer_blocks.0``.
    """
    backbone = model.backbone
    return_nodes = ['fpn.layer_blocks.0','fpn.layer_blocks.1','fpn.layer_blocks.2','fpn.layer_blocks.3']
    # NOTE(review): the extractor is rebuilt on every call. Caching it may be
    # worthwhile, but check first that doing so does not add state_dict keys.
    feature_extractor = create_feature_extractor(backbone,return_nodes)
    # Extraction runs without gradient tracking, so the semantic branch does
    # not backpropagate into the backbone through these features.
    with torch.no_grad():
      feature_pyramid_levels = feature_extractor(x)

    p3 = feature_pyramid_levels['fpn.layer_blocks.3']
    p2 = feature_pyramid_levels['fpn.layer_blocks.2']
    p1 = feature_pyramid_levels['fpn.layer_blocks.1']
    p0 = feature_pyramid_levels['fpn.layer_blocks.0']
    return p3,p2,p1,p0

  def _semantic_forward(self,p3,p2,p1,p0):
    """Fuse four feature levels into semantic logits.

    ``p3`` is upsampled x8, ``p2`` x4, ``p1`` x2 and ``p0`` x1, so each level
    must be half the spatial size of the next for the sum to line up. The
    fused map is projected to ``num_stuff_classes`` channels and upsampled x4.
    """
    p3 = self.upsample2(self.upsample2(self.upsample1(p3)))
    p2 = self.upsample2(self.upsample1(p2))
    p1 = self.upsample1(p1)
    p0 = self.conv1(p0)

    fused = p0 + p1 + p2 + p3
    return self.last_layer(fused)

  def forward(self,x,targets=None):
    """Run both branches on ``x``.

    Args:
      x: Input passed unchanged to the backbone feature extractor and to the
        instance model.
      targets: Passed to the instance model when ``self.flag == 'train'``;
        not used otherwise, so it may be omitted for inference.

    Returns:
      Tuple ``(instance_output, semantic_output)``. ``instance_output`` is
      whatever the instance model returns (for torchvision detection models:
      losses in training mode, predictions in eval mode); ``semantic_output``
      is the semantic head's prediction.
    """
    p3,p2,p1,p0 = self.get_FPN_levels(self.instance_model,x)

    # Instance segmentation branch; the mode is driven by `flag`, not by
    # this module's own training state.
    if self.flag == 'train':
      self.instance_model.train()
      instance_output = self.instance_model(x,targets)
    else:
      self.instance_model.eval()
      instance_output = self.instance_model(x)

    # Semantic segmentation branch.
    semantic_output = self._semantic_forward(p3,p2,p1,p0)

    return instance_output,semantic_output