In [1]:
import math

import random
import numpy as np
import torch
from torch import nn
from torch.autograd import Variable, grad
import torchvision.utils
from matplotlib import pyplot as plt
from matplotlib.colors import Normalize
from mpl_toolkits.axes_grid1 import make_axes_locatable
from skimage.color import rgb2gray
from skimage.metrics import structural_similarity

from modules.dataset import LoadMRI, DatasetReconMRI
from modules.dataset import build_loaders
from modules.utils import kspace2image, image2kspace, complex2pseudo, pseudo2real, pseudo2complex, imsshow
from modules.solver import Solver
import modules.kspace_pytorch as cl

In [2]:
%load_ext autoreload
%autoreload 2

# Preprocess

In [3]:
dataset = LoadMRI('./cine.npz')
CINE_INDEX = random.randint(0, 199)

In [4]:
data_test = DatasetReconMRI(dataset, acc=6, num_center_lines=20)

In [None]:
img_und, und_mask, img_gt = data_test[CINE_INDEX]
print(f"img_und: {img_und.shape}, und_mask: {und_mask.shape}, img_gt: {img_gt.shape}")
print(f"img_und: {img_und.dtype}, und_mask: {und_mask.dtype}, img_gt: {img_gt.dtype}")

img_und = np.abs(pseudo2real(img_und))
img_gt = np.abs(pseudo2real(img_gt))
# Try different index range in [0, 1000) to see different images in dataset.
imsshow(img_und, num_col=5, cmap='gray', is_colorbar=True)
imsshow(und_mask, num_col=5, cmap='gray', is_colorbar=True)
imsshow(img_gt, num_col=10, cmap='gray', is_colorbar=True)

In [None]:
TRAIN_INDICES = np.arange(0, 112)
VAL_INDICES = np.arange(112, 140)
TEST_INDICES = np.arange(140, 200)

train_loader, val_loader, test_loader = build_loaders(
    dataset, TRAIN_INDICES, VAL_INDICES, TEST_INDICES,
    acc = 8,
    num_center_lines=12,
    batch_size=10  # Reduce this number if your computer does not has large gpu vram
)
print(f"Number of batches for train/val/test: {len(train_loader)}/{len(val_loader)}/{len(test_loader)}")

# Network

## Network Framework

In [7]:
class MRIReconstructionFramework(nn.Module):
    def __init__(self, recon_net: nn.Module):
        super().__init__()
        self.recon_net = recon_net

    def forward(self, x_und, mask):
        B, C, T, H, W = x_und.shape
        im_recon = self.recon_net(x_und)
        return im_recon

## Cascading

In [6]:
class MRIReconstructionFramework(nn.Module):
    def __init__(self, recon_net: nn.Module):
        super().__init__()
        self.recon_net = recon_net

    def forward(self, x_und, mask):
        B, C, T, H, W = x_und.shape
        x_k = image2kspace(pseudo2complex(x_und))
#         x_k = x_k * mask
#         x_dc = kspace2image(x_k)
        x_dc = complex2pseudo(x_k)
        
        im_recon = self.recon_net(x_und, x_dc)
        return im_recon

In [39]:
class DC(nn.Module):
    def __init__(self):
        super(DC, self).__init__()
        self.lambda_num = nn.Parameter(torch.ones(4, 20, 192, 192))
#         self.lambda_num = 100
        
    def forward(self, x, dc_k):
        dc_k = pseudo2complex(dc_k)
        x = image2kspace(x)
        x = pseudo2complex(x)
        output = (self.lambda_num * dc_k + x) / (self.lambda_num + 1)
        output = kspace2image(output)
#         output = x
        output = complex2pseudo(output)
        
#         test = pseudo2real(output.detach().cpu().numpy())
#         img_und = np.abs(test[0])
#         imsshow(img_und, num_col=1, cmap='gray', is_colorbar=True)
        return output

class MultiLayerCNN(nn.Module):
    def __init__(self, n_hidden=64):
        super().__init__()
        self.conv1 = nn.Conv3d(2, n_hidden, kernel_size=3, padding=1)
        self.conv2 = nn.Conv3d(n_hidden, n_hidden, kernel_size=3, padding=1)
        self.conv3 = nn.Conv3d(n_hidden, n_hidden, kernel_size=3, padding=1)
        self.conv4 = nn.Conv3d(n_hidden, n_hidden, kernel_size=3, padding=1)
        self.conv5 = nn.Conv3d(n_hidden, 2, kernel_size=3, padding=1)

        self.relu = nn.ReLU()
        self.drop = nn.Dropout3d(p=0.1)

    def forward(self, im_und):
        """
        - im_und: tensor[B, C=2, H, W]
        """
        x = self.relu(self.drop(self.conv1(im_und)))
        x = self.relu(self.drop(self.conv2(x)))
        x = self.relu(self.drop(self.conv3(x)))
        x = self.relu(self.drop(self.conv4(x)))
        diff = self.conv5(x)
        return diff


