In [6]:
from MNIST_dataloader import Noisy_MNIST
from Fast_MRI_dataloader import Fast_MRI
import matplotlib.pyplot as plt
import torch.nn as nn
import torch
import torch.optim as optim
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
from torch.utils.data import TensorDataset
import glob
import numpy as np
from tqdm.auto import tqdm 
from PIL import Image
import numpy as np
from torch.fft import fft2, fftshift, ifft2, ifftshift

In [7]:
# %% dataloader for the Fast MRI dataset
def create_dataloaders_mri(data_loc, batch_size):
    dataset_train = Fast_MRI("train", data_loc)
    dataset_test  = Fast_MRI("test" , data_loc)
    
    Fast_MRI_train_loader =  DataLoader(dataset_train, batch_size=batch_size, shuffle=True,  drop_last=False)
    Fast_MRI_test_loader  =  DataLoader(dataset_test , batch_size=batch_size, shuffle=True, drop_last=False)
    
    return Fast_MRI_train_loader, Fast_MRI_test_loader

In [8]:
# Bram
#data_loc = 'D://5LSL0-Datasets//Fast_MRI_Knee' #change the datalocation to something that works for you
#Amin 
data_loc = 'C:/Users/amin2/Documents/School/5LSL0ML/5LSL0/data/Fast_MRI_Knee'


# define parameters
batch_size = 8

train_loader, test_loader = create_dataloaders_mri(data_loc, batch_size)

In [9]:
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self, ).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu1 = nn.LeakyReLU()
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.deconv2 = nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.deconv1 = nn.ConvTranspose2d(in_channels=16, out_channels=1, kernel_size=3, stride=1, padding=1)


    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.deconv2(x)
        x = self.deconv1(x)
        return x

In [14]:
# Define function to achieve Full k-space
def get_k_space(inputs):
    k_space = fftshift(fft2(inputs))
    return k_space

# Define function to achieve Partial k-space from Full k-space and Mask
def get_partial_k_space(input,M):
    return  torch.mul(input, M)

def get_accelerate_MRI(inputs):
    return ifft2(inputs)

In [16]:
model = ConvNet()

criterion = nn.MSELoss()
optimizer = optim.Adam( model.parameters(), lr=0.001)
num_epochs = 10
#unfolding steps
num_iterations = 5
prox_operator = nn.ModuleList([ConvNet() for _ in range(num_iterations)])
mu = 0.1

In [38]:
for epoch in range(num_epochs):
    train_loss = 0.0
    count = 0

    # Training loop
    model.train()
    loop = tqdm(train_loader)
    loop.set_description(f"Epoch [{epoch}/{num_epochs}]")
    for i,(kspace, M, gt) in enumerate(loop):
        gt_label = gt.unsqueeze(1)
        kspace_input = kspace.unsqueeze(1)
        acc_mri = torch.abs(ifft2(kspace_input))
        x_t = acc_mri
        optimizer.zero_grad()

        for iter in range(num_iterations):
            x_t = prox_operator[iter](x_t)
            F_x = get_k_space(x_t)

            k_space_y = get_k_space(acc_mri)
            F_x = torch.squeeze(F_x, dim=1) 
            k_space_y = torch.squeeze(k_space_y, dim=1) 
            z = F_x - mu * get_partial_k_space(F_x, M) + mu * get_partial_k_space(k_space_y, M)
            z = torch.unsqueeze(z, dim=1) 
            x_t = torch.abs(get_accelerate_MRI(z))
            loss = criterion(x_t, gt_label)
            loss.backward()
            optimizer.step()

        count+=1
        train_loss += loss.item() #* kspace_input.size(0)
        loop.set_postfix(loss=loss.item())
    count_test = 0
    test_loss = 0
    # Testing loop
    model.eval()
    with torch.no_grad():
       loop = tqdm(test_loader)
       for i,(kspace_t, M_t, gt_t) in enumerate(loop):
            gt_label_t = gt_t.unsqueeze(1)
            kspace_input_t = kspace_t.unsqueeze(1)
            acc_mri_t = torch.abs(ifft2(kspace_input_t))
            x_t = acc_mri_t

            for iter in range(num_iterations):
                x_t = model(x_t)
                F_x = get_k_space(x_t)

                k_space_y = get_k_space(acc_mri_t)
                F_x = torch.squeeze(F_x, dim=1) 
                k_space_y = torch.squeeze(k_space_y, dim=1) 
                z = F_x - mu * get_partial_k_space(F_x, M) + mu * get_partial_k_space(k_space_y, M)
                z = torch.unsqueeze(z, dim=1) 
                x_t = torch.abs(get_accelerate_MRI(z))
            loss = criterion(x_t, gt_label_t)
            count_test +=1
            test_loss += loss.item() #* kspace_input.size(0)
            loop.set_postfix(loss=loss.item())

    # Calculate average losses
    train_loss /= count
    test_loss /= count_test

    # train_losses.append(train_loss)
    # test_losses.append(test_loss)

    # Print epoch-wise loss
    print(f"Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Test Loss = {test_loss:.4f}")

# Save the trained model   
torch.save(model.state_dict(), 'Ex6_trained.pth')

Epoch [0/10]:  71%|███████▏  | 67/94 [05:51<02:21,  5.24s/it, loss=0.151]


KeyboardInterrupt: 