In [10]:
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F

class ResNetBackbone(nn.Module):
    def __init__(self, pretrained=True):
        super(ResNetBackbone, self).__init__()
        resnet = models.resnet101(pretrained=pretrained)
        self.stage1 = nn.Sequential(*list(resnet.children())[:4])  # res2
        self.stage2 = nn.Sequential(*list(resnet.children())[4])   # res3
        self.stage3 = nn.Sequential(*list(resnet.children())[5])   # res4
        self.stage4 = nn.Sequential(*list(resnet.children())[6])   # res5
    
    def forward(self, x):
        res2 = self.stage1(x)
        res3 = self.stage2(res2)
        res4 = self.stage3(res3)
        res5 = self.stage4(res4)
        return res2, res3, res4, res5

class SemanticFPN(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(SemanticFPN, self).__init__()
        self.lateral4 = nn.Conv2d(in_channels[3], out_channels, 1)
        self.lateral3 = nn.Conv2d(in_channels[2], out_channels, 1)
        self.lateral2 = nn.Conv2d(in_channels[1], out_channels, 1)
        self.lateral1 = nn.Conv2d(in_channels[0], out_channels, 1)
        
        self.fpn_out4 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.fpn_out3 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.fpn_out2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.fpn_out1 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        
    def forward(self, res2, res3, res4, res5):
        # Lateral connections
        lat4 = self.lateral4(res5)
        lat3 = self.lateral3(res4)
        lat2 = self.lateral2(res3)
        lat1 = self.lateral1(res2)
        
        # Top-down pathway
        fpn4 = self.fpn_out4(lat4)
        fpn3 = self.fpn_out3(lat3 + F.interpolate(fpn4, size=lat3.shape[-2:], mode='bilinear', align_corners=False))
        fpn2 = self.fpn_out2(lat2 + F.interpolate(fpn3, size=lat2.shape[-2:], mode='bilinear', align_corners=False))
        fpn1 = self.fpn_out1(lat1 + F.interpolate(fpn2, size=lat1.shape[-2:], mode='bilinear', align_corners=False))
        
        return fpn1

class PointHead(nn.Module):
    def __init__(self, in_c=256 * 2, num_classes=19, k=3, beta=0.75):
        super().__init__()
        self.mlp = nn.Conv1d(in_c, num_classes, 1)
        self.k = k
        self.beta = beta

    def forward(self, x, res2, out):
        if not self.training:
            return self.inference(x, res2, out)

        points = self.sampling_points(out, x.shape[-1] // 16, self.k, self.beta)

        coarse = self.point_sample(out, points, align_corners=False)
        fine = self.point_sample(res2, points, align_corners=False)

        feature_representation = torch.cat([coarse, fine], dim=1)

        rend = self.mlp(feature_representation)

        return {"rend": rend, "points": points}

    @torch.no_grad()
    def inference(self, x, res2, out):
        num_points = 8192

        while out.shape[-1] != x.shape[-1]:
            out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=True)

            points_idx, points = self.sampling_points(out, num_points, training=self.training)

            coarse = self.point_sample(out, points, align_corners=False)
            fine = this.point_sample(res2, points, align_corners=False)

            feature_representation = torch.cat([coarse, fine], dim=1)

            rend = self.mlp(feature_representation)

            B, C, H, W = out.shape
            points_idx = points_idx.unsqueeze(1).expand(-1, C, -1)
            out = (out.reshape(B, C, -1).scatter_(2, points_idx, rend).view(B, C, H, W))

        return {"fine": out}
    
    def sampling_points(self, mask, N, k=3, beta=0.75):
        # Simplified version for demonstration purposes
        return torch.rand(mask.size(0), N, 2, device=mask.device, dtype=mask.dtype)

    def point_sample(self, feature_map, point_coords, align_corners=False):
        # Simplified version for demonstration purposes
        return F.grid_sample(feature_map, point_coords.unsqueeze(2), align_corners=align_corners).squeeze(2)

class SemanticFPNPointRend(nn.Module):
    def __init__(self, backbone, fpn, point_head):
        super(SemanticFPNPointRend, self).__init__()
        self.backbone = backbone
        self.fpn = fpn
        self.point_head = point_head
    
    def forward(self, x):
        res2, res3, res4, res5 = self.backbone(x)
        coarse = self.fpn(res2, res3, res4, res5)
        output = self.point_head(x, res2, coarse)
        return output


In [11]:
# Initialize components
resnet_backbone = ResNetBackbone(pretrained=True)
fpn = SemanticFPN(in_channels=[256, 512, 1024, 2048], out_channels=256)
point_head = PointHead(in_c=256 * 2, num_classes=19)  # Adjust input channels and number of classes as needed

# Combine into a single model
pointrend_model = SemanticFPNPointRend(resnet_backbone, fpn, point_head)

# Example usage
input_tensor = torch.rand(1, 3, 512, 512)  # Example input tensor
output = pointrend_model(input_tensor)
print(output)

RuntimeError: Given groups=1, weight of size [256, 2048, 1, 1], expected input[1, 1024, 32, 32] to have 2048 channels, but got 1024 channels instead

In [2]:
import detectron2
from detectron2.config import get_cfg
from detectron2.engine import DefaultTrainer
from detectron2.data import MetadataCatalog
import os
from detectron2.projects.point_rend import ColorAugSSDTransform, add_pointrend_config

# Load the configuration
cfg = get_cfg()
add_pointrend_config(cfg)
cfg.merge_from_file("./model/configs/semFPN_pointrend.yaml")

# Set the weights and number of classes
cfg.MODEL.WEIGHTS = "detectron2://ImageNetPretrained/MSRA/R-101.pkl"
cfg.MODEL.RESNETS.DEPTH = 101
cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 19
cfg.MODEL.POINT_HEAD.NUM_CLASSES = 19
cfg.MODEL.POINT_HEAD.TRAIN_NUM_POINTS = 2048
cfg.MODEL.POINT_HEAD.SUBDIVISION_NUM_POINTS = 8192

# Dataset configuration
cfg.DATASETS.TRAIN = ("cityscapes_fine_sem_seg_train",)
cfg.DATASETS.TEST = ("cityscapes_fine_sem_seg_val",)

# Solver configuration
cfg.SOLVER.BASE_LR = 0.01
cfg.SOLVER.STEPS = (40000, 55000)
cfg.SOLVER.MAX_ITER = 65000
cfg.SOLVER.IMS_PER_BATCH = 32

# Input configuration
cfg.INPUT.MIN_SIZE_TRAIN = (512, 768, 1024, 1280, 1536, 1792, 2048)
cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING = "choice"
cfg.INPUT.MIN_SIZE_TEST = 1024
cfg.INPUT.MAX_SIZE_TRAIN = 4096
cfg.INPUT.MAX_SIZE_TEST = 2048
cfg.INPUT.CROP.ENABLED = True
cfg.INPUT.CROP.TYPE = "absolute"
cfg.INPUT.CROP.SIZE = (512, 1024)
cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 0.75
cfg.INPUT.COLOR_AUG_SSD = True

# Dataloader configuration
cfg.DATALOADER.NUM_WORKERS = 10

# Output directory
cfg.OUTPUT_DIR = "./output"

# Make sure the output directory exists
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

# Train the model
trainer = DefaultTrainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()

AssertionError: Torch not compiled with CUDA enabled