In [None]:
import itertools as it
import os
import re
from pathlib import Path

import dipy.stats.analysis as dsa
import dipy.tracking.streamline as dts
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pandas as pd
import seaborn as sns
from dipy.io.streamline import load_tractogram
from dipy.segment.clustering import QuickBundles
from dipy.segment.metric import AveragePointwiseEuclideanMetric, ResampleFeature
from dipy.viz import actor, colormap as cmap, window

In [None]:
# Index the BIDS dataset at /scratch/knavynde/topsy/, including its derivatives.
# NOTE(review): `BIDSLayout` is not imported anywhere in this notebook (presumably
# `from bids import BIDSLayout`, pybids) — add it to the imports cell or this cell
# fails on a fresh kernel. `layout` is also unused in the visible cells and is later
# rebound to a plotly `go.Layout` in the plotting cell.
layout = BIDSLayout("/scratch/knavynde/topsy/", derivatives=True)

In [None]:
# Input locations for a single subject. SLURM_TMPDIR is the job-local scratch
# directory under which the tck-tracts / tck-clusters directories are expected.
if "SLURM_TMPDIR" not in os.environ:
    raise RuntimeError(
        "SLURM_TMPDIR is not set; run inside a SLURM job or export it manually"
    )
SLURM_TMPDIR = Path(os.environ["SLURM_TMPDIR"])

SUBJECT = "sub-001"
TOPSY_DIR = Path("/scratch/knavynde/topsy")
PREPDWI_DIR = Path(
    "/project/6033503/cfmm-bids/Palaniyappan/TOPSY/baseline/correct7T.v0.3.1/"
    f"prepdwi_v0.0.12c/work/{SUBJECT}/dwi/uncorrected_denoise_unring_topup_eddy_regT1"
)

dwi_path = PREPDWI_DIR / "dwi.nii.gz"
fa_path = PREPDWI_DIR / "dti_FA.nii.gz"
t1_path = TOPSY_DIR / SUBJECT / "anat" / f"{SUBJECT}_space-orig_desc-preproc_T1w.nii.gz"
t1_map_path = TOPSY_DIR / SUBJECT / "anat" / f"{SUBJECT}_acq-MP2RAGE_run-01_T1map.nii.gz"
tracts_path = SLURM_TMPDIR / 'tck-tracts'
clusters_left = SLURM_TMPDIR / 'tck-clusters/left'
clusters_right = SLURM_TMPDIR / 'tck-clusters/right'

# Explicit check rather than `assert` (skipped under `python -O`), and it names
# exactly which inputs are missing.
_missing_inputs = [
    str(p) for p in (tracts_path, clusters_left, clusters_right) if not p.exists()
]
if _missing_inputs:
    raise FileNotFoundError(f"Required input directories not found: {_missing_inputs}")

In [None]:
# Load the MP2RAGE T1 map for the subject.
# NOTE(review): `t1map` is not used in the visible cells — presumably intended for
# the R1 column seen in results/tract_profiles.csv; confirm or remove.
t1map = nib.load(t1_map_path)

In [None]:
sup_f_l_path = tracts_path / "T_Sup-F_left.tck"
# Load the tract with the preprocessed T1w image as the spatial reference.
# NOTE(review): this cell previously passed an undefined `rand_cluster`; assumed
# the intended input is `sup_f_l_path`, which was otherwise unused.
tracts = load_tractogram(str(sup_f_l_path), str(t1_path))

In [None]:
# Build a line actor for the loaded tract, coloured by `cmap.line_colors`.
streamlines = tracts.streamlines
streamline_colors = cmap.line_colors(streamlines)
tracts_actor = actor.line(streamlines, streamline_colors)

In [None]:
# Load the DWI series; the original cell used `dwi` without ever loading it.
dwi = nib.load(dwi_path)
# First volume along the last axis, used as the slice background.
data = dwi.get_fdata()[..., 0]
# Display window from the statistics of non-zero voxels only (mask computed once).
foreground = data[data > 0]
mean, std = foreground.mean(), foreground.std()
value_range = (mean - 0.5 * std, mean + 1.5 * std)
slice_actor = actor.slicer(data, dwi.affine, value_range)

