## BACKBONE

In [1]:
import torch
import torch.nn as nn
import torchvision.models as models

class ResNetBackbone(nn.Module):
    """ImageNet-pretrained ResNet-50 with its classification head removed.

    The last two children of ``torchvision.models.resnet50`` (global average
    pool and fully-connected classifier) are discarded, so calling the module
    returns the spatial feature map produced by the final residual stage.
    """

    def __init__(self):
        super().__init__()

        resnet = models.resnet50(pretrained=True)
        # Keep every child except the trailing avgpool + fc pair.
        trunk = list(resnet.children())
        del trunk[-2:]
        self.backbone = nn.Sequential(*trunk)
        # Channel count of ResNet-50's last residual stage.
        self.out_channels = 2048

    def forward(self, x):
        """Return the backbone feature map for the image batch ``x``."""
        features = self.backbone(x)
        return features

## REGION PROPOSAL NETWORK

In [2]:
import torch.nn.functional as F

class RPN(nn.Module):
    """Region Proposal Network head applied densely over a feature map.

    A shared 3x3 convolution (channel count and spatial size preserved) is
    followed by two sibling 1x1 convolutions: one emits ``num_anchors``
    objectness logits per location, the other ``num_anchors * 4`` box values
    per location.

    Args:
        in_channels: channel count of the incoming feature map.
        num_anchors: number of anchor shapes evaluated at every location.
    """

    def __init__(self, in_channels, num_anchors=9):
        super().__init__()

        # Layer creation order is kept (conv, obj_logits, bbox) so parameter
        # registration and random initialisation order stay the same.
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.obj_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1)
        self.bbox = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=1)

    def forward(self, x):
        """Return ``(objectness_logits, box_values)`` for feature map ``x``."""
        shared = torch.relu(self.conv(x))
        return self.obj_logits(shared), self.bbox(shared)

## ANCHOR GENERATOR

In [3]:
def generate_anchors(feature_map_size, scales, ratio, stride=16):
    """Enumerate anchor boxes over a regular grid of feature-map cells.

    For cell (row ``i``, column ``j``) the anchor origin is the cell's top-left
    corner in input-image pixels, ``(j * stride, i * stride)``. At each origin
    one box is emitted per (scale, ratio) pair, with width ``scale * r`` and
    height ``scale / r``. Every anchor of a given scale therefore has area
    ``scale ** 2``.

    Anchors are ordered row-major over cells, then by scale, then by ratio.
    The anchor for cell (i, j), scale index s and ratio index r is at index
    ``((i * W + j) * len(scales) + s) * len(ratio) + r``.

    Args:
        feature_map_size: side length of a square feature map, or a
            ``(height, width)`` pair.
        scales: iterable of anchor base sizes, in pixels.
        ratio: iterable of aspect factors. Despite the singular name this is a
            collection; each element produces one anchor per scale.
        stride: pixel distance between neighbouring cells.

    Returns:
        float32 tensor of shape
        ``(height * width * len(scales) * len(ratio), 4)`` holding
        ``(x1, y1, x2, y2)`` boxes. The shape is ``(0, 4)`` when there are
        no anchors.
    """
    try:
        height, width = feature_map_size
    except TypeError:
        # A single number: square feature map.
        height = width = feature_map_size

    # Materialise once under separate names. The loop variables below must not
    # rebind `ratio`: the original `for ratio in ratio` replaced the list with
    # a scalar and crashed on the second scale.
    scale_list = list(scales)
    ratio_list = list(ratio)

    anchors = []
    for i in range(height):
        y = i * stride
        for j in range(width):
            x = j * stride
            for scale in scale_list:
                for r in ratio_list:
                    w = scale * r
                    h = scale / r
                    anchors.append((x, y, x + w, y + h))

    # reshape keeps the (0, 4) shape for an empty anchor list.
    return torch.tensor(anchors, dtype=torch.float32).reshape(-1, 4)

## PROPOSAL GENERATION

In [4]:
from torchvision.ops import nms

