In [None]:
# [ACTIVE] Personal Dataset creation

# Import statements
import fastmri
from fastmri.data import transforms as T

import h5py

import os

import csv

import numpy as np

import pandas as pd

import torch
from torch.utils.data import Dataset, DataLoader

import torch.nn.functional as F

import torchvision
from torchvision.io import read_image

import matplotlib
from matplotlib import pyplot as plt

from tqdm.notebook import tqdm

In [None]:
# Set the directory for the data
# NOTE(review): this variable is not referenced by any cell visible in this notebook;
# the cells below hard-code "../kadotab/Datasets/16_chans/train", which has one "../"
# fewer than this path -- confirm which relative location is the correct one.
directory = "../../kadotab/Datasets/16_chans"

In [None]:
# Method to visualize coils, to be used later
def show_coils(data, slice_nums, cmap=None):
    """Draw selected entries of ``data`` side by side in one figure, then close it.

    Args:
        data: Indexable collection; ``data[num]`` is passed to ``imshow``.
        slice_nums: Indices into ``data``; one subplot is drawn per index.
        cmap: Optional colormap forwarded to ``imshow``.

    The figure is closed once, after every subplot has been drawn, so calling
    this function repeatedly in a loop does not accumulate open figures.
    """
    fig = plt.figure()
    for position, num in enumerate(slice_nums, start=1):
        # Draw on this figure explicitly instead of relying on pyplot's
        # "current figure", which would silently create a new one after a close.
        ax = fig.add_subplot(1, len(slice_nums), position)
        ax.imshow(data[num], cmap=cmap)
    plt.close(fig) # Close the figure window

In [None]:
# Method to save and display coil images, to be used later
def save_coils(path, filename, i, j, data, slice_nums, cmap=None):
    """Draw selected entries of ``data`` in one figure and save it as a PNG.

    The output file is ``path + filename + str(i) + "_" + str(j) + ".png"``.

    Args:
        path: Prefix of the output location (string-concatenated, so a
            directory must end with a path separator).
        filename: Base name placed after ``path``.
        i: First index embedded in the file name (the caller passes the slice counter).
        j: Second index embedded in the file name (the caller passes the coil counter).
        data: Indexable collection; ``data[num]`` is passed to ``imshow``.
        slice_nums: Indices into ``data``; one subplot is drawn per index.
        cmap: Optional colormap forwarded to ``imshow``.
    """
    fig = plt.figure()
    # The loop variable must not be called `i`: that would overwrite the `i`
    # argument and every slice would be saved under the same name.
    for position, num in enumerate(slice_nums, start=1):
        ax = fig.add_subplot(1, len(slice_nums), position)
        ax.imshow(data[num], cmap=cmap)

    out_path = path + filename + str(i) + "_" + str(j) + ".png"
    out_dir = os.path.dirname(out_path)
    if out_dir:
        # savefig does not create missing directories
        os.makedirs(out_dir, exist_ok=True)

    # Save once, after all subplots are drawn, then release the figure
    fig.savefig(out_path)
    plt.close(fig) # Close the figure window

In [None]:
# Prepare the kspace data

# Create a csv file with all filenames for parsing through the dataset later
# Open the file and truncate it (clears the file)
# The `with` block closes (and therefore flushes) the file, so the next cell
# reads a complete csv. newline='' is what the csv module expects for files it writes.
with open("fastMRI_filenames.csv", 'w', newline='') as f:
    # Write the filenames from the desired folder into the csv file
    w = csv.writer(f)
    for path, dirs, files in os.walk("../kadotab/Datasets/16_chans/train"):
        # Sorted so the csv has the same order on every run
        for filename in sorted(files):
            w.writerow([filename])

In [None]:
# Load the k-space for each individual coil per slice as an image file and save it in a new directory

# Set the directories
csvDir = 'fastMRI_filenames.csv'
rootDir = '../kadotab/Datasets/16_chans/train'
pathDir = 'fastMRI_kspace_images/'