In [None]:
# Render the tract actor off-screen and display the snapshot with matplotlib.
# Set SHOW_SLICE = True to overlay the DWI slice built in the previous cell.
SHOW_SLICE = False

scene = window.Scene()
if SHOW_SLICE:
    slice_actor.display(z=25)
scene.add(tracts_actor)
if SHOW_SLICE:
    scene.add(slice_actor)
scene.reset_camera()
scene.elevation(130)
img = window.snapshot(scene, size=(1000, 1000))

fig, ax = plt.subplots(figsize=(20, 20))
ax.axis("off")
ax.imshow(img);

In [None]:
class TractProfile:
    """Weighted along-tract profile for a single streamline cluster.

    The cluster is loaded and summarised by a single QuickBundles centroid. Every
    streamline is oriented to that centroid, and Gaussian weights are computed on
    the *oriented* streamlines. ``get_profile`` can then sample any image along
    the bundle.

    Attributes:
        streamlines: streamlines as loaded, or ``None`` when the file is empty.
        cluster_bundle: QuickBundles result (a single cluster), or ``None``.
        oriented_streamlines: streamlines flipped to match the centroid's
            direction, or ``None``.
        weights: output of ``dsa.gaussian_weights`` on the oriented
            streamlines, or ``None``.
    """

    def __init__(self, streamlines, ref):
        """Load a cluster and precompute its orientation and weights.

        Args:
            streamlines: path to the tractogram file (e.g. a ``.tck`` cluster).
            ref: path to the reference image passed to ``load_tractogram``.

        Raises:
            Exception: if ``load_tractogram`` reports failure by returning ``False``.
        """
        cluster = load_tractogram(str(streamlines), str(ref))
        if cluster is False:
            raise Exception(f"Cluster {streamlines} could not be loaded")

        if len(cluster.streamlines) == 0:
            # Empty cluster: leave every derived attribute defined but empty.
            self.streamlines = None
            self.cluster_bundle = None
            self.oriented_streamlines = None
            self.weights = None
            return

        feature = ResampleFeature(nb_points=100)
        metric = AveragePointwiseEuclideanMetric(feature)

        # An infinite distance threshold forces every streamline into one cluster,
        # so centroids[0] summarises the whole bundle.
        qb = QuickBundles(np.inf, metric=metric)
        self.cluster_bundle = qb.cluster(cluster.streamlines)
        self.streamlines = cluster.streamlines

        # Orient BEFORE weighting: gaussian_weights compares streamlines node by
        # node, so node i must refer to the same end of the bundle in every
        # streamline (the order used in dipy's AFQ tract-profile example).
        self.oriented_streamlines = dts.orient_by_streamline(
            self.streamlines,
            self.cluster_bundle.centroids[0],
        )
        self.weights = dsa.gaussian_weights(self.oriented_streamlines)

    def get_profile(self, img):
        """Return the weighted along-tract profile of ``img`` for this cluster.

        Args:
            img: nibabel-style image exposing ``get_fdata()`` and ``affine``;
                assumed to share the streamlines' space — confirm per map.

        Returns:
            The ``dsa.afq_profile`` result (dipy default of 100 nodes).

        Raises:
            ValueError: if the cluster contained no streamlines.
        """
        if self.oriented_streamlines is None:
            raise ValueError("Cannot compute a profile for an empty cluster")
        return dsa.afq_profile(
            img.get_fdata(),
            self.oriented_streamlines,
            img.affine,
            weights=self.weights,
        )


