In [2]:
import os 
import nibabel as nib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import ants
from glob import glob
from tqdm import tqdm

In [2]:
import numpy as np
import ants
from nibabel.nifti1 import Nifti1Image

def get_ras_affine_from_ants(ants_img):
    """
    Parameters
    ----------
    ants_img : ants.ANTsImage
        The ANTs image whose affine is to be converted.

    Returns
    -------
    affine : np.ndarray
        The affine matrix in RAS coordinates.
    """
    spacing = np.array(ants_img.spacing)
    direction_lps = np.array(ants_img.direction)
    origin_lps = np.array(ants_img.origin)
    direction_length = direction_lps.shape[0] * direction_lps.shape[1]
    if direction_length == 9:
        rotation_lps = direction_lps.reshape(3, 3)
    elif direction_length == 4:  # 2D case (1, W, H, 1)
        rotation_lps_2d = direction_lps.reshape(2, 2)
        rotation_lps = np.eye(3)
        rotation_lps[:2, :2] = rotation_lps_2d
        spacing = np.append(spacing, 1)
        origin_lps = np.append(origin_lps, 0)
    elif direction_length == 16:  # Fix potential bad NIfTI
        rotation_lps = direction_lps.reshape(4, 4)[:3, :3]
        spacing = spacing[:-1]
        origin_lps = origin_lps[:-1]
    else:
        raise NotImplementedError(f"Unexpected direction length = {direction_length}.")

    rotation_ras = np.dot(np.diag([-1, -1, 1]), rotation_lps)
    rotation_ras_zoom = rotation_ras * spacing
    translation_ras = np.dot(np.diag([-1, -1, 1]), origin_lps)

    affine = np.eye(4)
    affine[:3, :3] = rotation_ras_zoom
    affine[:3, 3] = translation_ras

    return affine


def nifti_to_ants(nib_image: "Nifti1Image"):
    """

    Parameters
    ----------
    nib_image : Nifti1Image
        The Nifti image to be converted.

    Returns
    -------
    ants_image : ants.ANTsImage
        The converted ANTs image.
    """
    ndim = nib_image.ndim

    if ndim < 3:
        raise NotImplementedError(
            "Conversion is only implemented for 3D or higher images."
        )
    q_form = nib_image.get_qform()
    spacing = nib_image.header["pixdim"][1 : ndim + 1]

    origin = np.zeros(ndim)
    origin[:3] = np.dot(np.diag([-1, -1, 1]), q_form[:3, 3])

    direction = np.eye(ndim)
    direction[:3, :3] = np.dot(np.diag([-1, -1, 1]), q_form[:3, :3]) / spacing[:3]

    ants_img = ants.from_numpy(
        data=nib_image.get_fdata(),
        origin=origin.tolist(),
        spacing=spacing.tolist(),
        direction=direction,
    )
    "add nibabel conversion (lacey import to prevent forced dependency)"

    return ants_img

def ants_to_nifti(img, header=None):
    """
    Parameters
    ----------
    img : ants.ANTsImage
        The ANTs image to be converted.
    header : Nifti1Header, optional
        Optional header to use for the Nifti image.

    Returns
    -------
    img : Nifti1Image
        The converted Nifti image.
    """
    from nibabel.nifti1 import Nifti1Image

    affine = get_ras_affine_from_ants(img)
    arr = img.numpy()

    if header is not None:
        header.set_data_dtype(arr.dtype)

    return Nifti1Image(arr, affine, header)


In [None]:
print('standard')
standard = nib.load('C:/Users/BREIN/Desktop/Research/brain_atlases/MIITRA/MIITRA-T1w-05mm.nii.gz')
standard_ants = ants.image_read('C:/Users/BREIN/Desktop/Research/brain_atlases/MIITRA/MIITRA-T1w-05mm.nii.gz')
print(standard.shape)
print(standard.header.get_zooms())

print('dkt 0.5mm')
dkt = nib.load('C:/Users/BREIN/Desktop/Research/brain_atlases/MIITRA/DKT-05mm/DKT-05mm.nii.gz')
dkt_ants = ants.image_read('C:/Users/BREIN/Desktop/Research/brain_atlases/MIITRA/DKT-05mm/DKT-05mm.nii.gz')
print(dkt.shape)
print(dkt.header.get_zooms())

print('tpm')
tpm = nib.load('C:/Users/BREIN/Desktop/Research/ext_pckgs/matlab/spm12/tpm/TPM.nii')
# tpm_ants = ants.image_read('C:/Users/BREIN/Desktop/Research/ext_pckgs/matlab/spm12/tpm/TPM.nii')
tpm_arr = tpm.get_fdata()
tpm_brain = (1.0*tpm_arr[:,:,:,0]**1.2) + (1.4*tpm_arr[:,:,:,1]**1.5) + (tpm_arr[:,:,:,2]*0.5) + (tpm_arr[:,:,:,3]*0.25) + (tpm_arr[:,:,:,4]*0.25) 
tpm_brain = nib.Nifti1Image(tpm_brain, tpm.affine, tpm.header)
nib.save(tpm_brain, 'tpm_brain.nii.gz')
print(tpm.shape)
print(tpm.header.get_zooms())

tpm_ants = ants.image_read('tpm_brain.nii.gz')
miitra_2_tpm = ants.registration(fixed=standard_ants, moving=tpm_ants, type_of_transform='SyN')
warped_atlas = ants.apply_transforms(fixed=tpm_ants, moving=dkt_ants, transformlist=miitra_2_tpm['invtransforms'], interpolator='genericLabel')
ants.image_write(warped_atlas,'warped_atlas2.nii.gz')