# The images are written under pathDir, so make sure it exists first
os.makedirs(pathDir, exist_ok=True)

# Iterate through the csv of filenames and create png files for the data
with open(csvDir, 'r', newline='') as csv_file:
    for row in csv.reader(csv_file):
        if not row:
            continue  # tolerate blank lines in the csv

        # The first value in each row is the name of one k-space file
        name = os.path.join(rootDir, row[0])

        # Read the whole k-space array into memory, then release the HDF5 handle
        with h5py.File(name, 'r') as h5_file:
            sample = h5_file['kspace'][()]

        # Save each coil of each slice as an image; this will be the new dataset.
        # count1 indexes the first axis of the array (slices), count2 the
        # first axis of each slice (coils).
        for count1, kspace_slice in enumerate(sample):
            slice_kspace = T.to_tensor(kspace_slice) # Convert from numpy array to pytorch tensor
            slice_image = fastmri.ifft2c(slice_kspace) # Apply Inverse Fourier Transform to get the complex image
            slice_image_abs = fastmri.complex_abs(slice_image) # Compute absolute value to get a real image

            for count2 in range(len(kspace_slice)):
                # Checkpoint: draw the count2-th coil of the count1-th slice
                show_coils(slice_image_abs, [count2], cmap='gray')

                # Save the coil image; the two counters go into the file name
                save_coils(pathDir, row[0], count1, count2, slice_image_abs, [count2], cmap='gray')

In [None]:
# Prepare the image data

# Create a csv file with all filenames for parsing through the dataset later
# Open the file and truncate it (clears the file)
# The `with` block closes (and therefore flushes) the file before the dataset reads it
with open("fastMRI_images_filenames.csv", 'w', newline='') as f1:
    # Write the filenames from the desired folder into the csv file
    w1 = csv.writer(f1)
    for path, dirs, files in os.walk("fastMRI_kspace_images/"):
        # Sorted so the csv has the same order on every run
        for filename in sorted(files):
            # Must be w1: `w` is the writer of the *other* csv (fastMRI_filenames.csv)
            w1.writerow([filename])

In [None]:
# Create a custom dataset

import pandas as pd

class fastMRICustomDataset(Dataset):
    """Dataset of image files listed, one filename per row, in a headerless csv."""

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Arguments:
            csv_file (string): Path to the csv file with filenames
                (one filename per row, first column, no header row).
            root_dir (string): Directory with all the fastMRI image data.
            transform (callable, optional):  Transform to be applied on a sample.
        """
        # header=None: without it pandas would consume the first filename as
        # the column header and the dataset would be one sample short.
        # dtype=str keeps purely numeric filenames from being parsed as numbers.
        self.dataframe = pd.read_csv(csv_file, header=None, dtype=str)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Number of filenames listed in the csv."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Load the idx-th listed image with ``read_image`` and apply ``transform`` if set."""
        # Build the path of the individual file
        data_name = os.path.join(self.root_dir,
                                 self.dataframe.iloc[idx, 0])

        image = read_image(data_name)

        if self.transform is not None:
            image = self.transform(image)

        return image

In [None]:
# Attempting to load the dataset of images

# Where the filename index lives, and the folder holding the images it lists
csv_dir = 'fastMRI_images_filenames.csv'
root_dir = 'fastMRI_kspace_images/'

# Build the dataset (no transform is passed)
kspace_dataset = fastMRICustomDataset(csv_dir, root_dir)

# Check the dataset shape (passed)
# sample = kspace_dataset[0]
# print(sample.dtype)
# print(sample.shape)
# Output: torch.float32
# torch.Size([16, 16, 640, 320, 2])

In [None]:
# Setting up the dataloader
dataloader = DataLoader(kspace_dataset, batch_size=4, shuffle=True, num_workers=8)
# The dataloader should now load coil images

# Check output data: pull one batch and display its first image
train_img = next(iter(dataloader))
img = train_img[0]
if img.ndim == 3:
    # torchvision's read_image is documented to return channel-first (C, H, W)
    # tensors, while matplotlib expects the channel axis last.
    img = img.permute(1, 2, 0)
