In [1]:
import torch as pt
import torchvision as ptv

# 带批标准化的2d卷积

In [13]:
class NormConv2d(pt.nn.Module):
    def __init__(self,in_channels, out_channels, kernel_size, stride=1, padding=0, \
                 dilation=1, groups=1, bias=True,momentum=0.1,eps=1e-05,affine=True):
        super(NormConv2d,self).__init__()
        self.conv = pt.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        self.norm = pt.nn.BatchNorm2d(out_channels,momentum=momentum,eps=eps,affine=affine)
    
    def forward(self,x):
        return self.norm(self.conv(x))

In [14]:
test = NormConv2d(3,10,3,padding=1)
print(test)
test(pt.autograd.Variable(pt.randn(1,3,9,9))).size()

NormConv2d (
  (conv): Conv2d(3, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (norm): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True)
)


torch.Size([1, 10, 9, 9])

# 带批标准化的全连接层

In [15]:
class NormLinear(pt.nn.Module):
    def __init__(self,in_features, out_features, bias=True,momentum=0.1,eps=1e-05,affine=True):
        super(NormLinear,self).__init__()
        self.fc = pt.nn.Linear(in_features,out_features,bias=bias)
        self.norm = pt.nn.BatchNorm1d(out_features,eps,momentum,affine)
        
    def forward(self,x):
        return self.norm(self.fc(x))

In [16]:
test = NormLinear(4,16)
print(test)
test(pt.autograd.Variable(pt.randn(1,4)))

NormLinear (
  (fc): Linear (4 -> 16)
  (norm): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True)
)


Variable containing:

Columns 0 to 12 
    0     0     0     0     0     0     0     0     0     0     0     0     0

Columns 13 to 15 
    0     0     0
[torch.FloatTensor of size 1x16]

# 连续卷积层

