# Coursework for MRI reconstruction (Autumn 2019)

In this tutorial, we provide the data loader to read and process the MRI data in order to ease the difficulty of training your network. By providing this, we hope you focus more on methodology development. Please feel free to change it to suit what you need.

In [12]:
import h5py, os
from functions import transforms as T
from functions.subsample import MaskFunc
from scipy.io import loadmat
from torch.utils.data import DataLoader
import numpy as np
import torch
from matplotlib import pyplot as plt
import torch.nn as nn
import torch.optim as optim

In [13]:
def show_slices(data, slice_nums, cmap=None): # visualisation
    fig = plt.figure(figsize=(15,10))
    for i, num in enumerate(slice_nums):
        plt.subplot(1, len(slice_nums), i + 1)
        plt.imshow(data[num], cmap=cmap)
        plt.axis('off')

In [14]:
class MRIDataset(DataLoader):
    def __init__(self, data_list, acceleration, center_fraction, use_seed):
        self.data_list = data_list
        self.acceleration = acceleration
        self.center_fraction = center_fraction
        self.use_seed = use_seed

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        subject_id = self.data_list[idx]
        return get_epoch_batch(subject_id, self.acceleration, self.center_fraction, self.use_seed)

In [15]:
def get_epoch_batch(subject_id, acc, center_fract, use_seed=True):
    ''' random select a few slices (batch_size) from each volume'''

    fname, rawdata_name, slice = subject_id  
    
    with h5py.File(rawdata_name, 'r') as data:
        rawdata = data['kspace'][slice]
                      
    slice_kspace = T.to_tensor(rawdata).unsqueeze(0)
    S, Ny, Nx, ps = slice_kspace.shape

    # apply random mask
    shape = np.array(slice_kspace.shape)
    mask_func = MaskFunc(center_fractions=[center_fract], accelerations=[acc])
    seed = None if not use_seed else tuple(map(ord, fname))
    mask = mask_func(shape, seed)
      
    # undersample
    masked_kspace = torch.where(mask == 0, torch.Tensor([0]), slice_kspace)
    masks = mask.repeat(S, Ny, 1, ps)

    img_gt, img_und = T.ifft2(slice_kspace), T.ifft2(masked_kspace)
    
    
    # perform data normalization which is important for network to learn useful features
    # during inference there is no ground truth image so use the zero-filled recon to normalize
    norm = T.complex_abs(img_und).max()
    if norm < 1e-6: norm = 1e-6
    
    # normalized data
    img_gt, img_und, rawdata_und = img_gt/norm, img_und/norm, masked_kspace/norm
    
#    img_gt = T.center_crop(T.complex_abs(img_gt), [320, 320]).unsqueeze(1)
#    img_und = T.center_crop(T.complex_abs(img_und), [320, 320]).unsqueeze(1)
#     rawdata_und = T.center_crop(T.complex_abs(rawdata_und), [320, 320]).unsqueeze(1)
#     norm = T.center_crop(T.complex_abs(norm), [320, 320]).unsqueeze(1)
#     masks.T.center_crop(T.complex_abs(masks), [320, 320]).unsqueeze(1)    
        
    return img_gt.squeeze(0), img_und.squeeze(0), rawdata_und.squeeze(0), masks.squeeze(0), norm


In [16]:
def load_data_path(train_data_path, val_data_path):
    """ Go through each subset (training, validation) and list all 
    the file names, the file paths and the slices of subjects in the training and validation sets 
    """

    data_list = {}
    train_and_val = ['train', 'val']
    data_path = [train_data_path, val_data_path]
      
    for i in range(len(data_path)):

        data_list[train_and_val[i]] = []
        
        which_data_path = data_path[i]
    
        for fname in sorted(os.listdir(which_data_path)):
            
            subject_data_path = os.path.join(which_data_path, fname)
                     
            if not os.path.isfile(subject_data_path): continue 
            
            with h5py.File(subject_data_path, 'r') as data:
                num_slice = data['kspace'].shape[0]
                
            # the first 5 slices are mostly noise so it is better to exlude them
            data_list[train_and_val[i]] += [(fname, subject_data_path, slice) for slice in range(5, num_slice)]
    
    return data_list    

