In [None]:
# Mount Google Drive
# https://towardsdatascience.com/google-drive-google-colab-github-dont-just-read-do-it-5554d5824228
# Colab-only cell: mounts the user's Google Drive under ROOT, so Drive files are
# reachable under /content/drive/MyDrive/... (the path used by %run below).
from google.colab import drive
ROOT = "/content/drive"
drive.mount(ROOT)
# %pwd %ls
# run github settings: executes the helper notebook inside this kernel.
# NOTE(review): its contents are not visible here; presumably it defines GIT_PATH
# (used by `git push` in a later cell) and prepares the repo checkout — confirm.
%run /content/drive/MyDrive/CNNStanford/pytorch/pytorch_sandbox/Colab_Helper.ipynb

In [None]:
# Commit message for the commit cell below; IPython substitutes {MESSAGE} into `!` commands.
MESSAGE = "clean file & gitignore again"
# Set the git identity globally, then stage every change under the current directory.
# NOTE(review): assumes the working directory is the repository checkout (presumably
# set up by Colab_Helper) — confirm, since `git add .` stages everything there.
!git config --global user.email "ronyginosar@mail.huji.ac.il"
!git config --global user.name "ronyginosar"
!git add .

In [None]:
# Commit the staged changes with MESSAGE and push.
# NOTE(review): GIT_PATH is not defined in this file; presumably it comes from
# Colab_Helper (e.g. the remote URL) — confirm before running.
!git commit -m "{MESSAGE}"
!git push "{GIT_PATH}"

In [None]:
import torch
import torch.nn as nn
from torchinfo import summary
import sys
from torchvision import transforms

#%%
# Optional additions:
# ‘Valid’ [default]  vs. ‘Same’ padding:
#   https://stackoverflow.com/questions/37674306/what-is-the-difference-between-same-and-valid-padding-in-tf-nn-max-pool-of-t
# (not in the original brief):
# make feature_channels an enum and a constructor input
# add a 'depth' input for easier size changes
# module list using depth to create down and up directions, rather than code repetition, e.g.:
#  self.down_path = nn.ModuleList()
#         for i in range(depth):
#             self.down_path.append(
#                 UNetConvBlock(prev_channels, 2 ** (wf + i), padding, batch_norm)
#             )
#             prev_channels = 2 ** (wf + i)

#%%

# Implement UNet as a nn Module
class UNet(nn.Module):
    """
    U-Net: Convolutional Networks for Biomedical Image Segmentation (Ronneberger et al., 2015)
    https://arxiv.org/abs/1505.04597

    Contracting path: an input block of two unpadded 3x3 conv => ReLU layers, followed by
    ``depth - 1`` down blocks (2x2 max pool, then two conv => ReLU layers), each doubling the
    number of feature channels. Expanding path: ``depth - 1`` ``UpNConcat`` blocks, each
    halving the channels and concatenating the center-cropped feature map of the matching
    contracting level. A final 1x1 convolution maps to ``out_channels``.

    Submodules are registered as ``in_block``, ``down1`` .. ``down{depth-1}``,
    ``up1`` .. ``up{depth-1}`` and ``out`` (``down1..down4`` / ``up1..up4`` for the
    default ``depth=5``), in that order.

    Args:
        in_channels: number of channels of the input image.
        out_channels: number of output maps (e.g. segmentation classes).
        up_mode: 'upconv' for learned transposed convolutions [default], or 'upsample'
            for bilinear upsampling followed by a 1x1 convolution.
        depth: number of resolution levels, counting the input level; the default 5
            gives 4 downsampling and 4 upsampling steps.
        feature_channels: channels produced by the input block; doubled at every
            downsampling step (64 -> 1024 with the defaults).

    Raises:
        ValueError: if ``up_mode`` is not supported or ``depth < 1``.
    """
    def __init__(self,
                 in_channels=1,
                 out_channels=2,
                 up_mode='upconv',
                 depth=5,
                 feature_channels=64,
                 ):
        super().__init__()
        # Support Transposed Convolutions [default] vs. Bilinear upscaling.
        # Raise instead of assert so the check is not stripped under `python -O`.
        if up_mode not in ('upconv', 'upsample'):
            raise ValueError(f"up_mode must be 'upconv' or 'upsample', got {up_mode!r}")
        if depth < 1:
            raise ValueError(f"depth must be >= 1, got {depth!r}")
        bilinear = up_mode == 'upsample'
        self.depth = depth

        channels = feature_channels
        # conv 3*3, ReLU => *2 at full resolution
        self.in_block = conv_relu_2(in_channels, out_channels=channels)
        # contracting path: every downsampling step doubles the number of feature channels
        # (64 -> 128 -> 256 -> 512 -> 1024 with the defaults)
        for level in range(1, depth):
            self.add_module(f"down{level}",
                            maxpool(in_channels=channels, out_channels=channels * 2))
            channels *= 2
        # expanding path: copy and crop & up-conv 2*2; every step halves the channels
        for level in range(1, depth):
            self.add_module(f"up{level}", UpNConcat(channels, bilinear))
            channels //= 2
        # conv 1*1: `channels` is back to feature_channels here (64 -> out_channels by default)
        self.out = out_conv(in_channels=channels, out_channels=out_channels)

    def forward(self, x):
        """Run the U-Net on ``x`` and return the ``out_channels`` maps of the final 1x1 conv.

        All 3x3 convolutions are unpadded, so the output is spatially smaller than the
        input (e.g. 572x572 -> 388x388 with the default depth).
        """
        # contracting path; keep each level's output for the skip connections
        skips = [self.in_block(x)]
        for level in range(1, self.depth):
            skips.append(getattr(self, f"down{level}")(skips[-1]))
        x = skips.pop()  # bottom of the "U"
        # expanding path: up{level} pairs with the matching contracting output, deepest first
        for level in range(1, self.depth):
            x = getattr(self, f"up{level}")(skips.pop(), x)  # to_copy, to_up_append
        return self.out(x)


