In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

Original model implementation at https://github.com/Li-Chongyi/Water-Net_Code/blob/master/model.py:

```python
  def model(self):


    with tf.variable_scope("main_branch") as scope:      

      conb0 = tf.concat(axis = 3, values = [self.images,self.images_wb,self.images_ce,self.images_gc]) 
      conv_wb1 = tf.nn.relu(conv2d(conb0, 16,128, k_h=7, k_w=7, d_h=1, d_w=1,name="conv2wb_1"))
      conv_wb2 = tf.nn.relu(conv2d(conv_wb1, 128,128, k_h=5, k_w=5, d_h=1, d_w=1,name="conv2wb_2"))
      conv_wb3 = tf.nn.relu(conv2d(conv_wb2, 128,128, k_h=3, k_w=3, d_h=1, d_w=1,name="conv2wb_3"))
      conv_wb4 = tf.nn.relu(conv2d(conv_wb3, 128,64, k_h=1, k_w=1, d_h=1, d_w=1,name="conv2wb_4"))
      conv_wb5 = tf.nn.relu(conv2d(conv_wb4, 64,64, k_h=7, k_w=7, d_h=1, d_w=1,name="conv2wb_5"))
      conv_wb6 = tf.nn.relu(conv2d(conv_wb5, 64,64, k_h=5, k_w=5, d_h=1, d_w=1,name="conv2wb_6"))
      conv_wb7 = tf.nn.relu(conv2d(conv_wb6, 64,64, k_h=3, k_w=3, d_h=1, d_w=1,name="conv2wb_7"))

      conv_wb77 =tf.nn.sigmoid(conv2d(conv_wb7, 64,3, k_h=3, k_w=3, d_h=1, d_w=1,name="conv2wb_77"))

      conb00 = tf.concat(axis = 3, values = [self.images,self.images_wb]) 
      conv_wb9 = tf.nn.relu(conv2d(conb00, 3,32, k_h=7, k_w=7, d_h=1, d_w=1,name="conv2wb_9"))
      conv_wb10 = tf.nn.relu(conv2d(conv_wb9, 32,32, k_h=5, k_w=5, d_h=1, d_w=1,name="conv2wb_10"))
      wb1 =tf.nn.relu(conv2d(conv_wb10, 32,3, k_h=3, k_w=3, d_h=1, d_w=1,name="conv2wb_11"))

      conb11 = tf.concat(axis = 3, values = [self.images,self.images_ce]) 
      conv_wb99 = tf.nn.relu(conv2d(conb11, 3,32, k_h=7, k_w=7, d_h=1, d_w=1,name="conv2wb_99"))
      conv_wb100 = tf.nn.relu(conv2d(conv_wb99, 32,32, k_h=5, k_w=5, d_h=1, d_w=1,name="conv2wb_100"))
      ce1 =tf.nn.relu(conv2d(conv_wb100, 32,3, k_h=3, k_w=3, d_h=1, d_w=1,name="conv2wb_111"))

      conb111 = tf.concat(axis = 3, values = [self.images,self.images_gc]) 
      conv_wb999 = tf.nn.relu(conv2d(conb111, 3,32, k_h=7, k_w=7, d_h=1, d_w=1,name="conv2wb_999"))
      conv_wb1000 = tf.nn.relu(conv2d(conv_wb999, 32,32, k_h=5, k_w=5, d_h=1, d_w=1,name="conv2wb_1000"))
      gc1 =tf.nn.relu(conv2d(conv_wb1000, 32,3, k_h=3, k_w=3, d_h=1, d_w=1,name="conv2wb_1111"))

      weight_wb,weight_ce,weight_gc=tf.split(conv_wb77,3,3)
      output1=tf.add(tf.add(tf.multiply(wb1,weight_wb),tf.multiply(ce1,weight_ce)),tf.multiply(gc1,weight_gc))

    return output1
```

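where `ops.conv2d` is shown below. Note that it ignores its `input_dim` argument: the kernel's input depth is taken from `input_.get_shape()[-1]`. So a layer written as `conv2d(conb00, 3, 32, ...)` actually consumes the whole concatenation `[images, images_wb]`, which is 6 channels. That is why the PyTorch refiners below use `in_channels=6`.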

```python
def conv2d(input_, input_dim,output_dim, 
       k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
       name="conv2d"):
  with tf.variable_scope(name):
    w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim],
              initializer=tf.truncated_normal_initializer(stddev=stddev))
    conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME')

    biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0))
    conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape())
    return conv
```

TF 1.4 docs for tf.nn.conv2d: https://github.com/tensorflow/docs/blob/r1.4/site/en/api_docs/api_docs/python/tf/nn/conv2d.md

In [56]:
class ConfidenceMapGenerator(nn.Module):
    """Confidence-map (gating) branch of Water-Net.

    PyTorch port of the TF ``conv2wb_1`` ... ``conv2wb_77`` layers: eight
    stride-1 convolutions with "same" padding, a ReLU after each of the first
    seven and a sigmoid after the last. The three output channels are returned
    as three separate single-channel maps, one per derived input (WB, CE, GC).

    Submodules are registered under the names ``conv1``, ``relu1``, ...,
    ``conv7``, ``relu7``, ``conv8``, ``sigmoid`` (in that order). The
    ``state_dict`` keys are therefore ``conv{i}.weight`` / ``conv{i}.bias``.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, kernel_size) for conv1 .. conv8.
        layer_specs = [
            (12, 128, 7),  # x, wb, ce, gc concatenated along dim 1
            (128, 128, 5),
            (128, 128, 3),
            (128, 64, 1),
            (64, 64, 7),
            (64, 64, 5),
            (64, 64, 3),
            (64, 3, 3),    # one output channel per confidence map
        ]
        for i, (c_in, c_out, k) in enumerate(layer_specs, start=1):
            conv = nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=k, dilation=1, padding="same")
            setattr(self, f"conv{i}", conv)
            if i < len(layer_specs):
                # Every conv except the last one gets its own ReLU right after it.
                setattr(self, f"relu{i}", nn.ReLU())
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, wb, ce, gc):
        """Compute the (wb, ce, gc) confidence maps.

        Args:
            x, wb, ce, gc: tensors concatenated along dim 1; together they must
                have 12 channels (presumably 3 each, i.e. (N, 3, H, W)).

        Returns:
            Tuple of three tensors of shape (N, 1, H, W) with values in [0, 1].
        """
        h = torch.cat([x, wb, ce, gc], dim=1)
        for i in range(1, 8):
            h = getattr(self, f"relu{i}")(getattr(self, f"conv{i}")(h))
        maps = self.sigmoid(self.conv8(h))
        wb_map, ce_map, gc_map = maps.split(1, dim=1)
        return wb_map, ce_map, gc_map

In [57]:
# Standalone confidence-map generator, used for the shape checks in the next cells.
cmg = ConfidenceMapGenerator()

In [60]:
# Smoke test: four random (16, 3, 112, 112) tensors standing in for x, wb, ce, gc.
cmg_a, cmg_b, cmg_c = cmg(*(torch.randn(16, 3, 112, 112) for _ in range(4)))

In [61]:
# Each confidence map should be single-channel.
tuple(t.shape for t in (cmg_a, cmg_b, cmg_c))

(torch.Size([16, 1, 112, 112]),
 torch.Size([16, 1, 112, 112]),
 torch.Size([16, 1, 112, 112]))

In [27]:
class Refiner(nn.Module):
    """One Water-Net refinement branch (e.g. TF ``conv2wb_9`` -> ``_10`` -> ``_11``).

    Concatenates the raw input with one derived input along dim 1 and applies
    three stride-1, "same"-padded convolutions (kernels 7, 5, 3), each followed
    by a ReLU. Channel flow: 6 -> 32 -> 32 -> 3.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, kernel_size) for conv1 .. conv3.
        conv_specs = ((6, 32, 7), (32, 32, 5), (32, 3, 3))
        # Register all convs first, then all ReLUs. This gives the submodule
        # order conv1, conv2, conv3, relu1, relu2, relu3.
        for i, (c_in, c_out, k) in enumerate(conv_specs, start=1):
            setattr(self, f"conv{i}", nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=k, dilation=1, padding="same"))
        for i in range(1, len(conv_specs) + 1):
            setattr(self, f"relu{i}", nn.ReLU())

    def forward(self, x, xbar):
        """Refine ``xbar`` conditioned on ``x``.

        Args:
            x: raw input tensor.
            xbar: derived input tensor; ``x`` and ``xbar`` together must have
                6 channels (presumably 3 each, i.e. (N, 3, H, W)).

        Returns:
            Non-negative tensor of shape (N, 3, H, W).
        """
        h = torch.cat([x, xbar], dim=1)
        for i in (1, 2, 3):
            h = getattr(self, f"relu{i}")(getattr(self, f"conv{i}")(h))
        return h

In [28]:
# Standalone refiner, used for a quick output-shape check in the next cells.
refiner = Refiner()

In [29]:
# Smoke test: random (16, 3, 112, 112) tensors standing in for x and xbar.
refiner_out = refiner(*(torch.randn(16, 3, 112, 112) for _ in range(2)))

In [30]:
refiner_out.size()

torch.Size([16, 3, 112, 112])

In [63]:
# A single-channel confidence map broadcasts over the refiner's 3 output channels.
(refiner_out * cmg_a).shape

torch.Size([16, 3, 112, 112])

In [64]:
class WaterNet(nn.Module):
    """Water-Net: confidence-gated fusion of three refined derived inputs.

    Mirrors the TF ``main_branch``. The confidence-map generator produces one
    gating map per derived input (WB, CE, GC). Each refiner produces a refined
    image from the raw input plus one derived input. The output is the sum of
    the refined images, each weighted by its gating map.

    Args:
        tf_init: If True, re-initialise every ``nn.Conv2d`` the way the
            reference ``ops.conv2d`` does: weights drawn from a normal with
            std 0.02, truncated at two standard deviations
            (``tf.truncated_normal_initializer``), and zero biases.
            Defaults to False, which keeps PyTorch's default initialisation.
            The flag only matters when training from scratch, not when
            loading pretrained weights.
    """

    def __init__(self, tf_init: bool = False):
        super().__init__()
        self.cmg = ConfidenceMapGenerator()
        self.wb_refiner = Refiner()
        self.ce_refiner = Refiner()
        self.gc_refiner = Refiner()
        if tf_init:
            self._init_tf_style()

    def _init_tf_style(self, std: float = 0.02):
        """Re-initialise all conv layers like the reference TF ``ops.conv2d``."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                # TF re-draws samples lying beyond 2 std. torch's trunc_normal_
                # takes absolute bounds a/b, not multiples of std, so scale them.
                nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-2 * std, b=2 * std)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x, wb, ce, gc):
        """Fuse the refined WB/CE/GC images using the predicted confidence maps.

        Args:
            x: raw input image tensor.
            wb, ce, gc: derived input tensors, same shape as ``x``
                (presumably (N, 3, H, W)).

        Returns:
            Tensor with the refiners' output shape, presumably (N, 3, H, W).
        """
        wb_cm, ce_cm, gc_cm = self.cmg(x, wb, ce, gc)
        refined_wb = self.wb_refiner(x, wb)
        refined_ce = self.ce_refiner(x, ce)
        refined_gc = self.gc_refiner(x, gc)
        # Each confidence map broadcasts over the channels of its refined image.
        return torch.mul(refined_wb, wb_cm) + torch.mul(refined_ce, ce_cm) + torch.mul(refined_gc, gc_cm)

In [65]:
# Full model: confidence-map generator plus one refiner each for the WB, CE and GC inputs.
waternet = WaterNet()

In [66]:
# Smoke test: four random (16, 3, 112, 112) tensors standing in for x, wb, ce, gc.
waternet_out = waternet(*(torch.randn(16, 3, 112, 112) for _ in range(4)))

In [67]:
waternet_out.size()

torch.Size([16, 3, 112, 112])

In [68]:
from torchsummary import summary

In [71]:
# One (C, H, W) entry per forward() argument: x, wb, ce, gc.
summary(waternet, input_size=[(3, 112, 112)] * 4, device="cpu", batch_size=4)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [4, 128, 112, 112]          75,392
              ReLU-2         [4, 128, 112, 112]               0
            Conv2d-3         [4, 128, 112, 112]         409,728
              ReLU-4         [4, 128, 112, 112]               0
            Conv2d-5         [4, 128, 112, 112]         147,584
              ReLU-6         [4, 128, 112, 112]               0
            Conv2d-7          [4, 64, 112, 112]           8,256
              ReLU-8          [4, 64, 112, 112]               0
            Conv2d-9          [4, 64, 112, 112]         200,768
             ReLU-10          [4, 64, 112, 112]               0
           Conv2d-11          [4, 64, 112, 112]         102,464
             ReLU-12          [4, 64, 112, 112]               0
           Conv2d-13          [4, 64, 112, 112]          36,928
             ReLU-14          [4, 64, 1

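That completes the port.

In total the model has about 1.09M parameters (1,090,668): about 0.98M in the confidence-map generator and about 36K in each of the three refiners. That is small; for comparison, VGG-16 has about 138M. The inflated input size reported by torchsummary is most likely a torchsummary quirk with multiple inputs. It estimates input size as `np.prod(sum(input_size, ()))`, which concatenates all four shape tuples and multiplies every dimension together, so that figure is meaningless here. The per-layer output shapes and parameter counts are unaffected.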

In [74]:
# List every parameter/buffer name with its tensor shape.
for param_name, param_tensor in waternet.state_dict().items():
    print(param_name, param_tensor.shape)

cmg.conv1.weight torch.Size([128, 12, 7, 7])
cmg.conv1.bias torch.Size([128])
cmg.conv2.weight torch.Size([128, 128, 5, 5])
cmg.conv2.bias torch.Size([128])
cmg.conv3.weight torch.Size([128, 128, 3, 3])
cmg.conv3.bias torch.Size([128])
cmg.conv4.weight torch.Size([64, 128, 1, 1])
cmg.conv4.bias torch.Size([64])
cmg.conv5.weight torch.Size([64, 64, 7, 7])
cmg.conv5.bias torch.Size([64])
cmg.conv6.weight torch.Size([64, 64, 5, 5])
cmg.conv6.bias torch.Size([64])
cmg.conv7.weight torch.Size([64, 64, 3, 3])
cmg.conv7.bias torch.Size([64])
cmg.conv8.weight torch.Size([3, 64, 3, 3])
cmg.conv8.bias torch.Size([3])
wb_refiner.conv1.weight torch.Size([32, 6, 7, 7])
wb_refiner.conv1.bias torch.Size([32])
wb_refiner.conv2.weight torch.Size([32, 32, 5, 5])
wb_refiner.conv2.bias torch.Size([32])
wb_refiner.conv3.weight torch.Size([3, 32, 3, 3])
wb_refiner.conv3.bias torch.Size([3])
ce_refiner.conv1.weight torch.Size([32, 6, 7, 7])
ce_refiner.conv1.bias torch.Size([32])
ce_refiner.conv2.weight torch