# Setup

In [None]:
import os
import sys

# Resolve ../src relative to the notebook's working directory and make the
# project packages importable; the membership check keeps re-runs from
# appending duplicates to sys.path.
module_path = os.path.abspath(os.path.join('..', 'src'))
print(module_path)
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
%reload_ext autoreload
%autoreload 2
    
import glob
import random
import pickle
import scipy
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import plotly.express as px
import seaborn as sns
import nibabel as nib
from tqdm.auto import tqdm
from collections import Counter
from itertools import combinations, permutations

import dipy
from dipy.segment.metric import mdf
from dipy.viz import window, actor

from data.FiberData import FiberData
from data.BundleData import BundleData
from data.data_util import *
from utils.general_util import *
from utils.plot_util import *
from utils.line_fit import *
from model.model import *
from evaluation import *
from inference import *

In [None]:
# `torch` is used directly below; import it explicitly instead of relying on
# it leaking out of one of the `from ... import *` lines.
import torch

SEED = 2022
DEVICE_NUM = 5  # GPU index selected when set_device() reports CUDA
set_seed(seed=SEED)
DEVICE = set_device()
if DEVICE == 'cuda':
    torch.cuda.set_device(DEVICE_NUM)
    # Sanity check: number of visible GPUs, active index and its name.
    print(torch.cuda.device_count(),
          torch.cuda.current_device(),
          torch.cuda.get_device_name(DEVICE_NUM))

In [None]:
# All outputs live under ../results/. Every folder keeps its trailing slash
# because later cells build paths by plain string concatenation.
result_folder = "../results/"
model_folder, plot_folder, result_data_folder, log_folder, eval_folder = (
    f"{result_folder}{sub}/" for sub in ("models", "plots", "data", "logs", "evals")
)
data_files_folder = "../data_files/"

# CHANGE DATA FOLDER BELOW: directory holding the subject data.
data_folder = ""

# Load Training Data

Change the code below to select the training subjects.

In [None]:
# Metadata table. Later cells use its Subject, DX, Sex, bundle_count and
# streamline_count columns.
df_meta = pd.read_csv(f"{data_files_folder}metadata.csv")
print(df_meta.shape)
print(df_meta['DX'].value_counts())

In [None]:
# Rank the CN subjects by bundle count, breaking ties by streamline count
# (both descending), and train on the first `n_subj` of them.
n_subj = 10
df_tmp = (df_meta[df_meta['DX'] == 'CN']
          .sort_values(by=['bundle_count', 'streamline_count'], ascending=False))
subjs_train = df_tmp.head(n_subj)['Subject'].values
print(len(subjs_train))

In [None]:
# Summarise the selected training cohort: sex counts, then the full rows.
is_train = df_meta['Subject'].isin(subjs_train)
df_train = df_meta[is_train]
print(df_train['Sex'].value_counts())
df_train

In [None]:
%%time
args = {'n_points' : 256, 'n_lines' : None, 'min_lines' : 2, 
        'tracts_exclude' : ['CST_L_s', 'CST_R_s'],'preprocess' : '3d', 
        'rng' : None, 'verbose': False, 'data_folder' : data_folder}

data = FiberData(subjs_train, **args)
data.X.shape

# Load Inference Data

## Load Inference from Subject

In [None]:
def load_inference(subj_name, model_subfolder,
                   result_data_folder, data_folder,
                   epoch, seed=0):
    '''Load one subject's bundles plus a model's saved inference output.

    Parameters
    ----------
    subj_name : str
        Subject identifier passed to ``BundleData``.
    model_subfolder : str
        Model run name; the inference file is read from
        ``{result_data_folder}{model_subfolder}/E{epoch}_{subj_name}``.
    result_data_folder : str
        Root folder of saved inference results (with trailing slash).
    data_folder : str
        Folder with the raw subject data, forwarded to ``BundleData``.
    epoch : int
        Epoch whose inference output is loaded.
    seed : int, optional
        Seed of the ``np.random.RandomState`` handed to ``BundleData``.

    Returns
    -------
    BundleData
        The subject, after ``load_inference_data`` has been called on it.
    '''
    rng = np.random.RandomState(seed)
    subj = BundleData(subj_name,
                      n_points=256, n_lines=None, min_lines=2,
                      tracts_exclude=['CST_L_s', 'CST_R_s'],
                      preprocess='3d', rng=rng, verbose=False,
                      data_folder=data_folder)
    inference_path = f"{result_data_folder}{model_subfolder}/E{epoch}_{subj_name}"
    subj.load_inference_data(inference_path)
    return subj

In [None]:
# The model run is identified by a folder name that encodes its settings.
zdim = 6
n_subj_load = 10
epoch = 100
model_subfolder = (f'convVAE3L_XUXU_Z{zdim}_B512_LR2E-04_WD1E-03'
                   f'_GCN2E+00_CN{n_subj_load}')
# mean/std presumably are the normalisation statistics saved with the model.
model, mean, std = load_model_for_inference(model_subfolder, model_folder, epoch, DEVICE)
print(mean, std)

msetting = parse_model_setting(model_subfolder)
msetting

In [None]:
# File-name suffix derived from the run's `subj_train` setting, so that
# outputs of different runs do not overwrite each other.
key = "subj_train"
# Default: previously `suffix` was left undefined (NameError below) or kept a
# stale value from an earlier run when the setting was missing.
suffix = ""
value = msetting[key] if key in msetting else None
if isinstance(value, float):
    suffix = f"_{key}{value:.0E}"
elif isinstance(value, int):
    suffix = f"_{key}{value}"
