# Coursework for MRI reconstruction (Autumn 2019)

In this tutorial, we provide the data loader to read and process the MRI data in order to ease the difficulty of training your network. By providing this, we hope you focus more on methodology development. Please feel free to change it to suit what you need.

In [1]:
import h5py, os
from functions import transforms as T
from functions.subsample import MaskFunc
from scipy.io import loadmat
from torch.utils.data import DataLoader
import numpy as np
import torch
from matplotlib import pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
def show_slices(data, slice_nums, cmap=None): # visualisation
    fig = plt.figure(figsize=(15,10))
    for i, num in enumerate(slice_nums):
        plt.subplot(1, len(slice_nums), i + 1)
        plt.imshow(data[num], cmap=cmap)
        plt.axis('off')

In [3]:
class MRIDataset(DataLoader):
    def __init__(self, data_list, acceleration, center_fraction, use_seed):
        self.data_list = data_list
        self.acceleration = acceleration
        self.center_fraction = center_fraction
        self.use_seed = use_seed

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        subject_id = self.data_list[idx]
        return get_epoch_batch(subject_id, self.acceleration, self.center_fraction, self.use_seed)

In [4]:
def get_epoch_batch(subject_id, acc, center_fract, use_seed=True):
    ''' random select a few slices (batch_size) from each volume'''

    fname, rawdata_name, slice = subject_id  
    
    with h5py.File(rawdata_name, 'r') as data:
        rawdata = data['kspace'][slice]
                      
    slice_kspace = T.to_tensor(rawdata).unsqueeze(0)
    S, Ny, Nx, ps = slice_kspace.shape

    # apply random mask
    shape = np.array(slice_kspace.shape)
    mask_func = MaskFunc(center_fractions=[center_fract], accelerations=[acc])
    seed = None if not use_seed else tuple(map(ord, fname))
    mask = mask_func(shape, seed)
      
    # undersample
    masked_kspace = torch.where(mask == 0, torch.Tensor([0]), slice_kspace)
    masks = mask.repeat(S, Ny, 1, ps)

    img_gt, img_und = T.ifft2(slice_kspace), T.ifft2(masked_kspace)
    
    
    # perform data normalization which is important for network to learn useful features
    # during inference there is no ground truth image so use the zero-filled recon to normalize
    norm = T.complex_abs(img_und).max()
    if norm < 1e-6: norm = 1e-6
    
    # normalized data
    img_gt, img_und, rawdata_und = img_gt/norm, img_und/norm, masked_kspace/norm
    
#    img_gt = T.center_crop(T.complex_abs(img_gt), [320, 320]).unsqueeze(1)
#    img_und = T.center_crop(T.complex_abs(img_und), [320, 320]).unsqueeze(1)
#     rawdata_und = T.center_crop(T.complex_abs(rawdata_und), [320, 320]).unsqueeze(1)
#     norm = T.center_crop(T.complex_abs(norm), [320, 320]).unsqueeze(1)
#     masks.T.center_crop(T.complex_abs(masks), [320, 320]).unsqueeze(1)    

    img_gt = T.center_crop(T.complex_abs(img_gt), [320, 320])
    img_und = T.center_crop(T.complex_abs(img_und), [320, 320])
        
    return img_gt.squeeze(0), img_und.squeeze(0)


In [5]:
def load_data_path(train_data_path, val_data_path):
    """ Go through each subset (training, validation) and list all 
    the file names, the file paths and the slices of subjects in the training and validation sets 
    """

    data_list = {}
    train_and_val = ['train', 'val']
    data_path = [train_data_path, val_data_path]
      
    for i in range(len(data_path)):

        data_list[train_and_val[i]] = []
        
        which_data_path = data_path[i]
    
        for fname in sorted(os.listdir(which_data_path)):
            
            subject_data_path = os.path.join(which_data_path, fname)
                     
            if not os.path.isfile(subject_data_path): continue 
            
            with h5py.File(subject_data_path, 'r') as data:
                num_slice = data['kspace'].shape[0]
                
            # the first 5 slices are mostly noise so it is better to exlude them
            data_list[train_and_val[i]] += [(fname, subject_data_path, slice) for slice in range(5, num_slice)]
    
    return data_list    

In [6]:
class AlexNet(nn.Module):

    def __init__(self):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1), #320/320
            nn.Dropout2d(),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1), #320/320
            nn.Dropout2d(),
            nn.ReLU(inplace=True),
            nn.Conv2d(64 ,64, kernel_size=3, padding=1),  # 320/320
            nn.Dropout2d(),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1), #320/320
            nn.Dropout2d(),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),  # 320/320
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 1, kernel_size=3, padding=1),  # 320/320
            
            
        )

    def forward(self, x):
        x = self.features(x)
        #x = nn.functional.sigmoid(x)
        #x = x * 255
        #x = x.type(torch.cuda.int32)
        return x

