In [None]:
from pathlib import Path
import numpy as np
import os
from dipy.io.streamline import load_tractogram
import nibabel as nib
import matplotlib.pyplot as plt
import re
import pandas as pd
from dipy.segment.clustering import QuickBundles
from dipy.segment.metric import AveragePointwiseEuclideanMetric, ResampleFeature
import dipy.stats.analysis as dsa

import dipy.tracking.streamline as dts

In [None]:
# Input locations for this subject. The tract and cluster directories live under
# $SLURM_TMPDIR, so fail with a clear message if that variable is missing
# (Path(None) would otherwise raise an unhelpful TypeError).
slurm_tmpdir_env = os.environ.get("SLURM_TMPDIR")
if slurm_tmpdir_env is None:
    raise RuntimeError("SLURM_TMPDIR environment variable is not set")
SLURM_TMPDIR = Path(slurm_tmpdir_env)
dwi_path = Path("/project/6033503/cfmm-bids/Palaniyappan/TOPSY/baseline/correct7T.v0.3.1/prepdwi_v0.0.12c/work/sub-001/dwi/uncorrected_denoise_unring_topup_eddy_regT1/dwi.nii.gz")
t1_path = Path('/scratch/knavynde/topsy/sub-001/anat/sub-001_space-orig_desc-preproc_T1w.nii.gz')
t1_map_path = Path('/scratch/knavynde/topsy/sub-001/anat/sub-001_acq-MP2RAGE_run-01_T1map.nii.gz')
fa_path = Path('/project/6033503/cfmm-bids/Palaniyappan/TOPSY/baseline/correct7T.v0.3.1/prepdwi_v0.0.12c/work/sub-001/dwi/uncorrected_denoise_unring_topup_eddy_regT1/dti_FA.nii.gz')
tracts_path = SLURM_TMPDIR/'tck-tracts'
clusters_left = SLURM_TMPDIR/'tck-clusters/left'
clusters_right = SLURM_TMPDIR/'tck-clusters/right'

# Explicit check instead of `assert`: it survives `python -O` and names what is missing.
missing_dirs = [p for p in (tracts_path, clusters_left, clusters_right) if not p.exists()]
if missing_dirs:
    raise FileNotFoundError(f"Expected input directories not found: {missing_dirs}")

In [None]:
# Load the MP2RAGE T1 map image.
# NOTE(review): `t1map` is not referenced by any other cell shown here.
t1map = nib.load(str(t1_map_path))

In [None]:
# Load the T_Sup-F_left tract, using the preprocessed T1w image as the spatial reference.
sup_f_l_path = tracts_path/"T_Sup-F_left.tck"
# Load the path defined above (this previously referenced the undefined name `rand_cluster`).
tracts = load_tractogram(str(sup_f_l_path), str(t1_path))

In [None]:
# Build a line actor for the streamlines loaded in `tracts`.
# NOTE(review): `actor` and `cmap` are never imported in this notebook —
# presumably `from fury import actor, colormap as cmap`; TODO add to the imports cell.
streamlines = tracts.streamlines
line_colors = cmap.line_colors(streamlines)
tracts_actor = actor.line(streamlines, line_colors)

In [None]:
# Slice actor for the first DWI volume, windowed from mean-0.5*std to mean+1.5*std
# of its strictly positive voxels.
# NOTE(review): `actor` is not imported in this notebook — presumably
# `from fury import actor`; TODO add to the imports cell.
dwi = nib.load(dwi_path)  # `dwi` was previously used without being defined
dwi_first_volume = dwi.get_fdata()[..., 0]
positive_voxels = dwi_first_volume[dwi_first_volume > 0]
mean, std = positive_voxels.mean(), positive_voxels.std()
value_range = (mean - 0.5 * std, mean + 1.5 * std)
slice_actor = actor.slicer(dwi_first_volume, dwi.affine, value_range)