def generate_proposals(anchors, bbox, scores, iou_threshold=0.7, top_n=300):
    """Turn RPN outputs into a set of region proposals for a single image.

    The RPN head emits dense tensors laid out as ``(N, A * 4, H, W)`` (box
    offsets) and ``(N, A, H, W)`` (objectness logits). ``generate_anchors``
    enumerates anchors row by row, column by column, then over the A anchor
    shapes. Both RPN tensors are therefore permuted into that same
    ``(H, W, A)`` order before being combined with ``anchors``. The original
    ``anchors + bbox`` could not broadcast these shapes.

    Args:
        anchors: ``(H * W * A, 4)`` anchor boxes as ``(x1, y1, x2, y2)``.
        bbox: ``(1, A * 4, H, W)`` offsets predicted by the RPN, added directly
            to the anchor coordinates.
        scores: ``(1, A, H, W)`` objectness logits.
        iou_threshold: IoU above which NMS suppresses the lower-scored box.
        top_n: maximum number of proposals kept after NMS.

    Returns:
        ``(K, 4)`` tensor of proposals with ``K <= top_n``, in decreasing
        score order (the order ``nms`` returns).

    Raises:
        ValueError: if the RPN outputs are not 4-D, the batch size is not 1,
            or the tensor shapes disagree with the number of anchors.
    """
    if bbox.dim() != 4 or scores.dim() != 4:
        raise ValueError("expected 4-D RPN outputs of shape (N, C, H, W)")
    batch, num_anchors, height, width = scores.shape
    if batch != 1:
        # roi_pooling assigns every proposal to image 0, so the pipeline
        # only supports one image per forward pass.
        raise ValueError(f"generate_proposals supports batch size 1, got {batch}")
    if bbox.shape[1] != num_anchors * 4 or anchors.shape[0] != height * width * num_anchors:
        raise ValueError(
            f"RPN outputs (A={num_anchors}, H={height}, W={width}) do not match "
            f"{anchors.shape[0]} anchors"
        )

    # (1, A*4, H, W) -> (A, 4, H, W) -> (H, W, A, 4) -> (H*W*A, 4)
    deltas = bbox.reshape(num_anchors, 4, height, width).permute(2, 3, 0, 1).reshape(-1, 4)
    # (1, A, H, W) -> (A, H, W) -> (H, W, A) -> (H*W*A,)
    flat_scores = scores.reshape(num_anchors, height, width).permute(1, 2, 0).reshape(-1)

    proposals = anchors.to(deltas) + deltas
    # NMS only compares scores, so raw logits rank boxes exactly as
    # sigmoid probabilities would.
    keep = nms(proposals, flat_scores, iou_threshold)
    return proposals[keep][:top_n]

## ROI ALIGN

In [5]:
from torchvision.ops import roi_align

def roi_pooling(feature_map, proposals, spatial_scale=1 / 32, output_size=(7, 7)):
    """Crop a fixed-size feature patch for every proposal with RoI Align.

    Proposals are built from anchors laid out in input-image pixels, while
    ``roi_align`` samples ``feature_map`` coordinates. ``spatial_scale`` maps
    one onto the other. Without it (the original called ``roi_align`` with its
    default scale of 1.0), every box was read as if it were already in
    feature-map units.

    Args:
        feature_map: backbone output, presumably ``(1, C, H, W)``. All
            proposals are attributed to batch index 0, because the single
            proposal tensor is wrapped in a one-element list.
        proposals: ``(K, 4)`` boxes ``(x1, y1, x2, y2)`` in input-image pixels.
        spatial_scale: factor converting image coordinates to feature-map
            coordinates. Defaults to 1/32, the output stride of the ResNet-50
            trunk truncated after its last residual stage.
        output_size: ``(height, width)`` of each pooled patch. The default
            ``(7, 7)`` matches ``DetectionHead``'s ``in_channels*7*7`` input.

    Returns:
        Tensor of shape ``(K, C, output_size[0], output_size[1])``.
    """
    return roi_align(
        feature_map,
        [proposals],
        output_size=output_size,
        spatial_scale=spatial_scale,
    )

## DETECTION HEAD