def get_profiles_and_streamlines(paths, param_maps, ref):
    """Compute the node-averaged profile and streamline count for each cluster file.

    Args:
        paths: sequence of tractogram paths, one per cluster.
        param_maps: mapping of map name -> nibabel image, iterated in insertion order.
        ref: reference image path forwarded to ``TractProfile``.

    Returns:
        ``(means, counts)``:
        - ``means`` has shape ``(len(param_maps), len(paths))`` and holds each
          profile averaged over its nodes.
        - ``counts`` has shape ``(len(paths),)`` (float dtype) and holds the
          streamlines per cluster.
        Empty clusters get 0 in both.
    """
    n_nodes = 100  # length of each profile (dipy afq_profile default used by TractProfile)
    profiles = np.zeros((len(param_maps), len(paths), n_nodes))
    streamlines = np.zeros(len(paths))

    for i, path in enumerate(paths):
        profile = TractProfile(path, ref)
        # Explicit emptiness test: avoids relying on truthiness of a streamline container.
        if profile.streamlines is None or len(profile.streamlines) == 0:
            continue  # row stays zero for an empty cluster
        streamlines[i] = len(profile.streamlines)
        for j, param_map in enumerate(param_maps.values()):
            profiles[j, i] = profile.get_profile(param_map)
    return profiles.mean(axis=2), streamlines


In [None]:
data = clusters_left
ref_img = t1_path
# Number of clusters profiled in this run (presumably a quick test subset).
N_CLUSTERS = 25

parameter_maps = {
    "FA": nib.load(fa_path)
}

if not data.is_dir():
    raise FileNotFoundError("Input must be a directory")
# Path.glob yields entries in filesystem order, which is not guaranteed; sort so
# the first N_CLUSTERS files are the same on every run.
paths = sorted(data.glob("*.tck"))[:N_CLUSTERS]

# NOTE: this rebinds `streamlines` to the per-cluster counts array.
profiles, streamlines = get_profiles_and_streamlines(paths, parameter_maps, ref_img)



In [None]:
# Zero-padded cluster IDs (e.g. "00733") parsed from names like cluster_00733.tck.
cluster_numbers = []
for path in paths:
    match = re.search(r'(?<=cluster_)\d{5}(?=\.tck$)', str(path))
    if match is None:
        raise ValueError(f"Cannot parse a cluster number from {path}")
    cluster_numbers.append(match[0])

profile_table = pd.DataFrame(
    dict(zip(parameter_maps, profiles))
).assign(
    cluster=pd.Series(cluster_numbers),
    subject=1,
    streamlines=pd.Series(streamlines)
).set_index(["subject", "cluster"])

# Pass the path directly: pandas requires file handles to be opened with
# newline="" (otherwise extra blank lines appear on Windows).
profile_table.to_csv("test.csv")
profile_table

In [None]:
# Merge per-subject profile tables into a single CSV.
# NOTE(review): reads the same placeholder "test.csv" twice — replace with the
# real per-subject outputs.
# index_col=[0, 1] restores the (subject, cluster) MultiIndex written by to_csv,
# matching how results/tract_profiles.csv is read elsewhere in the notebook.
dataframes = (
    pd.read_csv(path, index_col=[0, 1]) for path in ["test.csv", "test.csv"]
)

merged = pd.concat(dataframes)
merged.to_csv("all_test.csv")

In [None]:
# Gaussian weights for the tract loaded earlier (`tracts`). The previous version
# referenced an undefined `oriented_cluster`. dipy's gaussian_weights compares
# streamlines node by node, so they are first oriented to a single QuickBundles
# centroid.
# NOTE: `tracts.streamlines` is used because `streamlines` may have been rebound
# to a counts array by the cluster-profiling cell.
centroid_metric = AveragePointwiseEuclideanMetric(ResampleFeature(nb_points=100))
tract_centroid = (
    QuickBundles(np.inf, metric=centroid_metric)
    .cluster(tracts.streamlines)
    .centroids[0]
)
oriented_cluster = dts.orient_by_streamline(tracts.streamlines, tract_centroid)
weighted_cluster = dsa.gaussian_weights(oriented_cluster)

In [None]:
fa = nib.load(fa_path)

# FA profile of a single example cluster from the left-hemisphere cluster directory.
example_cluster = clusters_left / "cluster_00733.tck"
profile = TractProfile(example_cluster, t1_path).get_profile(fa)

In [None]:
# Plot the example FA profile node by node.
fig, axes = plt.subplots()
axes.set_ylabel("Fractional anisotropy")
axes.set_xlabel("Node along CST")
axes.plot(profile);