elif isinstance(value, str):
    suffix = f"_{value}"
suffix

In [None]:
# Load one example CN and one example AD subject (replace the placeholders).
subj_name_cn = 'example-CN-subj-name'  # CN
subj_name_ad = 'example-AD-subj-name'  # AD
subj_cn, subj_ad = (
    load_inference(name, model_subfolder, result_data_folder, data_folder,
                   epoch, seed=SEED)
    for name in (subj_name_cn, subj_name_ad)
)

# Embedding Plot

In [None]:
def make_subj_df(subj, method=None):
    '''Tabulate a subject's embeddings with bundle and hemisphere labels.

    Parameters
    ----------
    subj : BundleData
        Subject with embeddings in ``X_encoded`` (or ``X_encoded_{method}``)
        and per-streamline labels in ``y``, mapped to bundle names through
        ``bundle_num``.
    method : str, optional
        Selects the alternative ``X_encoded_{method}`` embedding.

    Returns
    -------
    pandas.DataFrame
        Columns ``Z1..Zd`` (suffixed ``_{method}`` when given), ``Bundle``
        and ``Hemisphere``.
    '''
    col_suffix = f"_{method}" if method else ""
    embeddings = getattr(subj, f"X_encoded{col_suffix}")
    columns = [f"Z{k}{col_suffix}" for k in range(1, embeddings.shape[1] + 1)]
    df = pd.DataFrame(embeddings, columns=columns)
    df['Bundle'] = map_list_with_dict(subj.y, subj.bundle_num)
    df['Hemisphere'] = label_hemispheres(df.Bundle.tolist())
    return df

In [None]:
method = None  # None -> use the default `X_encoded` embeddings
df_cn, df_ad = (make_subj_df(s, method=method) for s in (subj_cn, subj_ad))
df_cn

## Plot 2D Embeddings

In [None]:
def plot_embedding(df, zdim, cmap_b=None, cmap_h=None,
                   annot="", method=None, suffix=None,
                   plot_folder=None):
    '''Plot 2D embedding (first 2 dimensions), coloured by bundle and by hemisphere.

    Parameters
    ----------
    df : pandas.DataFrame
        Table from ``make_subj_df`` with ``Z1{method}``, ``Z2{method}``,
        ``Bundle`` and ``Hemisphere`` columns.
    zdim : int
        Embedding dimensionality; only used in the saved file name.
    cmap_b, cmap_h : optional
        Palettes for the bundle / hemisphere panels; built with
        ``make_color_map('nipy_spectral', ...)`` when omitted.
    annot : str
        Prefix for the axis labels and the file name (e.g. "CN").
    method : str, optional
        Embedding variant; selects the ``_{method}``-suffixed columns.
    suffix : str, optional
        Extra file-name suffix. ``None`` is treated as "" (it used to end up
        in the file name as the literal text "None").
    plot_folder : str, optional
        When given (with trailing slash), the figure is saved there as PDF.
    '''
    fig, ax = plt.subplots(nrows=1, ncols=2, sharey=True, figsize=(22, 10))
    if cmap_b is None:
        cmap_b = make_color_map('nipy_spectral', set(df.Bundle), plot_cmap=False)
    if cmap_h is None:
        cmap_h = make_color_map('nipy_spectral', set(df.Hemisphere), plot_cmap=False)

    method = f"_{method}" if method else ""
    suffix = suffix or ""

    # Per panel: hue column, palette, legend y-anchor, legend font size, legend columns.
    panels = [('Bundle', cmap_b, 1.16, 12, 6),
              ('Hemisphere', cmap_h, 1.1, 15, 3)]
    for axis, (hue, palette, legend_y, legend_fs, legend_ncol) in zip(ax, panels):
        sns.scatterplot(x=f'Z1{method}', y=f'Z2{method}', hue=hue,
                        palette=palette, data=df, ax=axis, s=25, alpha=0.5)
        axis.set_xlabel(f"{annot} Z1{method}", fontsize=16)
        axis.set_ylabel(f"{annot} Z2{method}", fontsize=16)
        axis.legend(loc='upper center', bbox_to_anchor=(0.5, legend_y),
                    fontsize=legend_fs, ncol=legend_ncol, fancybox=True, shadow=True)

    if plot_folder:
        fig.tight_layout()
        fig.savefig(f'{plot_folder}{zdim}D_embeddings_{annot}{method}{suffix}.pdf',
                    bbox_inches='tight')

Pass a folder as `plot_folder` to save the plots to file.

In [None]:
# Share one pair of palettes between cohorts so colours are comparable.
cmap_b = labeled_colormap('nipy_spectral', subj_cn.bundle_idx.keys(), plot_cmap=False)
cmap_h = labeled_colormap('jet', set(df_cn.Hemisphere), plot_cmap=False)
for df_plot, cohort in ((df_cn, "CN"), (df_ad, "AD")):
    plot_embedding(df_plot, zdim, cmap_b=cmap_b, cmap_h=cmap_h,
                   method=method, annot=cohort, suffix=suffix,
                   plot_folder=None)

## Plot 3D Embeddings

In [None]:
# Interactive 3D view of the first three embedding dimensions (needs
# zdim >= 3), coloured by hemisphere.
hemisphere_colors = {"Left": "red", "Comm": "green", "Right": "blue"}
fig = px.scatter_3d(
    df_cn,
    x='Z1', y='Z2', z='Z3',
    color='Hemisphere',
    hover_name='Bundle',
    color_discrete_map=hemisphere_colors,
    width=700, height=600,
)
fig.update_traces(marker_size=2)
fig.show()

