<a href="https://colab.research.google.com/github/taravatp/Panopic-Feature-Pyramid-Network/blob/main/models/semantic_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import time

import torch
import torch.nn as nn
from torch.nn.modules.activation import ReLU
from torch.nn.modules.normalization import GroupNorm
import torchvision

# from torchvision.models.feature_extraction import create_feature_extractor

In [None]:
# Pick the compute device once for the whole notebook: use the GPU when
# CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'
print('device:', device)

In [None]:
class SemanticSegmentationModel(nn.Module):
  """Semantic segmentation head on top of a feature-pyramid backbone.

  The four pyramid levels returned by the backbone (keys '0'..'3') are each
  reduced to 128 channels and brought to the resolution of level '0', summed
  element-wise, projected to `num_classes` channels with a 1x1 convolution
  and finally upsampled by a factor of 4.

  Weight sharing (as in the original implementation): `upsample1` is applied
  to levels '1', '2' and '3', and `upsample2` is applied once to level '2'
  and twice to level '3', so the same parameters are reused across levels.
  """

  def __init__(self, num_classes, backbone=None):
    """Build the model.

    Args:
      num_classes: number of output channels (one per semantic class).
      backbone: optional module/callable mapping an image batch to a dict
        with keys '0', '1', '2', '3' holding 256-channel feature maps whose
        spatial size halves from one level to the next. When None, the
        pretrained backbone of torchvision's Mask R-CNN ResNet-50 FPN is
        used.

    The module is created on the default (CPU) device; move it with
    `.to(device)` like any other `nn.Module`.
    """
    super().__init__()

    self.num_classes = num_classes

    if backbone is None:
      # NOTE(review): recent torchvision releases prefer `weights=...` over
      # `pretrained=True`; kept as-is for compatibility with the version
      # this notebook was written against -- confirm before upgrading.
      backbone = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True).backbone
    self.backbone = backbone

    # Stage applied first to every coarse level: 256 -> 128 channels, 2x up.
    self.upsample1 = self._make_upsample_stage(256)
    # Stage applied repeatedly afterwards: 128 -> 128 channels, 2x up.
    self.upsample2 = self._make_upsample_stage(128)

    # The finest level only needs channel reduction, no upsampling.
    self.conv1 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)

    self.last_layer = nn.Sequential(
        nn.Conv2d(128, self.num_classes, kernel_size=1),
        nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False),
    )

  @staticmethod
  def _make_upsample_stage(in_channels, out_channels=128):
    """Return a 3x3 conv -> GroupNorm -> ReLU -> 2x bilinear upsample stack."""
    return nn.Sequential(
        # 3x3, stride 1, padding 1: spatial size is unchanged.
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
        # num_groups == num_channels, i.e. one channel per group.
        nn.GroupNorm(out_channels, out_channels),
        nn.ReLU(inplace=True),
        nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
    )

  def forward(self, x):
    """Return per-class score maps for the image batch `x`.

    With the default backbone the output has the same spatial size as `x`
    and `num_classes` channels.
    """
    features = self.backbone(x)  # dictionary of pyramid levels

    # Fetching the pyramid levels (shapes for the default backbone).
    p0 = features['0']  # (batch, 256, H/4,  W/4)
    p1 = features['1']  # (batch, 256, H/8,  W/8)
    p2 = features['2']  # (batch, 256, H/16, W/16)
    p3 = features['3']  # (batch, 256, H/32, W/32)

    # Bring every level to 128 channels at the resolution of level '0':
    # three 2x steps for p3, two for p2, one for p1, none for p0.
    p3 = self.upsample1(p3)
    p3 = self.upsample2(p3)
    p3 = self.upsample2(p3)

    p2 = self.upsample1(p2)
    p2 = self.upsample2(p2)

    p1 = self.upsample1(p1)

    p0 = self.conv1(p0)

    # Fuse the levels by element-wise summation.
    fused = p0 + p1 + p2 + p3

    # Producing the final result: class projection + 4x upsampling.
    return self.last_layer(fused)

In [None]:
if __name__ == '__main__':

  # Smoke test: build a 5-class model on the selected device, push a dummy
  # batch of two 3-channel 640x480 images through it and show the shape of
  # the result.
  test = SemanticSegmentationModel(5)
  test = test.to(device)

  test_input = torch.ones(2, 3, 640, 480, device=device)

  output = test(test_input)
  print(output.shape)

torch.Size([2, 5, 640, 480])
