In [83]:
import numpy as np
import torch
import sigpy as sp
import matplotlib.pyplot as plt

In [79]:
"""
Copyright (c) Facebook, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

import torch
from torch import nn
from torch.nn import functional as F

class ConvBlock(nn.Module):
    """
    A Convolutional Block that consists of two convolution layers each followed by
    instance normalization, LeakyReLU activation and dropout.
    """

    def __init__(self, in_chans, out_chans, drop_prob):
        """
        Args:
            in_chans (int): Number of channels in the input.
            out_chans (int): Number of channels in the output.
            drop_prob (float): Dropout probability.
        """
        super().__init__()

        self.in_chans = in_chans
        self.out_chans = out_chans
        self.drop_prob = drop_prob

        self.layers = nn.Sequential(
            nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm2d(out_chans),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Dropout2d(drop_prob),
            nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm2d(out_chans),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Dropout2d(drop_prob)
        )

    def forward(self, input):
        """
        Args:
            input (torch.Tensor): Input tensor of shape [batch_size, self.in_chans, height, width]
        Returns:
            (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, height, width]
        """
        return self.layers(input)

    def __repr__(self):
        return f'ConvBlock(in_chans={self.in_chans}, out_chans={self.out_chans}, ' \
            f'drop_prob={self.drop_prob})'


class TransposeConvBlock(nn.Module):
    """
    A Transpose Convolutional Block that consists of one convolution transpose layers followed by
    instance normalization and LeakyReLU activation.
    """

    def __init__(self, in_chans, out_chans):
        """
        Args:
            in_chans (int): Number of channels in the input.
            out_chans (int): Number of channels in the output.
        """
        super().__init__()

        self.in_chans = in_chans
        self.out_chans = out_chans

        self.layers = nn.Sequential(
            nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False),
            nn.InstanceNorm2d(out_chans),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )

    def forward(self, input):
        """
        Args:
            input (torch.Tensor): Input tensor of shape [batch_size, self.in_chans, height, width]
        Returns:
            (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, height, width]
        """
        return self.layers(input)

    def __repr__(self):
        return f'ConvBlock(in_chans={self.in_chans}, out_chans={self.out_chans})'


# class up_conv(nn.Module):
#     def __init__(self, chans, num_pool_layers, drop_prob, sites):
    

class Branched_UnetModel(nn.Module):
    """
    PyTorch implementation of a U-Net model.
    This is based on:
        Olaf Ronneberger, Philipp Fischer, and Thomas Brox. U-net: Convolutional networks
        for biomedical image segmentation. In International Conference on Medical image
        computing and computer-assisted intervention, pages 234–241. Springer, 2015.
    """

    def __init__(self, in_chans, out_chans, chans, num_pool_layers, drop_prob, sites):
        """
        Args:
            in_chans (int): Number of channels in the input to the U-Net model.
            out_chans (int): Number of channels in the output to the U-Net model.
            chans (int): Number of output channels of the first convolution layer.
            num_pool_layers (int): Number of down-sampling and up-sampling layers.
            drop_prob (float): Dropout probability.
        """
        super().__init__()

        self.in_chans = in_chans
        self.out_chans = out_chans
        self.chans = chans
        self.num_pool_layers = num_pool_layers
        self.drop_prob = drop_prob
        self.sites = sites
        
        self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)])
        ch = chans
        for i in range(num_pool_layers - 1):
            self.down_sample_layers += [ConvBlock(ch, ch * 2, drop_prob)]
            ch *= 2
        self.conv = ConvBlock(ch, ch * 2, drop_prob)
        
        
        
        self.up_conv = nn.ModuleList()
        self.up_transpose_conv = nn.ModuleList()
        
        for i in range(num_pool_layers - 1):
            self.up_transpose_conv += [TransposeConvBlock(ch * 2, ch)]
            self.up_conv += [ConvBlock(ch * 2, ch, drop_prob)]
            ch //= 2

        self.up_transpose_conv += [TransposeConvBlock(ch * 2, ch)]
        self.up_conv += [
            nn.Sequential(
                ConvBlock(ch * 2, ch, drop_prob),
                nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1),
            )]
 
       #ModuleLists for storing site specific Decoder weights
        self.Decoder_up_list = torch.nn.ModuleList(
                        [ self.up_conv  for i in range(self.sites)])
    
        self.Decoder_up_trans_list = torch.nn.ModuleList(
                        [ self.up_transpose_conv  for i in range(self.sites)])

        
    def forward(self, site, input):
        """
        Args:
            input (torch.Tensor): Input tensor of shape [batch_size, self.in_chans, height, width]
        Returns:
            (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, height, width]
        """
        stack = []
        output = input.unsqueeze(1)

        # Apply down-sampling layers
        for i, layer in enumerate(self.down_sample_layers):
            output = layer(output)
            stack.append(output)
            output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0)

        output = self.conv(output)

        # Apply up-sampling layers
        for transpose_conv, conv in zip(self.Decoder_up_trans_list[site], self.Decoder_up_list[site]):
            downsample_layer = stack.pop()
            output = transpose_conv(output)

            # Reflect pad on the right/botton if needed to handle odd input dimensions.
            padding = [0, 0, 0, 0]
            if output.shape[-1] != downsample_layer.shape[-1]:
                padding[1] = 1 # Padding right
            if output.shape[-2] != downsample_layer.shape[-2]:
                padding[3] = 1 # Padding bottom
            if sum(padding) != 0:
                output = F.pad(output, padding, "reflect")

            output = torch.cat([output, downsample_layer], dim=1)
            output = conv(output)
            
        return output.squeeze(1)

    
class OldUnetModel(nn.Module):
    """
    PyTorch implementation of a U-Net model.
    This is based on:
        Olaf Ronneberger, Philipp Fischer, and Thomas Brox. U-net: Convolutional networks
        for biomedical image segmentation. In International Conference on Medical image
        computing and computer-assisted intervention, pages 234–241. Springer, 2015.
    """

    def __init__(self, in_chans, out_chans, chans, num_pool_layers, drop_prob):
        """
        Args:
            in_chans (int): Number of channels in the input to the U-Net model.
            out_chans (int): Number of channels in the output to the U-Net model.
            chans (int): Number of output channels of the first convolution layer.
            num_pool_layers (int): Number of down-sampling and up-sampling layers.
            drop_prob (float): Dropout probability.
        """
        super().__init__()

        self.in_chans = in_chans
        self.out_chans = out_chans
        self.chans = chans
        self.num_pool_layers = num_pool_layers
        self.drop_prob = drop_prob

        self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)])
        ch = chans
        for i in range(num_pool_layers - 1):
            self.down_sample_layers += [ConvBlock(ch, ch * 2, drop_prob)]
            ch *= 2
        self.conv = ConvBlock(ch, ch * 2, drop_prob)

        self.up_conv = nn.ModuleList()
        self.up_transpose_conv = nn.ModuleList()
        for i in range(num_pool_layers - 1):
            self.up_transpose_conv += [TransposeConvBlock(ch * 2, ch)]
            self.up_conv += [ConvBlock(ch * 2, ch, drop_prob)]
            ch //= 2

        self.up_transpose_conv += [TransposeConvBlock(ch * 2, ch)]
        self.up_conv += [
            nn.Sequential(
                ConvBlock(ch * 2, ch, drop_prob),
                nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1),
            )]

    def forward(self, input):
        """
        Args:
            input (torch.Tensor): Input tensor of shape [batch_size, self.in_chans, height, width]
        Returns:
            (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, height, width]
        """
        stack = []
        output = input.unsqueeze(1)

        # Apply down-sampling layers
        for i, layer in enumerate(self.down_sample_layers):
            output = layer(output)
            stack.append(output)
            output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0)

        output = self.conv(output)

        # Apply up-sampling layers
        for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv):
            downsample_layer = stack.pop()
            output = transpose_conv(output)

            # Reflect pad on the right/botton if needed to handle odd input dimensions.
            padding = [0, 0, 0, 0]
            if output.shape[-1] != downsample_layer.shape[-1]:
                padding[1] = 1 # Padding right
            if output.shape[-2] != downsample_layer.shape[-2]:
                padding[3] = 1 # Padding bottom
            if sum(padding) != 0:
                output = F.pad(output, padding, "reflect")

            output = torch.cat([output, downsample_layer], dim=1)
            output = conv(output)

        return output.squeeze(1)

In [81]:
print('NEW MODEL!!!!!!!!!!!!')
new_model = Branched_UnetModel(in_chans = 2 , out_chans = 2, chans = 32, num_pool_layers = 4, drop_prob = 0.0, sites = 3)
for key in new_model.state_dict():
    print(key)
    
print('OLD MODEL!!!!!!!!!!!!')
old_model = OldUnetModel(in_chans = 2 , out_chans = 2, chans = 32, num_pool_layers = 4, drop_prob = 0.0)
for key in old_model.state_dict():
    print(key)
    


NEW MODEL!!!!!!!!!!!!
down_sample_layers.0.layers.0.weight
down_sample_layers.0.layers.4.weight
down_sample_layers.1.layers.0.weight
down_sample_layers.1.layers.4.weight
down_sample_layers.2.layers.0.weight
down_sample_layers.2.layers.4.weight
down_sample_layers.3.layers.0.weight
down_sample_layers.3.layers.4.weight
conv.layers.0.weight
conv.layers.4.weight
up_conv.0.layers.0.weight
up_conv.0.layers.4.weight
up_conv.1.layers.0.weight
up_conv.1.layers.4.weight
up_conv.2.layers.0.weight
up_conv.2.layers.4.weight
up_conv.3.0.layers.0.weight
up_conv.3.0.layers.4.weight
up_conv.3.1.weight
up_conv.3.1.bias
up_transpose_conv.0.layers.0.weight
up_transpose_conv.1.layers.0.weight
up_transpose_conv.2.layers.0.weight
up_transpose_conv.3.layers.0.weight
Decoder_up_list.0.0.layers.0.weight
Decoder_up_list.0.0.layers.4.weight
Decoder_up_list.0.1.layers.0.weight
Decoder_up_list.0.1.layers.4.weight
Decoder_up_list.0.2.layers.0.weight
Decoder_up_list.0.2.layers.4.weight
Decoder_up_list.0.3.0.layers.0.w

In [84]:
# Count parameters
total_params_new = np.sum([np.prod(p.shape) for p
                           in new_model.parameters() if p.requires_grad])
print('Total parameters %d' % total_params_new)

total_params_old = np.sum([np.prod(p.shape) for p
                           in old_model.parameters() if p.requires_grad])
print('Total parameters %d' % total_params_old)

Total parameters 7756418
Total parameters 7756418


In [43]:
print(new_model.state_dict()['Decoder_up_list.0.0.layers.0.weight'].shape)
print(new_model.state_dict()['up_conv.0.layers.0.weight'].shape)

print(new_model.state_dict()['Decoder_up_list.0.0.layers.4.weight'].shape)
print(new_model.state_dict()['up_conv.0.layers.4.weight'].shape)

torch.Size([256, 512, 3, 3])
torch.Size([256, 512, 3, 3])
torch.Size([256, 256, 3, 3])
torch.Size([256, 256, 3, 3])


# Testing Nested ModuleLists

In [56]:
outs = nn.ModuleList([nn.ModuleList([nn.Conv2d(3, 3, 3, 1, 1) for j in range(5)]) for i in range(1)])
total_params = np.sum([np.prod(p.shape) for p
                           in outs.parameters() if p.requires_grad])
print('Total parameters %d' % total_params)

Total parameters 420
