In [1]:
import torch
from torch import nn

In [2]:
torch.__version__

'1.12.1+cu102'

In [3]:
class ConvLayers(nn.Module):
    """ 
        Conv layers with batchnorm and activation layers
    """
    def __init__(self, input_ch, output_ch, kernel_size, stride, padding=1):
        super(ConvLayers, self).__init__()
        if kernel_size == 1:
            self.conv_layer = nn.Conv2d(input_ch, output_ch, kernel_size=kernel_size, stride=stride)
        else:
            self.conv_layer = nn.Conv2d(input_ch, output_ch, kernel_size=kernel_size, stride=stride, padding=padding)
        self.norm = nn.BatchNorm2d(output_ch)
        self.act = nn.GELU()

    def forward(self, x):
        return self.act(self.norm(self.conv_layer(x)))


class ConvLayers_1(nn.Module):
    """ 
        Conv layers with batchnorm and activation layers
    """
    def __init__(self, input_ch, output_ch, kernel_size, stride):
        super(ConvLayers_1, self).__init__()
        if stride == 2:
            self.conv_layer = nn.Conv2d(input_ch, output_ch, kernel_size=kernel_size, stride=stride)
        else:
            self.conv_layer = nn.Conv2d(input_ch, output_ch, kernel_size=kernel_size, stride=stride, padding="same")
        self.norm = nn.BatchNorm2d(output_ch)
        self.act = nn.GELU()

    def forward(self, x):
        return self.act(self.norm(self.conv_layer(x)))


class ConvBlock(nn.Module):
    """ Conv block with residual connections 
    """
    def __init__(self, list_conv_layers, num_repeat):
        super(ConvBlock, self).__init__()
        self.repeat = num_repeat
        temp_conv_layers = []
        for layer in list_conv_layers:
            temp_conv_layers.append(ConvLayers(layer[0], layer[1], layer[2], layer[3]))
        self.conv_layers = nn.Sequential(*temp_conv_layers)

    def forward(self, x):
        x_ = x
        for _ in range(self.repeat):
            x = self.conv_layers(x)
        return x + x_


class DarkNetHead(nn.Module):

    def __init__(self, model_parameters):
        super(DarkNetHead, self).__init__()
        temp_layers = []
        for layers in model_parameters:
            if layers[0] == "cl":
                temp_layers.append(ConvLayers_1(layers[1], layers[2], layers[3], layers[4]))
            elif layers[0] == "cb":
                temp_layers.append(ConvBlock(layers[1], layers[2]))
        self.conv_layers = nn.Sequential(*temp_layers)

    def forward(self, x):
        return self.conv_layers(x)




model_parameters = [
    ("cl", 3, 32, 3, 1),
    ("cl", 32, 64, 3, 2),
    ("cb", [(64, 32, 1, 1), (32, 64, 3, 1)], 1),
    ("cl", 64, 128, 3, 2),
    ("cb", [(128, 64, 1, 1), (64, 128, 3, 1)], 2),
    ("cl", 128, 256, 3, 2),
    ("cb", [(256, 128, 1, 1), (128, 256, 3, 1)], 8),
    ("cl", 256, 512, 3, 2),
    ("cb", [(512, 256, 1, 1), (256, 512, 3, 1)], 8),
    ("cl", 512, 1024, 3, 2),
    ("cb", [(1024, 512, 1, 1), (512, 1024, 3, 1)], 4)
]



In [4]:
conv1 = ConvLayers_1(64, 32, 1, 1)
conv2 = ConvLayers_1(32, 64, 7, 1)
inp = torch.rand((4, 64, 256, 256))
inp1 = conv1(inp)
print(inp1.shape)
inp2 = conv2(inp1)
print(inp2.shape)

torch.Size([4, 32, 256, 256])
torch.Size([4, 64, 256, 256])


In [5]:
inp = torch.rand((4, 3, 256, 256))

In [6]:
model = DarkNetHead(model_parameters)

In [7]:
out = model(inp)
out.shape

torch.Size([4, 1024, 7, 7])

In [8]:
class ConvSame(nn.Module):
    
    def __init__(self, in_ch, out_ch, k, s, norm=True, act=True):
        super(ConvSame, self).__init__()
        p = k//2
        self.conv = nn.Conv2d(in_ch, out_ch, k, s, p)
        if norm:
            self.norm = nn.BatchNorm2d(out_ch)
        else:
            self.norm = nn.Identity()
        if act:
            self.act = nn.GELU()
        else:
            self.act = nn.Identity()
        
    def forward(self, x):
        return self.act(self.norm(self.conv(x)))

class ConvPad(nn.Module):
    
    def __init__(self, in_ch, out_ch, k, s, p=1, norm=True, act=True):
        super(ConvPad, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, k, s, p)
        if norm:
            self.norm = nn.BatchNorm2d(out_ch)
        else:
            self.norm = nn.Identity()
        if act:
            self.act = nn.GELU()
        else:
            self.act = nn.Identity()
        
    def forward(self, x):
        return self.act(self.norm(self.conv(x)))

