In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import yaml
import torch.onnx

In [19]:
def autopad(kernel, padding=None):
    '''
    Pick the padding for a "same" convolution (output size == input size
    at stride 1).

    For an odd kernel the required padding is kernel // 2, e.g. a 3x3
    kernel needs a padding of 1.

    Args:
        kernel: kernel size, either a single int or a sequence of ints
            (one entry per dimension).
        padding: explicit padding. When it is not None it is returned
            unchanged and kernel is ignored.

    Returns:
        padding itself when given; otherwise kernel // 2 for an int kernel,
        or a list with x // 2 for every entry of a sequence kernel.
    '''
    # An explicit padding always wins over the derived one.
    if padding is not None:
        return padding

    # Scalar kernel: halve it directly.
    if isinstance(kernel, int):
        return kernel // 2

    # Sequence kernel: halve every dimension on its own.
    return [size // 2 for size in kernel]

class Conv(nn.Module):
    '''
    CBL block: Conv2d -> BatchNorm2d -> LeakyReLU.

    Args:
        in_channel: number of input channels.
        out_channel: number of output channels.
        kernel_size: convolution kernel size (int or sequence of ints).
        stride: convolution stride.
        padding: explicit padding; when None it is derived from
            kernel_size by autopad.
        groups: number of convolution groups.
        activation: when truthy the block ends with LeakyReLU(0.1),
            otherwise with nn.Identity (no activation).
    '''
    def __init__(self, in_channel, out_channel, kernel_size=1, stride=1,
                 padding=None, groups=1, activation=True):
        super().__init__()

        # The convolution is created without a bias term.
        self.conv = nn.Conv2d(
            in_channels=in_channel,
            out_channels=out_channel,
            kernel_size=kernel_size,
            stride=stride,
            padding=autopad(kernel_size, padding),
            groups=groups,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channel)
        if activation:
            self.act = nn.LeakyReLU(0.1, inplace=True)
        else:
            self.act = nn.Identity()

    def forward(self, x):
        '''Run convolution, batch norm and activation, in that order.'''
        convolved = self.conv(x)
        normalized = self.bn(convolved)
        return self.act(normalized)

    def fuse_forward(self, x):
        '''Run convolution and activation only; self.bn is skipped.'''
        convolved = self.conv(x)
        return self.act(convolved)


class Bottleneck(nn.Module):
    '''
    Residual unit: a 1x1 CBL followed by a 3x3 CBL, with an optional
    identity shortcut, i.e. x + CBL(CBL(x)).

    Args:
        in_channel: number of input channels.
        out_channel: number of output channels.
        shortcut: request the residual addition. It is only applied when
            in_channel == out_channel.
        groups: number of groups used by the 3x3 convolution.
        expansion: ratio hidden_channels / out_channel for the channels
            between the two convolutions.
    '''
    def __init__(self, in_channel, out_channel, shortcut=True, groups=1, expansion=0.5):
        super().__init__()

        mid_channel = int(out_channel * expansion)
        self.cv1 = Conv(in_channel, mid_channel, kernel_size=1, stride=1)
        self.cv2 = Conv(mid_channel, out_channel, kernel_size=3, stride=1, groups=groups)
        # The addition is only possible when input and output channels match.
        self.add = shortcut and in_channel == out_channel

    def forward(self, x):
        '''Return cv2(cv1(x)), plus x when the shortcut is active.'''
        residual = self.cv2(self.cv1(x))
        return x + residual if self.add else residual
    

    
class BottleneckCSP(nn.Module):
    '''
    Cross Stage Partial bottleneck (CSP1_x).

    Two branches are computed from the same input:
        y1: CBL -> repeats * Res unit -> plain 1x1 conv
        y2: plain 1x1 conv
    and merged as:
        concat(y1, y2) -> BN -> LeakyReLU -> CBL

    Args:
        in_channel: number of input channels.
        out_channel: number of output channels.
        repeats: number of Bottleneck (Res) units stacked in branch y1.
        shortcut: forwarded to every Bottleneck unit.
        expansion: ratio hidden_channels / out_channel for both branches.
        groups: forwarded to every Bottleneck unit. Declared last so the
            positional order of the earlier parameters is unchanged.
    '''
    def __init__(self, in_channel, out_channel, repeats=1, shortcut=True, expansion=0.5, groups=1):
        super().__init__()

        hidden_channel = int(out_channel * expansion)
        self.cv1 = Conv(in_channel, hidden_channel, 1, 1)
        # cv2 / cv3 are bare convolutions: the shared BN + activation is
        # applied after the two branches are concatenated.
        self.cv2 = nn.Conv2d(in_channel, hidden_channel, 1, 1, bias=False)
        self.cv3 = nn.Conv2d(hidden_channel, hidden_channel, 1, 1, bias=False)
        self.cv4 = Conv(2 * hidden_channel, out_channel, 1, 1)
        self.bn = nn.BatchNorm2d(2 * hidden_channel)
        self.act = nn.LeakyReLU(0.1, inplace=True)
        # expansion=1.0 keeps the Res units at hidden_channel width throughout.
        self.m = nn.Sequential(*[
            Bottleneck(hidden_channel, hidden_channel, shortcut, groups, expansion=1.0)
            for _ in range(repeats)
        ])

    def forward(self, x):
        '''Compute both branches, concatenate on the channel axis and fuse.'''
        # Read right to left: CBL -> repeats * Res -> conv.
        y1 = self.cv3(self.m(self.cv1(x)))
        y2 = self.cv2(x)
        return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1))))
    
