In [13]:
import os
from tqdm import tqdm
import numpy as np
import pandas as pd
from skimage import measure
from skimage import io
import cv2
import nibabel as nib

In [3]:
data_df = pd.read_csv('/home/prateek/from_kipchoge/ms_project/participants.csv')
data_df

Unnamed: 0,participant_id,sessions,age,observed_change_t2,observed_change_t3,observed_change_t4,observed_change_t5,time_dfference_t21,time_difference_t32,time_difference_t43,time_difference_t54
0,sub-ID13,3,34,1,0.0,,,438,154.0,,
1,sub-ID141,4,46,1,0.0,0.0,,406,252.0,105.0,
2,sub-ID24,3,38,1,1.0,,,333,133.0,,
3,sub-ID39,4,37,1,1.0,0.0,,450,193.0,273.0,
4,sub-ID88,2,69,1,,,,613,,,
...,...,...,...,...,...,...,...,...,...,...,...
165,sub-ID72,3,33,1,1.0,,,371,360.0,,
166,sub-ID153,2,55,1,,,,595,,,
167,sub-ID75,4,38,1,0.0,0.0,,372,193.0,262.0,
168,sub-ID16,3,28,0,1.0,,,361,357.0,,


In [4]:
def load_nifti(filename, header=False, verbose=False):
    """Load nifti files"""
    nifti = nib.load(filename)
    if verbose:
        print(' File loaded from:', filename)
        print(' File dimensions : ', nifti.shape)
        print(' Data type :', nifti.get_data_dtype())
    if header:
        return nifti.get_fdata(), nifti.header
    return nifti.get_fdata()

In [5]:
np.random.seed(55)
data_df.sample(frac=1)
patients = data_df.participant_id.values
patients

