# About

Use this notebook when 2D analysis is also needed.

The main purpose of this notebook is to extract 2D slices from 3D .nii files for use in 2D image-classification tasks.<br/>
-> Find the desired slice interval along the desired axis for a particular subject ID.<br/>
-> Create one folder per group (AD, CN and MCI) under each of the train, test and val sets, and save the slices there.

Ultimately, the slices are saved as '.jpg', '.png' or even '.pdf' files (choose the desired format), organized in a one-folder-per-class layout that PyTorch dataset loaders are expected to accept.

In [1]:
import os
from pathlib import Path
from nibabel.testing import data_path
import nibabel as nib

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
from matplotlib.pyplot import figure

In [2]:
class utils:
    """Helpers for exporting 2D slices of 3D NIfTI (.nii) volumes as image files."""

    def read_image(self, img_path):
        """Load the NIfTI file at `img_path` and return its voxel data.

        Returns the array produced by nibabel's ``get_fdata()``.
        """
        # Load the path exactly as given. Joining it onto nibabel's bundled
        # test-data directory only "worked" because an absolute path overrides
        # the join; relative paths were resolved inside the nibabel package.
        return nib.load(img_path).get_fdata()

    def take_mri(self, subject_id, path):
        """Return the first `.nii` file under `path` whose name contains `subject_id`.

        The directory tree is searched recursively with ``os.walk``; matching is
        a plain substring test on the file name.

        Raises:
            FileNotFoundError: if no matching file exists under `path`.
        """
        subject_id = str(subject_id)  # tolerate non-string IDs read from CSV
        for root, _dirs, files in os.walk(path):
            for file_name in files:
                if subject_id in file_name and file_name.endswith(".nii"):
                    return os.path.join(root, file_name)
        # Fail loudly here instead of handing None to read_image.
        raise FileNotFoundError(
            f"No .nii file containing {subject_id!r} found under {path!r}"
        )

    # TODO: try to normalize slices if possible — .nii volume sizes differ, e.g.
    #   43.5 MB : (256, 256, 166)
    #   23.6 MB : (192, 192, 160)
    #   47.2 MB : (256, 256, 180), and more different sizes...

    @staticmethod
    def _slice_bounds(shape, region, width):
        """Return ``(start, end, sagittal_start, sagittal_end)`` slice indices.

        A volume counts as "large" when either of its first two dimensions
        exceeds 200 and its third is at least 166. Large volumes use hard-coded
        centre indices 120 (axes 0/1) and 90 (axis 2). Smaller volumes use 85,
        except 80 for 'mid_to_back'. The window around the centre is
        ``[-width, +width)`` for 'mid', ``[-width, 0)`` for 'mid_to_back' and
        ``[0, +width)`` for anything else ('mid_to_forward').
        """
        is_large = (shape[0] > 200 or shape[1] > 200) and shape[2] >= 166
        if is_large:
            centre, sagittal_centre = 120, 90
        else:
            # NOTE(review): 'mid_to_back' centres on 80 for smaller volumes while
            # the other regions use 85 — preserved as-is; confirm it is intended.
            centre = sagittal_centre = 80 if region == 'mid_to_back' else 85

        if region == 'mid':
            low, high = -width, width
        elif region == 'mid_to_back':
            low, high = -width, 0
        else:  # 'mid_to_forward'
            low, high = 0, width

        return centre + low, centre + high, sagittal_centre + low, sagittal_centre + high

    def get_slices(self, subject_id, path, dataset, group, axis, region, width, iter, format):
        """Save 2D slices of one subject's volume as image files.

        Locates the subject's .nii file under ``path/dataset``, picks the slice
        window from ``_slice_bounds`` and writes every ``iter``-th slice in it to
        ``path/dataset/group/<subject_id>-<axis>Slice<index><format>`` as a
        grayscale image.

        Args:
            subject_id: substring identifying the subject's .nii file name.
            path: root data directory.
            dataset: split sub-folder name ('train', 'test' or 'val').
            group: class label; used as the output sub-folder name.
            axis: 'axial' slices array axis 0, 'coronal' array axis 1, anything
                else array axis 2 (using the separate sagittal window).
                NOTE(review): whether these array axes match the anatomical
                planes depends on the volume orientation — confirm per dataset.
            region: 'mid', 'mid_to_back' or 'mid_to_forward' (see _slice_bounds).
            width: half-width ('mid') or width (other regions) of the window.
            iter: step between saved slices.
            format: file extension including the dot, e.g. '.jpg'.

        Raises:
            FileNotFoundError: if the subject's .nii file cannot be found.
        """
        img = self.read_image(self.take_mri(subject_id, os.path.join(path, dataset)))
        start, end, sagittal_start, _ = self._slice_bounds(img.shape, region, width)

        if axis == 'axial':
            array_axis, first = 0, start
        elif axis == 'coronal':
            array_axis, first = 1, start
        else:
            array_axis, first = 2, sagittal_start

        # Create the output folder once, not on every slice.
        save_dir = os.path.join(path, dataset, str(group))
        os.makedirs(save_dir, mode=0o755, exist_ok=True)

        for index in range(first, first + (end - start), iter):
            image_slice = img.take(index, axis=array_axis)
            # Name the file after the index actually sliced; previously sagittal
            # files were labelled with `start + i` although `sagittal_start + i`
            # was the slice taken.
            file_name = f'{subject_id}-{axis}Slice{index}{format}'
            plt.imsave(os.path.join(save_dir, file_name), image_slice, cmap='gray')

    def run(self, data, path, dataset, axis='axial', region='mid', width=5, iter=1, format='.jpg'):
        """Export slices for every subject listed in `data`.

        Args:
            data: table with 'subject' and 'group' columns (e.g. a DataFrame);
                rows are processed in order regardless of the index labels.
            path: root data directory.
            dataset: 'train', 'test' or 'val'.
            axis: 'axial', 'sagittal', 'coronal', or 'all' for all three
                (exported in the order axial, coronal, sagittal per subject).
            region: 'mid_to_back', 'mid_to_forward' or 'mid'.
            width: int in [1, 50]; see get_slices.
            iter: int in [1, 5]; step between saved slices.
            format: '.png', '.jpg' or '.pdf'.

        Raises:
            AssertionError: if any argument is outside the accepted values.
        """
        assert dataset in ['train', 'test', 'val'], f"unknown dataset: {dataset!r}"
        assert format in ['.png', '.jpg', '.pdf'], f"unsupported format: {format!r}"
        assert axis in ['axial', 'sagittal', 'coronal', 'all'], f"unknown axis: {axis!r}"
        assert region in ['mid_to_back', 'mid_to_forward', 'mid'], f"unknown region: {region!r}"
        # Type check first so a non-int fails the assertion instead of raising
        # TypeError from the comparison.
        assert isinstance(width, int) and 0 < width <= 50, f"width greater than 0 and less equal than 50 is expected, got: {width}"
        assert isinstance(iter, int) and 0 < iter <= 5, f"iter greater than 0 and less equal than 5 is expected, got: {iter}"

        axes = ['axial', 'coronal', 'sagittal'] if axis == 'all' else [axis]

        # zip over the columns walks rows positionally; `data['subject'][i]`
        # was a label lookup that broke for non 0..n-1 indices.
        for subject_id, group in zip(data['subject'], data['group']):
            for current_axis in axes:
                self.get_slices(subject_id=subject_id, path=path, dataset=dataset,
                                group=group, axis=current_axis, region=region,
                                width=width, iter=iter, format=format)

In [3]:
# Root folder that holds the train/, val/ and test/ split sub-folders.
path = '/Users/toygar/Desktop/Bitirme/data/'
# Stateless helper object used by the export cells below.
u = utils()

In [4]:
# Each split's CSV lives at <path>/<split>/<split>.csv.
train, val, test = (
    pd.read_csv(os.path.join(path, f'{split}/{split}.csv'))
    for split in ('train', 'val', 'test')
)

In [66]:
# u.run(train, path, 'train', axis = 'axial', region='mid', width=5, iter=1, format = '.jpg')
# u.run(train, path, 'train', axis = 'coronal', region='mid', width=5, iter=1, format = '.jpg')
# u.run(train, path, 'train', axis = 'sagittal', region='mid', width=5, iter=1, format = '.jpg')

In [5]:
# Export slices along all three axes for every split, using identical settings.
for split_name, split_df in (('train', train), ('val', val), ('test', test)):
    u.run(split_df, path, split_name, axis='all', region='mid', width=12, iter=4, format='.jpg')