In [17]:
class AlexNet(nn.Module):

    def __init__(self):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1), #320/320
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1), #320/320
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1), #320/320
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1), #320/320
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1), #320/320
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 1, kernel_size=3, padding=1),  # 320/320
        )

    def forward(self, x):
        x = self.features(x)
        #x = nn.functional.sigmoid(x)
        #x = x * 255
        #x = x.type(torch.cuda.int32)
        return x

In [19]:


if __name__ == '__main__':
    
    data_path_train = '/tmp/NC2019MRI/train'
    data_path_val = '/tmp/NC2019MRI/train'
    data_list = load_data_path(data_path_train, data_path_val) # first load all file names, paths and slices.
    
    acc = 8
    cen_fract = 0.04
    seed = False # random masks for each slice 
    num_workers = 12 # data loading is faster using a bigger number for num_workers. 0 means using one cpu to load data
    
    lr = 1e-3
    
    network = AlexNet()
    network.to('cuda:0') #move the model on the GPU
    mse_loss = nn.MSELoss().to('cuda:0')
    
    optimizer = optim.Adam(network.parameters(), lr=lr)
    
    # create data loader for training set. It applies same to validation set as well
    train_dataset = MRIDataset(data_list['train'], acceleration=acc, center_fraction=cen_fract, use_seed=seed)
    train_loader = DataLoader(train_dataset, shuffle=True, batch_size=1, num_workers=num_workers) 
    

    j = 0
    for iteration, sample in enumerate(train_loader):
        
        
        
        img_gt, img_und, rawdata_und, masks, norm = sample
        img_gt = T.center_crop(T.complex_abs(img_gt), [320, 320]).unsqueeze(1).to('cuda:0')
        img_und = T.center_crop(T.complex_abs(img_und), [320, 320]).unsqueeze(1).to('cuda:0')

        
        output = network(img_und)       #feedforward
        
        print(output.shape)

        loss = mse_loss(output, img_gt)
        optimizer.zero_grad()       #set current gradients to 0
        loss.backward()      #backpropagate
        optimizer.step()     #update the weights
        print(loss.item(), "  ")
        
        i = 0
        j +=1
        
        if j%100 == 0:
            for row in range(0,320):
                for col in range(0,320):
                    if output[0,0,row,col].item() == img_gt[0,0,row,col].item():

                        i +=1
            print(i, "\n \n")
                
#         print(img_gt.shape)
#         print(img_und.shape)
        
#         # stack different slices into a volume for visualisation
#         A = masks[...,0].squeeze()
#         B = torch.log(T.complex_abs(rawdata_und) + 1e-9).squeeze()
#         C = T.complex_abs(img_und).squeeze()
#         D = T.complex_abs(img_gt).squeeze()
#         all_imgs = torch.stack([A,B,C,D], dim=0)

#         # from left to right: mask, masked kspace, undersampled image, ground truth
#         show_slices(all_imgs, [0, 1, 2, 3], cmap='gray')
#         plt.pause(1)

#         if iteration >= 0: break  # show 4 random slices
        

torch.Size([1, 1, 320, 320])
0.19762741029262543   
torch.Size([1, 1, 320, 320])
0.15016788244247437   
torch.Size([1, 1, 320, 320])
0.05371299758553505   
torch.Size([1, 1, 320, 320])
0.026851551607251167   
torch.Size([1, 1, 320, 320])
0.05690575763583183   
torch.Size([1, 1, 320, 320])
0.020738378167152405   
torch.Size([1, 1, 320, 320])
0.023768343031406403   
torch.Size([1, 1, 320, 320])
0.055165261030197144   
torch.Size([1, 1, 320, 320])
0.02021394670009613   
torch.Size([1, 1, 320, 320])
0.03680403158068657   
torch.Size([1, 1, 320, 320])
0.004967518616467714   
torch.Size([1, 1, 320, 320])
0.005652949679642916   
torch.Size([1, 1, 320, 320])
0.028324389830231667   
torch.Size([1, 1, 320, 320])
0.02644411474466324   
torch.Size([1, 1, 320, 320])
0.015069958753883839   
torch.Size([1, 1, 320, 320])
0.023926997557282448   
torch.Size([1, 1, 320, 320])
0.023080969229340553   
torch.Size([1, 1, 320, 320])
0.014897129498422146   
torch.Size([1, 1, 320, 320])
0.02409924939274788   
t

