# Imports

In [1]:
import nibabel as nib
import matplotlib.pyplot as plt
import numpy as np
from random import randint
import glob
import os
import csv

# Basic Helper Functions

In [None]:
def get_rand_slice_list(data_shape):
    """Pick a random slice index along each axis, biased towards the centre.

    For an axis of length ``n`` the index is drawn uniformly from the
    inclusive range ``[n // 2 - n // 4, n // 2 + n // 4]``, i.e. roughly the
    middle half of the axis.

    Args:
        data_shape: Sequence of exactly three axis lengths ``(x, y, z)``.

    Returns:
        Tuple ``(x_curr, y_curr, z_curr)`` of integer slice indices.
    """
    x_max, y_max, z_max = data_shape

    def _central_index(axis_len):
        # Floor division keeps both bounds integral. True division produced
        # floats, which randint rejects (non-integral values always, and any
        # float at all on Python 3.12+).
        centre = axis_len // 2
        spread = axis_len // 4
        return randint(centre - spread, centre + spread)

    x_curr = _central_index(x_max)
    y_curr = _central_index(y_max)
    z_curr = _central_index(z_max)
    return x_curr, y_curr, z_curr



def show_mri_slices_random(mri_data, explicit_pos=None):
    """Display one slice of a 3-D volume along each of its three axes.

    Args:
        mri_data: 3-D array exposing ``.shape`` and supporting
            ``[i, :, :]``-style indexing; each extracted slice must have a
            ``.T`` attribute.
        explicit_pos: Optional ``(x, y, z)`` slice indices. When omitted,
            the indices are drawn by ``get_rand_slice_list``, which biases
            them towards the middle of the volume.

    The shape, the chosen positions and each slice index are printed, and
    every slice is shown in its own matplotlib figure.
    """
    print('Data Shape = ', mri_data.shape)
    # Identity check: `== None` is ambiguous when an array is passed.
    if explicit_pos is None:
        x_curr, y_curr, z_curr = get_rand_slice_list(mri_data.shape)
    else:
        x_curr, y_curr, z_curr = explicit_pos
    print('Data Positions = ', x_curr, y_curr, z_curr)

    # (label, slice index, 2-D slice, extra imshow keyword arguments)
    views = (
        ('Slice 1: value: ', x_curr, mri_data[x_curr, :, :], {}),
        ('Slice 2: value: ', y_curr, mri_data[:, y_curr, :], {'aspect': 0.5}),
        ('Slice 3: value: ', z_curr, mri_data[:, :, z_curr], {'aspect': 0.5}),
    )
    for label, index, plane, extra in views:
        print(label, index)
        # imshow only accepts 'upper' or 'lower' for `origin`; the previous
        # value 0 is rejected by current matplotlib releases.
        plt.imshow(plane.T, cmap='gray', origin='lower', **extra)
        plt.show()
    
    
    
def BrainMasker(brain_data,skull_data, brain_obj=None,isSave=False):
    """Zero out the voxels of ``brain_data`` that the brain mask excludes.

    A voxel is replaced by 0 wherever ``(skull_data + 1) % 2`` is non-zero;
    for a 0/1 mask that is every voxel where ``skull_data`` is 0.

    Args:
        brain_data: Array holding the image volume.
        skull_data: Array holding the brain mask, combined with
            ``brain_data`` through ``np.ma.masked_array``.
        brain_obj: Object whose ``.affine`` is used when saving. Only
            needed when ``isSave`` is true.
        isSave: When true, write the masked volume to ``output.nii.gz`` in
            the current working directory. If ``brain_obj`` is missing, a
            message is printed and nothing is saved.

    Returns:
        A new array with the excluded voxels set to 0.
    """
    hidden = (skull_data + 1) % 2
    stripped = np.ma.masked_array(brain_data, hidden, fill_value=0).filled()
    if isSave:
        # Identity check: `== None` would be evaluated element-wise (and
        # then fail in `if`) should an array-like object be passed.
        if brain_obj is None:
            print('No affine transform available. provide brain object')
        else:
            new_image = nib.Nifti1Image(stripped, brain_obj.affine)
            nib.save(new_image, "output.nii.gz")
    return stripped



# Dataset Loading Functions

In [None]:
# Root folder of the dataset on the author's Windows machine; adjust it to
# the local copy. Presumably the value passed to the loading functions
# defined in this notebook -- confirm against the calling cells.
dataset_path = 'F:\\7thsemProjects\\MRIAnalysis\\3dmrMS'