In [6]:
class DetectionHead(nn.Module):
    """Two-layer MLP that classifies and regresses pooled RoI features.

    Each RoI's features are flattened (they must contain
    ``in_channels * 7 * 7`` values, i.e. a 7x7 pooled patch) and passed
    through two 1024-wide ReLU layers. Two sibling linear layers then produce
    per-class scores and per-class box values.

    Args:
        in_channels: channel count of the pooled RoI features.
        num_classes: number of classes predicted per RoI.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()

        flat_features = in_channels * 7 * 7
        hidden = 1024
        self.fc1 = nn.Linear(flat_features, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.cls_logits = nn.Linear(hidden, num_classes)
        self.bbox = nn.Linear(hidden, 4 * num_classes)

    def forward(self, x):
        """Return ``(class_logits, box_values)`` for RoI features ``x``."""
        h = torch.flatten(x, 1)
        for layer in (self.fc1, self.fc2):
            h = F.relu(layer(h))
        return self.cls_logits(h), self.bbox(h)

## COMPLETE MODEL

In [7]:
class FasterRCNN(nn.Module):
    """Two-stage detector: ResNet-50 trunk, RPN, RoI Align and an MLP head.

    Args:
        num_classes: number of classification outputs of the detection head
            (default 14, the value this notebook used).
        anchor_scales: anchor base sizes in pixels, passed to ``generate_anchors``.
        anchor_ratios: anchor aspect factors, passed to ``generate_anchors``.
        anchor_stride: pixel distance between neighbouring feature-map cells.
            The ResNet-50 trunk (truncated after its last residual stage)
            downsamples by 32. The original grid of 16 only covered the
            top-left quarter of the image.
    """

    def __init__(self, num_classes=14, anchor_scales=(64, 128, 256),
                 anchor_ratios=(0.5, 1, 2), anchor_stride=32):
        super().__init__()

        self.anchor_scales = list(anchor_scales)
        self.anchor_ratios = list(anchor_ratios)
        self.anchor_stride = anchor_stride

        self.backbone = ResNetBackbone()
        # The RPN predicts one objectness logit / box per (scale, ratio) pair.
        num_anchors = len(self.anchor_scales) * len(self.anchor_ratios)
        self.rpn = RPN(self.backbone.out_channels, num_anchors=num_anchors)
        self.head = DetectionHead(self.backbone.out_channels, num_classes=num_classes)

    def forward(self, images, target=None):
        """Run the full detector on ``images``.

        Args:
            images: image batch fed to the backbone. Presumably
                ``(1, 3, H, W)``: the proposal stage handles a single image.
            target: currently unused; kept for interface compatibility.

        Returns:
            ``(cls_logits, bbox)`` from the detection head, one row per proposal.
        """
        feature_map = self.backbone(images)
        rpn_obj_logits, rpn_bbox = self.rpn(feature_map)
        # NOTE(review): only the feature-map width is passed, so this assumes a
        # square feature map (i.e. square input images).
        anchors = generate_anchors(
            feature_map.shape[-1],
            self.anchor_scales,
            ratio=self.anchor_ratios,
            stride=self.anchor_stride,
        ).to(device=feature_map.device, dtype=feature_map.dtype)  # anchors are built on CPU
        proposals = generate_proposals(anchors, rpn_bbox, rpn_obj_logits)
        roi = roi_pooling(feature_map, proposals)
        cls_logits, bbox = self.head(roi)
        return cls_logits, bbox

## LOSS FUNCTIONS

In [8]:
def rpnloss(pred_obj, pred_box, get_obj, get_box):
    """Sum of the RPN objectness and box-regression losses.

    Args:
        pred_obj: objectness logits.
        pred_box: predicted box values.
        get_obj: objectness targets for ``pred_obj``. These are presumably
            ground-truth 0/1 labels; the ``get_`` prefix looks like a typo for
            ``gt_``.
        get_box: regression targets, compared element-wise with ``pred_box``.

    Returns:
        Scalar tensor: binary cross-entropy with logits plus smooth-L1 loss,
        each using PyTorch's default mean reduction.
    """
    objectness_term = F.binary_cross_entropy_with_logits(pred_obj, get_obj)
    regression_term = F.smooth_l1_loss(pred_box, get_box)
    total = objectness_term + regression_term
    return total

def detectionloss(pred_cls, pred_box, get_cls, get_box):
    """Sum of the detection head's classification and box-regression losses.

    Args:
        pred_cls: class logits, one row per sample.
        pred_box: predicted box values.
        get_cls: class-index targets for ``pred_cls``. These are presumably
            ground-truth labels; the ``get_`` prefix looks like a typo for
            ``gt_``.
        get_box: regression targets, compared element-wise with ``pred_box``.

    Returns:
        Scalar tensor: cross-entropy plus smooth-L1 loss, each using PyTorch's
        default mean reduction.
    """
    classification_term = F.cross_entropy(pred_cls, get_cls)
    regression_term = F.smooth_l1_loss(pred_box, get_box)
    total = classification_term + regression_term
    return total

In [13]:
model = FasterRCNN()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
model.train()  # explicit: BatchNorm/dropout layers should be in training mode

# NOTE(review): `train_dataloader` is not defined anywhere in this notebook.
# Running this cell raised NameError, so it must be created in an earlier cell.
# NOTE(review): `cls_logits` / `bbox` have one row per proposal, while
# `targets['labels']` / `targets['boxes']` presumably have one row per
# ground-truth object. Proposals need to be matched to ground truth before
# `detectionloss` is meaningful. The RPN also presumably receives no gradient
# through nms/roi_align, so `rpnloss` should be added to train it.
# TODO confirm.
for images, targets in train_dataloader:
    # Clear gradients once per step (the duplicate call was redundant).
    optimizer.zero_grad()
    cls_logits, bbox = model(images)
    loss = detectionloss(cls_logits, bbox, targets['labels'], targets['boxes'])
    loss.backward()
    optimizer.step()



NameError: name 'train_dataloader' is not defined

In [14]:
def inference(model, images):
    """Run ``model`` on ``images`` in evaluation mode without tracking gradients.

    The model is switched to eval mode for the call, so BatchNorm layers use
    their running statistics. Its previous train/eval mode is restored
    afterwards, even if the forward pass raises.

    Args:
        model: module whose forward returns a ``(cls_logits, bbox)`` pair.
        images: input batch passed straight to ``model``.

    Returns:
        ``(cls_logits, bbox)`` as produced by ``model(images)``.
    """
    was_training = model.training
    model.eval()
    try:
        # torch.no_grad must be instantiated: the bare class is not a context
        # manager, so the original `with torch.no_grad:` raised.
        with torch.no_grad():
            cls_logits, bbox = model(images)
    finally:
        model.train(was_training)
    return cls_logits, bbox