#%% Sequentials

def conv_relu_2(in_channels, out_channels, *args, **kwargs):
    """Two stacked (3x3 unpadded convolution => ReLU) layers.

    The first layer maps ``in_channels`` -> ``out_channels``; the second keeps
    ``out_channels``. Any extra ``*args`` / ``**kwargs`` are forwarded unchanged
    to both ``conv_relu`` calls.
    """
    first = conv_relu(in_channels, out_channels, *args, **kwargs)
    second = conv_relu(out_channels, out_channels, *args, **kwargs)
    return nn.Sequential(first, second)


def maxpool(in_channels, out_channels, kernel_size=2, *args, **kwargs):
    """Contracting-path down block: max pooling followed by ``conv_relu_2``.

    With the default ``kernel_size=2`` this is 2x2 max pooling with stride 2
    (``nn.MaxPool2d`` uses the kernel size as its default stride). Two
    conv => ReLU layers then map ``in_channels`` -> ``out_channels``. Extra
    ``*args`` / ``**kwargs`` are forwarded to ``conv_relu_2``.
    """
    pool = nn.MaxPool2d(kernel_size)
    convs = conv_relu_2(in_channels, out_channels, *args, **kwargs)
    return nn.Sequential(pool, convs)


def up_conv(in_channels, out_channels):
    """Learned 2x upsampling: a 2x2 transposed convolution with stride 2.

    Doubles the height and width of the feature map while mapping ``in_channels``
    to ``out_channels`` (the U-Net passes ``in_channels // 2`` to halve them).
    """
    layer = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels,
                               kernel_size=2, stride=2)
    return layer


def out_conv(in_channels, out_channels, *args, **kwargs):
    """Final 1x1 convolution of the U-Net.

    At the final layer a 1x1 convolution maps each ``in_channels``-component feature
    vector (64 in the default network) to the desired number of classes.

    Extra positional ``*args`` are forwarded to ``nn.Conv2d`` after the kernel size
    (i.e. as ``stride``, ``padding``, ...); ``**kwargs`` are forwarded as-is.
    ``kernel_size`` is fixed to 1 and must not be passed again.
    """
    # kernel_size is passed positionally: passing it by keyword would make the first
    # positional extra also bind to kernel_size and raise TypeError.
    return nn.Conv2d(in_channels, out_channels, 1, *args, **kwargs)