In [None]:
# The same 3D view, coloured by bundle.
fig = px.scatter_3d(
    df_cn,
    x='Z1', y='Z2', z='Z3',
    color='Bundle',
    width=700, height=600,
)
fig.update_traces(marker_size=2)
fig.show()

# Streamline distance

## MDF - Euclidean 

In [None]:
# Draw a random subset of training streamlines and run it through the model
# (presumably returning inputs, embeddings and reconstructions).
n_sample = 300
x, z, x_recon = select_random_samples(
    model, data.X, device=DEVICE, n_sample=n_sample,
    mean=mean, std=std, seed=SEED,
)
zdim = z.shape[1]  # embedding size as actually produced by the model
print(x.shape, z.shape)

In [None]:
'''Randomly generate sample from embedding space'''
x_recon_rand, z_rand = generate_random_samples(model, zdim=zdim, n_sample=n_sample, 
                                     mean=mean, std=std, seed=SEED)
print(x_recon_rand.shape, z_rand.shape)

In [None]:
def _pairwise_distances(streamlines, embeddings):
    '''For every unordered pair (i, j) with i before j, return the MDF distance
    of the streamlines and the Euclidean distance of their embeddings.'''
    mdf_d, euc_d = [], []
    for i, j in combinations(range(len(embeddings)), 2):
        mdf_d.append(mdf(streamlines[i], streamlines[j]))
        euc_d.append(np.linalg.norm(embeddings[i] - embeddings[j]))
    return mdf_d, euc_d


# Training samples: original streamlines vs. their encodings.
dist_mdf, dist_euc = _pairwise_distances(x, z)
# Random samples: decoded streamlines vs. the sampled codes.
dist_mdf_rand, dist_euc_rand = _pairwise_distances(x_recon_rand, z_rand)

In [None]:
from scipy.stats import gaussian_kde


def n_points_from_pair_count(n_pairs):
    '''Return N such that N*(N-1)/2 == n_pairs (points behind a full pairwise list).'''
    return int(round((1 + np.sqrt(1 + 8 * n_pairs)) / 2))


def plot_distance_corr(ax, x, y, zdim, title_text, fit_intercept=False,
                       n_sample=None):
    '''Scatter pairwise embedding distances against pairwise MDF distances.

    Points are shaded by a Gaussian-KDE density estimate and a fitted line
    (``fit_line``) is overlaid. Spearman/Pearson correlations and R2 are
    printed and written onto the axes.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on.
    x, y : sequence of float
        Paired distances: Euclidean distances between embeddings (``x``) and
        MDF distances between the corresponding streamlines (``y``).
    zdim : int
        Embedding dimensionality (axis label only).
    title_text : str
        Description of the samples, used in the title.
    fit_intercept : bool
        Forwarded to ``fit_line`` / ``get_fitted_line``.
    n_sample : int, optional
        Number of streamlines shown as "N" in the title. If omitted, it is
        inferred from ``len(x)``, assuming ``x`` holds all N*(N-1)/2 unordered
        pairs (as built with ``itertools.combinations``). The title previously
        read the global ``n_sample``, which later cells reassign.
    '''
    if n_sample is None:
        n_sample = n_points_from_pair_count(len(x))

    # Point density, used only to shade the scatter.
    xy = np.vstack([x, y])
    density = gaussian_kde(xy)(xy)
    print('Fitted Gaussian KDE...')

    spearman, spval = scipy.stats.spearmanr(x, y)
    pearson, ppval = scipy.stats.pearsonr(x, y)

    cmap = truncate_colormap(plt.get_cmap('gray_r'), 0.3, 0.8)
    ax.scatter(x, y, c=density, cmap=cmap, s=1)

    # Fitted line drawn from 0 to the right edge of the current x-range.
    coeff, r2 = fit_line(x, y, fit_intercept=fit_intercept)
    xn = np.linspace(0, ax.get_xlim()[1], 500)
    yn = get_fitted_line(xn, coeff, fit_intercept=fit_intercept)
    ax.plot(xn, yn, '--', color='r')

    print(f"Spearman r={spearman:.3f}, p={spval:.4f}, "
          f"Pearson r={pearson:.3f}, p={ppval:.4f}, "
          f"R2={r2:.3f}")

    ax.set_title(f"Pairwise distance of {title_text} (N={n_sample})", fontsize=15)
    ax.set_xlabel(f"Euclidean Distance of {zdim}D Embeddings", fontsize=15)
    ax.set_ylabel("MDF Distance of Streamlines", fontsize=15)

    # Correlation summary box in the upper-left corner.
    corr_str = f"Spearman={spearman:.3f}\nPearson={pearson:.3f}\nR2={r2:.3f}"
    props = dict(boxstyle='square', facecolor="0.9", edgecolor='lightgrey', alpha=0.5)
    ax.text(0.05, 0.85, corr_str, transform=ax.transAxes, fontsize=15,
            verticalalignment='bottom', horizontalalignment='left', bbox=props)

In [None]:
%%time

fit_intercept=False

fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(18, 8))
plot_distance_corr(ax[0], dist_euc, dist_mdf, zdim, "selected training samples",
                   fit_intercept=fit_intercept)
plot_distance_corr(ax[1], dist_euc_rand, dist_mdf_rand, zdim, f"random {zdim}D samples",
                   fit_intercept=fit_intercept)
fig.savefig(f'{plot_folder}distance_corr_Z{zdim}{suffix}.png')

In [None]:
%%time