# Drop a singleton channel axis so a single-channel image is drawn with the colormap
img = img.squeeze()
plt.imshow(img.numpy(), cmap="gray")
plt.show()

# Check dataloader data (passed)
# print(dataloader.dataset[0].dtype)
# print(dataloader.dataset[0].shape)
# Output: torch.float32
# torch.Size([16, 16, 640, 320, 2])

In [None]:
# Access one k-space sample and show coils

# Convert back from tensor to numpy array
# kspace_nparray = T.tensor_to_complex_np(dataloader.dataset[0])

# Choose the 8th slice
# slice_kspace = kspace_nparray[8]

# Show coils 0, 5, and 10 (passed, see OneNote for output)
# show_coils(np.log(np.abs(slice_kspace) + 1e-9), [0, 5, 10])

# **Building the Diffusion Model**

From the Google Colab tutorial: https://colab.research.google.com/drive/1sjy9odlSSy0RBVgMTgP7s99NXsqglsUL?usp=sharing#scrollTo=hqcoJ8ZlXE1i

## Step 1: The forward process = Noise scheduler

We first need to build the inputs for our model, which are more and more noisy images. Instead of doing this sequentially, we can use the closed form provided in the papers to calculate the image for any of the timesteps individually.

**Key Takeaways:**

- The noise-levels/variances can be pre-computed
- There are different types of variance schedules
- We can sample the image at any timestep independently (a sum of Gaussians is also Gaussian)
- No model is needed in this forward step

In [None]:
# Generic Simple LDM Code (not functional yet)

def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
    """Return ``timesteps`` evenly spaced values running from ``start`` to ``end``."""
    schedule = torch.linspace(start, end, steps=timesteps)
    return schedule

def get_index_from_list(vals, t, x_shape):
    """
    Look up ``vals`` at the indices in ``t`` (one index per batch element).

    The result has shape ``(t.shape[0], 1, ..., 1)`` with as many dimensions
    as ``x_shape``, and is placed on the device of ``t``.
    """
    picked = torch.gather(vals, -1, t.cpu())
    # One singleton axis for every non-batch dimension of x_shape
    singleton_axes = [1] * (len(x_shape) - 1)
    return picked.reshape(t.shape[0], *singleton_axes).to(t.device)

def forward_diffusion_sample(x_0, t, device="cpu"):
    """
    Noise the image batch ``x_0`` to timestep(s) ``t`` in a single step.

    Reads the module-level ``sqrt_alphas_cumprod`` and
    ``sqrt_one_minus_alphas_cumprod`` tables. Returns a tuple
    ``(noisy_image, noise)``, both moved to ``device``.
    """
    noise = torch.randn_like(x_0)

    # Per-sample scale factors for the image and for the noise
    signal_scale = get_index_from_list(sqrt_alphas_cumprod, t, x_0.shape).to(device)
    noise_scale = get_index_from_list(sqrt_one_minus_alphas_cumprod, t, x_0.shape).to(device)

    noise = noise.to(device)
    # mean + variance
    noisy = signal_scale * x_0.to(device) + noise_scale * noise
    return noisy, noise


# Define beta schedule
# NOTE(review): this rebinds the module-level name `T`, which the import cell binds to
# `fastmri.data.transforms` (used as `T.to_tensor` in the k-space export cell). Once this
# cell has run, re-running that export cell in the same kernel would presumably fail --
# consider renaming one of the two.
# T is the number of diffusion timesteps; betas holds one value per timestep.
T = 300
betas = linear_beta_schedule(timesteps=T)

# Pre-calculate different terms for closed form
alphas = 1. - betas
# Running product of alphas up to and including each timestep
alphas_cumprod = torch.cumprod(alphas, axis=0)
# The same running product shifted by one timestep, with 1.0 inserted at the front
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
# Read by sample_timestep
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
# Factor applied to x_0 in forward_diffusion_sample
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
# Factor applied to the noise in forward_diffusion_sample; also read by sample_timestep
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
# Its square root scales the noise added back in sample_timestep
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