In [None]:
def load_MS_dataset_paths(MS_dataset_path):
    """Collect the image file paths of every patient folder in the dataset.

    The dataset is expected to be laid out as
    ``<MS_dataset_path>/<group folder>/<patient folder>/<files>``.

    Args:
        MS_dataset_path: Root directory of the dataset.

    Returns:
        Dict mapping each patient folder name to a dict with the keys
        ``'brainmask'``, ``'segmentation'``, ``'t1w'``, ``'t1w_enhance'``,
        ``'t2w'`` and ``'flair'``, each holding the path of the matching
        file. If any patient folder lacks one of these files, an error line
        is printed and ``None`` is returned.
    """
    # Glob pattern of the file expected for each entry of a patient record.
    expected_files = {
        'brainmask': '*brainmask.nii.gz',
        'segmentation': '*consensus_gt.nii.gz',
        't1w': '*T1W.nii.gz',
        't1w_enhance': '*T1WKS.nii.gz',
        't2w': '*T2W.nii.gz',
        'flair': '*FLAIR.nii.gz',
    }

    # Two directory levels: group folders, then patient folders inside them.
    patient_dirs = []
    for group_dir in glob.glob(MS_dataset_path + '/*/'):
        patient_dirs.extend(glob.glob(group_dir + '/*/'))

    actual_data = {}
    for curr_data_path in patient_dirs:
        curr_dataset = {}
        for name, pattern in expected_files.items():
            # Sorted so the pick is deterministic if several files match.
            matches = sorted(glob.glob(curr_data_path + '/' + pattern))
            # Check before indexing: indexing an empty match list used to
            # raise IndexError before this error report could be reached.
            if not matches:
                print('Error at', name, ': ', curr_data_path)
                return None
            curr_dataset[name] = matches[-1]

        # The patient path ends with a separator; normalise it away so the
        # last component is the folder name.
        dataset_name = os.path.basename(os.path.normpath(curr_data_path))
        actual_data[dataset_name] = curr_dataset

    return actual_data

def load_dataset_details(MS_dataset_path):
    """Read the patient information table shipped with the dataset.

    The table is read from ``patient26-30/patient_info.csv`` below
    ``MS_dataset_path``.

    Args:
        MS_dataset_path: Root directory of the dataset.

    Returns:
        List with one dict per CSV row, keyed by the column headers.

    Raises:
        OSError: If the CSV file cannot be opened.
    """
    # os.path.join instead of a hard-coded '\patient26-30\...' literal: the
    # backslashes were invalid escape sequences and only worked on Windows.
    info_path = os.path.join(MS_dataset_path, 'patient26-30', 'patient_info.csv')
    # newline='' lets the csv module handle line endings itself.
    with open(info_path, newline='') as fh:
        return [dict(row) for row in csv.DictReader(fh, delimiter=',')]

def list_dataset(datasets_path_list):
    """Print each key of the given mapping followed by its value, one pair per line."""
    for entry in datasets_path_list.items():
        # Each entry is a (key, value) pair; print both separated by a space.
        print(*entry)
        
        
def show_all_data_patient(patient_datapath_dict):
    """Show the same three slices of every volume of one patient.

    Loads the files stored under the keys ``'t1w'``, ``'t2w'``, ``'flair'``,
    ``'brainmask'`` and ``'segmentation'`` of ``patient_datapath_dict``,
    draws one slice position from the T1w volume's shape, and displays each
    volume at that position. The T1w, T2w and FLAIR volumes are then shown
    a second time after being passed through ``BrainMasker`` with the
    brain-mask volume.
    """
    # (display label, key in patient_datapath_dict), in loading order.
    sources = (
        ('t1w', 't1w'),
        ('t2w', 't2w'),
        ('flair', 'flair'),
        ('skull data', 'brainmask'),
        ('concensus', 'segmentation'),
    )
    images = {label: nib.load(patient_datapath_dict[key]) for label, key in sources}

    # Voxel arrays are fetched, and later displayed, in this order.
    display_order = ('t1w', 't2w', 'flair', 'concensus', 'skull data')
    volumes = {label: images[label].get_fdata() for label in display_order}

    # One shared position so every volume is viewed at the same slices.
    x, y, z = get_rand_slice_list(volumes['t1w'].shape)
    position = (x, y, z)

    for label in display_order:
        print(label)
        show_mri_slices_random(volumes[label], position)

    # (display label, volume to mask) for the masked views.
    masked_views = (
        ('t1w masked', 't1w'),
        ('t2w masked', 't2w'),
        ('flair masked', 'flair'),
    )
    for label, source in masked_views:
        print(label)
        masked_volume = BrainMasker(volumes[source], volumes['skull data'])
        show_mri_slices_random(masked_volume, position)
    

# Process Pipeline

1. Dataset Loading
2. Dataset Viewing
3. Image Processing
    a. Skull Stripping
    b. Intensity Normalization - https://github.com/loli/medpy/blob/master/medpy/filter/IntensityRangeStandardization.py
    c. Size Standardization
4. Goal #1
    Obtain segmented images using FLAIR MRI scans only
5. Scoring in Segmentation
    Dice score