fit_intercept=False

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))
plot_distance_corr(ax, dist_euc, dist_mdf, zdim, "selected training samples",
                   fit_intercept=fit_intercept)
fig.savefig(f'{plot_folder}distance_corr_sel_Z{zdim}{suffix}.png')

## Picking the best embedding dimension (z)

In [None]:
# Fill these with one entry per evaluated embedding dimension: the dimension
# and the Spearman / Pearson / R2 values reported by plot_distance_corr.
val_z, val_spearman, val_pearson, val_r2 = [], [], [], []

In [None]:
# Elbow plot: each distance-correlation metric as a function of embedding size.
plt.figure(figsize=(12, 7))
for values, colour, metric_name in ((val_spearman, 'r', 'Spearman'),
                                    (val_pearson, 'g', 'Pearson'),
                                    (val_r2, 'b', 'R2')):
    plt.plot(val_z, values, c=colour, marker='o', label=metric_name)
plt.xlabel("Embedding Dimension", fontsize=16)
plt.ylabel("Metric Value", fontsize=16)
plt.grid()
plt.legend(fontsize=16)
plt.savefig(f'{plot_folder}nz-elbow-plot.pdf')

# Bundle Distance

In [None]:
def rearrange_square_matrix(A, new_idx):
    '''[UTIL] Return a copy of square matrix ``A`` with rows and columns reordered.

    ``result[a, b] == A[new_idx[a], new_idx[b]]``. Uses NumPy fancy
    indexing (``np.ix_``) instead of a Python double loop.

    Raises
    ------
    ValueError
        If ``A`` is not a square 2-D matrix (an ``assert`` was used before,
        which is stripped under ``python -O``).
    '''
    A = np.asarray(A)
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError("Not a square matrix")
    idx = np.asarray(list(new_idx), dtype=int)
    return A[np.ix_(idx, idx)]

def symmetric_lower_to_upper(A):
    '''[UTIL] Mirror a lower-triangular matrix into a full symmetric one.

    Computes ``A + A.T`` with the diagonal counted once. Off-diagonal entries
    become ``A[i, j] + A[j, i]``, which equals the lower-triangle value when
    the upper triangle of ``A`` is zero.
    '''
    diagonal_part = np.diag(A.diagonal())
    return A + A.T - diagonal_part

In [None]:
subj = subj_cn  # the bundle-distance analyses below use the CN example subject

# Hemisphere-based ordering of the bundles, used to reorder the distance
# matrices and as heatmap tick labels (see sort_bundle_name_by_hemisphere).
bundle_names = subj.bundle_idx.keys()
sorted_bundle_dict = sort_bundle_name_by_hemisphere(bundle_names)
new_idx = map_list_with_dict(bundle_names, sorted_bundle_dict)

## MDF Distance (Bundle Centroid)

In [None]:
from dipy.segment.clustering import QuickBundles


def _bundle_centroid(streamlines):
    '''Centroid of the single QuickBundles cluster built from one bundle.'''
    qb = QuickBundles(threshold=10., max_nb_clusters=1)
    return qb.cluster(streamlines).centroids[0]


# Cluster every bundle once and reuse its centroid. Previously the second
# bundle of each pair was re-clustered inside the inner loop, so every bundle
# was clustered many times.
centroids = [_bundle_centroid(subj.X[subj.get_subj_bundle_idx(name)])
             for name in subj.bundle_idx.keys()]

BCentDist_mat = np.zeros((len(subj.bundle_idx), len(subj.bundle_idx)))
# Fill the lower triangle (including the diagonal) only; the plotting cell
# mirrors it into the upper half.
for i, c1 in enumerate(centroids):
    for j in range(i + 1):
        BCentDist_mat[i, j] = mdf(c1, centroids[j])

In [None]:
# Mirror the lower triangle into a symmetric matrix. np.tril makes this
# idempotent; re-running the cell used to double every off-diagonal distance.
BCentDist_mat = symmetric_lower_to_upper(np.tril(BCentDist_mat))
new_BCentDist_mat = rearrange_square_matrix(BCentDist_mat, new_idx)

fig, ax = plt.subplots(figsize=(10,10))

# np.triu keeps the values, so every non-zero upper-triangle cell is hidden.
mask = np.triu(new_BCentDist_mat)
sns.heatmap(new_BCentDist_mat, square=True,
            xticklabels=sorted_bundle_dict.keys(),
            yticklabels=sorted_bundle_dict.keys(),
            mask=mask,
            ax=ax)
ax.set_title("Pairwise MDF Distance of Bundle Centroids")

## Euclidean Distance (Embedding Centroid)

In [None]:
# Centroid of each bundle in the *embedding* space: the mean of its encoded
# streamlines. This previously averaged subj.X (the raw streamlines), which
# contradicts the "Embedding Centroid" section and plot titles and made the
# Mantel comparison with the streamline-space MDF matrix meaningless.
emb_centroids = [subj.X_encoded[subj.get_subj_bundle_idx(name)].mean(axis=0)
                 for name in subj.bundle_idx.keys()]

ECentDist_mat = np.zeros((len(subj.bundle_idx), len(subj.bundle_idx)))
# Lower triangle only (zero diagonal); mirrored in the plotting cell.
for i, cent1 in enumerate(emb_centroids):
    for j in range(i + 1):
        ECentDist_mat[i, j] = np.linalg.norm(cent1 - emb_centroids[j])

In [None]:
# Mirror the lower triangle; np.tril keeps this correct on re-runs, which
# previously doubled the off-diagonal distances.
ECentDist_mat = symmetric_lower_to_upper(np.tril(ECentDist_mat))
new_ECentDist_mat = rearrange_square_matrix(ECentDist_mat, new_idx)