In [None]:
# Test it on the dataset
from torchvision import transforms 
from torch.utils.data import DataLoader
import numpy as np

# Side length that load_transformed_dataset resizes every image to (images become IMG_SIZE x IMG_SIZE)
IMG_SIZE = 64
# Batch size of the DataLoader built at the end of this cell
BATCH_SIZE = 128

def load_transformed_dataset():
    """Return the StanfordCars train and test splits concatenated into one dataset.

    Every image is resized to IMG_SIZE x IMG_SIZE, randomly flipped
    horizontally, converted to a tensor and rescaled to [-1, 1].
    """
    data_transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), # Scales data into [0,1]
        transforms.Lambda(lambda t: (t * 2) - 1), # Scale between [-1, 1]
    ])

    # Both splits share the same location, download flag and preprocessing
    shared_kwargs = dict(root=".", download=True, transform=data_transform)
    train_split = torchvision.datasets.StanfordCars(**shared_kwargs)
    test_split = torchvision.datasets.StanfordCars(split='test', **shared_kwargs)

    return torch.utils.data.ConcatDataset([train_split, test_split])
def show_tensor_image(image):
    """Draw a [-1, 1]-scaled CHW image tensor on the current matplotlib axes.

    If ``image`` is a 4-D batch, only its first element is drawn. Values
    outside [-1, 1] (e.g. heavily noised images) are saturated to black/white
    rather than being left to wrap around in the uint8 conversion.
    """
    reverse_transforms = transforms.Compose([
        transforms.Lambda(lambda t: (t + 1) / 2),
        transforms.Lambda(lambda t: t.permute(1, 2, 0)), # CHW to HWC
        # Clamp before the integer cast: casting out-of-range floats to uint8
        # does not saturate and produces speckled artefacts.
        transforms.Lambda(lambda t: (t * 255.).clamp(0., 255.)),
        transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),
        transforms.ToPILImage(),
    ])

    # Take first image of batch
    if len(image.shape) == 4:
        image = image[0, :, :, :]
    plt.imshow(reverse_transforms(image))

data = load_transformed_dataset()
# NOTE(review): this rebinds `dataloader`, which an earlier cell bound to the loader over
# kspace_dataset; from here on `dataloader` iterates over the dataset returned by
# load_transformed_dataset().
# drop_last=True discards a final batch that has fewer than BATCH_SIZE samples.
dataloader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

In [None]:
# Simulate forward diffusion
# Images of the first batch (element 0 of the batch)
first_batch = next(iter(dataloader))
image = first_batch[0]

plt.figure(figsize=(15,15))
plt.axis('off')
num_images = 10
stepsize = T // num_images

# One panel per sampled timestep, from no noise towards the final timestep
for panel, idx in enumerate(range(0, T, stepsize), start=1):
    t = torch.tensor([idx], dtype=torch.int64)
    plt.subplot(1, num_images+1, panel)
    img, noise = forward_diffusion_sample(image, t)
    show_tensor_image(img)

## Step 2: The backward process = U-Net

For a great introduction to UNets, have a look at this post: https://amaarora.github.io/2020/09/13/unet.html.

**Key Takeaways:**

- We use a simple form of a UNet to predict the noise in the image
- The input is a noisy image, the output is the noise in the image
- Because the parameters are shared across time, we need to tell the network which timestep we are in
- The timestep is encoded with sinusoidal position embeddings, as used in Transformers
- We output one single value (mean), because the variance is fixed

In [None]:
from torch import nn
import math


