# MRI Exploratory Data Analysis (EDA)
## Packages

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os
import nibabel as nib
from ipywidgets import interact
from matplotlib.colors import ListedColormap
import os
import pandas as pd
import numpy as np
import torch
from torchvision import transforms

## Helper Functions

In [46]:
def load_data(mri_case, mri_type="t2_tse_fs_cor"):
    """Load a single NIfTI volume from the raw data directory.

    Parameters
    ----------
    mri_case : str
        Case identifier, formatted into the file name ``<mri_case>.nii``.
    mri_type : str
        Sequence sub-directory below ``../raw_data/nii_files``.

    Returns
    -------
    numpy.ndarray
        The voxel data as returned by nibabel's ``get_fdata()``.
    """
    nifti_image = nib.load(f"../raw_data/nii_files/{mri_type}/{mri_case}.nii")
    return nifti_image.get_fdata()
    
def visualize_slices(mri_list, preprocess_slices=None):
    """Show an interactive slice browser for a list of MRI volumes.

    Two sliders are created with ``ipywidgets.interact``. One selects the
    volume (an index into ``mri_list``); the other selects the slice along
    axis 2.

    Parameters
    ----------
    mri_list : list of numpy.ndarray
        Volumes indexed as ``volume[:, :, slice_number]``, i.e. axis 2 is
        the slice axis. Volumes may have different numbers of slices.
    preprocess_slices : callable, optional
        Function applied to every volume (e.g. one of the
        ``resize_and_crop_3d_image_*`` helpers). When given, each original
        slice is shown next to its preprocessed counterpart. The processed
        stack is assumed to be a crop centred along axis 2. Original slices
        outside that crop leave the right-hand panel empty.

    Raises
    ------
    ValueError
        If ``mri_list`` is empty.
    """
    if not mri_list:
        raise ValueError("mri_list must contain at least one volume")

    if preprocess_slices is None:
        mri_list_processed = None
    else:
        mri_list_processed = [preprocess_slices(volume) for volume in mri_list]

    def show_slice(mri, slice_number):
        # Parameter names become the slider labels in `interact`; keep them.
        volume = mri_list[mri]
        # The slider spans the deepest volume. Clamp so that shallower
        # volumes show their last slice instead of raising IndexError.
        slice_number = min(slice_number, volume.shape[2] - 1)

        if mri_list_processed is None:
            plt.imshow(volume[:, :, slice_number], cmap='gray')
        else:
            processed = mri_list_processed[mri]
            n_slices = processed.shape[2]
            # First original slice kept by a crop centred on axis 2.
            starting_point = volume.shape[2] // 2 - n_slices // 2
            _, axes = plt.subplots(1, 2, figsize=(10, 5))
            axes[0].imshow(volume[:, :, slice_number], cmap='gray')
            axes[0].set_title('Not Preprocessed')
            if starting_point <= slice_number < starting_point + n_slices:
                axes[1].imshow(processed[:, :, slice_number - starting_point], cmap='gray')
                axes[1].set_title('Preprocessed')
        plt.show()

    max_slice = max(volume.shape[2] for volume in mri_list) - 1
    interact(show_slice, mri=(0, len(mri_list) - 1), slice_number=(0, max_slice))


def resize_and_crop_3d_image_cor(
    image,
    new_size=(384, 384),
    crop_size=(112, 112, 6),
    offset=(0, 0, 0),
):
    """Resize every slice in-plane, then crop a box around the volume centre.

    Parameters
    ----------
    image : numpy.ndarray
        Volume of shape ``(H, W, D)``; axis 2 indexes slices.
    new_size : tuple of int
        Target ``(height, width)`` passed to ``torchvision.transforms.Resize``.
    crop_size : tuple of int
        Number of voxels to keep along each of the three axes.
    offset : tuple of int
        Shift of the crop centre along each axis, in voxels of the resized
        volume. The default produces a crop exactly centred on the volume.

    Returns
    -------
    numpy.ndarray
        The cropped volume. Axis ``i`` has ``crop_size[i]`` elements, unless
        the crop window extends past the volume. In that case the window is
        clipped to the volume bounds and fewer elements are returned.
    """
    resize_transform = transforms.Resize(new_size)
    # Resize treats all leading tensor dimensions as batch/channel axes, so
    # move the slice axis to the front and resize every slice in one call.
    stacked_slices = torch.as_tensor(np.moveaxis(image, 2, 0))
    resized_image = np.moveaxis(resize_transform(stacked_slices).numpy(), 0, 2)

    window = []
    for dim, size, shift in zip(resized_image.shape, crop_size, offset):
        # `start + size` (rather than `centre + size // 2`) keeps odd crop
        # sizes from losing one voxel.
        start = dim // 2 - size // 2 + shift
        stop = start + size
        # Clip to the volume so an oversized crop never wraps around through
        # negative indices.
        window.append(slice(min(max(start, 0), dim), min(max(stop, 0), dim)))
    return resized_image[tuple(window)]


