In [1]:
import numpy as np
import pandas as pd
import nibabel as nib
import os

In [2]:

def create_region_df(region_files):
    df_rows = []

    # Process each region file
    for region_name, file_path in region_files.items():
        # Load the MRI data for the region
        region_data = np.array(nib.load(file_path).get_fdata(), dtype=np.int32)
        
        # Find unique voxel values and their indices
        unique_values, counts = np.unique(region_data, return_counts=True)
        
        # Iterate through each unique value to find its indices
        for value, count in zip(unique_values, counts):
            indices = list(zip(*np.where(region_data == value)))
            df_rows.append({
                "Region Name": region_name,
                "Region Unique Number": value,
                "Indices List of Voxels": indices,
                "Number of Voxels": count
            })
    
    # Create DataFrame
    region_df = pd.DataFrame(df_rows)
    return region_df



def create_and_save_masks(df, shape = (121,145,121)):
    # Ensure the output directory exists
        
    sample_img = nib.load('/data/users3/jchen/atlas/Jean_abcd/abcd_mask/final/roiFrontal_forSmri.nii')
    affine = sample_img.affine

    # Create a new column for the mask file paths
    df['Mask File Path'] = ''
    
    
    df['Delta X'] = 0
    df['Delta Y'] = 0
    df['Delta Z'] = 0
    df['Mask Center'] = ''
    df['Mni_Scaled'] = ''

    # Iterate through each row in the DataFrame to create and save masks
    for index, row in df.iterrows():
        #Create an empty mask with the given shape
        mask = np.zeros(shape, dtype=np.uint8)
        
        # Set the indices to 1
        for coord in row['Indices List of Voxels']:
            mask[coord[0], coord[1], coord[2]] = 1
            
        indices = np.array(row['Indices List of Voxels'])
        
        
    ## Get maximum spans in each dimension
        df.at[index, 'Delta X'] = indices[:, 0].max() - indices[:, 0].min()
        df.at[index, 'Delta Y'] = indices[:, 1].max() - indices[:, 1].min()
        df.at[index, 'Delta Z'] = indices[:, 2].max() - indices[:, 2].min()
        
        
    ## Get the centre of the non-zero voxels.(Will be our manual postional encoding)
        non_zero_coords = np.argwhere(mask)
        min_coords = non_zero_coords.min(axis=0)
        max_coords = non_zero_coords.max(axis=0) + 1  # Add 1 because slice end index is exclusive
        
        # Compute the center of the bounding box
        center = [(max_coord + min_coord) // 2 for min_coord, max_coord in zip(min_coords, max_coords)]
        df.at[index,'Mask Center'] = center
        
        # Convert to MNI Scale
        center.append(1)
        voxel = np.array(center)
        mni_coords = affine.dot(voxel)[:3]
        mni_coords_scaled = [round(mni_coords[0]/75,3) ,  round(mni_coords[1]/110,3) , round(mni_coords[2]/85, 3)]
        df.at[index, 'Mni_Scaled'] = mni_coords_scaled
        
        # # Define the file path for the mask
        # file_name = f"{row['Region Name']}_{row['Region Unique Number']}.nii"
        # file_path = os.path.join(output_dir, file_name)
        
        # # Save the mask as a NIfTI file
        # nib.save(nib.Nifti1Image(mask, affine), file_path)
        
        # # Update the DataFrame with the mask file path
        # df.at[index, 'Mask File Path'] = file_path
        
    df = df.drop('Indices List of Voxels', axis = 1)

    return df



In [3]:
def interactive_plot_image(img):
    """
    Create an interactive plot for a 3D image with the choice of using sliders or text boxes for input.
    Args:
    img (numpy.ndarray): The 3D image array.
    """
    
    def plot_image(x, y, z):
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        slices = [img[int(x), :, :], img[:, int(y), :], img[:, :, int(z)]]
        
        for i, slice in enumerate(slices):
            axes[i].imshow(slice.T, cmap="gray", origin="lower")
            axes[i].axis("off")
        
        plt.show()

    
    x_input = widgets.IntSlider(min=0, max=img.shape[0]-1, step=1, value=img.shape[0]//2, description='X Slice:')
    y_input = widgets.IntSlider(min=0, max=img.shape[1]-1, step=1, value=img.shape[1]//2, description='Y Slice:')
    z_input = widgets.IntSlider(min=0, max=img.shape[2]-1, step=1, value=img.shape[2]//2, description='Z Slice:')


    # Bind the inputs to the plot_image function
    ui = widgets.VBox([x_input, y_input, z_input])
    out = widgets.interactive_output(plot_image, {'x': x_input, 'y': y_input, 'z': z_input})
    
    display(ui, out)


In [4]:


# Define file paths for each region
region_files = {
    "frontal": '/data/users3/jchen/atlas/Jean_abcd/abcd_mask/final/roiFrontal_forSmri.nii',
    "cereb": '/data/users3/jchen/atlas/Jean_abcd/abcd_mask/final/roiCerebellum_forSmri.nii',
    "RN": '/data/users3/jchen/atlas/Jean_abcd/abcd_mask/final/roiRN_forSmri.nii',
    "thalamus": '/data/users3/jchen/atlas/Jean_abcd/abcd_mask/final/roiThalamus_forSmri.nii',
    "parietal": '/data/users3/jchen/atlas/Jean_abcd/abcd_mask/final/roiParietal_forSmri.nii'
}

# Generate the DataFrame
region_df = create_region_df(region_files)
# Filter out rows where 'Region Unique Number' is 0
filtered_df = region_df[region_df['Region Unique Number'] != 0]


In [None]:
filtered_df

Unnamed: 0,Region Name,Region Unique Number,Indices List of Voxels,Number of Voxels,Mask File Path,Delta X,Delta Y,Delta Z,Mask Center,Mni_Scaled
1,frontal,32,"[(86, 70, 48), (86, 70, 49), (88, 72, 48), (89...",51,,14,18,10,"[93, 79, 49, 1]","[-0.66, -0.068, 0.018]"
2,frontal,33,"[(92, 73, 54), (92, 74, 54), (93, 73, 54), (93...",22,,10,7,4,"[97, 75, 54, 1]","[-0.74, -0.123, 0.106]"
3,frontal,34,"[(81, 67, 54), (81, 67, 55), (81, 67, 56), (81...",105,,8,5,10,"[85, 70, 54, 1]","[-0.5, -0.191, 0.106]"
4,frontal,35,"[(83, 71, 57), (83, 71, 58), (83, 71, 59), (83...",492,,7,14,12,"[87, 78, 55, 1]","[-0.54, -0.082, 0.124]"
5,frontal,36,"[(81, 67, 58), (81, 67, 59), (81, 68, 58), (81...",83,,8,6,4,"[85, 70, 59, 1]","[-0.5, -0.191, 0.194]"
...,...,...,...,...,...,...,...,...,...,...
368,parietal,362,"[(19, 53, 63), (19, 54, 63), (20, 53, 63), (20...",125,,0,0,0,,
369,parietal,363,"[(19, 43, 63), (19, 44, 63), (20, 42, 66), (20...",382,,0,0,0,,
370,parietal,364,"[(17, 47, 65), (17, 47, 66), (17, 47, 67), (17...",828,,0,0,0,,
371,parietal,365,"[(15, 53, 63), (15, 53, 65), (15, 53, 67), (15...",560,,0,0,0,,


In [6]:
updated_df = create_and_save_masks(filtered_df)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df['Mask File Path'] = ''
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df['Delta X'] = 0
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df['Delta Y'] = 0
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in 

IndexError: index 125 is out of bounds for axis 0 with size 121

In [30]:
updated_df[updated_df['Region Name'] == 'cereb']

Unnamed: 0,Region Name,Region Unique Number,Number of Voxels,Mask File Path,Delta X,Delta Y,Delta Z,Mask Center,Mni_Scaled
169,cereb,1,1963,,28,30,41,"[47, 46, 27, 1]","[0.26, -0.518, -0.371]"
170,cereb,2,2140,,26,26,35,"[45, 48, 26, 1]","[0.3, -0.491, -0.388]"
171,cereb,4,1059,,25,29,35,"[48, 49, 22, 1]","[0.24, -0.477, -0.459]"
172,cereb,5,297,,21,31,27,"[47, 46, 20, 1]","[0.26, -0.518, -0.494]"
173,cereb,6,798,,23,27,33,"[48, 49, 23, 1]","[0.24, -0.477, -0.441]"
174,cereb,7,294,,22,26,18,"[47, 47, 14, 1]","[0.26, -0.505, -0.6]"
175,cereb,8,182,,18,15,16,"[45, 54, 15, 1]","[0.3, -0.409, -0.582]"
176,cereb,9,1237,,31,35,27,"[45, 44, 24, 1]","[0.3, -0.545, -0.424]"
177,cereb,10,2901,,37,39,39,"[42, 43, 24, 1]","[0.36, -0.559, -0.424]"
178,cereb,11,3254,,36,29,39,"[42, 46, 26, 1]","[0.36, -0.518, -0.388]"


In [32]:
updated_df.to_csv('3D_Masks_Data_parietal.csv', index=False)

In [34]:
## Testing the masks

frontal_gt = nib.load('/data/users3/jchen/atlas/Jean_abcd/abcd_mask/final/roiFrontal_forSmri.nii').get_fdata()