We import the necessary packages.

In [1]:
-- Load the Torch neural-network package, which provides the `nn` table used throughout.
require 'nn'
-- Global container for the whole model. It is deliberately global because the
-- block builders (convBlock1 / convBlock2) append their layers to `net` directly.
net = nn.Sequential()

### *Global Architecture*
<p>![](./model/general_structure.png)

*Equations for the output feature-map size of a convolution / pooling layer*
<p>$owidth  = \lfloor (width  + 2 \cdot padW - kW) / dW + 1 \rfloor$
<p>$oheight = \lfloor (height + 2 \cdot padH - kH) / dH + 1 \rfloor$

 *Typical structure 1 ---- d1 to d2 of ResNet50*
<p>![](./model/typical_structure_1.png)

In [10]:
--- Append a bottleneck residual unit with an identity shortcut to the global `net`.
-- The unit is a ConcatTable with two paths followed by an element-wise sum (CAddTable).
-- @param d0 number of input planes of the unit
-- @param d1 number of planes inside the bottleneck
-- @param d2 number of output planes of the residual path
-- NOTE(review): the shortcut is an identity, so the sum presumably needs d2 == d0;
-- the only call visible here uses 256 for both -- confirm against other callers.
-- NOTE(review): no ReLU is appended after the sum; confirm this is intentional.
function convBlock1(d0, d1, d2)
    -- Residual path: 1x1 (d0 -> d1), 3x3 stride 1 / pad 1 (d1 -> d1), 1x1 (d1 -> d2).
    -- Batch normalisation is currently disabled; if re-enabled it would follow each
    -- convolution as nn.SpatialBatchNormalization(d1), (d1) and (d2) respectively.
    local residual = nn.Sequential()
        :add(nn.SpatialConvolution(d0, d1, 1, 1))
        :add(nn.ReLU())
        :add(nn.SpatialConvolution(d1, d1, 3, 3, 1, 1, 1, 1))
        :add(nn.ReLU())
        :add(nn.SpatialConvolution(d1, d2, 1, 1))

    -- Shortcut path: passes the input through unchanged.
    local shortcut = nn.Sequential():add(nn.Identity())

    -- Feed the same input to both paths, then sum their outputs.
    local fork = nn.ConcatTable():add(residual):add(shortcut)
    net:add(fork)
    net:add(nn.CAddTable())
end

 *Typical structure 2 ----- d1 to d2 of ResNet50*
<p>![](./model/typical_structure_2.png)

In [11]:
--- Append a bottleneck residual unit with a projection shortcut to the global `net`.
-- Both paths receive the same input (ConcatTable) and are summed by a CAddTable.
-- @param s  stride applied by the first 1x1 convolution of the residual path and
--           by the 1x1 convolution of the shortcut path
-- @param d0 number of input planes of the unit
-- @param d1 number of planes inside the bottleneck
-- @param d2 number of output planes (both paths end with d2 planes)
-- NOTE(review): no ReLU is appended after the sum; confirm this is intentional.
function convBlock2(s, d0, d1, d2)
    -- Residual path: strided 1x1 (d0 -> d1), 3x3 stride 1 / pad 1 (d1 -> d1), 1x1 (d1 -> d2).
    -- Batch normalisation is currently disabled; if re-enabled it would follow each
    -- convolution as nn.SpatialBatchNormalization(d1), (d1) and (d2) respectively.
    local residual = nn.Sequential()
        :add(nn.SpatialConvolution(d0, d1, 1, 1, s, s))
        :add(nn.ReLU())
        :add(nn.SpatialConvolution(d1, d1, 3, 3, 1, 1, 1, 1))
        :add(nn.ReLU())
        :add(nn.SpatialConvolution(d1, d2, 1, 1, 1, 1))

    -- Shortcut path: a strided 1x1 projection (d0 -> d2) so both paths produce the
    -- same number of planes at the same stride. Its batch normalisation
    -- (nn.SpatialBatchNormalization(d2)) is disabled as well.
    local shortcut = nn.Sequential()
        :add(nn.SpatialConvolution(d0, d2, 1, 1, s, s))

    local fork = nn.ConcatTable():add(residual):add(shortcut)
    net:add(fork)
    net:add(nn.CAddTable())
end

### *up-projection*
<p> ![](./model/up_projection.png)
<p> $owidth  = (width  - 1) \cdot dW - 2 \cdot padW + kW + adjW$
<p> $oheight = (height - 1) \cdot dH - 2 \cdot padH + kH + adjH$

In [3]:
-- Minimal up-convolution demo: a transposed convolution (SpatialFullConvolution)
-- with a 2x2 kernel and stride 2, mapping 3 input planes to 4 output planes.
-- NOTE(review): on Lua 5.1 / LuaJIT the global name `module` shadows the built-in
-- `module` function; consider renaming if no later cell depends on it.
module = nn.SpatialFullConvolution(3, 4, 2, 2, 2, 2)
-- Random 3 x 10 x 8 input; by the size formula (no padding, no adjustment) each
-- spatial dimension is doubled by the layer.
test_tensor = torch.rand(3, 10, 8)
print(test_tensor)
local upsampled = module:forward(test_tensor)
print(upsampled)