In [None]:
# Profile table for all subjects; the first two CSV columns form a
# (subject, cluster) MultiIndex.
data: pd.DataFrame = pd.read_csv(
    "results/tract_profiles.csv",
    index_col=[0, 1],
)
data.head()

In [None]:
import pickle


def _load_pickle(path):
    """Unpickle and return the object stored at ``path``.

    NOTE(review): pickle.load can execute arbitrary code; only use it on trusted,
    locally generated resources. The ``.pyc`` extension of these files is
    misleading — they are pickles, not compiled Python bytecode.
    """
    with Path(path).open('rb') as fh:
        return pickle.load(fh)


layer_dict = _load_pickle("resources/layer_assignments.pyc")
tracts_dict = _load_pickle("resources/tract_assignments.pyc")

def get_cluster_number(name: str) -> int:
    """Return the 5-digit cluster number immediately preceding ``.vtp`` in ``name``.

    Example: ``"cluster_00733.vtp"`` -> ``733``.

    Raises:
        ValueError: if ``name`` has no 5-digit run directly before ``.vtp``.
    """
    match = re.search(r"\d{5}(?=\.vtp)", name)
    if match is None:
        raise ValueError(f"No 5-digit cluster number before '.vtp' in {name!r}")
    return int(match[0])

# Cluster numbers of the entries listed under "hemispheric" in layer_dict.
hem_clusters = layer_dict["hemispheric"]
left_hem = list(map(get_cluster_number, hem_clusters))
# NOTE(review): assumes right-hemisphere cluster k + 800 mirrors left cluster k —
# confirm against the atlas numbering.
right_hem = [cluster_id + 800 for cluster_id in left_hem]


In [None]:
# Group code per subject, looked up positionally as categories[subject] when the
# `group` column is attached to the profile table.
# As used in the ratio cell: 1 = HC, 2 = FEP. Codes 3 and 4 are not given a
# meaning in this notebook; those subjects are excluded from the HC/FEP selections.
# NOTE(review): lookup is by position, so element 0 corresponds to subject 0. If
# subject IDs start at 1 (the profile table uses subject=1), confirm this list is
# not shifted by one.
categories = [
    1,
    1,
    1,
    1,
    1,
    3,
    2,
    2,
    3,
    1,
    1,
    2,
    2,
    2,
    3,
    2,
    2,
    2,
    2,
    2,
    3,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    3,
    3,
    2,
    2,
    1,
    1,
    2,
    2,
    2,
    1,
    1,
    2,
    1,
    2,
    2,
    1,
    2,
    2,
    1,
    2,
    2,
    2,
    1,
    1,
    1,
    1,
    2,
    2,
    2,
    2,
    1,
    1,
    1,
    1,
    1,
    1,
    2,
    1,
    1,
    2,
    3,
    2,
    2,
    2,
    2,
    2,
    1,
    2,
    2,
    3,
    2,
    2,
    2,
    2,
    1,
    2,
    2,
    1,
    2,
    4,
    1,
    1,
    2,
    2,
    2,
    4,
    2,
    4,
    2,
    4,
    2,
    2,
    2,
    2,
    2,
    2,
    4,
    4,
    4,
    4,
    4,
    2,
    4,
    2,
    4,
    2,
    2,
    4,
    4,
    2,
    3,
    3,
    2,
    2,
    2,
    2,
    4,
    3,
    3,
    3,
    3,
    4,
    1,
    3,
    4,
    2
]

In [None]:
# Tract names (keys of tracts_dict) for the two families compared below.
superficial = [
    "T_Sup-F",
    "T_Sup-FP",
    "T_Sup-O",
    "T_Sup-OT",
    "T_Sup-P",
    "T_Sup-PO",
    "T_Sup-PT",
    "T_Sup-T"
]
assoc = [
    "T_AF",
    "T_CB",
    "T_EC",
    "T_EmC",
    "T_ILF",
    "T_IoFF",
    "T_MdLF",
    "T_PLIC",
    "T_SLF-I",
    "T_SLF-II",
    "T_SLF-III",
    "T_UF"
]