fig, ax = plt.subplots(figsize=(10,10))

# np.triu keeps the values, so every non-zero upper-triangle cell is hidden.
mask = np.triu(new_ECentDist_mat)
sns.heatmap(new_ECentDist_mat, square=True,
            xticklabels=sorted_bundle_dict.keys(),
            yticklabels=sorted_bundle_dict.keys(),
            mask=mask,
            ax=ax)
ax.set_title("Pairwise Euclidean Distance of Bundle Embedding Centroids")

## Bundle-based Minimum Distance (BMD)

In [None]:
from dipy.align.streamlinear import (StreamlineLinearRegistration,
                                     BundleMinDistanceMetric,
                                     BundleSumDistanceMatrixMetric,
                                     BundleMinDistanceAsymmetricMetric)

def bundle_min_distance(b1, b2):
    '''Bundle-based minimum distance (BMD) between two bundles, without registration.

    Evaluates dipy's ``BundleMinDistanceMetric`` at the identity transform, so
    the bundles are compared where they are. The original note gives the
    result in millimetres.
    '''
    metric = BundleMinDistanceMetric()
    metric.setup(b1, b2)
    # 12 affine parameters: translation (3), rotation (3), scaling (3),
    # shear (3). Zeros plus unit scales give the identity.
    identity_params = np.array([0, 0, 0, 0, 0, 0, 1., 1., 1, 0, 0, 0])
    return metric.distance(identity_params.tolist())

In [None]:
%%time 

BDist_mat = np.zeros((len(subj.bundle_idx), len(subj.bundle_idx)))

for i, bundle1 in enumerate(subj.bundle_idx.keys()):
    b1 = subj.X[subj.get_subj_bundle_idx(bundle1)]
    
    for j, bundle2 in enumerate(subj.bundle_idx.keys()):
        if j > i:
            continue
        b2 = subj.X[subj.get_subj_bundle_idx(bundle2)]
        score = bundle_min_distance(b1, b2)
        print(i, j, score)
        BDist_mat[i, j] = score

In [None]:
# Mirror the lower triangle; np.tril keeps this correct on re-runs, which
# previously doubled the off-diagonal distances.
BDist_mat = symmetric_lower_to_upper(np.tril(BDist_mat))
new_BDist_mat = rearrange_square_matrix(BDist_mat, new_idx)

fig, ax = plt.subplots(figsize=(10,10))

# np.triu keeps the values, so every non-zero upper-triangle cell is hidden.
mask = np.triu(new_BDist_mat)
sns.heatmap(new_BDist_mat, square=True,
            xticklabels=sorted_bundle_dict.keys(),
            yticklabels=sorted_bundle_dict.keys(),
            mask=mask,
            ax=ax)
ax.set_title("Pairwise Bundle-based Minimum Distances")

Computing this distance matrix can take a while, so save the result!

In [None]:
bmd_path = f"{eval_folder}dist_bundle_bmd{suffix}"
save_pickle(new_BDist_mat, bmd_path)  # reload later instead of recomputing

## Wasserstein Distance

https://pythonot.github.io/quickstart.html#computing-wasserstein-distance

In [None]:
import ot
import ot.plot

In [None]:
%%time

WDist_mat = np.zeros((len(subj.bundle_idx), len(subj.bundle_idx)))

for i, bundle1 in enumerate(subj.bundle_idx.keys()):
    for j, bundle2 in enumerate(subj.bundle_idx.keys()):
        if j > i:
            continue
        b1 = subj.X_encoded[subj.get_subj_bundle_idx(bundle1)]
        b2 = subj.X_encoded[subj.get_subj_bundle_idx(bundle2)]
        a = np.ones((len(b1),)) / len(b1)
        b = np.ones((len(b2),)) / len(b2)
        M = ot.dist(b1, b2)
        wdist = ot.emd2(a, b, M, numItermax=5e+5)
        print(i, j, wdist)
        WDist_mat[i, j] = wdist

In [None]:
# Mirror the lower triangle; np.tril keeps this correct on re-runs, which
# previously doubled the off-diagonal distances.
WDist_mat = symmetric_lower_to_upper(np.tril(WDist_mat))
new_WDist_mat = rearrange_square_matrix(WDist_mat, new_idx)

fig, ax = plt.subplots(figsize=(10,10))

# Boolean mask hiding the upper triangle and the diagonal. Uses the builtin
# `bool`: the `np.bool` alias was removed in NumPy 1.24.
mask = np.triu(np.ones_like(new_WDist_mat, dtype=bool))
sns.heatmap(new_WDist_mat, square=True,
            xticklabels=sorted_bundle_dict.keys(),
            yticklabels=sorted_bundle_dict.keys(),
            mask=mask,
            ax=ax)
ax.set_title("Pairwise Wasserstein Distance of Bundles")

Computing this distance matrix can take a while, so save the result!

In [None]:
wasserstein_path = f"{eval_folder}dist_bundle_wasserstein{suffix}"
save_pickle(new_WDist_mat, wasserstein_path)  # reload later instead of recomputing

### Bundle-Specific Wasserstein distance

(also known as the optimal transport distance)
Use the functions below to inspect the distance between two specific bundles.

