In [1]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

In [4]:
class Mish(nn.Module):
    """Mish activation, computed element-wise as ``x * tanh(softplus(x))``.

    Stateless: it has no parameters and works on tensors of any shape.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        smoothed = F.softplus(x)
        return x * torch.tanh(smoothed)
        

In [18]:
# a = Mish()
# a(torch.Tensor([5])) # tensor([4.9996])

# import math
# 5 * math.tanh(math.log(1+math.exp(5))) # 4.999552077529406

In [26]:
class BasicConv(nn.Module):
    """Conv2d -> BatchNorm2d -> Mish, the basic building unit of the backbone.

    Args:
        in_channel: number of input channels.
        out_channel: number of output channels.
        kernel_size: square kernel size. Padding is ``kernel_size // 2``, so an
            odd kernel with stride 1 preserves the spatial size.
        stride: convolution stride.
    """

    def __init__(self, in_channel, out_channel, kernel_size, stride=1):
        super().__init__()
        same_padding = kernel_size // 2
        # No conv bias: the following BatchNorm has its own learnable shift.
        self.conv = nn.Conv2d(
            in_channel,
            out_channel,
            kernel_size,
            stride,
            same_padding,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channel)
        self.activation = Mish()

    def forward(self, x):
        return self.activation(self.bn(self.conv(x)))
        

In [27]:
# a = BasicConv(10,20,3)
# print(a)
# # BasicConv(
# #   (conv): Conv2d(10, 20, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
# #   (bn): BatchNorm2d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# #   (activation): Mish()
# # )

In [28]:
class Resblock(nn.Module):
    """Residual unit: 1x1 BasicConv then 3x3 BasicConv, added onto the input.

    Args:
        channels: input and output channel count (the skip connection needs
            them to match).
        hidden_channels: width of the intermediate 1x1 conv; any falsy value
            (e.g. ``None``) means ``channels``.
        residual_activation: accepted but currently unused by this block.
    """

    def __init__(self, channels, hidden_channels=None, residual_activation=nn.Identity()):
        super().__init__()
        mid_channels = hidden_channels if hidden_channels else channels
        self.block = nn.Sequential(
            BasicConv(channels, mid_channels, 1),
            BasicConv(mid_channels, channels, 3),
        )

    def forward(self, x):
        residual = self.block(x)
        return x + residual

In [30]:
# a = Resblock(10,20,3)
# print(a)
# # Resblock(
# #   (block): Sequential(
# #     (0): BasicConv(
# #       (conv): Conv2d(10, 20, kernel_size=(1, 1), stride=(1, 1), bias=False)
# #       (bn): BatchNorm2d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# #       (activation): Mish()
# #     )
# #     (1): BasicConv(
# #       (conv): Conv2d(20, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
# #       (bn): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# #       (activation): Mish()
# #     )
# #   )
# # )

In [34]:
class Resblock_body(nn.Module):
    """One CSP stage: a stride-2 downsample followed by a split/merge block.

    The downsampled map goes through two parallel 1x1 convs:
    - ``split_conv0`` is kept as a shortcut branch;
    - ``split_conv1`` feeds a stack of Resblocks plus a 1x1 conv (``blocks_conv``).

    The two branches are concatenated on dim 1 (shortcut first), and
    ``concat_conv`` maps the result back to ``out_channels``.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by this stage.
        num_blocks: number of Resblocks in ``blocks_conv``. Ignored when
            ``first`` is true, which always uses a single Resblock.
        first: selects the first-stage layout. Both branches keep the full
            ``out_channels`` width, and the Resblock hidden width is
            ``out_channels // 2``. Otherwise each branch is ``out_channels // 2`` wide.
    """

    def __init__(self, in_channels, out_channels, num_blocks, first):
        super().__init__()
        self.downsample_conv = BasicConv(in_channels, out_channels, 3, stride=2)

        branch_channels = out_channels if first else out_channels // 2
        self.split_conv0 = BasicConv(out_channels, branch_channels, 1)
        self.split_conv1 = BasicConv(out_channels, branch_channels, 1)

        if first:
            layers = [Resblock(out_channels, out_channels // 2)]
        else:
            layers = [Resblock(branch_channels) for _ in range(num_blocks)]
        layers.append(BasicConv(branch_channels, branch_channels, 1))
        self.blocks_conv = nn.Sequential(*layers)

        merged_channels = out_channels * 2 if first else out_channels
        self.concat_conv = BasicConv(merged_channels, out_channels, 1)

    def forward(self, x):
        x = self.downsample_conv(x)
        shortcut = self.split_conv0(x)
        residual = self.blocks_conv(self.split_conv1(x))
        # NOTE(review): order here is [shortcut, residual]; confirm it matches
        # any third-party checkpoint before loading its weights.
        merged = torch.cat([shortcut, residual], dim=1)
        return self.concat_conv(merged)

In [38]:
# a = Resblock_body(10,20,3, False)
# print(a)

In [41]:
class CSPDarkNet(nn.Module):
    """CSPDarknet backbone: a 3x3 stem conv followed by five CSP stages.

    Each stage starts with a stride-2 downsample and outputs 64, 128, 256,
    512 and 1024 channels respectively. Each stage consumes the previous
    stage's output.

    Args:
        layers: sequence of five ints, the Resblock count for each stage.
            Stage 0 is built with ``first=True``, which always uses a single
            Resblock, so ``layers[0]`` does not change the architecture.

    Raises:
        ValueError: if ``layers`` does not have exactly five entries.

    ``forward`` returns ``(out3, out4, out5)``, the outputs of the last three
    stages (256, 512 and 1024 channels).
    """

    def __init__(self, layers):
        super().__init__()
        if len(layers) != 5:
            raise ValueError(
                "CSPDarkNet expects 5 stage depths, got {}".format(len(layers))
            )
        self.inplane = 32
        self.conv1 = BasicConv(3, self.inplane, 3, 1)
        self.feature_channels = [64, 128, 256, 512, 1024]

        # Stage i maps stage_inputs[i] -> feature_channels[i] with depth layers[i].
        # (Previously stages 2-4 all reused layers[1].)
        stage_inputs = [self.inplane] + self.feature_channels[:-1]
        self.stage = nn.ModuleList([
            Resblock_body(c_in, c_out, depth, index == 0)
            for index, (c_in, c_out, depth) in enumerate(
                zip(stage_inputs, self.feature_channels, layers)
            )
        ])

        self._init_weights()

    def _init_weights(self):
        """Re-initialise every Conv2d and BatchNorm2d in the network."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Normal(0, sqrt(2 / fan_out)) with fan_out = kh * kw * out_channels.
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                nn.init.normal_(m.weight, 0.0, math.sqrt(2.0 / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        x = self.conv1(x)

        x = self.stage[0](x)
        x = self.stage[1](x)
        # Chain the stages: each one expects the previous stage's channel count.
        out3 = self.stage[2](x)
        out4 = self.stage[3](out3)
        out5 = self.stage[4](out4)

        return out3, out4, out5
        

In [43]:
# a = CSPDarkNet([1,1,1,1,2])
# print(a)

In [46]:
def darknet53(pretrained, **kwargs):
    """Build a CSPDarkNet with stage depths ``[1, 2, 8, 8, 4]``.

    Args:
        pretrained: falsy to return a freshly initialised model, or a path to
            a checkpoint file containing a ``state_dict`` for ``CSPDarkNet``.
        **kwargs: accepted for API compatibility; currently unused.

    Returns:
        The constructed ``CSPDarkNet``, with weights loaded if a path was given.

    Raises:
        ValueError: if ``pretrained`` is truthy but not a ``str`` path.
    """
    # Validate before building the model so a bad argument fails fast.
    if pretrained and not isinstance(pretrained, str):
        raise ValueError(
            "darknet requires a pretrained path. got [{}]".format(pretrained)
        )

    model = CSPDarkNet([1, 2, 8, 8, 4])
    if pretrained:
        # nn.Module has no ``load`` method; read the checkpoint with torch.load.
        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts.
        # torch.load unpickles the file, so only load trusted checkpoints.
        state_dict = torch.load(pretrained, map_location="cpu")
        model.load_state_dict(state_dict)
    return model
        

In [47]:
a = darknet53(pretrained=False)  # freshly initialised backbone, no checkpoint loaded
print(a)

CSPDarkNet(
  (conv1): BasicConv(
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activation): Mish()
  )
  (stage): ModuleList(
    (0): Resblock_body(
      (downsample_conv): BasicConv(
        (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation): Mish()
      )
      (split_conv0): BasicConv(
        (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation): Mish()
      )
      (split_conv1): BasicConv(
        (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     