class SPP(nn.Module):
    '''
    Spatial pyramid pooling layer.

    Data flow:
        a = CBL(x)
        b = [MaxPool_k(a) for k in kernel_size_list]
        c = concat(a, *b) on the channel axis
        y = CBL(c)

    Every max-pool uses stride 1 and padding k // 2.

    Args:
        in_channel: number of input channels.
        out_channel: number of output channels.
        kernel_size_list: kernel sizes of the parallel max-pool layers.
    '''
    def __init__(self, in_channel, out_channel, kernel_size_list=(5, 9, 13)):
        super().__init__()

        hidden_channel = in_channel // 2
        self.cv1 = Conv(in_channel, hidden_channel, 1, 1)
        # The concat holds the un-pooled tensor plus one tensor per pool size.
        self.cv2 = Conv(hidden_channel * (len(kernel_size_list) + 1), out_channel, 1, 1)
        self.m = nn.ModuleList([
            nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=kernel_size // 2)
            for kernel_size in kernel_size_list
        ])

    def forward(self, x):
        '''Pool the reduced features at every kernel size and fuse them.'''
        a = self.cv1(x)
        return self.cv2(torch.cat([a] + [m(a) for m in self.m], dim=1))
    
# Focus: lossless 2x downsampling
class Focus(nn.Module):
    '''
    Space-to-channel rearrangement followed by a CBL block:
    concat(4 strided slices) -> CBL (conv, bn, leaky relu).

    Args:
        in_channel: number of channels of the input tensor; the inner
            Conv receives in_channel * 4 channels.
        out_channel: number of output channels.
        kernel_size, stride, padding, groups, activation: forwarded
            unchanged to the inner Conv block.
    '''
    def __init__(self, in_channel, out_channel, kernel_size=1,
                 stride=1, padding=None, groups=1, activation=True):
        super().__init__()

        self.conv = Conv(in_channel * 4, out_channel, kernel_size, stride, padding, groups, activation)

    def forward(self, x):
        '''Take every second element along the last two axes and convolve.'''
        # A step of 2 along each of the last two axes keeps every other
        # element. There are four possible starting offsets, and each one
        # yields one slice that is stacked along dim 1.
        offsets = ((0, 0), (1, 0), (0, 1), (1, 1))
        patches = [x[..., row::2, col::2] for row, col in offsets]
        return self.conv(torch.cat(patches, dim=1))
    
class Concat(nn.Module):
    '''
    Concatenate a sequence of tensors along a fixed dimension.

    Args:
        dimension: dimension passed to torch.cat as dim (default 1).
    '''
    def __init__(self, dimension=1):
        # nn.Module must be initialised before the instance can be called
        # or registered as a sub-module.
        super().__init__()
        self.d = dimension

    def forward(self, x):
        '''Return torch.cat(x, dim=self.d); x is a sequence of tensors.'''
        return torch.cat(x, dim=self.d)
    