In [None]:
# Render the tract actor into a 1000x1000 snapshot and display it with matplotlib.
# NOTE(review): `window` is not imported in this notebook — presumably
# `from fury import window`; TODO add to the imports cell.
# To overlay the DWI slice from the previous cell, also run:
#     slice_actor.display(z=25)
#     scene.add(slice_actor)
scene = window.Scene()
scene.add(tracts_actor)
scene.reset_camera()
scene.elevation(130)
img = window.snapshot(scene, size=(1000, 1000))

fig, ax = plt.subplots(figsize=(20, 20))
ax.axis("off")
ax.imshow(img)

In [None]:
class TractProfile:
    """Weighted tract profile of one streamline cluster stored in a ``.tck`` file.

    On construction the cluster is loaded against a reference image and summarised
    by a QuickBundles centroid. Its streamlines are oriented to that centroid
    exactly once, so the Gaussian weights and every profile computed by
    :meth:`get_profile` share the same node ordering.

    Attributes (all ``None`` when the file contains no streamlines):
        streamlines: streamlines as loaded from the file.
        cluster_bundle: QuickBundles clustering result; ``centroids[0]`` is used.
        oriented_streamlines: ``streamlines`` oriented to ``centroids[0]``.
        weights: ``dsa.gaussian_weights`` of ``oriented_streamlines``.
    """

    def __init__(self, streamlines, ref):
        """Load and prepare a cluster.

        Args:
            streamlines: path of the ``.tck`` cluster file.
            ref: path of the reference image passed to ``load_tractogram``.

        Raises:
            ValueError: if ``load_tractogram`` returns ``False`` (load failed).
        """
        cluster = load_tractogram(str(streamlines), str(ref))
        # load_tractogram can report failure by returning False instead of raising.
        if cluster is False:
            raise ValueError(f"Cluster {streamlines} could not be loaded")

        self.streamlines = None
        self.cluster_bundle = None
        self.oriented_streamlines = None
        self.weights = None
        if len(cluster.streamlines) == 0:
            return

        # An infinite distance threshold is meant to keep every streamline in one
        # cluster; only centroids[0] is used below.
        feature = ResampleFeature(nb_points=100)
        metric = AveragePointwiseEuclideanMetric(feature)
        qb = QuickBundles(np.inf, metric=metric)
        self.cluster_bundle = qb.cluster(cluster.streamlines)
        self.streamlines = cluster.streamlines

        # Orient before weighting. gaussian_weights compares streamlines node by node,
        # so the weights must be computed on the same oriented streamlines that
        # afq_profile samples in get_profile.
        self.oriented_streamlines = dts.orient_by_streamline(
            self.streamlines,
            self.cluster_bundle.centroids[0]
        )
        self.weights = dsa.gaussian_weights(self.oriented_streamlines)

    def get_profile(self, img):
        """Return the weighted profile of ``img`` sampled along this cluster.

        Args:
            img: image exposing ``get_fdata()`` and ``affine`` (e.g. a nibabel image).

        Returns:
            The result of ``dsa.afq_profile`` (one value per node, default n_points).

        Raises:
            ValueError: if the cluster has no streamlines.
        """
        if self.oriented_streamlines is None:
            raise ValueError("Cannot compute a profile for a cluster with no streamlines")
        return dsa.afq_profile(
            img.get_fdata(),
            self.oriented_streamlines,
            img.affine,
            weights=self.weights
        )


