# 2021/06/28

In [76]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, BatchNormalization, ReLU, MaxPool2D, UpSampling2D, Conv2DTranspose

In [71]:
### pytorch code
class double_conv(nn.Module):
    '''(conv => BN => ReLU) * 2'''
    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1,),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        # print(x.shape)
        x = self.conv(x)
        # print(x.shape)
        return x

In [70]:
x = torch.zeros(2, 1, 10, 3, dtype=torch.float)
model = double_conv(1, 4)
pred = model(x)

torch.Size([2, 1, 10, 3])
torch.Size([2, 4, 10, 3])


In [78]:
class double_conv(tf.keras.layers.Layer):
    def __init__(self, out_ch):
        super(double_conv, self).__init__()
        self.conv = Sequential([
            Conv2D(filters=out_ch, kernel_size=3, padding='same'),
            BatchNormalization(),
            ReLU(),
            Conv2D(filters=out_ch, kernel_size=3, padding='same'),
            BatchNormalization(),
            ReLU()
        ])
    
    def call(self, x):
    #   print(x.shape)
      x = self.conv(x)
    #   print(x.shape)
      return x

In [54]:
model = double_conv(4)
x = tf.zeros([2, 3, 10, 1], tf.float32)
pred = model(x)

(2, 3, 10, 1)
(2, 3, 10, 4)


In [32]:
tf.zeros([2, 2, 5, 3])

<tf.Tensor: shape=(2, 2, 5, 3), dtype=float32, numpy=
array([[[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]],


       [[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]]], dtype=float32)>

In [31]:
torch.zeros([2, 2, 5, 3])

tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]]])

In [None]:
### pytorch code
class inconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(inconv, self).__init__()
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x):
        x = self.conv(x)
        return x

In [47]:
class inconv(tf.keras.layers.Layer):
  def __init__(self, out_ch):
    super(inconv, self).__init__()
    self.conv = double_conv(out_ch)

  def call(self, x):
    x = self.conv(x)
    return x

model = inconv(4)
x = tf.zeros([2, 3, 10, 1], tf.float32)
pred = model(x)

(2, 3, 10, 1)
(2, 3, 10, 4)


In [None]:
### pytorch code
class down(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(down, self).__init__()
        self.mpconv = nn.Sequential(
            nn.MaxPool2d(2),
            double_conv(in_ch, out_ch)
        )

    def forward(self, x):
        x = self.mpconv(x)
        return x

In [57]:
class down(tf.keras.layers.Layer):
    def __init__(self, out_ch):
        super(down, self).__init__()
        self.mpconv = Sequential([
            MaxPool2D(pool_size=2),
            double_conv(out_ch)
        ])

    def call(self, x):
        # print(x.shape)
        x = self.mpconv(x)
        # print(x.shape)
        return x

model = down(4)
x = tf.zeros([2, 4, 10, 1], tf.float32)
pred = model(x)

(2, 4, 10, 1)
(2, 2, 5, 4)


In [74]:
### pytorch code
class up(nn.Module):
    def __init__(self, in_ch, out_ch, bilinear=True):
        super(up, self).__init__()

        #  would be a nice idea if the upsampling could be learned too,
        #  but my machine do not have enough memory to handle all those weights
        if bilinear:
            self.up = nn.UpsamplingBilinear2d(scale_factor=2)
        else:
            self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2)

        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x1, x2):
        print(x1.shape)
        x1 = self.up(x1)
        print(x1.shape)
        ### encoder의 feature map과 decoder의 feature map을 모두 사용하기 위해 둘의 shape를 맞춰주는 코드
        diffX = x1.size()[2] - x2.size()[2]
        diffY = x1.size()[3] - x2.size()[3]
        print(diffX, diffY)
        x2 = F.pad(x2, (diffX // 2, int(diffX / 2),
                        diffY // 2, int(diffY / 2)))
        ###
        x = torch.cat([x2, x1], dim=1)
        print("x:", x.shape)
        x = self.conv(x)
        print("x:", x.shape)
        return x

In [77]:
x = torch.zeros(1, 1, 10, 3, dtype=torch.float)
y = torch.zeros(1, 1, 20, 6, dtype=torch.float)
model = up(2, 1, bilinear=False)
pred = model(x, y)

torch.Size([1, 1, 10, 3])
torch.Size([1, 1, 20, 6])
0 0
x: torch.Size([1, 2, 20, 6])
x: torch.Size([1, 1, 20, 6])


In [113]:
class up(tf.keras.layers.Layer):
    def __init__(self, in_ch, out_ch, bilinear=True):
        super(up, self).__init__()

        if bilinear:
            self.up = UpSampling2D(size=(2, 2))
        else:
            self.up = Conv2DTranspose(filters=in_ch//2, kernel_size=2, strides=2)

        self.conv = double_conv(out_ch)
    
    def call(self, x_dec, x_enc):
        x_dec = self.up(x_dec)
        print(x_dec.shape)
        diffX = x_dec.shape[1] - x_enc.shape[1]
        diffY = x_dec.shape[2] - x_enc.shape[2]
        print(diffX, diffY)
        ### paddings 설명 추가
        ### 홀수 입력 시에는 안되는 것 아닌지?
        paddings = [[0, 0],
                    [diffX // 2, int(diffX / 2)],
                    [diffY // 2, int(diffY / 2)],
                    [0, 0]]
        x_enc = tf.pad(x_enc, paddings=paddings)
        print(x_enc.shape)
        x = tf.concat([x_enc, x_dec], axis=-1)
        print(x_enc.shape, x_dec.shape)
        print(x.shape)
        x = self.conv(x)
        print(x.shape)
        return x

In [114]:
model = up(2, 1)
x = tf.zeros([2, 3, 10, 1], tf.float32)
y = tf.zeros([2, 4, 16, 1], tf.float32)
pred = model(x, y)

(2, 6, 20, 1)
2 4
(2, 6, 20, 1)
(2, 6, 20, 1) (2, 6, 20, 1)
(2, 6, 20, 2)
(2, 6, 20, 1)