In [None]:
def plot_ot_embeddings(b1, b2, G0=None, label1=None, label2=None):
    '''Show two bundles' embeddings side by side with their optimal-transport plan.

    Left panel: the first two embedding dimensions of both bundles. Right
    panel: the same points, with the transport plan drawn as connecting
    lines by ``ot.plot.plot2D_samples_mat``.

    Parameters
    ----------
    b1, b2 : ndarray
        2-D arrays of bundle embeddings; only columns 0 and 1 are plotted.
    G0 : ndarray, optional
        Transport plan between ``b1`` and ``b2``. When omitted it is computed
        with ``ot.emd`` from uniform weights and a Euclidean cost. The function
        used to read the notebook global ``G0``.
    label1, label2 : str, optional
        Legend labels. They default to the notebook globals ``bundle1`` /
        ``bundle2`` if those exist (the previous implicit behaviour),
        otherwise "bundle 1" / "bundle 2".
    '''
    if G0 is None:
        a = np.ones((len(b1),)) / len(b1)
        b = np.ones((len(b2),)) / len(b2)
        G0 = ot.emd(a, b, ot.dist(b1, b2, metric='euclidean'))
    if label1 is None:
        label1 = globals().get('bundle1', 'bundle 1')
    if label2 is None:
        label2 = globals().get('bundle2', 'bundle 2')

    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 8))

    axes[0].scatter(b1[:, 0], b1[:, 1], c='r', s=3, label=label1)
    axes[0].scatter(b2[:, 0], b2[:, 1], c='b', s=3, label=label2)
    axes[0].legend(loc="upper right")
    axes[0].title.set_text('Bundle Embeddings')

    plt.sca(axes[1])  # plot2D_samples_mat draws on the current axes
    ot.plot.plot2D_samples_mat(b1, b2, G0, color=[.5, .5, 1])
    axes[1].scatter(b1[:, 0], b1[:, 1], c='r', label=label1)
    axes[1].scatter(b2[:, 0], b2[:, 1], c='b', label=label2)
    axes[1].legend(loc="upper right")
    axes[1].title.set_text('OT matrix')
    
def plot_ot_matrix(M, G0):
    '''Show a cost matrix and its optimal-transport plan as side-by-side images.'''
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(15, 8))
    for axis, matrix, title in zip(axes, (M, G0), ('Cost Matrix', 'OT Matrix')):
        axis.imshow(matrix, interpolation='nearest')
        axis.title.set_text(title)

In [None]:
# Compare two specific bundles (edit the names below).
bundle1 = "CST_L"
bundle2 = "CST_R"

b1, b2 = (subj.X_encoded[subj.get_bundle_idx(name)] for name in (bundle1, bundle2))
print(b1.shape, b2.shape)

# Uniform weights over each bundle's streamlines, Euclidean ground cost.
a = np.full(len(b1), 1.0 / len(b1))
b = np.full(len(b2), 1.0 / len(b2))
M = ot.dist(b1, b2, metric='euclidean')
G0 = ot.emd(a, b, M)   # optimal transport plan
W = ot.emd2(a, b, M)   # its total transport cost
print(W)

plot_ot_embeddings(b1, b2)
plot_ot_matrix(M, G0)

## Mantel Test

In [None]:
# Reload the matrices saved above so the expensive BMD / Wasserstein
# computations need not be repeated. Use the same names they were saved
# under, including the run suffix (it was previously dropped here).
new_BDist_mat = load_pickle(f"{eval_folder}dist_bundle_bmd{suffix}")
new_WDist_mat = load_pickle(f"{eval_folder}dist_bundle_wasserstein{suffix}")

In [None]:
# Zero every diagonal in place (self-distance). Presumably this is because
# skbio's mantel requires hollow distance matrices.
for dist_mat in (new_BCentDist_mat, new_ECentDist_mat, new_BDist_mat, new_WDist_mat):
    np.fill_diagonal(dist_mat, 0)

In [None]:
from skbio.stats.distance import mantel

# Mantel correlation between streamline-space and embedding-space distances:
# first centroid MDF vs. embedding-centroid Euclidean, then BMD vs. Wasserstein.
# With 100 permutations, p-values have a resolution of about 0.01.
for mat_x, mat_y in ((new_BCentDist_mat, new_ECentDist_mat),
                     (new_BDist_mat, new_WDist_mat)):
    coeff, p_value, _ = mantel(mat_x, mat_y, method='pearson', permutations=100)
    print(coeff, p_value)

## Plot Heatmaps

In [None]:
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

# Side-by-side heatmaps: MDF distance between bundle centroids (streamline
# space) and Euclidean distance between embedding centroids, each with a
# horizontal colour bar above it.
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(20,12))
tick_labels = list(sorted_bundle_dict.keys())

mask = np.triu(new_BCentDist_mat)  # hide non-zero upper-triangle cells
# Each colour bar is anchored to its own heatmap (ax1 / ax2). Before, it was
# anchored to the stale global `ax` from another figure, so the layout
# depended on which cells had been run.
cax1 = inset_axes(ax1,
                  width="70%",  # 70% of the anchor box (= heatmap) width
                  height="7%",  # 7% of the anchor box (= heatmap) height
                  loc='lower left',
                  bbox_to_anchor=(0.15, 1.15, 1, 1),  # centred, above the title
                  bbox_transform=ax1.transAxes,
                  borderpad=0,
                  )
cax1.tick_params(labelsize=18)
# Pass every label to seaborn so one tick exists per bundle; otherwise the
# "auto" tick thinning would misalign the set_*ticklabels calls below.
sns.heatmap(new_BCentDist_mat, square=True,
            xticklabels=tick_labels,
            yticklabels=tick_labels,
            mask=mask,
            cbar_ax=cax1,
            cbar_kws={"orientation": "horizontal",
                      "shrink": 0.5, "pad": .01},
            ax=ax1)