def get_profiles_and_streamlines(paths, param_maps, ref):
    """Mean profile value and streamline count for every cluster file.

    Args:
        paths: sequence of ``.tck`` cluster paths; one output column per path.
        param_maps: dict of name -> image; only the values are used, in dict order.
        ref: reference image path forwarded to ``TractProfile``.

    Returns:
        (profile_means, streamline_counts):
        ``profile_means`` has shape ``(len(param_maps), len(paths))`` and holds the
        mean over the profile nodes. ``streamline_counts`` is a float array of shape
        ``(len(paths),)``. Clusters with no streamlines get 0 in both.
    """
    # Must equal the number of nodes afq_profile returns (its n_points default, 100).
    n_nodes = 100
    profiles = np.zeros((len(param_maps), len(paths), n_nodes))
    streamline_counts = np.zeros(len(paths))

    for i, path in enumerate(paths):
        profile = TractProfile(path, ref)
        if profile.streamlines is None:
            continue  # empty cluster: count and profiles stay 0
        streamline_counts[i] = len(profile.streamlines)
        for j, param_map in enumerate(param_maps.values()):
            profiles[j, i] = profile.get_profile(param_map)
    return profiles.mean(axis=2), streamline_counts


In [None]:
# Profile the clusters in `clusters_left` against each parameter map.
N_CLUSTERS = 25  # number of clusters to profile; set to None to use all of them
cluster_dir = clusters_left
ref_path = t1_path

parameter_maps = {
    "FA": nib.load(fa_path)
}

if not cluster_dir.is_dir():
    raise FileNotFoundError("Input must be a directory")
# Path.glob yields entries in arbitrary filesystem order. Sort so that the selected
# subset and the row order are reproducible across runs.
paths = sorted(cluster_dir.glob("*.tck"))[:N_CLUSTERS]

profiles, streamlines = get_profiles_and_streamlines(paths, parameter_maps, ref_path)



In [None]:
# One row per cluster: the mean of each parameter map along the cluster profile plus
# the streamline count, indexed by (subject, cluster) and written to test.csv.
def _cluster_number(path):
    """Return the 5-digit id from a ``...cluster_NNNNN.tck`` path, as a string."""
    match = re.search(r'(?<=cluster_)\d{5}(?=\.tck$)', str(path))
    if match is None:
        raise ValueError(f"Cannot parse a cluster number from {path}")
    return match[0]


profile_table = pd.DataFrame(
    dict(zip(parameter_maps, profiles))
).assign(
    cluster=pd.Series([_cluster_number(path) for path in paths]),
    subject=1,
    streamlines=pd.Series(streamlines)
).set_index(["subject", "cluster"])

# Pass the path, not a text handle: pandas requires handles opened with newline=''.
profile_table.to_csv("test.csv")
profile_table

In [None]:
# Concatenate per-subject profile tables into a single CSV.
# NOTE(review): the input list repeats "test.csv" — looks like a placeholder; replace
# it with the real per-subject table paths.
table_paths = ["test.csv", "test.csv"]
# index_col=[0, 1] restores the (subject, cluster) index the tables were written with.
merged = pd.concat(
    pd.read_csv(path, index_col=[0, 1]) for path in table_paths
)
merged.to_csv("all_test.csv")

In [None]:
# Gaussian weights for one example cluster, computed on its streamlines after
# orienting them to the cluster's QuickBundles centroid.
# (`oriented_cluster` was previously used here without ever being defined; `dsa`
# is already imported in the imports cell.)
example_cluster = TractProfile(clusters_left/"cluster_00733.tck", t1_path)
oriented_cluster = dts.orient_by_streamline(
    example_cluster.streamlines,
    example_cluster.cluster_bundle.centroids[0]
)
weighted_cluster = dsa.gaussian_weights(oriented_cluster)

In [None]:
# FA profile along a single example cluster; plotted in the next cell.
fa = nib.load(fa_path)
example_tract = TractProfile(clusters_left / "cluster_00733.tck", t1_path)
profile = example_tract.get_profile(fa)

In [None]:
# Plot the FA profile node by node.
# NOTE(review): the x label says CST, but `profile` comes from cluster_00733 in
# `clusters_left` — confirm the label matches the tract shown.
fig, axes = plt.subplots()
axes.set_ylabel("Fractional anisotropy")
axes.set_xlabel("Node along CST")
axes.plot(profile)
