In [1]:
import os
import pathlib

import h5py
import numpy as np
import torch.nn as nn
import torch.optim as optim
from matplotlib import pyplot as plt
import fastmri
from CLPmodel import *

from fastmri.data.subsample import RandomMaskFunc
from fastmri.data import subsample
from fastmri.data import transforms, mri_data

import torch.cuda

from fastmri.evaluate import *

In [2]:
# Random undersampling pattern: 8x acceleration, keeping a fully sampled
# center fraction of 0.04 of k-space.
mask_func = RandomMaskFunc(accelerations=[8], center_fractions=[0.04])

def data_transform(kspace, mask, target, data_attributes, filename, slice_num):
    """Build one training example from a raw single-coil k-space slice.

    Used as the ``transform`` callable of ``mri_data.SliceDataset``.

    The full k-space is taken to image space, center-cropped to 320x320 and
    transformed back, so the undersampled data lives on the same 320x320 grid
    as ``target``. A fresh sampling mask is then drawn from the module-level
    ``mask_func``. The ``mask`` argument supplied by the dataset is ignored.

    Returns:
        mr_img: zero-filled image of the undersampled k-space
            (real/imag stacked in the last dimension).
        masked_kspace: the undersampled, cropped k-space.
        mask: the sampling mask that was applied.
        target: passed through unchanged.
    """
    # Use the explicitly imported `transforms` module (the same one the
    # training loop uses). The `T` alias is not imported in this notebook and
    # only resolved if a star import happened to provide it.
    image = fastmri.ifft2c(transforms.to_tensor(kspace))
    cropped_image = transforms.complex_center_crop(image, (320, 320))  # -> (320, 320, 2)
    full_kspace = fastmri.fft2c(cropped_image)

    # Older fastmri releases return (masked_data, mask); newer ones also
    # return num_low_frequencies. Keep the first two values either way.
    masked_kspace, mask = transforms.apply_mask(full_kspace, mask_func)[:2]

    # Zero-filled reconstruction of the undersampled k-space (network input).
    mr_img = fastmri.ifft2c(masked_kspace)

    return mr_img, masked_kspace, mask, target

# Slice-level dataset over ./trainingset for the single-coil challenge;
# every slice is passed through data_transform before being yielded.
dataset = mri_data.SliceDataset(
    challenge='singlecoil',
    transform=data_transform,
    root=pathlib.Path('./trainingset'),
)

In [3]:
%%time
import torch.optim as optim

net = Net().cuda()

criterion = nn.L1Loss()
optimizer = optim.Adam(net.parameters(), lr=0.0001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.95, amsgrad=False)

psnr = PSNR()
ssim = SSIM()

avg_psnr_f = 0.0
avg_ssim_f = 0.0

for epoch in range(20):  # loop over the dataset multiple times
    
    running_loss = 0.0
    count_slice = 0
    avg_psnr = 0.0
    avg_ssim = 0.0
    
    for mr_img, masked_kspace, mask, target in dataset:
        
        input1 = mr_img.unsqueeze(0).unsqueeze(0)
        input2 = masked_kspace.unsqueeze(0).unsqueeze(0)

        outputs = net(input1.cuda(), input2.cuda(), mask.cuda())

        abs1 = fastmri.complex_abs(outputs[0][0])
        abs2 = transforms.to_tensor(target).cuda()

        loss = criterion(abs1, abs2)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # print statistics
        running_loss += loss.item()

        avg_psnr += psnr(abs1, abs2).item()
        avg_ssim += ssim(abs1, abs2).item()

        count_slice += 1
        
        pass
    
    fig = plt.figure()
    plt.imshow(abs1.cpu().detach().numpy(), cmap='gray')
    plt.savefig('outputMRI/epoch'+str(epoch+1)+'.jpg')
    plt.close()
    
    avg_psnr_f += avg_psnr/count_slice
    avg_ssim_f += avg_ssim/count_slice
    
    print('Epoch n° (%d) loss: %.3f' % (epoch + 1, running_loss))
    
    print("Average psnr: ", avg_psnr /count_slice)
    print("Average ssim: ", avg_ssim /count_slice)
    
print('Finished Training')
print("Average psnr tot: ", avg_psnr_f /20)
print("Average ssim tot: ", avg_ssim_f /20)

Epoch n° (1) loss: 0.005
Average psnr:  21.310162841448165
Average ssim:  0.3772567137766248
Epoch n° (2) loss: 0.005
Average psnr:  21.351869215925646
Average ssim:  0.37924570928732776
Epoch n° (3) loss: 0.005
Average psnr:  21.34784204081485
Average ssim:  0.3785827343359878
Epoch n° (4) loss: 0.005
Average psnr:  21.30256297647788
Average ssim:  0.3780316072013074
Epoch n° (5) loss: 0.005
Average psnr:  21.303683410390924
Average ssim:  0.3770450601080778
Epoch n° (6) loss: 0.005
Average psnr:  21.310557262388954
Average ssim:  0.3774152050966041
Epoch n° (7) loss: 0.005
Average psnr:  21.32664200605778
Average ssim:  0.3778354526701231
Epoch n° (8) loss: 0.005
Average psnr:  21.325336303076917
Average ssim:  0.3781632274451183
Epoch n° (9) loss: 0.005
Average psnr:  21.3339732082927
Average ssim:  0.37904042289321443
Epoch n° (10) loss: 0.005
Average psnr:  21.325986233444425
Average ssim:  0.3776469751470142
Epoch n° (11) loss: 0.005
Average psnr:  21.346379157248627
Average ssim

In [4]:
# Persist the trained network's state_dict (parameters and buffers) to disk.
# Reload later with: net.load_state_dict(torch.load(PATH))
PATH = './saved_net.pth'
torch.save(obj=net.state_dict(), f=PATH)

In [5]:
# Number of slices iterated in the last training epoch (shown as the cell output).
count_slice

361