0.010833611711859703   
torch.Size([1, 1, 320, 320])
0.00997605174779892   
torch.Size([1, 1, 320, 320])
0.008599159307777882   
torch.Size([1, 1, 320, 320])
0.014913225546479225   
torch.Size([1, 1, 320, 320])
0.007227741647511721   
torch.Size([1, 1, 320, 320])
0.013068890199065208   
torch.Size([1, 1, 320, 320])
0.015860090032219887   
torch.Size([1, 1, 320, 320])
0.008455795235931873   
torch.Size([1, 1, 320, 320])
0.004275299143046141   
torch.Size([1, 1, 320, 320])
0.008825039491057396   
torch.Size([1, 1, 320, 320])
0.005080179776996374   
torch.Size([1, 1, 320, 320])
0.00400611013174057   
torch.Size([1, 1, 320, 320])
0.005499700084328651   
torch.Size([1, 1, 320, 320])
0.007923593744635582   
torch.Size([1, 1, 320, 320])
0.0103946253657341   
torch.Size([1, 1, 320, 320])
0.01940797083079815   
torch.Size([1, 1, 320, 320])
0.01661764644086361   
torch.Size([1, 1, 320, 320])
0.01168554276227951   
torch.Size([1, 1, 320, 320])
0.01726631261408329   
torch.Size([1, 1, 320, 320])
0

