In [11]:
%matplotlib notebook
import os, sys
import logging
import random
import h5py
import shutil
import time
import argparse
import numpy as np
import sigpy.plot as pl
import torch
import sigpy as sp
import torchvision
from torch import optim
from tensorboardX import SummaryWriter
from torch.nn import functional as F
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib
# import custom libraries
from utils import transforms as T
from utils import subsample as ss
from utils import complex_utils as cplx
from utils.resnet2p1d import generate_model
from utils.flare_utils import roll
from utils import data_ut as dut
# import custom classes
from utils.datasets import SliceData
from subsample_fastmri import MaskFunc
from MoDL_single import UnrolledModel
import argparse
import matplotlib.pyplot as plt
%matplotlib inline
import nibabel as nib
from models.SAmodel import MyNetwork
from models.Unrolled import Unrolled
from models.UnrolledRef import UnrolledRef
from models.UnrolledTransformer import UnrolledTrans

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
%load_ext autoreload
%autoreload 0

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
!pwd

/home/tal/docker/MoDLsinglechannel/modl_singlechannel_reference


In [13]:
!which python3

/home/tal/docker/dockvenv/bin/python3


In [14]:
class Namespace:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

In [15]:
checkpoint_file = "./L2_checkpoints_poisson_x4_SAunrolledOF/model_10.pt"
checkpoint = torch.load(checkpoint_file,map_location=device)

In [16]:
params = checkpoint["params"]
#single_MoDL = UnrolledModel(params).to(device)
#single_MoDL = MyNetwork(2,2).to(device)
single_MoDL = Unrolled(params).to(device)
#single_MoDL = UnrolledRef(params).to(device)
#single_MoDL = UnrolledTrans(params).to(device)
single_MoDL.load_state_dict(checkpoint['model'])

shared weights


<All keys matched successfully>

In [17]:
class DataTransform:
    """
    Data Transformer for training unrolled reconstruction models.
    """

    def __init__(self, mask_func, args, use_seed=False):
        self.mask_func = mask_func
        self.use_seed = use_seed
        self.rng = np.random.RandomState()

    def __call__(self, kspace, target, reference, reference_kspace,slice):
       
        im_lowres = abs(sp.ifft(sp.resize(sp.resize(kspace,(172,24)),(172,108))))
        magnitude_vals = im_lowres.reshape(-1)
        k = int(round(0.05 * magnitude_vals.shape[0]))
        scale = magnitude_vals[magnitude_vals.argsort()[::-1][k]]
        kspace = kspace/scale

        # Convert everything from numpy arrays to tensors
        kspace_torch = cplx.to_tensor(kspace).float()   
        target_torch = cplx.to_tensor(target).float() / scale
        
        # Use poisson mask instead
        mask2 = sp.mri.poisson((172,108), 2, calib=(18, 14), dtype=float, crop_corner=False, return_density=True, seed=0, max_attempts=6, tol=0.01)
        mask_torch = torch.stack([torch.tensor(mask2).float(),torch.tensor(mask2).float()],dim=2)
    
        #kspace_torch = T.kspace_cut(mask_torch,0.5)
        kspace_torch = T.awgn_torch(kspace_torch,15,L=1)
        kspace_torch = kspace_torch*mask_torch

    
        ### Reference addition ###
        im_lowres_ref = abs(sp.ifft(sp.resize(sp.resize(reference_kspace,(172,24)),(172,108))))
        magnitude_vals_ref = im_lowres_ref.reshape(-1)
        k_ref = int(round(0.05 * magnitude_vals_ref.shape[0]))
        scale_ref = magnitude_vals_ref[magnitude_vals_ref.argsort()[::-1][k_ref]]
        reference_torch = cplx.to_tensor(reference).float()/ scale_ref
        # Resolution degrading
       
        return kspace_torch, target_torch,mask_torch, reference_torch

In [18]:
def create_datasets(args):
    # Generate k-t undersampling masks
    train_mask = MaskFunc([0.08],[4])
    train_data = SliceData(
        root=str(args.data_path),
        transform=DataTransform(train_mask, args),
        sample_rate=1
    )
    return train_data
def create_data_loaders(args):
    train_data = create_datasets(args)