class Block(nn.Module):
    """Two 3x3 convolutions with a time-embedding injection, followed by a x2 resampling layer.

    With ``up=False`` the block takes ``in_ch`` channels and halves the spatial
    size with a strided convolution. With ``up=True`` it takes ``2 * in_ch``
    channels and doubles the spatial size with a transposed convolution.
    """

    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)

        # Only the first conv's input width and the resampling layer type depend on `up`
        first_conv_in = 2 * in_ch if up else in_ch
        resample_layer = nn.ConvTranspose2d if up else nn.Conv2d
        self.conv1 = nn.Conv2d(first_conv_in, out_ch, 3, padding=1)
        self.transform = resample_layer(out_ch, out_ch, 4, 2, 1)

        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x, t):
        # First conv -> ReLU -> batch norm
        h = self.conv1(x)
        h = self.relu(h)
        h = self.bnorm1(h)

        # Project the time embedding to out_ch values and add it to every spatial position
        time_emb = self.relu(self.time_mlp(t))
        h = h + time_emb[..., None, None]

        # Second conv -> ReLU -> batch norm
        h = self.conv2(h)
        h = self.relu(h)
        h = self.bnorm2(h)

        # Down or Upsample
        return self.transform(h)


class SinusoidalPositionEmbeddings(nn.Module):
    """Map a batch of timesteps to sinusoidal feature vectors.

    The first half of each output vector holds sines and the second half
    cosines of the timestep scaled by ``dim // 2`` different frequencies.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        half_dim = self.dim // 2
        # Frequencies decay exponentially from 1 down to 1/10000
        decay = math.log(10000) / (half_dim - 1)
        frequencies = torch.exp(torch.arange(half_dim, device=time.device) * -decay)
        # Outer product: one row per timestep, one column per frequency
        angles = time.unsqueeze(1) * frequencies.unsqueeze(0)
        # TODO: Double check the ordering here
        return torch.cat((angles.sin(), angles.cos()), dim=-1)


class SimpleUnet(nn.Module):
    """
    A simplified variant of the Unet architecture.

    Arguments:
        image_channels (int): Number of channels of the input image (default 3).
        out_dim (int): Number of channels of the output (default 3).
        down_channels (sequence of int): Channel widths along the downsampling
            path; the upsampling path uses the same widths in reverse order.
        time_emb_dim (int): Size of the timestep embedding.

    Called with no arguments, the network is built with the same layer sizes
    as before these arguments were added.
    """
    def __init__(self, image_channels=3, out_dim=3,
                 down_channels=(64, 128, 256, 512, 1024), time_emb_dim=32):
        super().__init__()
        down_channels = tuple(down_channels)
        # The decoder mirrors the encoder
        up_channels = down_channels[::-1]

        # Time embedding
        self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(time_emb_dim),
                nn.Linear(time_emb_dim, time_emb_dim),
                nn.ReLU()
            )

        # Initial projection
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)

        # Downsample: one Block per consecutive pair of widths
        self.downs = nn.ModuleList([
            Block(ch_in, ch_out, time_emb_dim)
            for ch_in, ch_out in zip(down_channels[:-1], down_channels[1:])
        ])
        # Upsample: one Block per consecutive pair of (reversed) widths
        self.ups = nn.ModuleList([
            Block(ch_in, ch_out, time_emb_dim, up=True)
            for ch_in, ch_out in zip(up_channels[:-1], up_channels[1:])
        ])

        # Edit: Corrected a bug found by Jakub C (see YouTube comment)
        self.output = nn.Conv2d(up_channels[-1], out_dim, 1)

    def forward(self, x, timestep):
        """Run the network on image batch ``x`` at the given ``timestep`` batch."""
        # Embedd time
        t = self.time_mlp(timestep)
        # Initial conv
        x = self.conv0(x)
        # Unet
        residual_inputs = []
        for down in self.downs:
            x = down(x, t)
            residual_inputs.append(x)
        for up in self.ups:
            # Pair each up block with the most recent unused down-block output
            residual_x = residual_inputs.pop()
            # Add residual x as additional channels
            x = torch.cat((x, residual_x), dim=1)
            x = up(x, t)
        return self.output(x)

model = SimpleUnet()
# Total number of scalar parameters across all layers
num_params = sum(p.numel() for p in model.parameters())
print("Num params: ", num_params)
model

**Further improvements that can be implemented:**

- Residual connections
- Different activation functions like SiLU, GELU, ...
- BatchNormalization
- GroupNormalization
- Attention
- ...

## Step 3: The loss

**Key Takeaways:**

- After some maths we end up with a very simple loss function
- There are other possible choices, like L2 loss, etc.

In [None]:
def get_loss(model, x_0, t):
    """Return the L1 loss between the noise added to ``x_0`` at ``t`` and the model's prediction.

    Reads the module-level ``device`` to decide where the noised batch is placed.
    """
    noised = forward_diffusion_sample(x_0, t, device)
    x_noisy, true_noise = noised
    predicted_noise = model(x_noisy, t)
    return F.l1_loss(true_noise, predicted_noise)

### Sampling

- Without adding @torch.no_grad() we quickly run out of memory, because PyTorch tracks all the previous images for gradient calculation
- Because we pre-calculated the noise variances for the forward pass, we also have to use them when we sequentially perform the backward process

In [None]:
@torch.no_grad()
def sample_timestep(x, t):
    """
    Perform one reverse-diffusion step on ``x`` at timestep ``t``.

    The module-level ``model`` predicts the noise in ``x``, from which the
    mean of the less noisy image is computed. That mean is returned as is when
    ``t`` is 0; otherwise fresh noise scaled by the square root of the
    posterior variance is added to it.
    """
    def schedule_at_t(table):
        # Value of a pre-computed schedule table at t, shaped to broadcast against x
        return get_index_from_list(table, t, x.shape)

    betas_t = schedule_at_t(betas)
    sqrt_one_minus_alphas_cumprod_t = schedule_at_t(sqrt_one_minus_alphas_cumprod)
    sqrt_recip_alphas_t = schedule_at_t(sqrt_recip_alphas)

    # Call model (current image - noise prediction)
    predicted_noise = model(x, t)
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * predicted_noise / sqrt_one_minus_alphas_cumprod_t
    )
    posterior_variance_t = schedule_at_t(posterior_variance)

    # As pointed out by Luis Pereira (see YouTube comment)
    # The t's are offset from the t's in the paper
    if t == 0:
        return model_mean

    fresh_noise = torch.randn_like(x)
    return model_mean + torch.sqrt(posterior_variance_t) * fresh_noise

@torch.no_grad()
def sample_plot_image():
    """Denoise one random image from timestep T-1 down to 0 and plot ten intermediate stages."""
    # Sample noise
    img = torch.randn((1, 3, IMG_SIZE, IMG_SIZE), device=device)

    plt.figure(figsize=(15,15))
    plt.axis('off')
    num_images = 10
    stepsize = T // num_images

    for i in reversed(range(T)):
        t = torch.full((1,), i, device=device, dtype=torch.long)
        # Edit: This is to maintain the natural range of the distribution
        img = torch.clamp(sample_timestep(img, t), -1.0, 1.0)
        if i % stepsize != 0:
            continue
        # Only every stepsize-th timestep gets a panel
        plt.subplot(1, num_images, i // stepsize + 1)
        show_tensor_image(img.detach().cpu())
    plt.show()

### Training

In [None]:
from torch.optim import Adam

# Train on the GPU when one is available
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
optimizer = Adam(model.parameters(), lr=0.001)
epochs = 100 # Try more!

for epoch in range(epochs):
    for step, batch in enumerate(dataloader):
        optimizer.zero_grad()

        # Element 0 of the batch holds the images
        images = batch[0]
        # One random timestep per image. Sized from the batch itself so the
        # loop also works with loaders whose batches are not BATCH_SIZE long.
        t = torch.randint(0, T, (images.shape[0],), device=device).long()
        loss = get_loss(model, images, t)
        loss.backward()
        optimizer.step()

        # Every 5th epoch, report the loss of the first step and plot a sample
        if epoch % 5 == 0 and step == 0:
            print(f"Epoch {epoch} | step {step:03d} Loss: {loss.item()} ")
            sample_plot_image()