In [7]:
class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)

        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        self.up1 = Up(1024, 256, bilinear)
        self.up2 = Up(512, 128, bilinear)
        self.up3 = Up(256, 64, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

In [None]:
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ResNet18():
    return ResNet(BasicBlock, [2,2,2,2])



In [8]:
from skimage.measure import compare_ssim 
def ssim(gt, pred):
    """ Compute Structural Similarity Index Metric (SSIM). """
    return compare_ssim(
        gt.transpose(1, 2, 0), pred.transpose(1, 2, 0), multichannel=True, data_range=gt.max()
    )

In [9]:


if __name__ == '__main__':
    
    data_path_train = '/tmp/NC2019MRI/train'
    data_path_val = '/tmp/NC2019MRI/train'
    data_list = load_data_path(data_path_train, data_path_val) # first load all file names, paths and slices.
    
    acc = 8
    cen_fract = 0.04
    seed = False # random masks for each slice 
    num_workers = 12 # data loading is faster using a bigger number for num_workers. 0 means using one cpu to load data
    
    lr = 1e-2
    
    network = AlexNet()
    network.to('cuda:0') #move the model on the GPU
    mse_loss = nn.SmoothL1Loss().to('cuda:0')
    
    optimizer = optim.Adagrad(network.parameters(), lr=lr)
    
    #create data loader for training set. It applies same to validation set as well
    train_dataset = MRIDataset(data_list['train'], acceleration=acc, center_fraction=cen_fract, use_seed=seed)
    train_loader = DataLoader(train_dataset, shuffle=True, batch_size=22, num_workers=num_workers) 
    

    for epoch in range(5):
        for iteration, sample in enumerate(train_loader):
        
            img_gt, img_und = sample
        
            img_gt = img_gt.unsqueeze(1).to('cuda:0')
            img_und = img_und.unsqueeze(1).to('cuda:0')

            output = network(img_und)       #feedforward
            #output = output.squeeze(1).cpu().detach().numpy()

            loss = mse_loss(output, img_gt)
            #loss = torch.tensor(ssim(img_gt, output)).to('cuda:0')
            #print(loss.item())
            optimizer.zero_grad()       #set current gradients to 0
            loss.backward()      #backpropagate
            optimizer.step()     #update the weights

            print(loss.item(), "  ")
        

        
        # stack different slices into a volume for visualisation
#         A = masks[...,0].squeeze()
#         B = torch.log(T.complex_abs(rawdata_und) + 1e-9).squeeze()
#         C = T.complex_abs(img_und).squeeze()
#         D = T.complex_abs(img_gt).squeeze()
#         all_imgs = torch.stack([A,B,C,D], dim=0)

#         # from left to right: mask, masked kspace, undersampled image, ground truth
#         show_slices(all_imgs, [0, 1, 2, 3], cmap='gray')
#         plt.pause(1)

#         if iteration >= 0: break  # show 4 random slices
        

0.07518192380666733   
9.327569961547852   
0.06144632399082184   
0.05618124082684517   
0.04870830103754997   
0.026030097156763077   
0.03983459994196892   
0.026752814650535583   
0.03928736597299576   
0.024190006777644157   
0.18752729892730713   
0.04776955395936966   
0.030598556622862816   
0.03932809829711914   
0.029066724702715874   
0.03421145677566528   
0.02866198867559433   
0.020785290747880936   
0.03168277069926262   
0.027459660544991493   
0.01922103948891163   
0.026630135253071785   
0.03134298324584961   
0.023488536477088928   
0.023812556639313698   
0.024823004379868507   
0.023182714357972145   
0.017839733511209488   
0.020840751007199287   
0.02427021414041519   
0.029289713129401207   
0.021374840289354324   
0.01830027624964714   
0.022052261978387833   
0.019012024626135826   
0.028029121458530426   
0.019736481830477715   
0.018501287326216698   
0.019669819623231888   
0.017037028446793556   
0.024067388847470284   
0.021896136924624443   
0.021419195

0.011401934549212456   
0.0119841443374753   
0.009917051531374454   
0.010840573348104954   
0.011116345413029194   
0.008670862764120102   
0.010109918192029   
0.01038879994302988   
0.01097245141863823   
0.00920596532523632   
0.011301460675895214   
0.009846222586929798   
0.00863924901932478   
0.010996202938258648   
0.011646910570561886   
0.010230579413473606   
0.011165080592036247   
0.009015867486596107   
0.008371284231543541   
0.009074992500245571   
0.009173386730253696   
0.007803079206496477   
0.011048378422856331   
0.009998255409300327   
0.007006182800978422   
0.009458058513700962   
0.016204319894313812   
0.014772916212677956   
0.010573266074061394   
0.009375123307108879   
0.011076550930738449   
0.011700423434376717   
0.01072125043720007   
0.009597443975508213   
0.009798936545848846   
0.012078200466930866   
0.00984059926122427   
0.010668319649994373   
0.008705221116542816   
0.01007887627929449   
0.011641093529760838   
0.009905527345836163   
0.00

In [10]:
image, gt = train_dataset[3]
image = image.unsqueeze(0).to('cuda:0')
image = image.unsqueeze(0)
#gt = gt.unsqueeze(0).to('cuda:0')
gt = gt.unsqueeze(0).numpy()
output = network(image)
output = output.squeeze(1).cpu().detach().numpy()
loss = torch.tensor(ssim(gt, output))
loss2 = torch.tensor(ssim(gt, image.squeeze(1).cpu().numpy()))
print(loss.item())
print(loss2.item())
#loss2 = mse_loss(output, gt)
len(train_dataset)

0.6400861740112305
0.340620756149292


  """


2134

In [11]:
e = []
a=[]
b=[]
i = 0
for i in range(0,len(train_dataset)):
    image, gt = train_dataset[i]
    image = image.unsqueeze(0).to('cuda:0')
    image = image.unsqueeze(0)
    #gt = gt.unsqueeze(0).to('cuda:0')
    gt = gt.unsqueeze(0).numpy()
    output = network(image)
    output = output.squeeze(1).cpu().detach().numpy()
    loss = torch.tensor(ssim(gt, output))
    loss2 = torch.tensor(ssim(gt, image.squeeze(1).cpu().numpy()))
    e.append(loss.item()-loss2.item())
    a.append(loss.item())

    if loss.item()-loss2.item() < 0:
        i+=1
print(np.nanmean(e))
print(np.nanmean(a))

  """


0.13082477027915188
0.4987707764482878


In [14]:
image, gt = train_dataset[3]
image = image.unsqueeze(0).to('cuda:0')
image = image.unsqueeze(0)
gt = gt.unsqueeze(0).to('cuda:0')
gt = gt.unsqueeze(0)
output = network(image)
loss = mse_loss(output, gt)
loss2 = mse_loss(image, gt)
print(loss.item())
print(loss2.item())

0.011205554008483887
0.0060116928070783615


In [None]:
acc = 8
cen_fract = 0.04
seed = False # random masks for each slice 
num_workers = 12 # data loading is faster using a bigger number for num_workers. 0 means using one cpu to load data
    

if __name__ == '__main__':
    
    data_path_train = '/tmp/NC2019MRI/train'
    data_path_val = '/tmp/NC2019MRI/train'
    data_list = load_data_path(data_path_train, data_path_val) # first load all file names, paths and slices.
    
    acc = 8
    cen_fract = 0.04
    seed = False # random masks for each slice 
    num_workers = 12 # data loading is faster using a bigger number for num_workers. 0 means using one cpu to load data
    # create data loader for training set. It applies same to validation set as well
    train_dataset = MRIDataset(data_list['train'], acceleration=acc, center_fraction=cen_fract, use_seed=seed)
    train_loader = DataLoader(train_dataset, shuffle=True, batch_size=1, num_workers=num_workers) 
    

    a = [[],[]]
    for iteration, sample in enumerate(train_loader):
        img_gt, img_und, rawdata_und, masks, norm = sample
        img_gt = T.center_crop(T.complex_abs(img_gt), [320, 320]).unsqueeze(1)
        img_und = T.center_crop(T.complex_abs(img_und), [320, 320]).unsqueeze(1)
        a[0].append(img_und)
        a[1].append(img_gt)
    b = torch.cat(a[0][:])
    c = torch.cat(a[1][:])
train = torch.stack((b,c),dim=0)
del a
del b
del c
del train_loader
del train_dataset
train.shape

In [11]:
lr = 1e-3
    
network = AlexNet()
network.to('cuda:0') #move the model on the GPU
mse_loss = nn.MSELoss().to('cuda:0')
    
optimizer = optim.Adam(network.parameters(), lr=lr)
train_loader = DataLoader(train, shuffle=True, batch_size=1, num_workers=num_workers) 
for iteration, sample in enumerate(train_loader):
    #img_gt, img_und, rawdata_und, masks, norm = sample        
    
    output = network(img_und)       #feedforward
    print(output.shape)

    loss = mse_loss(output, img_gt)
    optimizer.zero_grad()       #set current gradients to 0
    loss.backward()      #backpropagate
    optimizer.step()     #update the weights
    print(loss.item(), "  ")
        
    i = 0
    j +=1
        
    if j%100 == 0:
        for row in range(0,320):
            for col in range(0,320):
                if output[0,0,row,col].item() == img_gt[0,0,row,col].item():

                        i +=1
        print(i, "\n \n")

ValueError: not enough values to unpack (expected 5, got 1)

In [12]:
a = torch.tensor([2])
a.astype(np.float64)

AttributeError: 'Tensor' object has no attribute 'astype'