#     print(train_data[0])

    train_loader = DataLoader(
        dataset=train_data,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=8,
        pin_memory=True,
    )
    return train_loader
def build_optim(args, params):
    optimizer = torch.optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay)
    return optimizer

In [67]:
import numpy as np
from skimage.metrics import structural_similarity as ssim, peak_signal_noise_ratio as psnr, normalized_root_mse as nrmse
from skimage import img_as_float
from types import SimpleNamespace as Namespace


# Hyperparameters
params = Namespace()
params.data_path = "./test_data/"
params.batch_size = 1
params.num_grad_steps = 4
params.num_cg_steps = 8
params.share_weights = True
params.modl_lamda = 0.05
params.lr = 0.0001
params.weight_decay = 0
params.lr_step_size = 10
params.lr_gamma = 0.5
params.epoch = 21
params.reference_mode = 0
params.reference_lambda = 0.1

# Load test data
test_loader = create_data_loaders(params)
checkpoint = torch.load(checkpoint_file, map_location=device)
single_MoDL = Unrolled(params).to(device)

# Initialize lists to store metrics
mse_in_list, mse_out_list = [], []
psnr_in_list, psnr_out_list = [], []
ssim_in_list, ssim_out_list = [], []

single_MoDL.eval()  # Set model to evaluation mode

with torch.no_grad():  # Disable gradient computation for evaluation
    for data in test_loader:
        input, target, mask, reference = data
        input = input.to(device)
        reference = reference.to(device)
        
        # Forward pass through the model
        output = single_MoDL(input.float(),reference)
        
        # Convert tensors to numpy arrays
        cplx_image_target = np.abs(cplx.to_numpy(T.fft2(target.cpu()))).squeeze(0)
        cplx_image_in = np.abs(cplx.to_numpy(input.cpu()).squeeze(0))
        cplx_image_out = np.abs(cplx.to_numpy(T.fft2(output.cpu())).squeeze(0))

        img_target = img_as_float(np.abs(cplx_image_target))/np.max(np.abs(cplx_image_target))
        img_in = img_as_float(np.abs(cplx_image_in)/np.max(cplx_image_in))/np.max(np.abs(cplx_image_in))
        img_out = img_as_float(np.abs(cplx_image_out))/np.max(np.abs(cplx_image_out))

        # Calculate metrics
        # Calculate SSIM
        data_range = img_target.max() - img_target.min()
        ssim_in, _ = ssim(img_target, img_in, data_range=data_range, full=True)
        ssim_out, _ = ssim(img_target, img_out, data_range=data_range, full=True)
        
        # Calculate PSNR
        psnr_in = T.PSNR(cplx.to_tensor(img_target).unsqueeze(0),cplx.to_tensor( img_in).unsqueeze(0))
        psnr_out = T.PSNR(cplx.to_tensor(img_target).unsqueeze(0), cplx.to_tensor(img_out).unsqueeze(0))

        # Calculate MSE
        mse_in = np.mean(np.abs(cplx_image_in-cplx_image_target)**2)
        mse_out = np.mean(np.abs(cplx_image_out-cplx_image_target)**2)

        # Append metrics to lists
        mse_in_list.append(mse_in)
        mse_out_list.append(mse_out)
        psnr_in_list.append(psnr_in.numpy())
        psnr_out_list.append(psnr_out.numpy())
        ssim_in_list.append(ssim_in)
        ssim_out_list.append(ssim_out)


# Calculate and print average metrics
print(f'Average MSE input: {np.mean(mse_in_list)}')
print(f'Average MSE output: {np.mean(mse_out_list)}')
print(f'Average PSNR input: {np.mean(psnr_in_list)}')
print(f'Average PSNR output: {np.mean(psnr_out_list)}')
print(f'Average SSIM input: {np.mean(ssim_in_list):.4f}')
print(f'Average SSIM output: {np.mean(ssim_out_list):.4f}')


shared weights
Average MSE input: 0.006740964591199441
Average MSE output: 0.004576812963932753
Average PSNR input: 38.877258084369274
Average PSNR output: 55.392005920410156
Average SSIM input: 0.9041
Average SSIM output: 0.9913
