In [1]:
import sys
sys.path.append('/share/gpu0/jjwhit/rcGAN/')
import numpy as np
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from models.lightning.mmGAN import mmGAN
from models.archs.mri.generator import UNetModel
from models.archs.mri.discriminator import DiscriminatorModel
from data.lightning.MassMappingDataModule import MMDataTransform

import yaml
import json
import types
import torch
import os

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')



In [20]:
import torch
from torch import nn


class ResidualBlock(nn.Module):
    """
    A Convolutional Block that consists of two convolution layers each followed by
    instance normalization, relu activation and dropout.
    """

    def __init__(self, in_chans, out_chans, batch_norm=True):
        """
        Args:
            in_chans (int): Number of channels in the input.
            out_chans (int): Number of channels in the output.
        """
        super().__init__()

        self.in_chans = in_chans
        self.out_chans = out_chans
        self.batch_norm = batch_norm

        if self.in_chans != self.out_chans:
            self.out_chans = self.in_chans

        # self.norm = nn.BatchNorm2d(self.out_chans)
        self.conv_1_x_1 = nn.Conv2d(self.in_chans, self.out_chans, kernel_size=(1, 1))
        self.layers = nn.Sequential(
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv2d(self.in_chans, self.out_chans, kernel_size=(3, 3), padding=1),
            # nn.BatchNorm2d(self.out_chans),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv2d(self.in_chans, self.out_chans, kernel_size=(3, 3), padding=1),
        )

    def forward(self, input):
        """
        Args:
            input (torch.Tensor): Input tensor of shape [batch_size, self.in_chans, height, width]

        Returns:
            (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, height, width]
        """
        output = input

        return self.layers(output) + self.conv_1_x_1(output)


class FullDownBlock(nn.Module):
    def __init__(self, in_chans, out_chans):
        """
        Args:
            in_chans (int): Number of channels in the input.
            out_chans (int): Number of channels in the output.
        """
        super().__init__()
        self.in_chans = in_chans
        self.out_chans = out_chans

        self.downsample = nn.Sequential(
            nn.AvgPool2d(kernel_size=(2, 2), stride=2),
            nn.Conv2d(self.in_chans, self.out_chans, kernel_size=(3, 3), padding=1),
            nn.InstanceNorm2d(self.out_chans),
            nn.LeakyReLU(negative_slope=0.2),
        )

    def forward(self, input):
        """
        Args:
            input (torch.Tensor): Input tensor of shape [batch_size, self.in_chans, height, width]

        Returns:
            (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, height, width]
        """

        return self.downsample(input)  # self.resblock(self.downsample(input))

    def __repr__(self):
        return f'AvgPool(in_chans={self.in_chans}, out_chans={self.out_chans}\nResBlock(in_chans={self.out_chans}, out_chans={self.out_chans}'


class DiscriminatorModel(nn.Module):
    def __init__(self, in_chans, out_chans, z_location=None, model_type=None, mbsd=False):
        """
        Args:
            in_chans (int): Number of channels in the input to the U-Net model.
            out_chans (int): Number of channels in the output to the U-Net model.
        """
        super().__init__()

        self.in_chans = in_chans
        self.out_chans = 2
        self.z_location = z_location
        self.model_type = model_type
        self.mbsd = mbsd

        # CHANGE BACK TO 16 FOR MORE
        self.initial_layers = nn.Sequential(
            nn.Conv2d(self.in_chans, 32, kernel_size=(3, 3), padding=1),  # 384x384
            nn.LeakyReLU()
        )

        self.encoder_layers = nn.ModuleList()
        self.encoder_layers += [FullDownBlock(32, 64)]  # 64x64
        self.encoder_layers += [FullDownBlock(64, 128)]  # 32x32
        self.encoder_layers += [FullDownBlock(128, 256)]  # 16x16
        self.encoder_layers += [FullDownBlock(256, 512)]  # 8x8
        self.encoder_layers += [FullDownBlock(512, 512)]  # 4x4
        self.encoder_layers += [FullDownBlock(512, 512)]  # 2x2
        self.encoder_layers += [FullDownBlock(512, 512)]
        self.encoder_layers += [FullDownBlock(512, 512)]

        self.dense = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512 * 4 * 4, 1),
        )

    def forward(self, input, y):
        output = torch.cat([input, y], dim=1)
        print('input.shape: ', input.shape)
        print('y.shape: ', y.shape)
        print('output.shape: ', output.shape)
        output = self.initial_layers(output)
        print('output.shape: ', output.shape)
        # Apply down-sampling layers
        print('Inside layers' )
        for layer in self.encoder_layers:
            print('output.shape: ', output.shape)
            output = layer(output)
        print('BEfore DEnse' )
        print('output.shape: ', output.shape)
        return self.dense(output)


In [21]:

x = torch.zeros(size=(5, 2, 1024, 1024), device=device).type(torch.FloatTensor)
y = torch.zeros(size=(5, 2, 1024, 1024), device=device).type(torch.FloatTensor)


In [22]:


discriminator = DiscriminatorModel(in_chans=4, out_chans=2)



In [23]:

real_pred = discriminator(input=x, y=y)


input.shape:  torch.Size([5, 2, 1024, 1024])
y.shape:  torch.Size([5, 2, 1024, 1024])
output.shape:  torch.Size([5, 4, 1024, 1024])
output.shape:  torch.Size([5, 32, 1024, 1024])
Inside layers
output.shape:  torch.Size([5, 32, 1024, 1024])
output.shape:  torch.Size([5, 64, 512, 512])
output.shape:  torch.Size([5, 128, 256, 256])
output.shape:  torch.Size([5, 256, 128, 128])
output.shape:  torch.Size([5, 512, 64, 64])
output.shape:  torch.Size([5, 512, 32, 32])
output.shape:  torch.Size([5, 512, 16, 16])
output.shape:  torch.Size([5, 512, 8, 8])
BEfore DEnse
output.shape:  torch.Size([5, 512, 4, 4])


In [24]:
real_pred

tensor([[-0.1461],
        [-0.1461],
        [-0.1461],
        [-0.1461],
        [-0.1461]], grad_fn=<AddmmBackward0>)