In [2]:
class ContinuousConv2d(pt.nn.Module):
    def __init__(self,channel_list,kernel_list,padding_list=None,stride_list=None,conv=pt.nn.Conv2d,afunc=pt.nn.ReLU,**other):
        super(ContinuousConv2d,self).__init__()
        if padding_list is None:
            padding_list = [int(x // 2) for x in kernel_list]
        if stride_list is None:
            stride_list = [1 for _ in range(len(kernel_list))]
        conv_list = []
        for i in range(len(channel_list) - 1):
            conv_list.append(conv(channel_list[i],channel_list[i + 1],kernel_list[i],stride=stride_list[i],padding=padding_list[i],**other))
            conv_list.append(afunc())
        self.conv = pt.nn.Sequential(*conv_list[:-1])
        
    def forward(self,x):
        return self.conv(x)

In [25]:
test = ContinuousConv2d([1,8,16],[3,3])
print(test)
test(pt.autograd.Variable(pt.randn(1,1,10,10)))

ContinuousConv2d (
  (conv): Sequential (
    (0): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU ()
    (2): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
)


Variable containing:
(0 ,0 ,.,.) = 
  0.0524  0.4676  0.4081  ...   0.2864  0.2125 -0.0094
  0.0184  0.2728  0.0101  ...   0.2032  0.2477  0.2890
  0.3800  0.4502  0.0744  ...  -0.0487  0.3251  0.2397
           ...             ⋱             ...          
  0.2249 -0.0121  0.0616  ...   0.1377  0.2923  0.1872
  0.0310  0.2627  0.0765  ...   0.1064  0.0898 -0.0992
 -0.1131  0.1662  0.1120  ...  -0.1280  0.1784  0.0268

(0 ,1 ,.,.) = 
 -0.0101 -0.1852 -0.1042  ...   0.0359 -0.0763 -0.4257
  0.1514 -0.2760 -0.0718  ...   0.0341 -0.2317 -0.0404
 -0.1742  0.2562  0.2920  ...  -0.0881  0.0118  0.0711
           ...             ⋱             ...          
 -0.2541 -0.3063  0.0141  ...   0.1248 -0.3487 -0.3627
 -0.1559  0.1848 -0.0870  ...  -0.0911  0.1414 -0.3398
  0.0629 -0.0027  0.0227  ...   0.0360  0.0105 -0.0180

(0 ,2 ,.,.) = 
 -0.2058  0.0845 -0.2227  ...  -0.3317 -0.1473 -0.2814
 -0.2114 -0.4305 -0.4807  ...  -0.3583 -0.2428 -0.1912
 -0.0444 -0.1291 -0.1698  ...  -0.3227 -0.3870 -0.07

# 连续MLP层

In [34]:
class ContinuousLinear(pt.nn.Module):
    def __init__(self,feature_list,bias=True,linear=pt.nn.Linear,afunc=pt.nn.ReLU,**other):
        super(ContinuousLinear,self).__init__()
        linear_list = []
        for i in range(len(feature_list) - 1):
            linear_list.append(linear(feature_list[i],feature_list[i + 1],bias=bias,**other))
            linear_list.append(afunc())
        self.fc = pt.nn.Sequential(*linear_list[:-1])
        
    def forward(self,x):
        return self.fc(x)

In [35]:
test = ContinuousLinear([10,16,32])
print(test)

ContinuousLinear (
  (fc): Sequential (
    (0): Linear (10 -> 16)
    (1): ReLU ()
    (2): Linear (16 -> 32)
  )
)


In [36]:
test(pt.autograd.Variable(pt.randn(1,10)))

Variable containing:

Columns 0 to 9 
-0.1863  0.0415 -0.2109 -0.2867 -0.1637 -0.5474 -0.1163 -0.1524  0.4732 -0.1101

Columns 10 to 19 
 0.3767 -0.4234 -0.0083  0.1701  0.1976  0.1648  0.1711  0.0952  0.3829 -0.4260

Columns 20 to 29 
-0.1192  0.1996 -0.1150  0.1657  0.5129 -0.1024 -0.0795  0.5058  0.3117 -0.1583

Columns 30 to 31 
 0.6542  0.3360
[torch.FloatTensor of size 1x32]

# 残差网络单元

In [36]:
class ResNet_2DCell(pt.nn.Module):
    def __init__(self,channel_list,kernel_list,fill_mode="fc",conv_gen=pt.nn.Conv2d,afunc=pt.nn.ReLU,**other):
        super(ResNet_2DCell,self).__init__()
        self.conv = ContinuousConv2d(channel_list,kernel_list,padding_list=None,stride_list=None,conv=conv_gen,afunc=afunc,**other)
        if channel_list[0] == channel_list[-1]:
            self.x_handle = lambda x:x
        else:
            if fill_mode == "fc":
                self.x_handle = self.fc_fill(channel_list[0],channel_list[-1])
            else:
                self.x_handle = self.zero_fill(channel_list[0],channel_list[-1])
                
    def forward(self,x):
        return self.conv(x) + self.x_handle(x)
    
    def fc_fill(self,in_channel,out_channel):
        self.fc = pt.nn.Conv2d(in_channel,out_channel,1)
        return self._fill_fc
    
    def _fill_fc(self,x):
        return self.fc(x)
    
    def zero_fill(self,in_channel,out_channel):
        self.zeros_channel = abs(out_channel - in_channel)
        self.zeros = pt.nn.Parameter(pt.zeros(1,self.zeros_channel,1,1))
        return self._fill_zero
    
    def _fill_zero(self,x):
        zero_size = list(x.size())
        zero_size[1] = self.zeros_channel
        return pt.cat([self.zeros.expand(zero_size),x],dim=1)

In [37]:
test = ResNet_2DCell([4,8,8],[3,3],fill_mode="zero")
print(test)
test(pt.autograd.Variable(pt.randn(1,4,12,12)))

ResNet_2DCell (
  (conv): ContinuousConv2d (
    (conv): Sequential (
      (0): Conv2d(4, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU ()
      (2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
)


Variable containing:
(0 ,0 ,.,.) = 
 -0.0555  0.1706  0.0639  ...  -0.0556  0.2508 -0.0897
 -0.0964 -0.2122  0.1826  ...   0.2391 -0.0813  0.1576
  0.0580  0.1242 -0.3076  ...   0.0564  0.1631  0.1593
           ...             ⋱             ...          
 -0.1037 -0.0790  0.2274  ...  -0.4336 -0.2406  0.0980
  0.1578 -0.5764 -0.0776  ...   0.3832 -0.1542  0.0207
  0.1452  0.1513  0.2245  ...  -0.0869  0.2425  0.1719

(0 ,1 ,.,.) = 
 -0.3203 -0.0553 -0.0904  ...  -0.3223  0.0240 -0.1086
  0.0927 -0.0232 -0.1161  ...   0.0373 -0.2010  0.2305
 -0.0440  0.3771  0.1093  ...  -0.1854 -0.0639 -0.0560
           ...             ⋱             ...          
 -0.1814 -0.2062 -0.0132  ...  -0.2375  0.1535  0.1873
 -0.0464 -0.0889  0.2539  ...   0.1940  0.1197 -0.1147
 -0.0131  0.2106  0.3264  ...   0.1869  0.1519 -0.0704

(0 ,2 ,.,.) = 
 -0.0863  0.1882  0.0667  ...  -0.0438 -0.0350 -0.0461
  0.1552  0.0801 -0.0268  ...  -0.1854  0.3407 -0.0192
  0.1695  0.0870  0.3840  ...   0.6435 -0.1532  0.02

In [39]:
# a = pt.zeros(1,4,1,1)
# a = a.expand(8,4,1,1)
# pt.cat((a,pt.ones(8,4,6,6)),dim=1)