def _both_hemisphere_cluster_ids(tract_names):
    """Cluster numbers for ``tract_names``: the listed IDs, then the same IDs + 800.

    +800 is the offset the notebook uses for the right-hemisphere counterparts.
    Tract names absent from ``tracts_dict`` are skipped; order follows
    ``tract_names`` and then each tract's cluster list.
    """
    listed_ids = [
        get_cluster_number(cluster_file)
        for tract in tract_names
        if tract in tracts_dict
        for cluster_file in tracts_dict[tract]
    ]
    return listed_ids + [cluster_id + 800 for cluster_id in listed_ids]


sup_tracts = _both_hemisphere_cluster_ids(superficial)
assoc_tracts = _both_hemisphere_cluster_ids(assoc)

In [None]:
# Ratio
idx = pd.IndexSlice
grouped = data.assign(group=lambda x: [categories[y] for y, _ in x.index] )
assoc_data = data.loc[idx[:, assoc_tracts], :]



sup_HC = grouped.loc[(grouped["group"] == 1) & (grouped["streamlines"] > 0)].loc[idx[:, sup_tracts], :]
sup_FEP = grouped.loc[(grouped["group"] == 2) & (grouped["streamlines"] > 0)].loc[idx[:, sup_tracts], :]
assoc_HC = grouped.loc[(grouped["group"] == 1) & (grouped["streamlines"] > 0)].loc[idx[:, assoc_tracts], :]
assoc_FEP = grouped.loc[(grouped["group"] == 2) & (grouped["streamlines"] > 0)].loc[idx[:, assoc_tracts], :]


ratio_HC = pd.DataFrame(index=sup_HC.groupby(level=0).mean().index).assign(
    FA=sup_HC.groupby(level=0).mean()["FA"]/assoc_HC.groupby(level=0).mean()["FA"],
    R1=sup_HC.groupby(level=0).mean()["R1"]/assoc_HC.groupby(level=0).mean()["R1"],
    streamlines=sup_HC.groupby(level=0).sum()["streamlines"]/assoc_HC.groupby(level=0).sum()["streamlines"]
)

ratio_FEP = pd.DataFrame(index=sup_FEP.groupby(level=0).mean().index).assign(
    FA=sup_FEP.groupby(level=0).mean()["FA"]/assoc_FEP.groupby(level=0).mean()["FA"],
    R1=sup_FEP.groupby(level=0).mean()["R1"]/assoc_FEP.groupby(level=0).mean()["R1"],
    streamlines=sup_FEP.groupby(level=0).sum()["streamlines"]/assoc_FEP.groupby(level=0).sum()["streamlines"]
)

ratio_means = pd.concat([
    ratio_HC.assign(group="HC"),
    ratio_FEP.assign(group="FEP")
])

aggregator = {
    "FA": "mean",
    "R1": "mean",
    "streamlines": "sum"
}

means = pd.concat([
    assoc_HC.groupby(level=0).agg(aggregator).assign(group="HC", tracts="association"),
    assoc_FEP.groupby(level=0).agg(aggregator).assign(group="FEP", tracts="association"),
    sup_HC.groupby(level=0).agg(aggregator).assign(group="HC", tracts="superficial"),
    sup_FEP.groupby(level=0).agg(aggregator).assign(group="FEP", tracts="superficial")
])



In [None]:
import plotly.express as px
import plotly.graph_objects as go

# Shared figure styling. NOTE: this rebinds `layout`, which the first cell set
# to a BIDSLayout.
layout = go.Layout(
    margin=go.layout.Margin(l=10, r=10, b=10, t=50),
    font_size=14,
)

axis_labels = {
    "FA": "FA",
    "R1": "R1",
    "streamlines": "# streamlines",
}

# Per-subject streamline totals by tract family, split by group.
fig = px.box(
    means,
    x="tracts",
    y="streamlines",
    color="group",
    points="all",
    width=1000,
    template="seaborn",
    title="Number of Streamlines",
    labels=axis_labels,
)
fig.update_layout(layout)
fig.show()