standard
(380, 440, 380)
(0.5, 0.5, 0.5)
dkt 0.5mm
(380, 440, 380)
(0.5, 0.5, 0.5)
tpm


In [4]:
import ants
def squeeze_img(inp_img):
    img = inp_img.get_fdata()
    if img.ndim == 4 and img.shape[3] == 1:
        out_img = nib.Nifti1Image(img[:,:,:,0], inp_img.affine, inp_img.header)
        return out_img
    else: 
        return inp_img

img_dir = 'Z:/1_combined/UCSF/PROC/T1/GM_Density_wmap'
out_dir = 'Z:/1_combined/UCSF/PROC/T1/native_dkt'
ucsf_demo = pd.read_excel('C:/Users/BREIN/Desktop/copathology_visualization_temp/data/ucsf_demo.xlsx')

dkt_arr = nib.load('warped_atlas.nii.gz').get_fdata().astype(int)
labels = np.unique(dkt_arr)
labels = labels[labels != 0]

rows = []

print(len(os.listdir(img_dir)))
for img_path in os.listdir(img_dir):
    subj = img_path.split('_')[0]
    print(subj)
    subj_num = int(subj.split('-')[1])
    print(subj_num)

    dx = ucsf_demo.loc[ucsf_demo['PIDN']==subj_num, 'FTLD_MMCsubtype_combined'].iloc[0]
    print(dx)

    img = nib.load(os.path.join(img_dir, img_path))
    img = squeeze_img(img)
    img_arr = img.get_fdata()

    print(img.shape)
    print(img.header.get_zooms())

    row = {
        "PTID" : subj_num, "DX" : dx
    }

    for l in labels:
        mask = dkt_arr == l
        row[f"{l}"] = (
            np.nanmean(img_arr[mask]) if np.any(mask) else np.nan
        )
    
    rows.append(row)
    
final_df = pd.DataFrame(rows)

165
sub-00278
278
C
(121, 145, 121)
(1.5, 1.5, 1.5)
sub-12670
12670
ALS
(121, 145, 121)
(1.5, 1.5, 1.5)
sub-08008
8008
B
(121, 145, 121)
(1.5, 1.5, 1.5)
sub-02795
2795
B
(121, 145, 121)
(1.5, 1.5, 1.5)
sub-06757
6757
B
(121, 145, 121)
(1.5, 1.5, 1.5)
sub-11442
11442
Pick's
(121, 145, 121)
(1.5, 1.5, 1.5)
sub-01201
1201
AD


KeyboardInterrupt: 

In [6]:
dx_list = ["Pick's", "CBD", "PSP", "A", "B", "C", "D", "U", "AD"]
final_df = pd.read_csv('C:/Users/BREIN/Desktop/copathology_visualization_temp/data/ucsf_regional_wscore.csv')
dx_filtered_df = final_df[final_df['DX'].isin(dx_list)]
print(len(final_df))
print(len(dx_filtered_df))
print(dx_filtered_df['DX'].value_counts())


130
130
DX
B         31
AD        18
A         17
CBD       16
C         14
Pick's    14
U         10
PSP       10
Name: count, dtype: int64


In [7]:
from atlas_vis import DKTAtlas62ROIPlotter
plotter_62  = DKTAtlas62ROIPlotter(
    cmap='Reds',
    clim=(-0.5, 6),  
    window_size=(1200, 1000),
    nan_color='lightgray',
    background='white',
    template_key='pial'
)

for dx in dx_filtered_df['DX'].unique():
    subgroup_df = dx_filtered_df[dx_filtered_df['DX'] == dx]
    print(dx)
    print(len(subgroup_df))
    X_raw = subgroup_df.loc[:,'1002':'2035'].values.astype(float)
    X_mean = X_raw.mean(axis=0)
    print(len(X_mean))

    l_values = X_mean[:31].tolist()
    r_values = X_mean[31:].tolist()
    print(len(l_values))
    print(len(r_values))
    print(np.min(l_values + r_values))
    print(np.max(l_values + r_values))

    plotter_62(l_values, r_values, save_path=f'./ucsf_wscore_surface_map/{dx}.png')

C
14
62
31
31
0.48011502685714297
6.094212692357145


Python-Atlas-Visualization-main/atlas_vis/core/__init__.py:98: Argument 'mesh' must be passed as a keyword argument to function 'BasePlotter.update_scalars'.
From version 0.50, passing this as a positional argument will result in a TypeError.
  self.plotter.update_scalars(l_value, mesh)
Python-Atlas-Visualization-main/atlas_vis/core/__init__.py:100: Argument 'mesh' must be passed as a keyword argument to function 'BasePlotter.update_scalars'.
From version 0.50, passing this as a positional argument will result in a TypeError.
  self.plotter.update_scalars(r_value, mesh)


B
31
62
31
31
0.406867683516129
2.4934854560645165
Pick's
14
62
31
31
-0.19773706900000002
5.102006410214287
AD
18
62
31
31
0.8573132713888889
2.5402041798333332
U
10
62
31
31
0.8456794592
2.6915207335
A
17
62
31
31
0.48364690164705887
3.30765488
PSP
10
62
31
31
0.3185475692
1.9715625228999998
CBD
16
62
31
31
0.30346685925000005
2.1959366478125