0.008256149478256702   
torch.Size([1, 1, 320, 320])
0.012493710964918137   
torch.Size([1, 1, 320, 320])
0.009755158796906471   
torch.Size([1, 1, 320, 320])
0.011372942477464676   
torch.Size([1, 1, 320, 320])
0.00925787165760994   
torch.Size([1, 1, 320, 320])
0.008630167692899704   
torch.Size([1, 1, 320, 320])
0.01488501112908125   
torch.Size([1, 1, 320, 320])
0.011832425370812416   
torch.Size([1, 1, 320, 320])
0.03038792684674263   
torch.Size([1, 1, 320, 320])
0.004015335813164711   
torch.Size([1, 1, 320, 320])
0.01253080740571022   
torch.Size([1, 1, 320, 320])
0.013093621470034122   
torch.Size([1, 1, 320, 320])
0.013149199075996876   
torch.Size([1, 1, 320, 320])
0.0056561375968158245   
torch.Size([1, 1, 320, 320])
0.009040060453116894   
torch.Size([1, 1, 320, 320])
0.02297799102962017   
torch.Size([1, 1, 320, 320])
0.011743951588869095   
torch.Size([1, 1, 320, 320])
0.014628634788095951   
torch.Size([1, 1, 320, 320])
0.005364877637475729   
torch.Size([1, 1, 320, 320

0.018514154478907585   
torch.Size([1, 1, 320, 320])
0.008734422735869884   
torch.Size([1, 1, 320, 320])
0.004840021021664143   
torch.Size([1, 1, 320, 320])
0.012733411975204945   
torch.Size([1, 1, 320, 320])
0.006797416135668755   
torch.Size([1, 1, 320, 320])
0.019572651013731956   
torch.Size([1, 1, 320, 320])
0.0030657427851110697   
torch.Size([1, 1, 320, 320])
0.02225157804787159   
torch.Size([1, 1, 320, 320])
0.009773332625627518   
torch.Size([1, 1, 320, 320])
0.02966342866420746   
torch.Size([1, 1, 320, 320])
0.01111353375017643   
torch.Size([1, 1, 320, 320])
0.004873605445027351   
torch.Size([1, 1, 320, 320])
0.016523869708180428   
torch.Size([1, 1, 320, 320])
0.006294709630310535   
torch.Size([1, 1, 320, 320])
0.007359260227531195   
torch.Size([1, 1, 320, 320])
0.008319549262523651   
torch.Size([1, 1, 320, 320])
0.017862778156995773   
0 
 

torch.Size([1, 1, 320, 320])
0.016974447295069695   
torch.Size([1, 1, 320, 320])
0.01300343032926321   
torch.Size([1, 1, 3

torch.Size([1, 1, 320, 320])
0.01636374555528164   
torch.Size([1, 1, 320, 320])
0.011318680830299854   
torch.Size([1, 1, 320, 320])
0.005028023850172758   
torch.Size([1, 1, 320, 320])
0.014446995221078396   
torch.Size([1, 1, 320, 320])
0.011641213670372963   
torch.Size([1, 1, 320, 320])
0.006947210058569908   
torch.Size([1, 1, 320, 320])
0.010210229083895683   
torch.Size([1, 1, 320, 320])
0.011458239518105984   
torch.Size([1, 1, 320, 320])
0.01889229379594326   
torch.Size([1, 1, 320, 320])
0.005419281776994467   
torch.Size([1, 1, 320, 320])
0.01086676586419344   
torch.Size([1, 1, 320, 320])
0.00884534977376461   
torch.Size([1, 1, 320, 320])
0.009376616217195988   
torch.Size([1, 1, 320, 320])
0.009922802448272705   
torch.Size([1, 1, 320, 320])
0.008014694787561893   
torch.Size([1, 1, 320, 320])
0.011974465101957321   
torch.Size([1, 1, 320, 320])
0.01967705227434635   
torch.Size([1, 1, 320, 320])
0.014589080587029457   
torch.Size([1, 1, 320, 320])
0.011094962246716022  

0 
 

torch.Size([1, 1, 320, 320])
0.029682636260986328   
torch.Size([1, 1, 320, 320])
0.008506281301379204   
torch.Size([1, 1, 320, 320])
0.008514919318258762   
torch.Size([1, 1, 320, 320])
0.012836115434765816   
torch.Size([1, 1, 320, 320])
0.006493000779300928   
torch.Size([1, 1, 320, 320])
0.003783969907090068   
torch.Size([1, 1, 320, 320])
0.01802266761660576   
torch.Size([1, 1, 320, 320])
0.009944195859134197   
torch.Size([1, 1, 320, 320])
0.026429152116179466   
torch.Size([1, 1, 320, 320])
0.02101719379425049   
torch.Size([1, 1, 320, 320])
0.009322010912001133   
torch.Size([1, 1, 320, 320])
0.014938666485249996   
torch.Size([1, 1, 320, 320])
0.011510479263961315   
torch.Size([1, 1, 320, 320])
0.013556637801229954   
torch.Size([1, 1, 320, 320])
0.01744111441075802   
torch.Size([1, 1, 320, 320])
0.014751252718269825   
torch.Size([1, 1, 320, 320])
0.005744120106101036   
torch.Size([1, 1, 320, 320])
0.0165756456553936   
torch.Size([1, 1, 320, 320])
0.02047047764062

torch.Size([1, 1, 320, 320])
0.01937352493405342   
torch.Size([1, 1, 320, 320])
0.008761405013501644   
torch.Size([1, 1, 320, 320])
0.016490809619426727   
torch.Size([1, 1, 320, 320])
0.008508685044944286   
torch.Size([1, 1, 320, 320])
0.010437838733196259   
torch.Size([1, 1, 320, 320])
0.007129645440727472   
torch.Size([1, 1, 320, 320])
0.00880376435816288   
torch.Size([1, 1, 320, 320])
0.011613298207521439   
torch.Size([1, 1, 320, 320])
0.009954252280294895   
torch.Size([1, 1, 320, 320])
0.006581378635019064   
torch.Size([1, 1, 320, 320])
0.009374206885695457   
torch.Size([1, 1, 320, 320])
0.012247643433511257   
torch.Size([1, 1, 320, 320])
0.008769707754254341   
torch.Size([1, 1, 320, 320])
0.021477198228240013   
torch.Size([1, 1, 320, 320])
0.01031448319554329   
torch.Size([1, 1, 320, 320])
0.007768441457301378   
torch.Size([1, 1, 320, 320])
0.014378603547811508   
torch.Size([1, 1, 320, 320])
0.01256829034537077   
torch.Size([1, 1, 320, 320])
0.00802470650523901  

0.012908983044326305   
torch.Size([1, 1, 320, 320])
0.009996550157666206   
torch.Size([1, 1, 320, 320])
0.009112062864005566   
torch.Size([1, 1, 320, 320])
0.010164485312998295   
torch.Size([1, 1, 320, 320])
0.020701756700873375   
torch.Size([1, 1, 320, 320])
0.008740645833313465   
torch.Size([1, 1, 320, 320])
0.010430428199470043   
torch.Size([1, 1, 320, 320])
0.006895731668919325   
torch.Size([1, 1, 320, 320])
0.0031293611973524094   
torch.Size([1, 1, 320, 320])
0.008285119198262691   
torch.Size([1, 1, 320, 320])
0.007800525985658169   
torch.Size([1, 1, 320, 320])
0.012852136977016926   
torch.Size([1, 1, 320, 320])
0.009884118102490902   
torch.Size([1, 1, 320, 320])
0.026541052386164665   
torch.Size([1, 1, 320, 320])
0.01306945364922285   
torch.Size([1, 1, 320, 320])
0.003416755236685276   
torch.Size([1, 1, 320, 320])
0.006536874920129776   
torch.Size([1, 1, 320, 320])
0.004496005363762379   
torch.Size([1, 1, 320, 320])
0.008864397183060646   
torch.Size([1, 1, 320,

0.013985750265419483   
torch.Size([1, 1, 320, 320])
0.015021469444036484   
torch.Size([1, 1, 320, 320])
0.010113107040524483   
torch.Size([1, 1, 320, 320])
0.004367392510175705   
torch.Size([1, 1, 320, 320])
0.015253565274178982   
torch.Size([1, 1, 320, 320])
0.006780078634619713   
torch.Size([1, 1, 320, 320])
0.005201913882046938   
torch.Size([1, 1, 320, 320])
0.008684467524290085   
torch.Size([1, 1, 320, 320])
0.006460657808929682   
torch.Size([1, 1, 320, 320])
0.01796748861670494   
torch.Size([1, 1, 320, 320])
0.015891242772340775   
torch.Size([1, 1, 320, 320])
0.019380230456590652   
torch.Size([1, 1, 320, 320])
0.011468835175037384   
torch.Size([1, 1, 320, 320])
0.012788905762135983   
torch.Size([1, 1, 320, 320])
0.006567603442817926   
torch.Size([1, 1, 320, 320])
0.011041325516998768   
torch.Size([1, 1, 320, 320])
0.008490213192999363   
torch.Size([1, 1, 320, 320])
0.0049604382365942   
0 
 

torch.Size([1, 1, 320, 320])
0.006542273797094822   
torch.Size([1, 1, 3

0.009871968999505043   
torch.Size([1, 1, 320, 320])
0.004156534560024738   
torch.Size([1, 1, 320, 320])
0.012925305403769016   
torch.Size([1, 1, 320, 320])
0.011478092521429062   
torch.Size([1, 1, 320, 320])
0.016434213146567345   
torch.Size([1, 1, 320, 320])
0.014947030693292618   
torch.Size([1, 1, 320, 320])
0.012949448078870773   
torch.Size([1, 1, 320, 320])
0.0035770279355347157   
torch.Size([1, 1, 320, 320])
0.01269338745623827   
torch.Size([1, 1, 320, 320])
0.008120350539684296   
torch.Size([1, 1, 320, 320])
0.0071282051503658295   
torch.Size([1, 1, 320, 320])
0.003808019682765007   
torch.Size([1, 1, 320, 320])
0.010216818191111088   
torch.Size([1, 1, 320, 320])
0.010573696345090866   
torch.Size([1, 1, 320, 320])
0.0043564364314079285   
torch.Size([1, 1, 320, 320])
0.005972932558506727   
torch.Size([1, 1, 320, 320])
0.0063759684562683105   
torch.Size([1, 1, 320, 320])
0.004571036901324987   
torch.Size([1, 1, 320, 320])
0.003779652528464794   
torch.Size([1, 1, 3

0 
 

torch.Size([1, 1, 320, 320])
0.009286841377615929   
torch.Size([1, 1, 320, 320])
0.007708774879574776   
torch.Size([1, 1, 320, 320])
0.013429123908281326   
torch.Size([1, 1, 320, 320])
0.019310686737298965   
torch.Size([1, 1, 320, 320])
0.012012469582259655   
torch.Size([1, 1, 320, 320])
0.0043474589474499226   
torch.Size([1, 1, 320, 320])
0.02650519832968712   
torch.Size([1, 1, 320, 320])
0.007986977696418762   
torch.Size([1, 1, 320, 320])
0.01803942583501339   
torch.Size([1, 1, 320, 320])
0.005814345087856054   
torch.Size([1, 1, 320, 320])
0.016405832022428513   
torch.Size([1, 1, 320, 320])
0.009916587732732296   
torch.Size([1, 1, 320, 320])
0.014089920558035374   
torch.Size([1, 1, 320, 320])
0.004296397790312767   
torch.Size([1, 1, 320, 320])
0.01570507325232029   
torch.Size([1, 1, 320, 320])
0.02008497156202793   
torch.Size([1, 1, 320, 320])
0.0056082881055772305   
torch.Size([1, 1, 320, 320])
0.013996814377605915   
torch.Size([1, 1, 320, 320])
0.00554830534

0.0050641633570194244   
torch.Size([1, 1, 320, 320])
0.0028296392410993576   
torch.Size([1, 1, 320, 320])
0.003470601513981819   
torch.Size([1, 1, 320, 320])
0.00594797357916832   
torch.Size([1, 1, 320, 320])
0.013933003880083561   
torch.Size([1, 1, 320, 320])
0.012795907445251942   
torch.Size([1, 1, 320, 320])
0.01559969037771225   
torch.Size([1, 1, 320, 320])
0.012285646982491016   
torch.Size([1, 1, 320, 320])
0.012821059674024582   
torch.Size([1, 1, 320, 320])
0.011786825954914093   
torch.Size([1, 1, 320, 320])
0.018903853371739388   
torch.Size([1, 1, 320, 320])
0.011116567067801952   
torch.Size([1, 1, 320, 320])
0.005704541224986315   
torch.Size([1, 1, 320, 320])
0.014238348230719566   
torch.Size([1, 1, 320, 320])
0.002502900082617998   
torch.Size([1, 1, 320, 320])
0.010612010955810547   
torch.Size([1, 1, 320, 320])
0.004371567163616419   
torch.Size([1, 1, 320, 320])
0.011007709428668022   
torch.Size([1, 1, 320, 320])
0.009828979149460793   
torch.Size([1, 1, 320,

0.009582871571183205   
torch.Size([1, 1, 320, 320])
0.024775376543402672   
torch.Size([1, 1, 320, 320])
0.006418372504413128   
torch.Size([1, 1, 320, 320])
0.011194519698619843   
torch.Size([1, 1, 320, 320])
0.00956302136182785   
torch.Size([1, 1, 320, 320])
0.010639174841344357   
torch.Size([1, 1, 320, 320])
0.014242277480661869   
torch.Size([1, 1, 320, 320])
0.008289630524814129   
torch.Size([1, 1, 320, 320])
0.016743414103984833   
torch.Size([1, 1, 320, 320])
0.0265642236918211   
torch.Size([1, 1, 320, 320])
0.008735055103898048   
torch.Size([1, 1, 320, 320])
0.008536129258573055   
torch.Size([1, 1, 320, 320])
0.01097650732845068   
torch.Size([1, 1, 320, 320])
0.014400213025510311   
torch.Size([1, 1, 320, 320])
0.005698395427316427   
torch.Size([1, 1, 320, 320])
0.004090063739567995   
torch.Size([1, 1, 320, 320])
0.022937368601560593   
torch.Size([1, 1, 320, 320])
0.016519175842404366   
torch.Size([1, 1, 320, 320])
0.004780339542776346   
torch.Size([1, 1, 320, 320

torch.Size([1, 1, 320, 320])
0.016582010313868523   
torch.Size([1, 1, 320, 320])
0.00914565660059452   
torch.Size([1, 1, 320, 320])
0.010325596667826176   
torch.Size([1, 1, 320, 320])
0.006744238547980785   
torch.Size([1, 1, 320, 320])
0.011217457242310047   
torch.Size([1, 1, 320, 320])
0.00828289519995451   
torch.Size([1, 1, 320, 320])
0.008898388594388962   
torch.Size([1, 1, 320, 320])
0.00847158208489418   
torch.Size([1, 1, 320, 320])
0.010587646625936031   
torch.Size([1, 1, 320, 320])
0.00694517744705081   
torch.Size([1, 1, 320, 320])
0.020069990307092667   
torch.Size([1, 1, 320, 320])
0.007677307352423668   
torch.Size([1, 1, 320, 320])
0.005119465757161379   
torch.Size([1, 1, 320, 320])
0.0035394374281167984   
torch.Size([1, 1, 320, 320])
0.003617672249674797   
torch.Size([1, 1, 320, 320])
0.006311008706688881   
torch.Size([1, 1, 320, 320])
0.005709254182875156   
torch.Size([1, 1, 320, 320])
0.013077841140329838   
torch.Size([1, 1, 320, 320])
0.009091433137655258

In [8]:
acc = 8
cen_fract = 0.04
seed = False # random masks for each slice 
num_workers = 12 # data loading is faster using a bigger number for num_workers. 0 means using one cpu to load data
    

if __name__ == '__main__':
    
    data_path_train = '/tmp/NC2019MRI/train'
    data_path_val = '/tmp/NC2019MRI/train'
    data_list = load_data_path(data_path_train, data_path_val) # first load all file names, paths and slices.
    
    acc = 8
    cen_fract = 0.04
    seed = False # random masks for each slice 
    num_workers = 12 # data loading is faster using a bigger number for num_workers. 0 means using one cpu to load data
    # create data loader for training set. It applies same to validation set as well
    train_dataset = MRIDataset(data_list['train'], acceleration=acc, center_fraction=cen_fract, use_seed=seed)
    train_loader = DataLoader(train_dataset, shuffle=True, batch_size=1, num_workers=num_workers) 
    

    a = [[],[]]
    for iteration, sample in enumerate(train_loader):
        img_gt, img_und, rawdata_und, masks, norm = sample
        img_gt = T.center_crop(T.complex_abs(img_gt), [320, 320]).unsqueeze(1)
        img_und = T.center_crop(T.complex_abs(img_und), [320, 320]).unsqueeze(1)
        a[0].append(img_und)
        a[1].append(img_gt)
    b = torch.cat(a[0][:])
    c = torch.cat(a[1][:])
train = torch.stack((b,c),dim=0)
del a
del b
del c
del train_loader
del train_dataset
train.shape

torch.Size([2, 2134, 1, 320, 320])

In [11]:
lr = 1e-3
    
network = AlexNet()
network.to('cuda:0') #move the model on the GPU
mse_loss = nn.MSELoss().to('cuda:0')
    
optimizer = optim.Adam(network.parameters(), lr=lr)
train_loader = DataLoader(train, shuffle=True, batch_size=1, num_workers=num_workers) 
for iteration, sample in enumerate(train_loader):
    #img_gt, img_und, rawdata_und, masks, norm = sample        
    
    output = network(img_und)       #feedforward
    print(output.shape)

    loss = mse_loss(output, img_gt)
    optimizer.zero_grad()       #set current gradients to 0
    loss.backward()      #backpropagate
    optimizer.step()     #update the weights
    print(loss.item(), "  ")
        
    i = 0
    j +=1
        
    if j%100 == 0:
        for row in range(0,320):
            for col in range(0,320):
                if output[0,0,row,col].item() == img_gt[0,0,row,col].item():

                        i +=1
        print(i, "\n \n")

ValueError: not enough values to unpack (expected 5, got 1)

False