In [None]:
import pathlib

# NOTE: the relative order of the lines below is kept as-is on purpose; the
# star imports can rebind names, so reordering them could change what
# `transforms`, `psnr`, etc. refer to.
import h5py
import numpy as np
from matplotlib import pyplot as plt
import fastmri
from CLPmodel import *

from fastmri.data.subsample import RandomMaskFunc
from fastmri.data import subsample
from fastmri.data import transforms, mri_data

import torch
import torch.cuda
import torch.optim as optim

from fastmri.evaluate import *

In [None]:
# Undersampling pattern shared by every slice: RandomMaskFunc configured with
# a single centre fraction (0.04) and a single acceleration factor (8).
mask_func = RandomMaskFunc(
    center_fractions=[0.04],
    accelerations=[8],
)

def data_transform(kspace, mask, target, data_attributes, filename, slice_num):
    """Turn one raw k-space slice into the inputs used for evaluation.

    The slice is moved to image space, centre-cropped to 320x320, moved back
    to k-space and undersampled with the module-level ``mask_func``.

    Args:
        kspace: raw k-space of a single slice, converted with ``to_tensor``.
        mask: not used; a new mask is produced by ``mask_func`` instead.
        target: reference image, passed through untouched.
        data_attributes: not used.
        filename: not used.
        slice_num: not used.

    Returns:
        Tuple ``(mr_img, masked_kspace, mask, target)``: the inverse transform
        of the undersampled k-space, the undersampled k-space itself, the mask
        returned by ``apply_mask``, and the unchanged ``target``.
    """
    crop_size = (320, 320)

    # Crop in image space, then go back to k-space so the mask is applied to
    # the k-space of the cropped image.
    full_image = fastmri.ifft2c(transforms.to_tensor(kspace))
    cropped_image = transforms.complex_center_crop(full_image, crop_size)  # torch.Size([320, 320, 2])
    cropped_kspace = fastmri.fft2c(cropped_image)

    # NOTE(review): no seed is passed to apply_mask, so the sampled mask is
    # presumably different on every call/run -- confirm this is intended for
    # an evaluation script.
    undersampled_kspace, sampling_mask = transforms.apply_mask(cropped_kspace, mask_func)
    undersampled_image = fastmri.ifft2c(undersampled_kspace)

    return undersampled_image, undersampled_kspace, sampling_mask, target

# Single-coil slice dataset rooted at ./testset; data_transform is handed to
# the dataset as its per-sample transform.
dataset = mri_data.SliceDataset(
    challenge='singlecoil',
    root=pathlib.Path('./testset'),
    transform=data_transform,
)

In [None]:
# Checkpoint file; its content is fed to load_state_dict below, so it is
# expected to be a state_dict compatible with Net.
PATH = './saved_net.pth'

# Build the model, move it to the GPU, then copy the saved parameters in.
# NOTE(review): torch.load unpickles the file -- only load checkpoints that
# come from a trusted source.
net = Net()
net = net.cuda()
net.load_state_dict(torch.load(PATH))

In [None]:
%%time
import torch.optim as optim

net = Net().cuda()

psnr = PSNR()
ssim = SSIM()

count_slice = 0.0
    
avg_psnr = 0.0
avg_ssim = 0.0

for mr_img, masked_kspace, mask, target in dataset:
            
    input1 = mr_img.unsqueeze(0).unsqueeze(0)
    input2 = masked_kspace.unsqueeze(0).unsqueeze(0)

    outputs = net(input1.cuda(), input2.cuda(), mask.cuda())

    abs1 = fastmri.complex_abs(outputs[0][0])
    abs2 = transforms.to_tensor(target).cuda()

    avg_psnr += psnr(abs1, abs2).item()
    avg_ssim += ssim(abs1, abs2).item()
    
    if(count_slice%30==0):
        fig = plt.figure()
        plt.imshow(abs1.cpu().detach().numpy(), cmap='gray')
        plt.savefig('outputMRI/img'+str(count_slice+1)+'.jpg')
        plt.close()
        
        fig = plt.figure()
        plt.imshow(abs2.cpu().detach().numpy(), cmap='gray')
        plt.savefig('outputMRI/imggg'+str(count_slice+1)+'.jpg')
        plt.close()
        
    count_slice += 1

    pass

print("Average psnr: ", avg_psnr / count_slice)
print("Average ssim: ", avg_ssim / count_slice)
print("Count_slice: ", count_slice)
print('Finished validation')