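#### This notebook loads the fastMRI single-coil knee dataset, applies undersampling masks and simple transformations to the data, and trains a GAN (a U-Net/ResNet generator with a PatchGAN discriminator) to reconstruct the fully sampled images.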

In [None]:
import os

# Work from the project's model directory on the mounted Google Drive;
# presumably so the relative `fastMRI` package and "./Phase 4/..." paths resolve.
PROJECT_DIR = "/content/drive/MyDrive/Accelerated MRI Scanning/Repositories/Model"
os.chdir(PROJECT_DIR)

# ***Import Libraries***

In [None]:
%matplotlib inline
import h5py
import numpy as np
from PIL import Image
from PIL import ImageFilter
import matplotlib
import random
import cv2
# from scipy.fft import fft, ifft
import numpy as np
from matplotlib import pyplot as plt

In [None]:
!pip install runstats

In [None]:
import numpy as np
import torch
import os
from collections import OrderedDict
from torch.autograd import Variable
import torch.nn as nn
from torch.nn import init
from torch.nn import functional as F
import functools
import torchvision.models as models
from torch.utils.tensorboard import SummaryWriter
writer= SummaryWriter()
import logging

In [None]:
# loading fastmri modules
from fastMRI.data import transforms
from fastMRI.data.mri_data import SliceData
from fastMRI.common.evaluate import nmse, psnr, ssim
from fastMRI.common.subsample import MaskFunc
from torch.utils.data import DataLoader

# ***Generator***

In [None]:
class ResnetGenerator(nn.Module):
    """U-Net style generator with a stack of residual blocks at the bottleneck.

    Encoder: 4 ConvBlocks (the first maps ``input_nc`` -> ``chans``, each later one
    doubles the channels), each followed by 2x2 average pooling in ``forward``.
    Bottleneck: ``n_blocks`` channel-preserving ResnetBlocks, then a ConvBlock that
    doubles the channel count. Decoder: TransposeConvBlock + ConvBlock pairs with
    skip connections from the encoder, ending in a 1x1 conv to ``output_nc``.

    Args:
        input_nc: number of input channels.
        output_nc: number of output channels.
        chans: channels produced by the first encoder stage.
        norm_layer: normalization layer (or functools.partial of one) used by the
            ResnetBlocks; convolutions there get a bias only for InstanceNorm2d.
        use_dropout: whether each ResnetBlock contains a Dropout(0.5) layer.
        n_blocks: number of ResnetBlocks in the bottleneck (must be >= 0).
        gpu_ids, use_parallel, learn_residual: stored on the instance; not used by
            ``forward``.
        padding_type: padding of the ResnetBlocks: 'reflect', 'replicate' or 'zero'.

    Note: ``n_blocks``, ``use_dropout`` and ``padding_type`` are now honoured (they
    used to be fixed at 6 / True / 'reflect'), so checkpoints saved from a model
    built with a different block count will not load into the new one.
    """

    def __init__(
            self, input_nc, output_nc, chans=64, norm_layer=nn.BatchNorm2d, use_dropout=False,
            n_blocks=6, gpu_ids=[], use_parallel=True, learn_residual=False, padding_type='reflect'):
        assert (n_blocks >= 0)
        super(ResnetGenerator, self).__init__()
        self.in_chans = input_nc
        self.out_chans = output_nc
        # Use the constructor argument, not the notebook-level `ngf` global.
        self.chans = chans
        self.gpu_ids = gpu_ids
        self.use_parallel = use_parallel
        self.learn_residual = learn_residual
        drop_prob = 0.1
        num_pool_layers = 4

        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        # Encoder: channel count doubles at every stage after the first.
        self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, chans, drop_prob)])
        ch = chans
        for _ in range(num_pool_layers - 1):
            self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob))
            ch *= 2

        # Bottleneck: residual blocks must keep the channel count (x + f(x)).
        self.resnet = nn.Sequential(*[
            ResnetBlock(ch, ch, padding_type=padding_type, norm_layer=norm_layer,
                        use_dropout=use_dropout, use_bias=use_bias)
            for _ in range(n_blocks)
        ])

        self.conv = ConvBlock(ch, ch * 2, drop_prob)

        # Decoder: each transpose conv halves channels; the following ConvBlock
        # consumes the concatenation with the matching encoder output.
        self.up_conv = nn.ModuleList()
        self.up_transpose_conv = nn.ModuleList()
        for _ in range(num_pool_layers - 1):
            self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch))
            self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob))
            ch //= 2

        self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch))
        self.up_conv.append(
            nn.Sequential(
                ConvBlock(ch * 2, ch, drop_prob),
                nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1),
            )
        )

    def forward(self, input):
        """Map an ``(N, input_nc, H, W)`` tensor to ``(N, output_nc, H, W)``."""
        stack = []
        output = input

        # Down-sampling path; keep every pre-pool activation for the skips.
        for layer in self.down_sample_layers:
            output = layer(output)
            stack.append(output)
            output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0)

        output = self.resnet(output)
        output = self.conv(output)

        # Up-sampling path.
        for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv):
            downsample_layer = stack.pop()
            output = transpose_conv(output)

            # Odd spatial sizes lose a row/column in avg_pool2d: reflect-pad the
            # right/bottom edge so the skip concatenation lines up.
            padding = [0, 0, 0, 0]
            if output.shape[-1] != downsample_layer.shape[-1]:
                padding[1] = 1  # padding right
            if output.shape[-2] != downsample_layer.shape[-2]:
                padding[3] = 1  # padding bottom
            if any(padding):
                output = F.pad(output, padding, "reflect")
            output = torch.cat([output, downsample_layer], dim=1)
            output = conv(output)

        return output

In [None]:
class ResnetBlock(nn.Module):
    """Residual block computing ``x + conv_block(x)``.

    ``conv_block`` is: pad -> 3x3 conv -> norm -> LeakyReLU(0.2) -> [Dropout(0.5)]
    -> pad -> 3x3 conv -> norm, all at ``in_chans`` channels so the residual sum
    keeps the input shape.

    Args:
        in_chans: channels of the input (and of the output).
        out_chans: accepted for interface compatibility; unused, because the
            residual connection requires the output to have ``in_chans`` channels.
        padding_type: 'reflect', 'replicate' or 'zero'.
        norm_layer: callable that builds a normalization layer from a channel count.
        use_dropout: insert Dropout(0.5) between the two convolutions.
        use_bias: ``bias`` flag of both convolutions.

    Raises:
        NotImplementedError: for an unknown ``padding_type``.

    Note: the previous version lost the second convolution to an operator-precedence
    bug in a conditional expression, so old checkpoints do not match this layout.
    """

    def __init__(self, in_chans, out_chans, padding_type, norm_layer, use_dropout, use_bias):
        super(ResnetBlock, self).__init__()

        blocks = self._pad_and_conv(in_chans, padding_type, use_bias)
        blocks += [norm_layer(in_chans), nn.LeakyReLU(negative_slope=0.2, inplace=True)]
        if use_dropout:
            blocks.append(nn.Dropout(0.5))
        blocks += self._pad_and_conv(in_chans, padding_type, use_bias)
        blocks.append(norm_layer(in_chans))

        self.conv_block = nn.Sequential(*blocks)

    @staticmethod
    def _pad_and_conv(chans, padding_type, use_bias):
        # Builds fresh layers on every call so the two convolutions never share weights.
        if padding_type == 'reflect':
            return [nn.ReflectionPad2d(1),
                    nn.Conv2d(chans, chans, kernel_size=3, bias=use_bias)]
        if padding_type == 'replicate':
            return [nn.ReplicationPad2d(1),
                    nn.Conv2d(chans, chans, kernel_size=3, bias=use_bias)]
        if padding_type == 'zero':
            return [nn.Conv2d(chans, chans, kernel_size=3, padding=1, bias=use_bias)]
        raise NotImplementedError('padding [%s] is not implemented' % padding_type)

    def forward(self, x):
        out = x + self.conv_block(x)
        return out

In [None]:
class ConvBlock(nn.Module):
    """
    Two stacked stages of 3x3 conv -> InstanceNorm -> LeakyReLU(0.2) -> Dropout2d.

    The first stage maps ``in_chans`` to ``out_chans``; the second keeps
    ``out_chans``. Padding of 1 preserves the spatial size.
    """

    def __init__(self, in_chans: int, out_chans: int, drop_prob: float):
        """
        Args:
            in_chans: Number of channels in the input.
            out_chans: Number of channels in the output.
            drop_prob: Dropout probability.
        """
        super().__init__()

        self.in_chans, self.out_chans, self.drop_prob = in_chans, out_chans, drop_prob

        def stage(channels_in: int) -> list:
            return [
                nn.Conv2d(channels_in, out_chans, kernel_size=3, padding=1, bias=True),
                nn.InstanceNorm2d(out_chans),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Dropout2d(drop_prob),
            ]

        # Flat Sequential with indices 0-7, the same layout checkpoints expect.
        self.layers = nn.Sequential(*stage(in_chans), *stage(out_chans))

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """
        Args:
            image: Input 4D tensor of shape `(N, in_chans, H, W)`.

        Returns:
            Output tensor of shape `(N, out_chans, H, W)`.
        """
        result = self.layers(image)
        return result


In [None]:
class TransposeConvBlock(nn.Module):
    """
    Upsampling block: a 2x2, stride-2 transpose convolution (no bias) followed by
    InstanceNorm and LeakyReLU(0.2). Doubles the height and width.
    """

    def __init__(self, in_chans: int, out_chans: int):
        """
        Args:
            in_chans: Number of channels in the input.
            out_chans: Number of channels in the output.
        """
        super().__init__()

        self.in_chans = in_chans
        self.out_chans = out_chans

        upsample = nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False)
        normalize = nn.InstanceNorm2d(out_chans)
        activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.layers = nn.Sequential(upsample, normalize, activate)

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """
        Args:
            image: Input 4D tensor of shape `(N, in_chans, H, W)`.

        Returns:
            Output tensor of shape `(N, out_chans, H*2, W*2)`.
        """
        upsampled = self.layers(image)
        return upsampled

# ***Discriminator***

In [None]:
# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator: a stack of 4x4 convolutions giving a map of patch scores.

    Layout of ``self.model`` (the indices define the checkpoint keys):
      conv(stride 2) + LeakyReLU(0.2),
      (n_layers - 1) x [conv(stride 2) + norm + LeakyReLU], channels ndf * min(2**k, 8),
      conv(stride 1) + norm + LeakyReLU,
      conv(stride 1) down to 1 channel, optionally followed by a Sigmoid.
    Convolutions followed by a norm get a bias only when the norm is InstanceNorm2d.
    """

    def __init__(self, input_nc, ndf=64, n_layers=10, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=[],
                 use_parallel=True):
        super(NLayerDiscriminator, self).__init__()
        self.gpu_ids = gpu_ids
        self.use_parallel = use_parallel

        base_norm = norm_layer.func if type(norm_layer) == functools.partial else norm_layer
        use_bias = base_norm == nn.InstanceNorm2d

        kernel = 4
        pad = -(-(kernel - 1) // 2)  # integer ceil((kernel - 1) / 2) -> 2

        def width(k):
            # Channel multiplier of stage k, capped at 8.
            return min(2 ** k, 8)

        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=kernel, stride=2, padding=pad),
            nn.LeakyReLU(0.2, True),
        ]

        prev = 1
        for k in range(1, n_layers):
            cur = width(k)
            layers.extend([
                nn.Conv2d(ndf * prev, ndf * cur, kernel_size=kernel, stride=2, padding=pad, bias=use_bias),
                norm_layer(ndf * cur),
                nn.LeakyReLU(0.2, True),
            ])
            prev = cur

        cur = width(n_layers)
        layers.extend([
            nn.Conv2d(ndf * prev, ndf * cur, kernel_size=kernel, stride=1, padding=pad, bias=use_bias),
            norm_layer(ndf * cur),
            nn.LeakyReLU(0.2, True),
        ])
        layers.append(nn.Conv2d(ndf * cur, 1, kernel_size=kernel, stride=1, padding=pad))

        if use_sigmoid:
            layers.append(nn.Sigmoid())

        self.model = nn.Sequential(*layers)

    def forward(self, input):
        run_parallel = len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor) and self.use_parallel
        if not run_parallel:
            return self.model(input)
        return nn.parallel.data_parallel(self.model, input, self.gpu_ids)

# ***Image Pool***

In [None]:
class ImagePool():
    """Buffer of previously seen images that can be returned instead of new ones.

    With ``pool_size == 0`` queries are a pass-through. Otherwise the pool fills
    up with the first ``pool_size`` images; afterwards each incoming image is,
    with probability 0.5, swapped with a random stored image (the stored one is
    returned), or returned unchanged.
    """

    def __init__(self, pool_size):
        self.pool_size = pool_size
        if self.pool_size > 0:
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        """Return a batch the same size as ``images``, mixing in pooled images."""
        if self.pool_size == 0:
            return images

        picked = []
        for sample in images.data:
            sample = torch.unsqueeze(sample, 0)

            if self.num_imgs < self.pool_size:
                # Pool still filling: store the image and hand it straight back.
                self.num_imgs += 1
                self.images.append(sample)
                picked.append(sample)
                continue

            if random.uniform(0, 1) <= 0.5:
                picked.append(sample)
                continue

            slot = random.randint(0, self.pool_size - 1)
            previous = self.images[slot].clone()
            self.images[slot] = sample
            picked.append(previous)

        return Variable(torch.cat(picked, 0))


# ***Losses***


In [None]:
# TODO: Operation check
def contextual_bilateral_loss(x: torch.Tensor,
                              y: torch.Tensor,
                              weight_sp: float = 0.1,
                              band_width: float = 1.,
                              loss_type: str = 'cosine'):
    """
    Computes Contextual Bilateral (CoBi) Loss between x and y,
        proposed in https://arxiv.org/pdf/1905.05169.pdf.
    Parameters
    ---
    x : torch.Tensor
        features of shape (N, C, H, W).
    y : torch.Tensor
        features of shape (N, C, H, W).
    weight_sp : float, optional
        weight of the spatial (pixel-position) similarity; the feature
        similarity gets weight ``1 - weight_sp``.
    band_width : float, optional
        a band-width parameter used to convert distance to similarity.
        in the paper, this is described as :math:`h`.
    loss_type : str, optional
        a loss type to measure the distance between features:
        'cosine', 'l1' or 'l2'.
        Note: `l1` and `l2` frequently raises OOM.
    Returns
    ---
    cx_loss : torch.Tensor
        scalar contextual loss between x and y (Eq (1) in the paper).
    Raises
    ---
    AssertionError
        if x and y differ in size.
    ValueError
        if loss_type is not one of the supported types (previously an unknown
        type silently reused the spatial distance as the feature distance).
    """
    feature_distances = {
        'cosine': compute_cosine_distance,
        'l1': compute_l1_distance,
        'l2': compute_l2_distance,
    }
    assert x.size() == y.size(), 'input tensor must have the same size.'
    if loss_type not in feature_distances:
        raise ValueError(f'select a loss type from {list(feature_distances)}.')

    # spatial similarity between pixel positions
    grid = compute_meshgrid(x.shape).to(x.device)
    dist_raw = compute_l2_distance(grid, grid)
    dist_tilde = compute_relative_distance(dist_raw)
    cx_sp = compute_cx(dist_tilde, band_width)

    # feature similarity
    dist_raw = feature_distances[loss_type](x, y)
    dist_tilde = compute_relative_distance(dist_raw)
    cx_feat = compute_cx(dist_tilde, band_width)

    # combined: best match for every x position, averaged, then -log
    cx_combine = (1. - weight_sp) * cx_feat + weight_sp * cx_sp
    k_max_NC, _ = torch.max(cx_combine, dim=2, keepdim=True)
    cx = k_max_NC.mean(dim=1)
    cx_loss = torch.mean(-torch.log(cx + 1e-5))
    return cx_loss

def compute_cx(dist_tilde, band_width):
    """Turn relative distances into contextual similarities normalized over dim 2."""
    # Eq(3): similarity decays exponentially with the relative distance.
    weights = torch.exp((1 - dist_tilde) / band_width)
    # Eq(4): normalize so each slice along dim 2 sums to one.
    normalizer = weights.sum(dim=2, keepdim=True)
    return weights / normalizer

def compute_relative_distance(dist_raw):
    """Divide each distance by the minimum along dim 2 (plus 1e-5 for stability)."""
    row_min = dist_raw.min(dim=2, keepdim=True).values
    return dist_raw / (row_min + 1e-5)

def compute_cosine_distance(x, y):
    """Pairwise cosine distance between spatial feature vectors of x and y.

    Both inputs are centered on y's per-channel mean (over batch and space) and
    L2-normalized over channels. Returns ``(N, H*W, H*W)`` where entry ``[n, i, j]``
    is ``1 - cos(x_i, y_j)``.
    """
    center = y.mean(dim=(0, 2, 3), keepdim=True)
    n, c = x.size(0), x.size(1)

    x_unit = F.normalize(x - center, p=2, dim=1).reshape(n, c, -1)  # (N, C, H*W)
    y_unit = F.normalize(y - center, p=2, dim=1).reshape(n, c, -1)  # (N, C, H*W)

    similarity = x_unit.transpose(1, 2).bmm(y_unit)  # (N, H*W, H*W)
    return 1 - similarity


def compute_l1_distance(x: torch.Tensor, y: torch.Tensor):
    """Pairwise L1 distance between the spatial feature vectors of x and y.

    Args:
        x, y: tensors of shape (N, C, H, W).
    Returns:
        (N, H*W, H*W) tensor with ``[n, i, j] = sum_c |x[n, c, i] - y[n, c, j]|``.

    The previous version summed the signed differences over channels before
    taking ``abs`` (not an L1 distance) and materialized an (N, C, HW, HW)
    difference tensor; ``torch.cdist`` needs no such explicit intermediate.
    """
    N, C, H, W = x.size()
    x_vec = x.reshape(N, C, H * W).transpose(1, 2).contiguous()  # (N, HW, C)
    y_vec = y.reshape(N, C, H * W).transpose(1, 2).contiguous()  # (N, HW, C)
    dist = torch.cdist(x_vec, y_vec, p=1)
    dist = dist.clamp(min=0.)
    return dist

# TODO: Considering avoiding OOM.
def compute_l2_distance(x, y):
    """Pairwise squared Euclidean distance between spatial feature vectors.

    Args:
        x, y: tensors of shape (N, C, H, W).
    Returns:
        (N, H*W, H*W) tensor with ``[n, i, j] = ||x[n, :, i] - y[n, :, j]||^2``,
        clamped at 0 to absorb rounding in the expansion below.

    The previous version paired ``|y_i|^2`` with ``|x_j|^2`` (only correct when
    x == y) and broadcast the norm terms incorrectly for batch sizes > 1.
    """
    N, C, _, _ = x.size()
    x_vec = x.reshape(N, C, -1)
    y_vec = y.reshape(N, C, -1)
    x_sq = (x_vec ** 2).sum(dim=1)  # (N, HW)
    y_sq = (y_vec ** 2).sum(dim=1)  # (N, HW)
    cross = torch.bmm(x_vec.transpose(1, 2), y_vec)  # (N, HW, HW): x_i . y_j
    # ||x_i - y_j||^2 = |x_i|^2 - 2 x_i.y_j + |y_j|^2
    dist = x_sq.unsqueeze(2) - 2 * cross + y_sq.unsqueeze(1)
    dist = dist.clamp(min=0.)
    return dist

def compute_meshgrid(shape):
    """Normalized pixel-coordinate grid of shape (N, 2, H, W) for an (N, C, H, W) shape.

    Channel 0 holds the row index / (H + 1), channel 1 the column index / (W + 1).
    """
    N, _, H, W = shape
    row_coords = torch.arange(0, H, dtype=torch.float32) / (H + 1)
    col_coords = torch.arange(0, W, dtype=torch.float32) / (W + 1)

    row_plane = row_coords.unsqueeze(1).expand(H, W)
    col_plane = col_coords.unsqueeze(0).expand(H, W)
    grid = torch.stack((row_plane, col_plane)).unsqueeze(0)  # (1, 2, H, W)

    return grid.repeat(N, 1, 1, 1)

In [None]:
class ContentLoss:
    """Applies a pixel-space criterion (e.g. ``nn.L1Loss()``) to (fake, real) images."""

    def __init__(self, loss):
        self.criterion = loss

    def get_loss(self, fakeIm, realIm):
        loss_value = self.criterion(fakeIm, realIm)
        return loss_value

class PerceptualLoss():
  def contentFunc(self):
    conv_3_3_layer = 14
    cnn = models.vgg19(pretrained=True).features
    cnn = cnn.cuda()
    model = nn.Sequential()
    model = model.cuda()
    for i,layer in enumerate(list(cnn)):
      model.add_module(str(i),layer)
      if i == conv_3_3_layer:
        break
    return model

  def __init__(self, loss):
    self.criterion = loss
    self.contentFunc = self.contentFunc()

  def get_loss(self, fakeIm, realIm):
    new_fakeIm = fakeIm[:,:, :, :] * torch.ones(3)[None,:, None, None].to(torch.device('cuda') )
    new_realIm = realIm[:,:, :, :] * torch.ones(3)[None,:, None, None].to(torch.device('cuda') )

    f_fake = self.contentFunc.forward(new_fakeIm)

    f_real = self.contentFunc.forward(new_realIm)
    f_real_no_grad = f_real
    loss = self.criterion(f_fake, f_real_no_grad)
    return loss

class GANLoss(nn.Module):
    """Adversarial criterion comparing a prediction map with a constant label map.

    Args:
        use_l1: use ``nn.L1Loss`` if True, otherwise ``nn.BCELoss`` (which expects
            probabilities, e.g. from a sigmoid-terminated discriminator).
        target_real_label / target_fake_label: label values for real / fake.
        tensor: tensor constructor used to build the label maps.

    Label maps are cached and rebuilt only when the prediction's shape, device
    or dtype changes. They are created on the prediction's device, so this works
    on CPU as well as GPU (the previous version always moved them to CUDA).
    """

    def __init__(
            self, use_l1=True, target_real_label=1.0,
            target_fake_label=0.0, tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        self.loss = nn.L1Loss() if use_l1 else nn.BCELoss()

    def _label_like(self, cached, input, value):
        # Reuse the cached map only if it matches the prediction exactly; a
        # numel-only check would accept e.g. (1, 4, 1, 1) for (1, 1, 2, 2).
        if (cached is not None and cached.shape == input.shape
                and cached.device == input.device and cached.dtype == input.dtype):
            return cached
        label = self.Tensor(input.size()).fill_(value)
        return label.to(device=input.device, dtype=input.dtype)

    def get_target_tensor(self, input, target_is_real):
        """Return the label map for ``input``: real labels if ``target_is_real`` is truthy."""
        if target_is_real:
            self.real_label_var = self._label_like(self.real_label_var, input, self.real_label)
            return self.real_label_var
        self.fake_label_var = self._label_like(self.fake_label_var, input, self.fake_label)
        return self.fake_label_var

    def forward(self, input, target_is_real):
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)
    

class DiscLoss:
    """BCE GAN losses for the generator and the discriminator."""

    def name(self):
        return 'DiscLoss'

    def __init__(self):
        self.criterionGAN = GANLoss(use_l1=False)
        # NOTE(review): the pool is created but never queried by this class.
        self.fake_AB_pool = ImagePool(pool_size=50)

    def get_g_loss(self, net, realA, fakeB):
        """Generator loss: ``net(fakeB)`` should be classified as real. ``realA`` is unused."""
        pred_fake = net.forward(fakeB)
        return self.criterionGAN(pred_fake, 1)

    def get_loss(self, net, realA, fakeB, realB):
        """Discriminator loss: mean of the fake (label 0) and real (label 1) terms.

        ``fakeB`` is detached, so backpropagating this loss stops at the
        generator output instead of flowing into the generator's parameters.
        ``realA`` is unused. The predictions and both loss terms are also kept
        on ``self`` for inspection.
        """
        # Fake: generated images should score close to zero.
        self.pred_fake = net.forward(fakeB.detach())
        self.loss_D_fake = self.criterionGAN(self.pred_fake, 0)

        # Real
        self.pred_real = net.forward(realB)
        self.loss_D_real = self.criterionGAN(self.pred_real, 1)

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        return self.loss_D
		
# class DiscLossLS(DiscLoss):
# 	def name(self):
# 		return 'DiscLossLS'

# 	def __init__(self):
# 		super(DiscLoss, self).__init__()
# 		# DiscLoss.initialize(self, opt, tensor)
# 		self.criterionGAN = GANLoss(use_l1=True, tensor=tensor)
		
# 	def get_g_loss(self,net, realA, fakeB):
# 		return DiscLoss.get_g_loss(self,net, realA, fakeB)
		
# 	def get_loss(self, net, realA, fakeB, realB):
# 		return DiscLoss.get_loss(self, net, realA, fakeB, realB)

        

In [None]:

# class DiscLossWGANGP(DiscLossLS):
#     def name(self):
#       return 'DiscLossWGAN-GP'

#     def __init__():
#       super(DiscLossWGANGP, self).__init__()
#       # DiscLossLS.initialize(self, opt, tensor)
#       self.LAMBDA = 10
        
#     def get_g_loss(net, realA, fakeB):
#         # First, G(A) should fake the discriminator
#         D_fake = net.forward(fakeB)
#         return -D_fake.mean()

#     def calc_gradient_penalty(self, netD, real_data, fake_data):
#         alpha = torch.rand(1, 1)
#         alpha = alpha.expand(real_data.size())
#         alpha = alpha.cuda()

#         interpolates = alpha * real_data + ((1 - alpha) * fake_data)

#         interpolates = interpolates.cuda()
#         interpolates = Variable(interpolates, requires_grad=True)

#         disc_interpolates = netD.forward(interpolates)

#         gradients = autograd.grad(
#             outputs=disc_interpolates, inputs=interpolates, grad_outputs=torch.ones(disc_interpolates.size()).cuda(),
#             create_graph=True, retain_graph=True, only_inputs=True
#         )[0]

#         gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * self.LAMBDA
#         return gradient_penalty

#     def get_loss( net, realA, fakeB, realB):
#         D_fake = net.forward(fakeB)
#         D_fake = D_fake.mean()

#         # Real
#         D_real = net.forward(realB)
#         D_real = D_real.mean()
#         # Combined loss
#         loss_D = D_fake - D_real
#         gradient_penalty = calc_gradient_penalty(net, realB.data, fakeB.data)
#         return loss_D + gradient_penalty

		

# ***Init Weights***

In [None]:
def weights_init(m):
    """Initializer for ``Module.apply``: conv weights ~ N(0, 0.02) with zero bias;
    BatchNorm2d weights ~ N(1, 0.02) with zero bias.

    Modules are matched by class name, so weight-less containers whose name
    contains "Conv" (e.g. ConvBlock, TransposeConvBlock) and BatchNorm layers
    built with ``affine=False`` are skipped instead of raising AttributeError;
    ``apply`` still reaches the real layers inside the containers.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if classname.find('Conv') != -1:
        if weight is None:
            return
        weight.data.normal_(0.0, 0.02)
        if bias is not None:
            bias.data.fill_(0)
    elif classname.find('BatchNorm2d') != -1 and weight is not None:
        weight.data.normal_(1.0, 0.02)
        bias.data.fill_(0)


# ***Initialize Model***

In [None]:
# --- Hyper-parameters --------------------------------------------------------
channel_rate = 64
# Note the image_shape must be multiple of patch_shape
image_shape = (320, 320, 1)
patch_shape = (channel_rate, channel_rate, 3)  # NOTE(review): image_shape / patch_shape are unused below

ngf, ndf = 32, 32           # base channel widths of generator / discriminator
input_nc, output_nc = 1, 1  # single-channel images in and out

use_dropout = True
gpu_ids = [0]
use_parallel = False
learn_residual = False
use_sigmoid = True

# Instance norm without affine parameters but with running statistics.
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True)

# --- Generator ---------------------------------------------------------------
netG = ResnetGenerator(
    input_nc, output_nc, ngf,
    norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=12,
    gpu_ids=gpu_ids, use_parallel=use_parallel, learn_residual=learn_residual,
)
netG.cuda(gpu_ids[0])

# --- Discriminator -----------------------------------------------------------
n_layers_D = 12
netD = NLayerDiscriminator(
    input_nc, ndf, n_layers_D,
    norm_layer=norm_layer, use_sigmoid=use_sigmoid,
    gpu_ids=gpu_ids, use_parallel=use_parallel,
)
netD.cuda(gpu_ids[0])
netD.apply(weights_init)

NLayerDiscriminator(
  (model): Sequential(
    (0): Conv2d(1, 32, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2))
    (3): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2))
    (6): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2))
    (9): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(256, 256, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2))
    (12): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running

In [None]:
# print(ConvBlock(1, 1, 0.0))
# print(ResnetBlock(1,1,"reflect",functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True), True, nn.InstanceNorm2d ))

# ***Load Data***

In [None]:
class DataTransform:
    """Turn one fastMRI k-space slice into a normalized (input image, target) pair.

    Args:
        mask_func: undersampling mask generator passed to ``transforms.apply_mask``.
        resolution: side length of the center crop applied to the input image.
        which_challenge: stored on the instance; not used by ``__call__``.
        use_seed: if True (default), the mask seed is derived from the file name,
            so a given file always gets the same mask; if False, ``seed=None`` is
            passed to ``apply_mask`` (presumably a fresh random mask; depends on
            MaskFunc). Previously this flag was ignored.
    """

    def __init__(self, mask_func, resolution, which_challenge, use_seed=True):
        self.mask_func = mask_func
        self.resolution = resolution
        self.which_challenge = which_challenge
        self.use_seed = use_seed

    def __call__(self, kspace, target, attrs, fname, slice):
        """Return ``(image, target, mean, std)``.

        ``image`` is the center-cropped magnitude of the inverse FFT of the masked
        k-space, normalized by its own mean/std and clamped to [-6, 6]. ``target``
        is normalized with that same mean/std (eps=1e-11) and clamped.
        ``attrs`` and ``slice`` are unused.
        """
        kspace = transforms.to_tensor(kspace)

        seed = tuple(map(ord, fname)) if self.use_seed else None
        masked_kspace, _ = transforms.apply_mask(kspace, self.mask_func, seed)

        image = transforms.ifft2(masked_kspace)
        image = transforms.complex_center_crop(image, (self.resolution, self.resolution))
        image = transforms.complex_abs(image)
        image, mean, std = transforms.normalize_instance(image)
        image = image.clamp(-6, 6)

        target = transforms.to_tensor(target)
        # Normalize the target with the input's statistics so both share one scale.
        target = transforms.normalize(target, mean, std, eps=1e-11)
        target = target.clamp(-6, 6)
        return image, target, mean, std
        #return image, mean, std, fname, slice

In [None]:
import pathlib

# Single-coil knee train / validation folders on the mounted Google Drive.
root = pathlib.Path('/content/drive/MyDrive/Accelerated MRI Scanning/Repositories/fastMRImaster/Dataset/knee_sc/singlecoil_train/')
root_val = pathlib.Path('/content/drive/MyDrive/Accelerated MRI Scanning/Repositories/fastMRImaster/Dataset/knee_sc/singlecoil_val/')
batch_size = 1
# Never request more loader workers than CPUs this process may use. A fixed 8
# triggered DataLoader's "suggested max number of worker" warning (the truncated
# "cpuset_checked" output of this cell) and oversubscribes small machines.
_available_cpus = len(os.sched_getaffinity(0)) if hasattr(os, 'sched_getaffinity') else (os.cpu_count() or 1)
num_workers = min(8, _available_cpus)

def create_dataset(root, resolution=320, center_fractions=[0.08], accelerations=[4]):
    """Build a single-coil ``SliceData`` dataset over ``root``.

    Each slice goes through ``DataTransform`` with an undersampling mask built
    by ``MaskFunc(center_fractions, accelerations)`` and a center crop of
    ``resolution`` pixels; all slices are used (sample_rate=1.0).
    """
    challenge = 'singlecoil'
    transform = DataTransform(MaskFunc(center_fractions, accelerations), resolution, challenge)
    return SliceData(root=root, transform=transform, sample_rate=1.0, challenge=challenge)

data_train = create_dataset(root=root)
data_val = create_dataset(root=root_val)
    
def create_dataloader(data, batch_size, num_workers, shuffle=False):
    """Wrap ``data`` in a ``DataLoader``.

    Args:
        data: map-style dataset.
        batch_size: samples per batch.
        num_workers: worker processes (0 loads in the main process).
        shuffle: reshuffle every epoch. Defaults to False (the original
            behaviour); pass True for training so batches are not always drawn
            in file order.
    """
    return DataLoader(
        dataset=data,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
    )


data_load_train = create_dataloader(data=data_train, batch_size=batch_size, num_workers=num_workers)

  cpuset_checked))


# ***Optimizer***

In [None]:
# Adam settings shared by the generator and discriminator optimizers.
lr = 0.001
beta1 = 0.5
device = torch.device('cuda')

In [None]:

adam_kwargs = dict(lr=lr, betas=(beta1, 0.999))
optimizer_G = torch.optim.Adam(netG.parameters(), **adam_kwargs)
optimizer_D = torch.optim.Adam(netD.parameters(), **adam_kwargs)

def forward(netG, input_A, input_B):
    """Run the generator on ``input_A``.

    Returns:
        ``(real_A, fake_B, real_B)``: the wrapped input, ``netG(real_A)`` and the
        wrapped target. The previous version computed ``fake_B`` and then
        discarded it, returning None.
    """
    real_A = Variable(input_A)
    fake_B = netG.forward(real_A)
    real_B = Variable(input_B)
    return real_A, fake_B, real_B


def backward_D(netD, real_A, fake_B, real_B, eval=False):
    """Compute the discriminator loss for one batch and backpropagate it unless ``eval``.

    Returns:
        The loss value as a Python float.
    """
    loss_D = DiscLoss().get_loss(netD, real_A, fake_B, real_B)
    if eval:
        return loss_D.item()
    # retain_graph=True keeps the autograd graph after this pass. optimize_parameters
    # calls this several times with the same fake_B and then backward_G, which needs
    # the generator's graph if DiscLoss.get_loss does not detach fake_B.
    loss_D.backward(retain_graph=True)
    return loss_D.item()
  
def backward_G(netD, real_A, fake_B, real_B, eval=False, content_weight=100, contextual_weight=50):
    """Generator loss for one batch, backpropagated unless ``eval``.

    loss = GAN loss (netD should score ``fake_B`` as real)
         + content_weight * L1(fake_B, real_B)
         + contextual_weight * contextual bilateral loss (cosine features).

    The weights used to be hard-coded as 100 and 50; those remain the defaults.

    Returns:
        The loss value as a Python float.
    """
    discLoss = DiscLoss()
    contentLoss = ContentLoss(nn.L1Loss())

    # First, G(A) should fool the discriminator.
    loss_G_GAN = discLoss.get_g_loss(netD, real_A, fake_B)
    # Second, G(A) should match B pixel-wise ...
    loss_G_Content = contentLoss.get_loss(fake_B, real_B) * content_weight
    # ... and in contextual (feature-statistics) terms.
    loss_G_percept = contextual_bilateral_loss(fake_B, real_B, loss_type="cosine")

    loss_G = loss_G_GAN + loss_G_Content + loss_G_percept * contextual_weight

    if not eval:
        loss_G.backward()
    return loss_G.item()
    
def optimize_parameters(netD, netG, real_A, fake_B, real_B, optimizer_G, optimizer_D, critic_updates=5):
    """One GAN training step on a single batch.

    Runs ``critic_updates`` discriminator updates on the same
    ``(real_A, fake_B, real_B)``, then one generator update. ``fake_B`` must be
    ``netG(real_A)`` computed by the caller with autograd enabled. ``netG`` is
    kept for interface compatibility only: the previous version ran an extra
    generator forward pass here whose output was discarded (wasted compute,
    plus an extra update of running normalization statistics in train mode).

    Returns:
        ``(loss_G, loss_D)``: the generator loss as a float and the mean critic
        loss as a 0-d tensor.
    """
    lossd = []
    for _ in range(critic_updates):
        optimizer_D.zero_grad()
        lossd.append(backward_D(netD, real_A, fake_B, real_B))
        optimizer_D.step()
    loss_D = torch.mean(torch.tensor(lossd))

    optimizer_G.zero_grad()
    loss_G = backward_G(netD, real_A, fake_B, real_B)
    optimizer_G.step()

    return loss_G, loss_D

# ***PostProcessing***

In [None]:
# Notch centres for IdealNotchFilter / GaussianNotchFilter, stored as (x, y) pairs:
# both filters swap each pair and use it as (row, col) in the spectrum.
# NOTE(review): presumably hand-picked periodic-artifact peaks in the fftshift-ed
# 320x320 spectrum of a prediction (all values lie in [0, 320)) — confirm.
points = [
          [120.53896104 , 78.63419913], [201.05844156 , 78.63419913], [237.42207792 , 79.5       ], [ 81.57792208 , 76.03679654], [122.27056277 ,118.46103896], 
          [199.32683983 ,116.72943723], [240.88528139 ,117.5952381 ], [198.46103896 ,199.84632035], [119.67316017 ,197.24891775], [ 82.44372294 ,196.38311688], 
          [277.24891775 ,198.98051948], [201.05844156 ,237.94155844], [ 82.44372294 ,237.94155844], [ 39.15367965 ,120.19264069], [279.84632035 ,116.72943723],
          [274.65151515 , 80.36580087], [201.92424242 , 37.94155844], [117.07575758 , 41.4047619 ], [117.94155844 ,274.30519481], [197.5952381  ,281.23160173],
          [240.01948052 ,242.27056277], [278.98051948 ,241.4047619 ]
    ]

In [None]:
import cv2
import numpy as np
from matplotlib import pyplot as plt
import matplotlib
from math import exp, pow

class IdealNotchFilter:
    """Ideal (binary) notch-reject filter for a centred (fftshift-ed) spectrum."""

    def __init__(self):
        pass

    def apply_filter(self, fshift, points, d0):
        """Zero every sample near a notch, then return ``abs(ifft2(ifftshift(fshift)))``.

        Args:
            fshift: shifted spectrum indexed ``[u, v]``; it is modified IN PLACE.
            points: sequence of (x, y) notches; each is used as (row=y, col=x).
            d0: threshold compared with the SQUARED distance
                ``(u - row)**2 + (v - col)**2``, and likewise with the squared
                distance to ``(-row, -col)``.
        Returns:
            The real-valued filtered image.
        """
        m, n = fshift.shape[0], fshift.shape[1]
        u = np.arange(m).reshape(-1, 1)
        v = np.arange(n).reshape(1, -1)

        reject = np.zeros((m, n), dtype=bool)
        for point in points:
            row, col = point[1], point[0]
            near = ((u - row) ** 2 + (v - col) ** 2) <= d0
            mirrored = ((u + row) ** 2 + (v + col) ** 2) <= d0
            reject |= near | mirrored

        fshift[reject] *= 0.0
        return np.abs(np.fft.ifft2(np.fft.ifftshift(fshift)))

class GaussianNotchFilter:
    """Gaussian notch-reject filter for a centred (fftshift-ed) spectrum."""

    def __init__(self):
        pass

    def apply_filter(self, fshift, points, d0):
        """Attenuate the spectrum around each notch and return ``abs(ifft2(ifftshift(fshift)))``.

        For every notch (x, y), used as (row=y, col=x), every sample (u, v) is
        multiplied by ``1 - exp(-0.5 * d1 * d2 / d0**2)``. Here d1 is the Euclidean
        distance to (row, col) and d2 the distance to (-row, -col).

        Vectorized with numpy (the per-pixel Python loop took millions of
        interpreted iterations for a 320x320 spectrum with ~20 notches). The
        factors are applied one notch at a time, as before. ``np.exp`` may differ
        from ``math.exp`` in the last bit.

        NOTE(review): in a shifted spectrum the symmetric partner of a peak is
        usually (M - row, N - col), not (-row, -col); kept as-is to match the
        original behaviour.

        Args:
            fshift: shifted spectrum indexed ``[u, v, ...]``; modified IN PLACE.
            points: sequence of (x, y) notch positions.
            d0: notch width parameter.
        Returns:
            The real-valued filtered image.
        Raises:
            ZeroDivisionError: if ``d0 == 0`` and there is at least one notch.
        """
        m, n = fshift.shape[0], fshift.shape[1]
        u = np.arange(m, dtype=float).reshape(-1, 1)
        v = np.arange(n, dtype=float).reshape(1, -1)
        d0_sq = float(d0) ** 2
        # Broadcast the (m, n) gain over any trailing axes of fshift.
        gain_shape = (m, n) + (1,) * (fshift.ndim - 2)

        for point in points:
            if d0_sq == 0:
                raise ZeroDivisionError("float division by zero")
            row, col = point[1], point[0]
            d1 = np.sqrt((u - row) ** 2 + (v - col) ** 2)
            d2 = np.sqrt((u + row) ** 2 + (v + col) ** 2)
            gain = 1 - np.exp(-0.5 * (d1 * d2 / d0_sq))
            fshift *= gain.reshape(gain_shape)

        return np.abs(np.fft.ifft2(np.fft.ifftshift(fshift)))

# ***Save or Load Model***

In [None]:
# Checkpoint folder used by save_network / load_network.
weights_dir = os.path.join(
    "/content/drive/MyDrive/Accelerated MRI Scanning/Repositories/FastMRI-Challenge",
    "Phase 4",
    "weights",
)
# Stores the weights in the weights folder.
def save_network(network, network_label, epoch_label, gpu_ids, directory=None):
    """Save ``network``'s state_dict as ``<directory>/<epoch_label>_net_<network_label>.pth``.

    The network is moved to CPU for saving, then back to ``gpu_ids[0]`` when
    ``gpu_ids`` is non-empty and CUDA is available.

    Args:
        network: module to save.
        network_label: name tag in the file name (e.g. 'G' or 'D').
        epoch_label: epoch tag in the file name.
        gpu_ids: GPU ids; the first one is where the network is moved back to.
        directory: target folder; defaults to the notebook-level ``weights_dir``.
            Created if it does not exist (torch.save would otherwise fail).
    """
    directory = weights_dir if directory is None else directory
    os.makedirs(directory, exist_ok=True)
    save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
    save_path = os.path.join(directory, save_filename)
    torch.save(network.cpu().state_dict(), save_path)
    if len(gpu_ids) and torch.cuda.is_available():
        network.cuda(device=gpu_ids[0])

def load_network(network, network_label, epoch_label, directory=None):
    """Load ``<directory>/<epoch_label>_net_<network_label>.pth`` into ``network``.

    Call this AFTER constructing ``network``: it fills the existing module's
    parameters in place via ``load_state_dict``.

    Tensors are loaded with ``map_location='cpu'`` and then copied into the
    network's own parameters, so this works whether or not a GPU is present.
    ``directory`` defaults to the notebook-level ``weights_dir``.

    NOTE: ``torch.load`` unpickles the file; only load trusted checkpoints.
    """
    directory = weights_dir if directory is None else directory
    save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
    save_path = os.path.join(directory, save_filename)
    network.load_state_dict(torch.load(save_path, map_location='cpu'))

# ***Main***

In [None]:
import random
import cv2
# from scipy.fft import fft, ifft
import numpy as np

In [None]:

def generate_images(netG, test_input, tar, ep=-1, ep_iter=-1, out_dir="./Phase 4/images"):
    """Predict ``test_input`` with ``netG``, print metrics and plot the result.

    Only the first batch element and first channel are used. The prediction is
    always saved as ``<out_dir>/Pred_<ep>_<ep_iter>.png``. The 4-panel figure
    (input, ground truth, prediction, error) is saved as
    ``<out_dir>/<ep>_<ep_iter>.png`` only when both ``ep`` and ``ep_iter`` are >= 0.

    The forward pass runs in eval mode without autograd. ``netG`` is put back in
    train mode afterwards, even if the forward pass raises.
    """
    netG.eval()
    try:
        with torch.no_grad():
            prediction = netG.forward(test_input)
    finally:
        netG.train()

    input_img = test_input[0, 0].detach().cpu().numpy()
    gt = tar[0, 0].detach().cpu().numpy()
    pred = prediction[0, 0].detach().cpu().numpy()

    matplotlib.image.imsave(os.path.join(out_dir, "Pred_%s_%s.png" % (ep, ep_iter)), pred, cmap="gray")

    # ssim is called on (1, H, W) arrays, as before; np.newaxis replaces the
    # hard-coded reshape(1, 320, 320) so other resolutions also work.
    losses = {
        "NMSE": nmse(gt, pred),
        "SSIM": ssim(gt[np.newaxis], pred[np.newaxis]),
        "PSNR": psnr(gt, pred),
    }
    print(losses, end="\n\n")

    title = ['Input Image', 'Ground Truth', 'Predicted Image', 'Reconstruction Error']
    display_list = [input_img, gt, pred, gt - pred]

    fig, axes = plt.subplots(1, len(title), figsize=(15, 15))
    for ax, name, img in zip(axes, title, display_list):
        ax.set_title(name)
        ax.imshow(img * 0.5 + 0.5, cmap="gray")
        ax.axis('off')
    if ep > -1 and ep_iter > -1:
        fig.savefig(os.path.join(out_dir, '%s_%s.png' % (ep, ep_iter)), dpi=300, bbox_inches='tight')
    plt.show()

In [None]:
report_interval: int = 100  # iterations between progress reports (presumably read by the training loop)

In [None]:
def evaluate(epoch, netG, netD, data_loader, opt_g, opt_d, writer):
    """Compute the mean generator loss over `data_loader` without updating weights.

    Args:
        epoch: epoch index, used as the TensorBoard step for 'Dev_Loss'.
        netG: generator being evaluated.
        netD: discriminator, passed to ``bakward_G`` to compute the loss.
        data_loader: yields batches whose first two items are (input, target);
            further items (the training loader also yields mean/std) are ignored.
        opt_g, opt_d: unused; kept so existing call sites keep working.
        writer: TensorBoard SummaryWriter receiving the 'Dev_Loss' scalar.

    Returns:
        (mean loss over all batches, elapsed wall-clock seconds).

    netG's previous train/eval mode is restored before returning, so calling
    this between training epochs does not leave the generator in eval mode.
    """
    import time  # local import: this cell must not depend on a later cell's `import time`

    was_training = netG.training
    netG.eval()
    losses = []
    start = time.perf_counter()
    try:
        with torch.no_grad():
            for batch in data_loader:
                # The training loader yields 4 items (input, target, mean, std);
                # only input and target are needed for the loss.
                inputs, targets = batch[0], batch[1]
                # Add the channel dim to both, exactly like train() does, so
                # bakward_G sees the same shapes it gets during training.
                inputs = inputs.unsqueeze(1).to(device)
                targets = targets.unsqueeze(1).to(device)
                outputs = netG(inputs)
                # presumably bakward_G returns a scalar loss tensor (.item() is used)
                loss = bakward_G(netD, inputs, outputs, targets, eval=True)
                losses.append(loss.item())
    finally:
        netG.train(was_training)

    mean_loss = np.mean(losses)
    writer.add_scalar('Dev_Loss', mean_loss, epoch)
    return mean_loss, time.perf_counter() - start

In [None]:
# Training configuration read by train() (and evaluate()) as notebook globals.
device = torch.device('cuda')  # batches are moved here with .to(device)
epoch_count = 5                # number of passes over the training loader
batchSize = 1                  # per-step increment of train()'s image counters; presumably equals the DataLoader batch size

# gloss = []
# dloss = []
# trainlo
import time
def train(data_loader, netG, netD, writer, val_loader=None, vis_interval=50):
  """Train netG/netD for `epoch_count` epochs with logging, visualisation and checkpoints.

  Each iteration does the following:
    - runs one optimize_parameters step;
    - writes the TensorBoard scalars 'TrainLoss' (raw) and 'TrainLossAvg'
      (exponential moving average, reset every epoch);
    - prints a progress line, which is also sent to logging.info every
      `report_interval` iterations.
  Every `vis_interval` iterations, generate_images is called on the current
  batch. After each epoch the model is evaluated, and both networks are saved as
  'latest' and under the epoch number.

  Args:
    data_loader: training DataLoader yielding (input, target, mean, std) batches.
    netG: generator.
    netD: discriminator.
    writer: TensorBoard SummaryWriter.
    val_loader: loader passed to evaluate() after each epoch. Defaults to
      `data_loader`, in which case the reported loss is on the training data.
    vis_interval: iteration interval between generate_images calls.

  Relies on notebook globals: epoch_count, batchSize, device, report_interval,
  optimizer_G, optimizer_D, gpu_ids, optimize_parameters, evaluate,
  generate_images and save_network.
  """
  if val_loader is None:
    val_loader = data_loader

  num_batches = len(data_loader)
  print('#training batches = %d' % num_batches)
  total_steps = 0

  for epoch in range(epoch_count):
    # evaluate()/generate_images() may switch the networks to eval mode; make
    # sure BatchNorm/Dropout are back in training mode at the start of each epoch.
    netG.train()
    netD.train()
    epoch_iter = 0
    global_step = epoch * num_batches
    avg_loss = 0.0
    start_iter = time.perf_counter()

    for i, data in enumerate(data_loader):
      total_steps += batchSize
      epoch_iter += batchSize
      inputs, targets, _mean, _std = data
      inputs = inputs.unsqueeze(1).to(device)
      targets = targets.unsqueeze(1).to(device)
      outputs = netG(inputs)
      loss_g, loss_d = optimize_parameters(netD, netG, inputs, outputs, targets, optimizer_G, optimizer_D)

      # float() yields a plain Python number. This works for tensor or float
      # losses, and it stops avg_loss from keeping every iteration's autograd
      # graph alive if the losses are graph-attached tensors.
      train_loss = float(loss_g + loss_d)
      avg_loss = train_loss if i == 0 else 0.99 * avg_loss + 0.01 * train_loss
      writer.add_scalar('TrainLoss', train_loss, global_step + i)
      writer.add_scalar('TrainLossAvg', avg_loss, global_step + i)

      # Measure the iteration time *before* resetting the timer.
      now = time.perf_counter()
      iter_time = now - start_iter
      start_iter = now

      message = (
          f'Epoch = [{epoch:3d}/{epoch_count:3d}], '
          f'Iter = [{i:4d}/{num_batches:4d}], '
          f'Train Loss = {train_loss:.4g}, Avg Loss = {avg_loss:.4g}, '
          f'Time = {iter_time:.4f}s'
      )
      if i % report_interval == 0:
        # Only shown if the root logger level is INFO or lower.
        logging.info(message)
      print(message)

      if i % vis_interval == 0:
        print('Visualising images :- ')
        generate_images(netG, inputs, targets, epoch, epoch_iter)

    e_loss, tme = evaluate(epoch, netG, netD, val_loader, optimizer_G, optimizer_D, writer)
    print(f"\n\nLoss: {e_loss}, time: {tme}")

    print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
    save_network(netG, 'G', 'latest', gpu_ids)
    save_network(netD, 'D', 'latest', gpu_ids)
    save_network(netG, 'G', epoch, gpu_ids)
    save_network(netD, 'D', epoch, gpu_ids)


In [None]:
# !pip install torch==1.4.0

In [None]:
# Release memory held by PyTorch's CUDA caching allocator (e.g. left over from earlier cells) before training starts.
torch.cuda.empty_cache()

In [None]:
# LOAD_NETWORK
# NOTE(review): the marker above presumably indicates where load_network(...) would be
# called to resume from a checkpoint before training -- TODO confirm.
# Train netG/netD on the training loader; train() saves checkpoints at the end of every epoch.
train(data_load_train,netG,netD, writer)

In [None]:
# Scratch check of sequence unpacking.
a, b, c, d = range(1, 5)
print(a, b, c, d)

#***Testing***

In [None]:
# import torch

# noise = torch.distributions.Normal(torch.tensor([0.0]), torch.tensor(3.20))
# noise = noise.sample()

In [None]:

# def generate_images(netG, test_input, tar):
#   # the training=True is intentional here since
#   # we want the batch statistics while running the netG
#   # on the test dataset. If we use training=False, we will get
#   # the accumulated statistics learned from the training dataset
#   # (which we don't want)
#   prediction = netG(test_input, training=True)
#   plt.figure(figsize=(15,15))

#   display_list = [test_input[0], tar[0], prediction[0]]
#   title = ['Input Image', 'Ground Truth', 'Predicted Image']

#   for i in range(3):
#     plt.subplot(1, 3, i+1)
#     plt.title(title[i])
#     # getting the pixel values between [0, 1] to plot it.
#     plt.imshow(display_list[i] * 0.5 + 0.5)
#     plt.axis('off')
#   plt.show()

In [None]:
import h5py

In [None]:
import pathlib  # explicit so this cell does not rely on an import made elsewhere

# Single-coil test split; adjust this path to your Drive layout.
flname = pathlib.Path('/content/drive/MyDrive/Accelerated MRI Scanning/Repositories/fastMRImaster/Dataset/knee_sc/singlecoil_test/')
if not flname.is_dir():
  # Fail early with a clear message (e.g. Drive not mounted) rather than a
  # confusing error, or possibly an empty dataset, further down.
  raise FileNotFoundError(f'Test data directory not found: {flname}')
batch_size = 1
num_workers = 2
test_dataset = create_dataset(flname)
test_dataset_dataloader = create_dataloader(test_dataset, batch_size, num_workers)


# test_file = h5py.File(flname, "r+")
# test_kspace = test_file['reconstruction'][:]

In [None]:
# for i, x in enumerate(test_dataset_dataloader):
#   print(i, x)

In [None]:
# Rebuild the generator from the notebook's hyper-parameter globals, move it to
# the first configured GPU, then overwrite its weights with the epoch-4 checkpoint.
netGload = ResnetGenerator(
    input_nc,
    output_nc,
    ngf,
    norm_layer=norm_layer,
    use_dropout=use_dropout,
    n_blocks=12,
    gpu_ids=gpu_ids,
    use_parallel=use_parallel,
    learn_residual=learn_residual,
)
netGload.cuda(gpu_ids[0])
# netGload.apply(weights_init)  # presumably unnecessary: load_network overwrites the parameters
load_network(netGload, "G", 4);

In [None]:
# Validation DataLoader for the qualitative check below. batch_size/num_workers are
# whatever globals were last assigned (e.g. by the test-set cell above).
data_load_val=create_dataloader(data_val,batch_size, num_workers)

In [None]:
# Visualise every 20th validation batch with the restored generator.
device = torch.device('cuda')
for i, data in enumerate(data_load_val):
  if i % 20 != 0:
    continue
  input, target, mean, std = data
  input = input.unsqueeze(1).to(device)
  target = target.unsqueeze(1).to(device)
  print(input.shape, target.shape)
  generate_images(netGload, input, target, i, i)
    

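#***Pre-processing***
Experiments applying PIL sharpening, edge-enhancement and smoothing filters to a saved sample image.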

In [None]:
# Save a sample image to disk so PIL's filters can be applied to it.
matplotlib.image.imsave('pre.png', sampled_image_abs4, cmap="gray")
imageObject = Image.open("pre.png")

# Filter variants: sharpening once and twice, edge enhancement, and smoothing.
sharpened1 = imageObject.filter(ImageFilter.SHARPEN)
sharpened2 = sharpened1.filter(ImageFilter.SHARPEN)
t1 = imageObject.filter(ImageFilter.EDGE_ENHANCE)
t2 = imageObject.filter(ImageFilter.EDGE_ENHANCE_MORE)
t3 = imageObject.filter(ImageFilter.SMOOTH_MORE)

# One subplot per variant so every image is visible inline. Repeated
# plt.imshow calls on the same axes would overwrite each other, and
# Image.show() opens an external viewer rather than rendering in the notebook.
variants = [
    ("Original", imageObject),
    ("SHARPEN x1", sharpened1),
    ("SHARPEN x2", sharpened2),
    ("EDGE_ENHANCE", t1),
    ("EDGE_ENHANCE_MORE", t2),
    ("SMOOTH_MORE", t3),
]
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
for ax, (title, img) in zip(axes.ravel(), variants):
  ax.imshow(img, cmap="gray")
  ax.set_title(title)
  ax.axis("off")
plt.show()

show_slice(sharpened2, cmap='gray')

In [None]:
#Three lines to make our compiler able to draw:
import sys
# matplotlib.use('Agg')

import matplotlib.pyplot as plt
import numpy as np

arr = np.full((20, 20), 1.0)  # 20x20 float64 canvas of ones for the plotting experiment below

#Two  lines to make our compiler able to draw:
# plt.savefig(sys.stdout.buffer)
# sys.stdout.flush()




In [None]:
# Paint a few regions with distinct values so the plot below shows structure.
arr[6:14, 6:14] = 3   # central 8x8 square
arr[3, :] = 2         # the whole of row 3
arr[19] = 5           # row 19 (the last row of the 20x20 array)
print(arr)

In [None]:

# Each column of `arr` is drawn as a series of black '+' markers with no connecting line.
plt.plot(arr, color="k", marker="+", linestyle="None")
plt.show()