ax1.set_xticklabels(tick_labels, rotation=90, fontsize=16)
ax1.set_yticklabels(tick_labels, rotation=0, fontsize=16)
ax1.set_title("Pairwise MDF Distance of\nBundle Centroids", fontsize=22)

mask = np.triu(new_ECentDist_mat)
cax2 = inset_axes(ax2,
                  width="70%",
                  height="7%",
                  loc='lower left',
                  bbox_to_anchor=(0.15, 1.15, 1, 1),
                  bbox_transform=ax2.transAxes,
                  borderpad=0,
                  )
cax2.tick_params(labelsize=18)
sns.heatmap(new_ECentDist_mat, square=True,
            xticklabels=tick_labels,
            yticklabels=[],  # y labels are shared with the left panel
            mask=mask,
            cbar_ax=cax2,
            cbar_kws={"orientation": "horizontal",
                      "shrink": 0.5, "pad": .01},
            ax=ax2)
ax2.set_xticklabels(tick_labels, rotation=90, fontsize=16)
ax2.set_title("Pairwise Euclidean Distance of\nBundle Embedding Centroids", fontsize=22)

plt.subplots_adjust(wspace=.07)
fig.savefig(f'{plot_folder}distance_centroid_Z{zdim}{suffix}.pdf')

In [None]:
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

# Side-by-side heatmaps: bundle-based minimum distance (streamline space) and
# Wasserstein distance between bundle embeddings, each with a horizontal
# colour bar above it.
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(20,12))
tick_labels = list(sorted_bundle_dict.keys())

mask = np.triu(new_BDist_mat)  # hide non-zero upper-triangle cells
# Each colour bar is anchored to its own heatmap (ax1 / ax2). Before, it was
# anchored to the stale global `ax` from another figure, so the layout
# depended on which cells had been run.
cax1 = inset_axes(ax1,
                  width="70%",  # 70% of the anchor box (= heatmap) width
                  height="7%",  # 7% of the anchor box (= heatmap) height
                  loc='lower left',
                  bbox_to_anchor=(0.15, 1.15, 1, 1),  # centred, above the title
                  bbox_transform=ax1.transAxes,
                  borderpad=0,
                  )
cax1.tick_params(labelsize=18)
# Pass every label to seaborn so one tick exists per bundle; otherwise the
# "auto" tick thinning would misalign the set_*ticklabels calls below.
sns.heatmap(new_BDist_mat, square=True,
            xticklabels=tick_labels,
            yticklabels=tick_labels,
            mask=mask,
            cbar_ax=cax1,
            cbar_kws={"orientation": "horizontal",
                      "shrink": 0.5, "pad": .01},
            ax=ax1)
ax1.set_xticklabels(tick_labels, rotation=90, fontsize=16)
ax1.set_yticklabels(tick_labels, rotation=0, fontsize=16)
ax1.set_title("Pairwise BMD Distance of Bundles", fontsize=22)

mask = np.triu(new_WDist_mat)
cax2 = inset_axes(ax2,
                  width="70%",
                  height="7%",
                  loc='lower left',
                  bbox_to_anchor=(0.15, 1.15, 1, 1),
                  bbox_transform=ax2.transAxes,
                  borderpad=0,
                  )
cax2.tick_params(labelsize=18)
sns.heatmap(new_WDist_mat, square=True,
            xticklabels=tick_labels,
            yticklabels=[],  # y labels are shared with the left panel
            mask=mask,
            cbar_ax=cax2,
            cbar_kws={"orientation": "horizontal",
                      "shrink": 0.5, "pad": .01},
            ax=ax2)
ax2.set_xticklabels(tick_labels, rotation=90, fontsize=16)
ax2.set_title("Pairwise Wasserstein Distance of\nBundle Embeddings", fontsize=22)

plt.subplots_adjust(wspace=.07)
fig.savefig(f'{plot_folder}distance_bundle_Z{zdim}{suffix}.pdf')

# Interpolation

In [None]:
def sample_streamlines_from_bundle(subj, bundle_name, n_sample=1, seed=0):
    '''Randomly pick ``n_sample`` streamlines from one bundle of ``subj``.

    Returns whatever ``sample_idx`` returns for the bundle's streamlines.
    Presumably these are positions within that bundle rather than
    subject-level indices; confirm against ``sample_idx``.
    '''
    streamlines = subj.X[subj.get_bundle_idx(bundle_name)]
    return sample_idx(streamlines, n_sample=n_sample, seed=seed)
    
def get_bundle_from_hemisphere(bnames, hemisphere='L'):
    '''[DATA-UTIL] Return the bundle names belonging to one hemisphere.

    Parameters
    ----------
    bnames : iterable of str
        Bundle names; left / right bundles carry an ``_L`` / ``_R`` suffix.
    hemisphere : {'L', 'R', 'C'}
        'L' -> names ending in ``_L``; 'R' -> names ending in ``_R``;
        'C' -> every other name (commissural / midline bundles).

    Returns
    -------
    list of str
        Matching names, in input order.

    Raises
    ------
    ValueError
        For any other ``hemisphere``. Previously the function printed a
        message and returned None, which only failed later at the call site.
    '''
    if hemisphere == 'L':
        return [b for b in bnames if b.endswith('_L')]
    if hemisphere == 'R':
        # Match '_R' (not bare 'R') to mirror the '_L' rule; names that merely
        # end in an "R" are not right-hemisphere bundles.
        return [b for b in bnames if b.endswith('_R')]
    if hemisphere == 'C':
        return [b for b in bnames if not b.endswith(('_L', '_R'))]
    raise ValueError("Invalid hemisphere input, only accept L, R, C")