#%% base module of a 3x3 convolution (unpadded convolutions), followed by a (ReLU)
def conv_relu(in_channels, out_channels, *args, **kwargs):
    """(convolution => ReLU)

    Base module: a 3x3 convolution without padding ("valid" convolution, so with the
    default stride each spatial dimension shrinks by 2), followed by a ReLU.

    At most one extra positional argument is accepted; it is forwarded to ``nn.Conv2d``
    right after the kernel size, i.e. as ``stride``. ``**kwargs`` are forwarded as-is.
    ``kernel_size`` (3) and ``padding`` (0) are fixed and must not be passed again.
    """
    # kernel_size is positional so that *args bind to `stride` instead of colliding
    # with a kernel_size keyword (which raised TypeError for any positional extra).
    conv = nn.Conv2d(in_channels, out_channels, 3, *args, padding=0, **kwargs)
    return nn.Sequential(conv, nn.ReLU())

#%% Classes

class UpNConcat(nn.Module):
    """Expanding-path block of the U-Net.

    Upsamples the incoming feature map while halving its number of channels, then
    concatenates it with the correspondingly center-cropped feature map from the
    contracting path, and applies two 3x3 conv => ReLU layers. The cropping is
    necessary because every unpadded convolution loses border pixels, so the
    contracting-path map is larger than the upsampled one.

    Args:
        feature_channels: channels of the incoming (deeper) feature map; the block
            outputs ``feature_channels // 2`` channels.
        bilinear: if True, use bilinear upsampling followed by a 1x1 convolution to
            halve the channels; otherwise use a learned 2x2 transposed convolution.
    """
    def __init__(self, feature_channels, bilinear=False):
        super().__init__()
        half = feature_channels // 2
        if not bilinear:
            # learned upsampling
            self.up = up_conv(in_channels=feature_channels, out_channels=half)
        else:
            # fixed interpolation; the 1x1 conv reduces the number of channels
            self.up = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                nn.Conv2d(in_channels=feature_channels, out_channels=half, kernel_size=1),
            )
        # after crop and concat: double conv from feature_channels down to half
        self.conv = conv_relu_2(in_channels=feature_channels, out_channels=half)

    def forward(self, to_copy, to_up_append):
        """Upsample ``to_up_append``, crop ``to_copy`` to match, concatenate, convolve.

        Args:
            to_copy: feature map from the contracting path (skip connection).
            to_up_append: deeper feature map to be upsampled.

        Returns:
            ``self.conv`` applied to ``cat([cropped to_copy, upsampled], dim=1)``.
        """
        upsampled = self.up(to_up_append)
        # Optional alternative (not implemented): pad the upsampled map up to the size of
        # to_copy instead of cropping to_copy. The original paper uses unpadded convolutions
        # in both the down- and up-sampling paths, which avoids corrupting semantic
        # information; this is one of the reasons the overlap-tile strategy was proposed.
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        height, width = upsampled.size(2), upsampled.size(3)
        # center-crop the skip map to the upsampled map's size
        # (from: https://amaarora.github.io/2020/09/13/unet.html)
        crop = transforms.CenterCrop([height, width])
        skip = crop(to_copy)
        merged = torch.cat([skip, upsampled], dim=1)
        return self.conv(merged)

#%%
def main(argv):
    """Smoke-test the UNet: one forward pass on a random batch, then print a summary table.

    Args:
        argv: command-line arguments (currently unused).
    """
    # Create a net instance. UNet's own default (depth 5) is the standard network, so no
    # depth argument is passed here.
    net = UNet(in_channels=1, out_channels=2, up_mode='upconv')

    # - set the network to 'eval' mode
    net.eval()
    # - generate a random input batch [N x 1 x 572 x 572]
    #   (with unpadded convolutions this yields a 388x388 output map)
    batch_size = 1
    input_size = (batch_size, 1, 572, 572)
    x = torch.randn(input_size)

    # - feed the batch through the network (forward), on GPU if available, else CPU
    #   (both input and network must live on the same device)
    if torch.cuda.is_available():
        print("running on GPU")
        device = torch.device('cuda:0')
    else:
        print("running on CPU")
        device = torch.device('cpu')
    net = net.to(device)
    x = x.to(device)
    # inference only: skip building the autograd graph to save memory
    with torch.no_grad():
        y = net(x)
    print(f"output size: {tuple(y.shape)}")

    # - Make a nice summary table of the network using 'torchinfo'
    # Display the following columns: input size, output size, #params, #MACs per layer
    # print needed for notebook mode
    print(summary(net,
                  input_size=input_size,
                  col_names=["input_size", "output_size", "num_params", "mult_adds"],
                  col_width=15,
                  depth=10,
                  device=device))


#%%
if __name__ == "__main__":
    # Entry point when executed directly (a notebook kernel also runs as "__main__").
    main(argv=sys.argv)