class ConvBlock(nn.Module):
    
    def __init__(self, in_ch):
        super(ConvBlock, self).__init__()
        out_ch = in_ch // 2
        self.conv1 = ConvSame(in_ch, out_ch, k=1, s=1, norm=False)
        self.conv2 = ConvSame(out_ch, in_ch, k=3, s=1)
    
    def forward(self, x):
        x_ = x
        x = self.conv2(self.conv1(x))
        return x + x_
    

class RepeatBlock(nn.Module):
    
    def __init__(self, in_ch, repeat):
        super(RepeatBlock, self).__init__()
        self.repeat = repeat
        self.conv = ConvBlock(in_ch)
        
    def forward(self, x):
        for _ in range(self.repeat):
            x = self.conv(x)
        return x
    
class HeadBlock1(nn.Module):
    
    def __init__(self,  in_ch=1024, out_ch=512):
        super(HeadBlock1, self).__init__()
        out_ = out_ch * 2
        self.conv1 = ConvSame(in_ch, out_ch, 1, 1)
        self.conv2 = ConvSame(out_ch, out_, 3, 1)
        self.conv3 = ConvSame(out_, out_ch, 1, 1)
        self.conv4 = ConvSame(out_ch, out_, 3, 1)
        
    def forward(self, x):
        return self.conv4(self.conv3(self.conv2(self.conv1(x))))
    

    
class UpSample(nn.Module):
    
    def __init__(self,):
        super(UpSample, self).__init__()
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        
    def forward(self, x):
        return self.up(x)
    
    
class BackBone(nn.Module):
    
    def __init__(self):
        super(BackBone, self).__init__()
        
        self.conv1 = ConvSame(3, 32, 3, 1)
        self.conv2 = ConvPad(32, 64, 3, 2)
        
        self.conv3 = RepeatBlock(64, 1)
        self.conv4 = ConvPad(64, 128, 3, 2)
        
        self.conv5 = RepeatBlock(128, 2)
        self.conv6 = ConvPad(128, 256, 3, 2)
        
        self.conv7 = RepeatBlock(256, 8)
        self.conv8 = ConvPad(256, 512, 3, 2)
        
        self.conv9 = RepeatBlock(512, 8)
        self.conv10 = ConvPad(512, 1024, 3, 2)
        
        self.conv11 = RepeatBlock(1024, 4)
        self.conv12 = HeadBlock1(1024, 512)
    
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.conv6(x)
        x1 = self.conv7(x)
        
        x = self.conv8(x1)
        x2 = self.conv9(x)
        
        x = self.conv10(x2)
        x = self.conv11(x)
        x = self.conv12(x)
    

        return x1, x2, x
    
    
class OutputBlock1(nn.Module):
    
    def __init__(self, in_ch, out_ch):
        super(OutputBlock1, self).__init__()
        self.conv = ConvSame(in_ch, out_ch, 1, 1)
        
    def forward(self, x):
        return self.conv(x)
    
    
class OutputBlock2(nn.Module):
    
    def __init__(self, in_ch, out_ch, next_inp, next_out):
        super(OutputBlock2, self).__init__()
        self.conv1 = ConvSame(in_ch, out_ch, 1, 1)
        self.upsample = UpSample()
        self.conv2 = HeadBlock1(768, 256)
        self.conv3 = ConvSame(next_inp, next_out, 1, 1)
        
    def forward(self, x1, x2):
        #print(x1.shape, x2.shape)
        x1 = self.conv1(x1)
        x1 = self.upsample(x1)
        x1 = torch.cat([x1, x2], dim=1)
        #print(x1.shape)
        x2 = self.conv2(x1)
        #print(x2.shape)
        x1 = self.conv3(x2)
        return x1, x2
        
   

class OutputBlock3(nn.Module):
    
    def __init__(self, in_ch, out_ch, next_inp, next_out):
        super(OutputBlock3, self).__init__()
        self.conv1 = ConvSame(in_ch, out_ch, 1, 1)
        self.upsample = UpSample()
        self.conv2 = HeadBlock1(384, 128)
        self.conv3 = ConvSame(next_inp, next_out, 1, 1)
        
    def forward(self, x1, x2):
        x1 = self.conv1(x1)
        x1 = self.upsample(x1)
        x1 = torch.cat([x1, x2], dim=1)
        x1 = self.conv2(x1)
        x1 = self.conv3(x1)
        return x1
    
    
class YoloModel(nn.Module):
    
    def __init__(self):
        super(YoloModel, self).__init__()
        self.backbone = BackBone()
        self.block1 = OutputBlock1(1024, 255)
        self.block2 = OutputBlock2(1024, 256, 512, 255)
        self.block3 = OutputBlock3(512, 128, 256, 255)
        
    def forward(self, x):
        x1, x2, x3 = self.backbone(x)
        x = self.block1(x3)
        x2, x3 = self.block2(x3, x2)
        x1 = self.block3(x3, x1)
        return x, x2, x1



In [9]:
m = YoloModel()

In [10]:
inp = torch.rand((4,3,500,500))

In [11]:
o = m(inp)
o[0].shape, o[1].shape, o[2].shape

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 64 but got size 63 for tensor number 1 in the list.

In [None]:
num_param = sum(p.numel() for p in m.parameters() if p.requires_grad)

In [None]:
num_param