In [None]:
# Subject-level indices of all left / right / commissural streamlines.
subj = subj_cn
bnames = subj.bundle_idx.keys()
left_cn, right_cn, comm_cn = (
    subj.get_subj_multibundle_idx(get_bundle_from_hemisphere(bnames, side))
    for side in ('L', 'R', 'C')
)
print(left_cn.shape, right_cn.shape, comm_cn.shape)

In [None]:
# Draw `n_sample` streamlines from each hemisphere group and report the
# bundle each one belongs to.
n_sample = 2
sample_left, sample_right, sample_comm = (
    group[sample_idx(group, n_sample=n_sample, seed=SEED)]
    for group in (left_cn, right_cn, comm_cn)
)

for picked in (sample_left, sample_right, sample_comm):
    print([f"{get_bundle_for_idx(i, subj.bundle_idx)} @ {i}" for i in picked])

In [None]:
def interpolate_np(p1, p2, n_points=10):
    '''[UTIL] Evenly spaced points on the straight line from ``p1`` to ``p2``.

    Each coordinate is interpolated separately with ``np.linspace``. The
    per-coordinate results are stacked along axis 1, so row 0 equals ``p1``
    and the last row equals ``p2``.
    '''
    per_coordinate = []
    for start, stop in zip(p1, p2):
        per_coordinate.append(np.linspace(start, stop, n_points))
    return np.stack(per_coordinate, axis=1)

# Walk in a straight line through embedding space from the first sampled
# left-hemisphere streamline to the first right-hemisphere one, then decode
# each intermediate point back into a streamline.
setting = 'l0r0'  # tag for the output file name (left sample 0 -> right sample 0)
start_code = subj.X_encoded[sample_left[0]]
end_code = subj.X_encoded[sample_right[0]]
X_enc_interp = interpolate_np(start_code, end_code, n_points=50)
X_recon_interp = decode_embeddings(X_enc_interp, model, mean=mean, std=std)
print(X_recon_interp.shape)

In [None]:
# Render the interpolated streamlines: the two end streamlines are drawn
# thick and the intermediate ones thin, all in one colour.
palette = mpl.colors.CSS4_COLORS  # public name (mpl._color_data is private)
scene = window.Scene()
# Reuse a saved camera only if one has been captured. The "Get current camera
# setting" cell defines pos/foc/vup *after* this one, so the first run used
# to fail with a NameError.
if all(var in globals() for var in ('pos', 'foc', 'vup')):
    scene.set_camera(position=pos,
                     focal_point=foc,
                     view_up=vup)

linecolor = 'white'
rgb = mpl.colors.to_rgb(palette[linecolor])
scene.add(actor.streamtube(np.expand_dims(X_recon_interp[0], axis=0),
                           linewidth=0.8, colors=rgb))
scene.add(actor.streamtube(X_recon_interp[1:-1],
                           linewidth=0.3, colors=rgb))
scene.add(actor.streamtube(np.expand_dims(X_recon_interp[-1], axis=0),
                           linewidth=0.8, colors=rgb))

window.record(scene, size=(1200, 1200), out_path=plot_folder+f"interp_{setting}.png")
window.show(scene, size=(1000,1000), reset_camera=False)

In [None]:
# Capture the viewpoint currently shown in the 3D window (position, focal
# point, view-up) so later renders can reuse it.
pos, foc, vup = scene.get_camera()

# Bundle Visualization

In [None]:
subj = subj_cn
label = 'cn'
bundles_vis = ['CST_L', 'CST_R']

# tab10 colours, one per visualised bundle (same RGB tuples as the internal
# datad table, taken from the public colormap object).
color_list = plt.get_cmap('tab10').colors
# Stand-alone legend that maps each bundle to its colour.
for i, name in enumerate(bundles_vis):
    plt.axhline(-i, linewidth=10, c=color_list[i], label=name)
plt.legend(bbox_to_anchor=(1, 1), loc='upper left', ncol=1, prop={'size': 15})

In [None]:
# Render each selected bundle's original streamlines in its legend colour.
scene = window.Scene()
scene.SetBackground(1, 1, 1)  # white background

for i, bundle in enumerate(bundles_vis):
    lines = subj.X[subj.get_subj_bundle_idx(bundle)]
    scene.add(actor.line(lines, fake_tube=True, linewidth=6,
                         colors=color_list[i]))
window.show(scene, size=(1000, 1000), reset_camera=False)
# To save the render to file, uncomment:
# window.record(scene, size=(1200, 1200), out_path=f'{plot_folder}vis_X_3b.png')

In [None]:
# Capture and print the current viewpoint so it can be reused (or hard-coded).
pos, foc, vup = scene.get_camera()
print(pos, foc, vup)

In [None]:
# Render the model's reconstructions of the same bundles, from the camera
# captured above, so they can be compared with the originals.
scene = window.Scene()
scene.SetBackground(1, 1, 1)  # white background
scene.set_camera(position=pos, focal_point=foc, view_up=vup)

for i, bundle in enumerate(bundles_vis):
    lines = subj.X_recon[subj.get_subj_bundle_idx(bundle)]
    scene.add(actor.line(lines, fake_tube=True, linewidth=6,
                         colors=color_list[i]))
window.show(scene, size=(1000, 1000), reset_camera=False)
# To save the render to file, uncomment:
# window.record(scene, size=(1200, 1200), out_path=f'{plot_folder}vis_Xrecon_3b{suffix}.png')