(1,.,.) = 
  0.3510  0.8965  0.1771  0.2964  0.0756  0.9900  0.2502  0.1683
  0.9663  0.2215  0.6616  0.6700  0.0140  0.8059  0.2070  0.2623
  0.4283  0.2288  0.1214  0.1835  0.4322  0.9043  0.8306  0.4843
  0.0744  0.5301  0.4120  0.7430  0.9963  0.8685  0.9603  0.7024
  0.0025  0.0093  0.5117  0.1075  0.0840  0.5810  0.0124  0.1445
  0.3673  0.5886  0.7817  0.1717  0.8611  0.9383  0.4569  0.0229
  0.3068  0.4965  0.2727  0.0286  0.9265  0.4755  0.3327  0.5246
  0.4313  0.3779  0.0560  0.6320  0.0253  0.0421  0.0982  0.4765
  0.7283  0.1565  0.4620  0.6260  0.2125  0.1111  0.7072  0.4412
  0.9286  0.6560  0.7593  0.5818  0.9576  0.2324  0.7174  0.4389

(2,.,.) = 
  0.9527  0.2666  0.3942  0.7933  0.3135  0.1783  0.4429  0.4553
  0.4045  0.6094  0.5859  0.2418  0.2703  0.4132  0.0301  0.7754
  0.8879  0.7070  0.0500  0.6373  0.7745  0.1209  0.1227  0.5153
  0.5187  0.5497  0.8286  0.0481  0.9426  0.8888  0.5560  0.8578
  0.5306  0.9919  0.2784  0.8953  0.0304  0.5168  0.0108  0.0279
  

(1,.,.) = 
 Columns 1 to 9
 -0.0329  0.0340 -0.1175  0.1193  0.0091 -0.0121  0.1806  0.0374  0.1119
 -0.0369 -0.1168  0.1741 -0.2666 -0.0201 -0.0637  0.0424 -0.0049 -0.0055
 -0.1405  0.1350 -0.0204 -0.0000 -0.1044  0.0812 -0.1011  0.0742  0.0175
  0.1737 -0.2910 -0.0388 -0.0861  0.0720 -0.2131  0.1106 -0.2140 -0.0564
  0.0436  0.0546  0.1885  0.0222  0.0969 -0.0247 -0.0185 -0.0068 -0.0208
  0.0231 -0.0959  0.0326  0.0121  0.0300 -0.0117 -0.0530 -0.0775  0.0115
  0.2010 -0.0129  0.0941  0.0709 -0.0259  0.0438 -0.1455  0.0801  0.0696
  0.0072  0.0482  0.1088 -0.0928 -0.0023 -0.1258  0.1366 -0.2500  0.2084
 -0.0128 -0.0458  0.0655 -0.0256 -0.0673  0.0460 -0.0039 -0.0143 -0.0073
 -0.0977 -0.0393 -0.1125 -0.0023  0.0693 -0.1665 -0.0977 -0.0550 -0.0193
  0.1202  0.0356 -0.0621  0.0641  0.0854  0.1106 -0.0244 -0.0060 -0.0521
  0.0844 -0.0486  0.0879 -0.1791  0.2262 -0.1476 -0.0736 -0.0777  0.1270
 -0.0563  0.0214  0.0080  0.0682  0.1372  0.0338  0.1201 -0.0296 -0.0198
 -0.0534 -0.1194  0.0204

1354  0.1492  0.0796  0.2265  0.1627