class Detect(nn.Module):
    '''
    Detection heads: one 1x1 convolution per input feature level.

    Every head outputs num_anchor * (num_classes + 5) channels, e.g.
    (5 + 80) * 3 = 255 for 80 classes and 3 anchors. Per anchor the layout
    is:
        0..3: cx, cy, w, h
        4:    objectness
        5..:  class scores

    Args:
        num_classes: number of object classes.
        num_anchor: number of anchors predicted by each head.
        reference_channels: iterable with the input channel count of each
            head, one entry per feature level.
    '''
    def __init__(self, num_classes, num_anchor, reference_channels):
        super().__init__()

        self.num_anchor = num_anchor
        self.num_classes = num_classes
        # Values predicted per anchor: 4 box terms + objectness + classes.
        self.num_output = num_classes + 5
        self.m = nn.ModuleList(
            nn.Conv2d(input_channel, self.num_output * self.num_anchor, 1)
            for input_channel in reference_channels
        )
        self.init_weight()

    def forward(self, x):
        '''
        Apply head i to x[i] for every level.

        Note: x is modified in place (each entry is replaced by its head
        output) and the same list object is returned.
        '''
        for ilevel, module in enumerate(self.m):
            x[ilevel] = module(x[ilevel])
        return x

    def init_weight(self):
        '''
        Shift the objectness and class biases of every head.

        The offsets depend on the head stride (8, 16, 32 in head order) and
        on num_classes; the box biases (columns 0..3) are left untouched.
        '''
        strides = [8, 16, 32]
        # One pass per scale / head.
        for head, stride in zip(self.m, strides):
            # Work on a detached copy: in-place arithmetic on a view of a
            # leaf parameter that requires grad is rejected by autograd.
            bias = head.bias.detach().clone().view(self.num_anchor, -1)  # (num_anchor, 5 + num_classes)
            # Columns: cx, cy, w, h, objectness, classification...
            # The objectness logit is turned into a probability by a sigmoid,
            # prob = 1 / (1 + e^-z), so the bias sets its starting value.
            bias[:, 4] += math.log(8 / (640 / stride) ** 2)           # objectness
            bias[:, 5:] += math.log(0.6 / (self.num_classes - 0.99))  # classification
            head.bias = nn.Parameter(bias.view(-1), requires_grad=True)

