# Coursework for MRI reconstruction (Autumn 2019)

In this tutorial, we provide the data loader to read and process the MRI data in order to ease the difficulty of training your network. By providing this, we hope you focus more on methodology development. Please feel free to change it to suit what you need.

In [12]:
import h5py, os
from functions import transforms as T
from functions.subsample import MaskFunc
from scipy.io import loadmat
from torch.utils.data import DataLoader
import numpy as np
import torch
from matplotlib import pyplot as plt
import torch.nn as nn
import torch.optim as optim

In [13]:
def show_slices(data, slice_nums, cmap=None): # visualisation
    fig = plt.figure(figsize=(15,10))
    for i, num in enumerate(slice_nums):
        plt.subplot(1, len(slice_nums), i + 1)
        plt.imshow(data[num], cmap=cmap)
        plt.axis('off')

In [14]:
class MRIDataset(DataLoader):
    def __init__(self, data_list, acceleration, center_fraction, use_seed):
        self.data_list = data_list
        self.acceleration = acceleration
        self.center_fraction = center_fraction
        self.use_seed = use_seed

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        subject_id = self.data_list[idx]
        return get_epoch_batch(subject_id, self.acceleration, self.center_fraction, self.use_seed)

In [15]:
def get_epoch_batch(subject_id, acc, center_fract, use_seed=True):
    ''' random select a few slices (batch_size) from each volume'''

    fname, rawdata_name, slice = subject_id  
    
    with h5py.File(rawdata_name, 'r') as data:
        rawdata = data['kspace'][slice]
                      
    slice_kspace = T.to_tensor(rawdata).unsqueeze(0)
    S, Ny, Nx, ps = slice_kspace.shape

    # apply random mask
    shape = np.array(slice_kspace.shape)
    mask_func = MaskFunc(center_fractions=[center_fract], accelerations=[acc])
    seed = None if not use_seed else tuple(map(ord, fname))
    mask = mask_func(shape, seed)
      
    # undersample
    masked_kspace = torch.where(mask == 0, torch.Tensor([0]), slice_kspace)
    masks = mask.repeat(S, Ny, 1, ps)

    img_gt, img_und = T.ifft2(slice_kspace), T.ifft2(masked_kspace)
    
    
    # perform data normalization which is important for network to learn useful features
    # during inference there is no ground truth image so use the zero-filled recon to normalize
    norm = T.complex_abs(img_und).max()
    if norm < 1e-6: norm = 1e-6
    
    # normalized data
    img_gt, img_und, rawdata_und = img_gt/norm, img_und/norm, masked_kspace/norm
    
#    img_gt = T.center_crop(T.complex_abs(img_gt), [320, 320]).unsqueeze(1)
#    img_und = T.center_crop(T.complex_abs(img_und), [320, 320]).unsqueeze(1)
#     rawdata_und = T.center_crop(T.complex_abs(rawdata_und), [320, 320]).unsqueeze(1)
#     norm = T.center_crop(T.complex_abs(norm), [320, 320]).unsqueeze(1)
#     masks.T.center_crop(T.complex_abs(masks), [320, 320]).unsqueeze(1)    

    img_gt = T.center_crop(T.complex_abs(img_gt), [320, 320])
    img_und = T.center_crop(T.complex_abs(img_und), [320, 320])
        
    return img_gt.squeeze(0), img_und.squeeze(0)


In [16]:
def load_data_path(train_data_path, val_data_path):
    """ Go through each subset (training, validation) and list all 
    the file names, the file paths and the slices of subjects in the training and validation sets 
    """

    data_list = {}
    train_and_val = ['train', 'val']
    data_path = [train_data_path, val_data_path]
      
    for i in range(len(data_path)):

        data_list[train_and_val[i]] = []
        
        which_data_path = data_path[i]
    
        for fname in sorted(os.listdir(which_data_path)):
            
            subject_data_path = os.path.join(which_data_path, fname)
                     
            if not os.path.isfile(subject_data_path): continue 
            
            with h5py.File(subject_data_path, 'r') as data:
                num_slice = data['kspace'].shape[0]
                
            # the first 5 slices are mostly noise so it is better to exlude them
            data_list[train_and_val[i]] += [(fname, subject_data_path, slice) for slice in range(5, num_slice)]
    
    return data_list    

In [17]:
class AlexNet(nn.Module):

    def __init__(self):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1), #320/320
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1), #320/320
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1), #320/320
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1), #320/320
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1), #320/320
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 1, kernel_size=3, padding=1),  # 320/320
        )

    def forward(self, x):
        x = self.features(x)
        #x = nn.functional.sigmoid(x)
        #x = x * 255
        #x = x.type(torch.cuda.int32)
        return x

In [20]:


if __name__ == '__main__':
    
    data_path_train = '/tmp/NC2019MRI/train'
    data_path_val = '/tmp/NC2019MRI/train'
    data_list = load_data_path(data_path_train, data_path_val) # first load all file names, paths and slices.
    
    acc = 8
    cen_fract = 0.04
    seed = False # random masks for each slice 
    num_workers = 12 # data loading is faster using a bigger number for num_workers. 0 means using one cpu to load data
    
    lr = 1e-3
    
    network = AlexNet()
    network.to('cuda:0') #move the model on the GPU
    mse_loss = nn.MSELoss().to('cuda:0')
    
    optimizer = optim.Adam(network.parameters(), lr=lr)
    
    # create data loader for training set. It applies same to validation set as well
    train_dataset = MRIDataset(data_list['train'], acceleration=acc, center_fraction=cen_fract, use_seed=seed)
    train_loader = DataLoader(train_dataset, shuffle=True, batch_size=1, num_workers=num_workers) 
    

    j = 0
    for iteration, sample in enumerate(train_loader):
        
        img_gt, img_und = sample
        print(img_gt.shape)
        
        
#         img_gt, img_und, rawdata_und, masks, norm = sample
#         img_gt = T.center_crop(T.complex_abs(img_gt), [320, 320]).unsqueeze(1).to('cuda:0')
#         img_und = T.center_crop(T.complex_abs(img_und), [320, 320]).unsqueeze(1).to('cuda:0')

        
#         output = network(img_und)       #feedforward
        
#         print(output.shape)

#         loss = mse_loss(output, img_gt)
#         optimizer.zero_grad()       #set current gradients to 0
#         loss.backward()      #backpropagate
#         optimizer.step()     #update the weights
#         print(loss.item(), "  ")
        
#         i = 0
#         j +=1
        
#         if j%100 == 0:
#             for row in range(0,320):
#                 for col in range(0,320):
#                     if output[0,0,row,col].item() == img_gt[0,0,row,col].item():

#                         i +=1
#             print(i, "\n \n")
                
#         print(img_gt.shape)
#         print(img_und.shape)
        
#         # stack different slices into a volume for visualisation
#         A = masks[...,0].squeeze()
#         B = torch.log(T.complex_abs(rawdata_und) + 1e-9).squeeze()
#         C = T.complex_abs(img_und).squeeze()
#         D = T.complex_abs(img_gt).squeeze()
#         all_imgs = torch.stack([A,B,C,D], dim=0)

#         # from left to right: mask, masked kspace, undersampled image, ground truth
#         show_slices(all_imgs, [0, 1, 2, 3], cmap='gray')
#         plt.pause(1)

#         if iteration >= 0: break  # show 4 random slices
        

ValueError: too many values to unpack (expected 2)

In [8]:
acc = 8
cen_fract = 0.04
seed = False # random masks for each slice 
num_workers = 12 # data loading is faster using a bigger number for num_workers. 0 means using one cpu to load data
    

if __name__ == '__main__':
    
    data_path_train = '/tmp/NC2019MRI/train'
    data_path_val = '/tmp/NC2019MRI/train'
    data_list = load_data_path(data_path_train, data_path_val) # first load all file names, paths and slices.
    
    acc = 8
    cen_fract = 0.04
    seed = False # random masks for each slice 
    num_workers = 12 # data loading is faster using a bigger number for num_workers. 0 means using one cpu to load data
    # create data loader for training set. It applies same to validation set as well
    train_dataset = MRIDataset(data_list['train'], acceleration=acc, center_fraction=cen_fract, use_seed=seed)
    train_loader = DataLoader(train_dataset, shuffle=True, batch_size=1, num_workers=num_workers) 
    

    a = [[],[]]
    for iteration, sample in enumerate(train_loader):
        img_gt, img_und, rawdata_und, masks, norm = sample
        img_gt = T.center_crop(T.complex_abs(img_gt), [320, 320]).unsqueeze(1)
        img_und = T.center_crop(T.complex_abs(img_und), [320, 320]).unsqueeze(1)
        a[0].append(img_und)
        a[1].append(img_gt)
    b = torch.cat(a[0][:])
    c = torch.cat(a[1][:])
train = torch.stack((b,c),dim=0)
del a
del b
del c
del train_loader
del train_dataset
train.shape

torch.Size([2, 2134, 1, 320, 320])

In [11]:
lr = 1e-3
    
network = AlexNet()
network.to('cuda:0') #move the model on the GPU
mse_loss = nn.MSELoss().to('cuda:0')
    
optimizer = optim.Adam(network.parameters(), lr=lr)
train_loader = DataLoader(train, shuffle=True, batch_size=1, num_workers=num_workers) 
for iteration, sample in enumerate(train_loader):
    #img_gt, img_und, rawdata_und, masks, norm = sample        
    
    output = network(img_und)       #feedforward
    print(output.shape)

    loss = mse_loss(output, img_gt)
    optimizer.zero_grad()       #set current gradients to 0
    loss.backward()      #backpropagate
    optimizer.step()     #update the weights
    print(loss.item(), "  ")
        
    i = 0
    j +=1
        
    if j%100 == 0:
        for row in range(0,320):
            for col in range(0,320):
                if output[0,0,row,col].item() == img_gt[0,0,row,col].item():

                        i +=1
        print(i, "\n \n")

ValueError: not enough values to unpack (expected 5, got 1)

False