Columns 10 to 16
  0.2226  0.0219  0.1108  0.1165  0.2782  0.1300  0.2634
  0.3328  0.0672  0.1080  0.1830  0.2705  0.1484  0.3074
  0.2598  0.0092  0.1784  0.2126  0.1605  0.0413  0.2618
  0.3171  0.0879  0.1662  0.1134  0.2765  0.0625  0.3472
  0.2686  0.0524  0.1492  0.0679  0.2002  0.0521  0.2018
  0.2920  0.1359  0.0957  0.2083  0.0837  0.0646  0.2673
  0.1776 -0.1055  0.2475 -0.0532  0.1645 -0.0653  0.2522
  0.2204  0.0651  0.2003  0.0397  0.1594  0.0656  0.2400
  0.2344  0.0319  0.1825  0.2571  0.2165  0.2271  0.2282
  0.2634  0.0431  0.2519  0.1882  0.2927  0.2142  0.2491
  0.2311 -0.0914  0.2244  0.1248  0.1791  0.1158  0.2751
  0.2089  0.0733  0.1677  0.1161  0.2260  0.0951  0.3845
  0.1533  0.1198  0.1480  0.0664  0.2740  0.0894  0.1930
  0.1250  0.0684  0.2419  0.1400  0.2808  0.1151  0.2176
  0.2921  0.0960  0.3268  0.1475  0.2447 -0.0192  0.3008
  0.3164  0.1554  0.3603  0.1176  0.3371  0.1198  0.2749
  0.2746  0.1150  0.2643  0.0841 

 -0.4772 -0.1984 -0.3912 -0.2775 -0.4648 -0.3330 -0.3771 -0.2104
 -0.1767 -0.0946 -0.2454 -0.0372 -0.2123 -0.1199 -0.1807 -0.1163 -0.1520
 -0.2843 -0.5237 -0.3697 -0.4628 -0.3480 -0.2471 -0.1833 -0.3905 -0.2457
 -0.2643 -0.0343 -0.1723 -0.1161 -0.1542 -0.1243 -0.0536 -0.2715 -0.1152
 -0.3556 -0.3817 -0.4119 -0.4683 -0.2266 -0.4898 -0.3468 -0.3331 -0.4917
 -0.2027 -0.1060 -0.2392 -0.0360 -0.1144 -0.1989 -0.2015 -0.0802 -0.1820
 -0.1423 -0.3236 -0.1465 -0.4627 -0.3036 -0.3588 -0.1330 -0.4444 -0.2470
 -0.1971 -0.1061 -0.1084 -0.1975 -0.1278 -0.1802 -0.1838 -0.1043 -0.0900
 -0.4049 -0.3917 -0.3229 -0.3975 -0.5478 -0.4166 -0.1530 -0.4247 -0.3503
 -0.1589 -0.1198 -0.1590 -0.1048 -0.2297 -0.0357 -0.2429 -0.0613 -0.0786
 -0.1555 -0.4760 -0.2605 -0.5600 -0.2996 -0.5344 -0.2712 -0.3507 -0.4817
 -0.1202 -0.2043 -0.1916 -0.0979 -0.2788  0.0056 -0.1490 -0.1461 -0.1858
 -0.2932 -0.3050 -0.3505 -0.4468 -0.2988 -0.4825 -0.4471 -0.4512 -0.1431
 -0.1578 -0.0986 -0.2396 -0.0669 -0.1596 -0.1207 -0.1247 -0

In [12]:
-- Assemble the front of the network. The test input used below is a 3 x 304 x 228
-- tensor (planes first). Spatial sizes in the comments follow the nn output-size
-- formula floor((size + 2*pad - k) / stride + 1).
local stem_conv = nn.SpatialConvolution(3, 64, 7, 7, 2, 2, 3, 3)  -- -> 64 x 152 x 114
local stem_pool = nn.SpatialMaxPooling(3, 3, 2, 2, 1, 1)          -- -> 64 x 76 x 57
net:add(stem_conv)
net:add(stem_pool)
-- net:add(nn.SpatialBatchNormalization(64))

-- Residual stages built from the helper functions defined earlier.
convBlock2(1, 64, 64, 256)    -- projection shortcut, stride 1 -> 256 planes
convBlock1(256, 64, 256)      -- identity shortcut             -> 256 planes
convBlock2(2, 256, 128, 512)  -- projection shortcut, stride 2 -> 512 x 38 x 29

In [13]:
-- Print a textual summary of the assembled model under a heading.
local net_summary = net:__tostring()
print('CRN net\n' .. net_summary)

CRN net
nn.Sequential {
  [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> (7) -> (8) -> output]
  (1): nn.SpatialConvolution(3 -> 64, 7x7, 2,2, 3,3)
  (2): nn.SpatialMaxPooling(3x3, 2,2, 1,1)
  (3): nn.ConcatTable {
    input
      |`-> (1): nn.Sequential {
      |      [input -> (1) -> (2) -> (3) -> (4) -> (5) -> output]
      |      (1): nn.SpatialConvolution(64 -> 64, 1x1)
      |      (2): nn.ReLU
      |      (3): nn.SpatialConvolution(64 -> 64, 3x3, 1,1, 1,1)
      |      (4): nn.ReLU
      |      (5): nn.SpatialConvolution(64 -> 256, 1x1)
      |    }
       `-> (2): nn.Sequential {
             [input -> (1) -> output]
             (1): nn.SpatialConvolution(64 -> 256, 1x1)
           }
       ... -> output
  }
  (4): nn.CAddTable
  (5): nn.ConcatTable {
    input
      |`-> (1): nn.Sequential {
      |      [input -> (1) -> (2) -> (3) -> (4) -> (5) -> output]
      |      (1): nn.SpatialConvolution(256 -> 64, 1x1)
      |      (2): nn.ReLU
      |      (3): nn.SpatialConvo

In [14]:
-- Random test input: 3 planes, 304 x 228 spatial size.
a_test = torch.rand(3, 304, 228)
-- Show the tensor dimensions (the same LongStorage that `#a_test` yields).
print(a_test:size())

   3
 304
 228
[torch.LongStorage of size 3]



In [15]:
-- Run the test input through the network and report the output dimensions
-- (the same LongStorage that `#a_test_output` yields).
a_test_output = net:forward(a_test)
print(a_test_output:size())

 512
  38
  29
[torch.LongStorage of size 3]