In [27]:
class Yolo(nn.Module):
    '''
    YOLO-style detector whose layers are assembled from a YAML config file.

    Args:
        num_classes: number of object classes to predict.
        config_file: path to a YAML file providing the keys "backbone",
            "head", "anchors", "depth_multiple" and "width_multiple".
        rank: stored as self.rank; not used elsewhere in this class.

    Attributes set by __init__:
        model: nn.Sequential holding every layer in config order.
        saved_index: sorted layer indices that other layers read from
            (every "from" entry that is not -1).
        anchors: buffer of shape (3, 3, 2) with the configured anchor
            sizes divided by the stride of their level.
    '''
    def __init__(self, num_classes, config_file, rank=0):
        super().__init__()
        self.num_classes = num_classes
        self.rank = rank
        self.strides = [8, 16, 32]
        self.model, self.saved_index, anchors = self.build_model(config_file)
        # One stride per level: view it as (3, 1, 1) so it broadcasts over
        # the (3 levels, 3 anchors, 2 sizes) anchor tensor.
        self.register_buffer(
            'anchors',
            torch.FloatTensor(anchors).view(3, 3, 2) /
            torch.FloatTensor(self.strides).view(3, 1, 1)
        )
        self.apply(self.init_weight)

    def init_weight(self, m):
        '''
        Per-module hook used with self.apply.

        Leaves Conv2d untouched, sets eps / momentum on BatchNorm2d and
        switches LeakyReLU / ReLU / ReLU6 to in-place mode.
        '''
        type_t = type(m)
        if type_t is nn.Conv2d:
            pass
        elif type_t is nn.BatchNorm2d:
            m.eps = 1e-3
            m.momentum = 0.03
        elif type_t in [nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
            m.inplace = True

    def forward(self, x):
        '''Placeholder: the forward pass is not implemented; returns None.'''
        # TODO: implement the layer-by-layer forward pass.
        pass

    def build_model(self, config_file, input_channel=3):
        '''
        Build the layer stack described by config_file.

        Each entry of "backbone" + "head" is a 4-item list:
        [from_index, repeat_count, module_name, args]. from_index is -1
        (previous layer), an absolute layer index, or a list of those.

        Args:
            config_file: path to the YAML model description.
            input_channel: channel count of the network input.

        Returns:
            (nn.Sequential of all layers, sorted list of layer indices that
            later layers read from, the raw "anchors" value of the config).
        '''
        # NOTE(review): yaml.load with FullLoader on an external file; if the
        # config can come from an untrusted source prefer yaml.safe_load.
        with open(config_file, "r") as f:
            self.yaml = yaml.load(f, Loader=yaml.FullLoader)

        all_layers_cfg_list = self.yaml["backbone"] + self.yaml["head"]
        anchors = self.yaml["anchors"]
        depth_multiple = self.yaml["depth_multiple"]
        width_multiple = self.yaml["width_multiple"]
        num_classes = self.num_classes
        num_anchor = len(anchors[0]) // 2  # e.g. [10, 13, 16, 30, 33, 23] -> 3 (w, h) pairs
        num_output = num_anchor * (num_classes + 5)
        # Index 0 is the network input; index k + 1 is the output of layer k.
        all_layers_channels = [input_channel]
        all_layers = []
        saved_layer_index = []

        def parse_string(value):
            # Map the strings "None" / "True" / "False" to Python constants.
            if value == "None":
                return None
            elif value == "True":
                return True
            elif value == "False":
                return False
            else:
                return value

        def make_divisible(x, divisor):
            # Round x up to the next multiple of divisor.
            return math.ceil(x / divisor) * divisor

        def channels_from(index):
            # Negative indices are relative to the layers built so far;
            # absolute layer indices are shifted by one because slot 0 of
            # all_layers_channels holds the network input.
            return all_layers_channels[index if index < 0 else index + 1]

        for layer_index, (from_index, repeat_count, module_name, args) in enumerate(all_layers_cfg_list):

            args = [parse_string(a) for a in args]
            # NOTE(review): eval resolves the class named in the config file
            # (reflection); only load configs from trusted sources.
            module_class = eval(module_name)

            if repeat_count > 1:
                repeat_count = max(round(repeat_count * depth_multiple), 1)

            if module_class in [Conv, Bottleneck, SPP, Focus, BottleneckCSP]:
                channel_input, channel_output = channels_from(from_index), args[0]

                # Scale the width, except for layers that already emit the
                # detection output size.
                if channel_output != num_output:
                    channel_output = make_divisible(channel_output * width_multiple, 8)

                # Replace the first config argument by (in, out) channels.
                args = [channel_input, channel_output, *args[1:]]

                if module_class in [BottleneckCSP]:
                    # For CSP, repeat_count is the number of inner Res units,
                    # not the number of times the CSP block itself repeats.
                    args.insert(2, repeat_count)
                    repeat_count = 1
            elif module_class is Concat:
                # Concat reads from several layers: from_index is a list.
                channel_output = sum(channels_from(x) for x in from_index)
            elif module_class is Detect:
                # Detect also reads from a list of layers.
                reference_channel = [channels_from(x) for x in from_index]
                args = [num_classes, num_anchor, reference_channel]
                channel_output = num_output
            else:
                # Any other module is assumed to keep the channel count.
                channel_output = channels_from(from_index)

            if repeat_count > 1:
                # nn.Sequential keeps the repeated stack callable as one layer.
                module_instance = nn.Sequential(*[
                    module_class(*args) for _ in range(repeat_count)
                ])
            else:
                module_instance = module_class(*args)

            module_instance.from_index = from_index
            module_instance.layer_index = layer_index
            all_layers.append(module_instance)
            all_layers_channels.append(channel_output)

            if not isinstance(from_index, list):
                from_index = [from_index]

            # Remember every explicitly referenced layer (-1 needs no saving).
            saved_layer_index.extend(filter(lambda x: x != -1, from_index))

        return nn.Sequential(*all_layers), sorted(saved_layer_index), anchors

    # Backward-compatible alias for the original (misspelled) method name.
    build_mode = build_model