In [None]:
def resize_and_crop_3d_image_seg(
    image,
    new_size=(384, 384),
    crop_size=(150, 150, 8),  # cf. (112, 112, 6) in resize_and_crop_3d_image_cor
    offset=(20, -20, 0),
):
    """Resize every slice in-plane, then crop a box near the volume centre.

    Same procedure as the coronal variant, except that by default the crop
    centre is shifted by +20 voxels along axis 0 and -20 voxels along axis 1.

    Parameters
    ----------
    image : numpy.ndarray
        Volume of shape ``(H, W, D)``; axis 2 indexes slices.
    new_size : tuple of int
        Target ``(height, width)`` passed to ``torchvision.transforms.Resize``.
    crop_size : tuple of int
        Number of voxels to keep along each of the three axes.
    offset : tuple of int
        Shift of the crop centre along each axis, in voxels of the resized
        volume.

    Returns
    -------
    numpy.ndarray
        The cropped volume. Axis ``i`` has ``crop_size[i]`` elements, unless
        the crop window extends past the volume. In that case the window is
        clipped to the volume bounds and fewer elements are returned.
    """
    resize_transform = transforms.Resize(new_size)
    # Resize treats all leading tensor dimensions as batch/channel axes, so
    # move the slice axis to the front and resize every slice in one call.
    stacked_slices = torch.as_tensor(np.moveaxis(image, 2, 0))
    resized_image = np.moveaxis(resize_transform(stacked_slices).numpy(), 0, 2)

    window = []
    for dim, size, shift in zip(resized_image.shape, crop_size, offset):
        # `start + size` (rather than `centre + size // 2`) keeps odd crop
        # sizes from losing one voxel.
        start = dim // 2 - size // 2 + shift
        stop = start + size
        # Clip to the volume so an oversized or shifted crop never wraps
        # around through negative indices.
        window.append(slice(min(max(start, 0), dim), min(max(stop, 0), dim)))
    return resized_image[tuple(window)]

In [3]:
# Browse the raw (unprocessed) t2_tse_cor volume of case 7729409.
t2_tse_cor = load_data(mri_case="7729409", mri_type="t2_tse_cor")
visualize_slices(mri_list=[t2_tse_cor])

interactive(children=(IntSlider(value=0, description='mri', max=0), IntSlider(value=9, description='slice_numb…

In [4]:
# Compare case 8797386 (default sequence) with its resized-and-cropped version.
case_879 = load_data(mri_case="8797386")
# NOTE(review): not used in this cell; visualize_slices below recomputes the same crop.
resized_and_cropped_image = resize_and_crop_3d_image_cor(image=case_879)
visualize_slices(mri_list=[case_879], preprocess_slices=resize_and_crop_3d_image_cor)

interactive(children=(IntSlider(value=0, description='mri', max=0), IntSlider(value=7, description='slice_numb…

In [44]:
train_data = pd.read_csv("../data/train_data.csv")
# .iloc makes the selection explicitly positional (rows 17-24), independent
# of the index labels of the DataFrame.
cases = train_data["MRI_Case_ID"].iloc[17:25]
mri = [load_data(mri_case, "t1_tse_sag") for mri_case in cases]
# Increase sagittal to 8, 140 x 140
# NOTE(review): resize_and_crop_3d_image_seg defaults to crop_size=(150, 150, 8);
# confirm whether 140 x 140 or 150 x 150 is intended.

In [47]:
# Compare each t1_tse_sag volume loaded above with its resized-and-cropped version.
visualize_slices(mri, preprocess_slices=resize_and_crop_3d_image_seg)

interactive(children=(IntSlider(value=3, description='mri', max=7), IntSlider(value=9, description='slice_numb…