class ReconstructionNet(nn.Module):
    def __init__(self):
        super(ReconstructionNet, self).__init__()
        
        # Define the 3D CNN network
        self.cnn = MultiLayerCNN()
        self.dc = DC()
        # Define the output layer
        
    def forward(self, x, x_dc):
        # Pass the input through the 3D CNN network
        x1 = self.cnn(x)
#         x = self.dc(x, x_dc)
        x2 = self.dc(x1, x_dc)
        x3 = self.cnn(x2)
        x2 = self.dc(x3, x_dc)
        x3 = self.cnn(x2)
        x = x3 + x
        # Pass the output of the 3D CNN network through the output layer
        
        return x



## Unet

In [78]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.Dropout3d(p=0.1),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.Dropout3d(p=0.1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)

class Down(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.mpconv = nn.Sequential(
            nn.MaxPool3d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.mpconv(x)

class Up(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose3d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        diffZ = x2.size()[2] - x1.size()[2]
        diffY = x2.size()[3] - x1.size()[3]
        diffX = x2.size()[4] - x1.size()[4]
        x1 = nn.functional.pad(x1, (diffX // 2, diffX - diffX // 2,
                                    diffY // 2, diffY - diffY // 2,
                                    diffZ // 2, diffZ - diffZ // 2,))
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class UNet3D(nn.Module):
    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.bilinear = bilinear

        self.conv1 = DoubleConv(in_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)
        self.up1 = Up(1024, 512)
        self.up2 = Up(512, 256)
        self.up3 = Up(256, 128)
        self.up4 = Up(128, 64)
        self.conv_out = nn.Conv3d(64, out_channels, kernel_size=1)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.conv_out(x)
        return logits


## MultiCNN with Dropout

In [8]:
# Network definition
class MultiLayerCNN(nn.Module):
    def __init__(self, n_hidden=64):
        super().__init__()
        self.conv1 = nn.Conv3d(2, n_hidden, kernel_size=3, padding=1)
        self.conv2 = nn.Conv3d(n_hidden, n_hidden, kernel_size=3, padding=1)
        self.conv3 = nn.Conv3d(n_hidden, n_hidden, kernel_size=3, padding=1)
        self.conv4 = nn.Conv3d(n_hidden, n_hidden, kernel_size=3, padding=1)
        self.conv5 = nn.Conv3d(n_hidden, 2, kernel_size=3, padding=1)

        self.relu = nn.ReLU()
        self.drop = nn.Dropout3d(p=0.1)

    def forward(self, im_und):
        """
        - im_und: tensor[B, C=2, H, W]
        """
        x = self.relu(self.drop(self.conv1(im_und)))
        x = self.relu(self.drop(self.conv2(x)))
        x = self.relu(self.drop(self.conv3(x)))
        x = self.relu(self.drop(self.conv4(x)))
        diff = self.conv5(x)
        im_recon = diff + im_und
        return im_recon


# test a forward
im_mock = torch.randn(5, 2, 20, 192, 192)
net = MultiLayerCNN()
out = net(im_mock)
print(out.shape)

torch.Size([5, 2, 20, 192, 192])


## Train

In [9]:
from modules.utils import compute_psnr, compute_ssim
class MSELoss():
    def __call__(self, im_recon, im_gt):
        """
        - im_recon: tensor[B, C=2, T, H, W]
        - im_gt: tensor[B, C=2, T, H, W]
        """
        B, C, T, H, W = im_recon.shape
        x = pseudo2real(im_recon)  # [B, T, H, W]
        y = pseudo2real(im_gt)     # [B, T, H, W]
        loss = torch.mean((y - x) ** 2) * B
        return loss

# class SSIM_Loss():
#     def __call__(self, im_recon, im_gt):
#         im_recon = pseudo2real(im_recon)
#         im_gt = pseudo2real(im_gt)
#         psnr_val = [(1- compute_ssim(im_recon[i], im_gt[i], is_minmax=True)) for i in range(2)]
#         psnr_val = torch.tensor(psnr_val)
#         return torch.mean(psnr_val)
        
    
    
# test for loss computation
mse = MSELoss()
x_mock = torch.randn(5, 2, 20, 192, 192)
y_mock = torch.randn(5, 2, 20, 192, 192)
print(mse(x_mock, y_mock))

tensor(4.2932)


In [None]:
# Define network
net = MRIReconstructionFramework(
    recon_net=MultiLayerCNN()
)
# net = MRIReconstructionFramework(ReconstructionNet())
# net = MRIReconstructionFramework(UNet3D(2,2))
# checkpoints =torch.load("03-26_12-50-44-cas_without2-checkpoint-epoch50.pth")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# net = MRIReconstructionFramework(ReconstructionNet())
# net.load_state_dict(checkpoints)
# net = net.to(device)
# optimizer=torch.optim.Adam(
#         net.parameters(),
#         lr=0.0001,
#     )
optimizer=torch.optim.Adam(
        net.parameters(),
        lr=0.001,
    )
# Training & Validation
solver = Solver(
    model=net,
    optimizer=optimizer,
    criterion=MSELoss(),
    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size  = 400, gamma = 0.8)
    # scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda= lambda epoch: 1/(epoch+1))
    # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.8, last_epoch=-1)
    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 2, gamma=0.1, last_epoch=-1)
)

epochs_to_train = 20
solver.train(epochs_to_train, train_loader, val_loader=val_loader)
solver.validate(test_loader)

In [None]:
# changing data_index to see diffferent sample's visualization
data_index = 5 # range in [0, 60)
time_index = 13
solver.visualize(test_loader, idx=data_index,time_index=time_index, dpi=100)

## Output

In [23]:
import time
timestamp = time.strftime("%m-%d_%H-%M-%S", time.localtime())

In [41]:
torch.save(net.state_dict(), f'{timestamp}-cas_without2-checkpoint-epoch{epochs_to_train}.pth')

In [156]:
import itertools
batch_num = 5
batch = next(itertools.islice(test_loader, batch_num, None))
x_und, und_mask, image_gt = batch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
x_und = x_und.to(device)
und_mask = und_mask.to(device)
net.eval()
im_recon = net(x_und, und_mask)
# for batch in test_loader:
#     # convert the batch to a numpy array
#     # do something with the numpy array
#     print(batch[0].shape)
#     net(batch[0], batch[0])

In [157]:
im_recon = pseudo2real(im_recon).detach().cpu().numpy()
image_gt = pseudo2real(image_gt).detach().cpu().numpy()

In [None]:
idx = 1
imsshow(im_recon[idx], num_col=5, cmap='gray', is_colorbar=True)
import imageio

# Create a numpy array with a sequence of images

# Write the numpy array to a GIF file using imageio
output = (im_recon[idx] * 255).astype('uint8')
imageio.mimsave(f'animation_{batch_num}_{idx}.gif', output, fps=10)

In [None]:
from modules.utils import compute_psnr, compute_ssim
print(im_recon.shape)
print(f"psnr is {compute_psnr(im_recon[0][0], image_gt[0][0], is_minmax=True):.2f}")
print(f"psnr is {compute_ssim(im_recon[0][0], image_gt[0][0], is_minmax=True):.2f}")

## GIF

In [None]:
checkpoints =torch.load("03-26_12-50-44-cas_without-checkpoint-epoch50.pth")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = MRIReconstructionFramework(ReconstructionNet())
net.load_state_dict(checkpoints)
net = net.to(device)
net.eval()

In [None]:
import time
import imageio
import itertools
from modules.utils import compute_psnr, compute_ssim


net = net.to(device)
# net.eval()
PSNR = []
SSIM = []
for batch_num in range(15):
    print(batch_num, end=' ')
    batch = next(itertools.islice(test_loader, batch_num, None))
    x_und, und_mask, image_gt = batch
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    x_und = x_und.to(device)
    und_mask = und_mask.to(device)
    im_recon = net(x_und, und_mask)
    x_und    = pseudo2real(x_und).detach().cpu().numpy()
    im_recon = pseudo2real(im_recon).detach().cpu().numpy()
    image_gt = pseudo2real(image_gt).detach().cpu().numpy()
    for idx in range(2):
        for time in range(20):
            psnr_val = compute_psnr(im_recon[idx][time], image_gt[idx][time], is_minmax=True)
            ssim_val = compute_ssim(im_recon[idx][time], image_gt[idx][time], is_minmax=True)
            PSNR.append(psnr_val)
            SSIM.append(ssim_val)
#         output = (x_und[idx] * 255).astype('uint8')
#         imageio.mimsave(f'./gif_und/animation_{batch_num}_{idx}_{PSNR[-1]:.2f}.gif', output, fps=10)

In [None]:
PSNR_average = sum(PSNR) / len(PSNR)
PSNR_variance = np.var(PSNR)
SSIM_average = sum(SSIM) / len(SSIM)
SSIM_variance = np.var(SSIM)
print(f"PSNR平均值是:{PSNR_average:.4f}, PSNR方差是:{PSNR_variance:.4f}")
print(f"SSIM平均值是:{SSIM_average:.4f}, SSIM方差是:{SSIM_variance:.4f}")

## Offset

In [23]:
batch = next(itertools.islice(test_loader, 3, None))
x_und, und_mask, image_gt = batch
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
x_und = x_und.to(device)
und_mask = und_mask.to(device)
im_recon = net(x_und, und_mask)
im_recon = pseudo2real(im_recon).detach().cpu().numpy()
image_gt = pseudo2real(image_gt).detach().cpu().numpy()

In [None]:
id = 1
imsshow(image_gt[id] - im_recon[id], num_col=5, is_colorbar=True)