array(['sub-ID13', 'sub-ID141', 'sub-ID24', 'sub-ID39', 'sub-ID88',
       'sub-ID45', 'sub-ID26', 'sub-ID69', 'sub-ID122', 'sub-ID42',
       'sub-ID101', 'sub-ID71', 'sub-ID25', 'sub-ID84', 'sub-ID145',
       'sub-ID112', 'sub-ID41', 'sub-ID50', 'sub-ID149', 'sub-ID156',
       'sub-ID65', 'sub-ID151', 'sub-ID9', 'sub-ID68', 'sub-ID87',
       'sub-ID20', 'sub-ID104', 'sub-ID83', 'sub-ID52', 'sub-ID147',
       'sub-ID143', 'sub-ID11', 'sub-ID60', 'sub-ID3', 'sub-ID18',
       'sub-ID74', 'sub-ID119', 'sub-ID57', 'sub-ID110', 'sub-ID114',
       'sub-ID99', 'sub-ID48', 'sub-ID137', 'sub-ID159', 'sub-ID31',
       'sub-ID6', 'sub-ID152', 'sub-ID12', 'sub-ID8', 'sub-ID59',
       'sub-ID36', 'sub-ID134', 'sub-ID108', 'sub-ID116', 'sub-ID14',
       'sub-ID55', 'sub-ID140', 'sub-ID170', 'sub-ID167', 'sub-ID10',
       'sub-ID85', 'sub-ID106', 'sub-ID130', 'sub-ID118', 'sub-ID66',
       'sub-ID34', 'sub-ID125', 'sub-ID102', 'sub-ID54', 'sub-ID165',
       'sub-ID29', 'sub-ID40', 'sub-I

In [6]:
num_train_pats = 110
num_valid_pats = 30
num_test_pats = 30

In [7]:

training_patients = patients[:num_train_pats]
validation_patients = patients[num_train_pats:(num_train_pats+num_valid_pats)]
testing_patients = patients[(num_train_pats+num_valid_pats):]

print(f"training patients : {len(training_patients)}\n{training_patients}")
print(f"validation patients : {len(validation_patients)}\n{validation_patients}")
print(f"testing patients : {len(testing_patients)}\n{testing_patients}")

training patients : 110
['sub-ID13' 'sub-ID141' 'sub-ID24' 'sub-ID39' 'sub-ID88' 'sub-ID45'
 'sub-ID26' 'sub-ID69' 'sub-ID122' 'sub-ID42' 'sub-ID101' 'sub-ID71'
 'sub-ID25' 'sub-ID84' 'sub-ID145' 'sub-ID112' 'sub-ID41' 'sub-ID50'
 'sub-ID149' 'sub-ID156' 'sub-ID65' 'sub-ID151' 'sub-ID9' 'sub-ID68'
 'sub-ID87' 'sub-ID20' 'sub-ID104' 'sub-ID83' 'sub-ID52' 'sub-ID147'
 'sub-ID143' 'sub-ID11' 'sub-ID60' 'sub-ID3' 'sub-ID18' 'sub-ID74'
 'sub-ID119' 'sub-ID57' 'sub-ID110' 'sub-ID114' 'sub-ID99' 'sub-ID48'
 'sub-ID137' 'sub-ID159' 'sub-ID31' 'sub-ID6' 'sub-ID152' 'sub-ID12'
 'sub-ID8' 'sub-ID59' 'sub-ID36' 'sub-ID134' 'sub-ID108' 'sub-ID116'
 'sub-ID14' 'sub-ID55' 'sub-ID140' 'sub-ID170' 'sub-ID167' 'sub-ID10'
 'sub-ID85' 'sub-ID106' 'sub-ID130' 'sub-ID118' 'sub-ID66' 'sub-ID34'
 'sub-ID125' 'sub-ID102' 'sub-ID54' 'sub-ID165' 'sub-ID29' 'sub-ID40'
 'sub-ID7' 'sub-ID115' 'sub-ID17' 'sub-ID132' 'sub-ID28' 'sub-ID127'
 'sub-ID76' 'sub-ID161' 'sub-ID19' 'sub-ID96' 'sub-ID33' 'sub-ID131'
 'sub-ID1

In [8]:
training_df = data_df[:num_train_pats]
validation_df = data_df[num_train_pats:(num_train_pats+num_valid_pats)]
testing_df = data_df[(num_train_pats+num_valid_pats):]


In [18]:
def extract_square_bounding_box(bounding_box, margin):
    top, left, bottom, right = bounding_box
    width = right - left
    height = bottom - top
    square_size = max(width, height) + margin
    center_x = (left + right) // 2
    center_y = (top + bottom) // 2
    top = center_y - square_size // 2
    bottom = center_y + square_size // 2
    left = center_x - square_size // 2
    right = center_x + square_size // 2
    return top, left, bottom, right

def get_bbox(mask, margin=30):
    x,y,z = mask.nonzero()
    num_pixels = np.sum(np.sum(mask,axis=0),axis=0)
    largest_slice = mask[:,:,np.argmax(num_pixels)]
    props = measure.regionprops(largest_slice.astype(int))[0]
    return extract_square_bounding_box(props.bbox, margin)


def crop_image(image, bbox):
    top, left, bottom, right = bbox
    cropped_image = image[top:bottom, left:right]
    return cropped_image


def rescale_and_resize(image, target_size, rescale=True, equalize=False, rotate=False):
    image = np.where(image<0, 0.1, image)
    if rescale:
        image = cv2.normalize(image, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_32F)
    if image.shape != target_size:
        image = cv2.resize(image.astype(np.uint8), target_size, interpolation = cv2.INTER_LINEAR)
    if equalize:
        clahe = cv2.createCLAHE(clipLimit=3)
        image = clahe.apply(image)
    if rotate:
        image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)

    return image

def save_slice(directory, data, filename):
    file_path = directory+filename
    if cv2.imwrite(file_path, data):
        return filename
    else:
        print(f"Issue writing file {file_path}")

def strip_small_lesions(target_slice, x_len, y_len):
    labels = measure.label(target_slice, background=0)
    props = pd.DataFrame(measure.regionprops_table(labels, properties=('label','bbox','axis_major_length','axis_minor_length')))
    props['remove'] = [(x<x_len) and (y<y_len) for x,y in props.loc[:,['axis_major_length','axis_minor_length']].values]
    labels_to_remove = list(props[props.remove==True].label.values)
    labels_to_remove.append(0) # adding background label to prevent background turning white
    new_labels = np.isin(labels, labels_to_remove, invert=True)
    return new_labels*1


def extract_slice(patient_id, data_directory, target_directory, time_A, prefix,
                   instance=0, target_size=(512,512), no_lesion_sampling=None, lesion_size=3,
                  file_pattern = "{}/{}/ses-{}/anat/{}_ses-{}_{}.nii.gz"):
    flair_A, header = load_nifti(file_pattern.format(data_directory,
                                            patient_id, time_A,
                                            patient_id, time_A,
                                            "FLAIR"), header=True)
    t1_A = load_nifti(file_pattern.format(data_directory,
                                            patient_id, time_A,
                                            patient_id, time_A,
                                            "T1w"))
    lesion_A = load_nifti(file_pattern.format(data_directory,
                                            patient_id, time_A,
                                            patient_id, time_A,
                                            "ground_truth"))
    mask = load_nifti(file_pattern.format(data_directory,
                                            patient_id, time_A,
                                            patient_id, time_A,
                                            "defacemask"))
    
    lesions = (lesion_A > 0)*1
    # extract slices with lesions 
    x_l, y_l, z_l = lesions.nonzero()
    # extract all non background slices
    x, y, z = mask.nonzero()
    slices = np.unique(z)
    num_lesion_slices = len(np.unique(z_l))
    
    p_write = num_lesion_slices/(len(slices))
    num_slices = 0 
    data = []
    pix_x , pix_y = header['pixdim'][1:3]
    x_len, y_len = lesion_size/pix_x, lesion_size/pix_y
    
    bbox = get_bbox(mask,margin=18)
    
    for z_slice in slices:
        
        target_slice = lesions[:,:,z_slice]
        target_slice = strip_small_lesions(target_slice, x_len, y_len)
        change = 0
        write_slice = 0
        if np.sum(target_slice) > 0:
            change = 1
            write_slice=1
        else:
            if no_lesion_sampling == 'same':
                write_slice = np.random.choice([0,1], p = [1-p_write, p_write])
            elif no_lesion_sampling == 'all':
                write_slice = 1 # select all
            else:
                write_slice = 0 # ignore all 
        #print(f'{change}, {write_slice}, {num_lesion_slices}, {num_slices_no_lesions}')
        if write_slice:
            
            fA = crop_image(flair_A[:,:,z_slice], bbox)
            t1A = crop_image(t1_A[:,:,z_slice], bbox)
            labels = crop_image(target_slice, bbox)
            
            slice_A = rescale_and_resize(fA, target_size, equalize=True, rotate=True)
            sliceT1_A = rescale_and_resize(t1A, target_size, equalize=True, rotate=True)
            slice_L = rescale_and_resize(labels, target_size, rotate=True)
            flair_slice_name = f"{prefix}_flair_{patient_id}_{time_A}_{num_slices}.png"
            t1_slice_name = f"{prefix}_t1_{patient_id}_{time_A}_{num_slices}.png"
            slice_name = f"{prefix}_label_{patient_id}_{time_A}_{num_slices}.png"
            path_f = save_slice(target_directory+'image/', slice_A, flair_slice_name)
            path_t = save_slice(target_directory+'image/', sliceT1_A, t1_slice_name)
            path_L = save_slice(target_directory+'label/', slice_L, slice_name)
            data.append({'patient_id': patient_id,
               'path_flair': path_f.replace(target_directory, ''),
               'path_t1' : path_t.replace(target_directory, ''),
               'label' : path_L.replace(target_directory, ''),
               'slice' : z_slice,
               'time' : time_A,
               'change': change,
               'original' : target_slice.shape,
                'cropped': labels.shape,
                'bbox' : bbox,
               'size': target_size[0]})
            num_slices +=1

    return pd.DataFrame(data)

def build_2d_slices(data_df, data_directory, target_directory, prefix, no_lesion_sampling='all',
                         target_size=(512,512)):
    data = []
    for patient_id, grp in tqdm(data_df.groupby('participant_id')):
        num_sessions = grp['sessions'].values[0]
        for tm in range(num_sessions):
            data.append(extract_slice(patient_id, data_directory,
                                                  target_directory, tm+1,
                                                  prefix = prefix,
                                                  target_size=target_size,
                                                  no_lesion_sampling=no_lesion_sampling))
    return pd.concat(data)

In [22]:
data_directory = '/home/prateek/from_kipchoge/ms_project/data'
target_directory = '/home/prateek/ms_project/ms_slice_data/'
train_slices = build_2d_slices(training_df, data_directory, target_directory,
                             prefix='train', no_lesion_sampling='all', target_size=(256,256))

100%|██████████| 110/110 [29:56<00:00, 16.33s/it]


In [20]:
valid_slices = build_2d_slices(validation_df, data_directory, target_directory,
                             prefix='valid', no_lesion_sampling='all', target_size=(256,256))

100%|██████████| 30/30 [07:27<00:00, 14.92s/it]


In [21]:
test_slices = build_2d_slices(testing_df, data_directory, target_directory,
                             prefix='test', no_lesion_sampling='all', target_size=(256,256))

100%|██████████| 30/30 [06:50<00:00, 13.69s/it]


In [23]:
test_slices.to_csv(f'{target_directory}/test.csv',index=False)
train_slices.to_csv(f'{target_directory}/train.csv',index=False) 
valid_slices.to_csv(f'{target_directory}/valid.csv',index=False)