# Set up

In [None]:
import os
import sys

# Make the project's ``src`` package importable from this notebook.
module_path = os.path.abspath(os.path.join('../src/'))
print(module_path)
already_on_path = module_path in sys.path
if not already_on_path:
    sys.path.append(module_path)

In [None]:
%reload_ext autoreload
%autoreload 2
    
import re
import glob
import random
import pickle
import scipy
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import plotly.express as px
import seaborn as sns
from tqdm.auto import tqdm
from collections import Counter, defaultdict
from itertools import combinations, permutations
from sklearn import linear_model
from sklearn.metrics import mean_absolute_error
from scipy.stats import ttest_ind, levene
from statsmodels.stats.multitest import multipletests

import dipy
from dipy.segment.metric import mdf
from dipy.viz import window, actor
from dipy.stats.analysis import assignment_map
from nibabel.streamlines.array_sequence import ArraySequence

from data.FiberData import FiberData
from data.BundleData import BundleData
from data.data_util import *
from utils.general_util import *
from utils.plot_util import *
from utils.line_fit import *
from model.model import *
from evaluation import *
from inference import *

In [None]:
# Reproducibility and device selection.
SEED = 2022
DEVICE_NUM = 5  # index of the GPU to use when set_device() reports 'cuda'

set_seed(seed=SEED)
DEVICE = set_device()

if DEVICE == 'cuda':
    # NOTE(review): ``torch`` is assumed to come from one of the star imports above.
    torch.cuda.set_device(DEVICE_NUM)
    n_gpus = torch.cuda.device_count()
    active_gpu = torch.cuda.current_device()
    gpu_name = torch.cuda.get_device_name(DEVICE_NUM)
    print(n_gpus, active_gpu, gpu_name)

In [None]:
# Output locations, all relative to the notebook's working directory.
result_folder = "../results/"
model_folder = result_folder + "models/"
plot_folder = result_folder + "plots/"
result_data_folder = result_folder + "data/"
log_folder = result_folder + "logs/"
eval_folder = result_folder + "evals/"
data_files_folder = "../data_files/"

# CHANGE DATA FOLDER BELOW
data_folder = ""

# Load Metadata

Adapt this section if your dataset has its own metadata.

In [None]:
# Mapping keyed by bundle abbreviation (used later to relabel plot axes).
bundle_names = load_pickle(data_files_folder + "bundle_names")

In [None]:
# Load metadata with bundle count and streamline count, ordered by subject.
metadata_path = data_files_folder + "metadata.csv"
df_meta = pd.read_csv(metadata_path).sort_values('Subject')
print(df_meta.shape)
print(df_meta.DX.value_counts())

# Load Inference Data

In [None]:
def load_inference(subj_name, model_subfolder, 
                   epoch, seed=SEED):
    """Build a ``BundleData`` for one subject and attach its saved inference output.

    The subject is loaded from the module-level ``data_folder`` with fixed
    preprocessing settings (256 points per line, CST_L_s/CST_R_s excluded,
    '3d' preprocessing). The inference results are read from
    ``{result_data_folder}{model_subfolder}/E{epoch}_{subj_name}``.

    Args:
        subj_name: subject identifier.
        model_subfolder: model result sub-folder under ``result_data_folder``.
        epoch: model epoch whose inference output is loaded.
        seed: seed for the ``RandomState`` handed to ``BundleData``.

    Returns:
        The populated ``BundleData`` instance.
    """
    subj = BundleData(
        subj_name,
        n_points=256,
        n_lines=None,
        min_lines=2,
        tracts_exclude=['CST_L_s', 'CST_R_s'],
        preprocess='3d',
        rng=np.random.RandomState(seed),
        verbose=False,
        data_folder=data_folder,
        align_bundles_path=f"{data_files_folder}bundle_centroid",
    )
    inference_path = f"{result_data_folder}{model_subfolder}/E{epoch}_{subj_name}"
    subj.load_inference_data(inference_path)
    return subj

In [None]:
# Load Model: checkpoint plus the (presumably normalisation) mean/std returned with it.
zdim = 6
model_subfolder = 'convVAE3L_XUXU_Z' + str(zdim) + '_B512_LR2E-04_WD1E-03_GCN2E+00_CN10'
epoch = 100
model, mean, std = load_model_for_inference(model_subfolder, model_folder, epoch, DEVICE)
print(mean, std)

# Hyper-parameters decoded from the model folder name.
msetting = parse_model_setting(model_subfolder)
msetting

In [None]:
# File-name suffix derived from the model setting ``key`` so that evaluation
# outputs of differently configured models do not overwrite each other.
key = "subj_train"
# Default: without this, ``suffix`` was undefined (NameError) whenever ``key``
# was missing from ``msetting``.
suffix = ""
if key in msetting:
    value = msetting[key]
    if isinstance(value, float):
        suffix = f"_{key}{value:.0E}"
    elif isinstance(value, int):
        suffix = f"_{key}{value}"
    elif isinstance(value, str):
        suffix = f"_{value}"
suffix

In [None]:
# Load inference data for single subject.
# load_inference takes (subj_name, model_subfolder, epoch, seed) and reads the
# result/data folders from module-level variables; passing them positionally
# raised TypeError (too many positional arguments).

subj_name_cn = 'example-CN-subj-name'  # CN
subj_cn = load_inference(subj_name_cn, model_subfolder, epoch, seed=SEED)

subj_name_ad = 'example-AD-subj-name'  # AD
subj_ad = load_inference(subj_name_ad, model_subfolder, epoch, seed=SEED)

# Anomaly Detection

The mean absolute error (MAE) between the original and the reconstructed streamlines is used as the anomaly metric.

## Bundle Specific

### MAE (Reconstruction)

In [None]:
def compute_mae(x, xb):
    """Mean absolute error between two batches, flattening each sample to one row.

    ``x`` and ``xb`` are reshaped to ``(n_samples, -1)`` and compared with
    sklearn's ``mean_absolute_error`` (uniform average over all columns).
    """
    n_x, n_xb = x.shape[0], xb.shape[0]
    return mean_absolute_error(x.reshape((n_x, -1)), xb.reshape((n_xb, -1)))

def get_bundle_mae(model_subfolder, eval_folder, epoch, save_df=True, suffix=None):
    """Compute (or load) reconstruction MAE per subject and per bundle.

    Args:
        model_subfolder: sub-folder of ``result_data_folder`` holding the
            ``E{epoch}_<subject>.pkl`` inference files.
        eval_folder: folder where ``mae_bundle{suffix}.csv`` is written/read.
        epoch: model epoch whose inference files are used.
        save_df: if True, recompute from the inference files and write the CSV;
            if False, only load the previously written CSV.
        suffix: extra file-name suffix; ``None`` is treated as ``""``.
            ``"_filtered"`` is appended (once) when the module-level
            ``do_filter`` flag is truthy. NOTE(review): ``do_filter`` is not
            defined in this block and must exist in the notebook namespace.

    Returns:
        DataFrame with columns Subject, Bundle, MAE, Count. Each subject also
        gets a Bundle == 'all' row computed over all of its streamlines.
    """
    # None would otherwise end up as the literal text "None" in the file name.
    suffix = suffix or ""
    if do_filter and not suffix.endswith("filtered"):
        suffix = suffix + "_filtered"
    csv_path = f"{eval_folder}mae_bundle{suffix}.csv"

    if not save_df:
        return pd.read_csv(csv_path)

    # The trailing "_" keeps e.g. epoch 10 from also matching E100_*.pkl, and
    # sorting makes the processing order reproducible.
    prefix = f"E{epoch}_"
    fnames = sorted(glob.glob(f"{result_data_folder}{model_subfolder}/{prefix}*.pkl"))

    ls = []  # rows of [subj, bundle, mae, count]
    for i, fname in enumerate(fnames, start=1):
        subj_name = os.path.basename(fname)[len(prefix):-len(".pkl")]
        print(f"{i}: {subj_name}")
        subj = load_inference(subj_name, model_subfolder, epoch)

        for bName in subj.bundle_idx.keys():
            bIdx = subj.get_subj_bundle_idx(bName)
            mae = compute_mae(subj.X[bIdx], subj.X_recon[bIdx])
            ls.append([subj_name, bName, mae, len(bIdx)])

        ls.append([subj_name, 'all',
                   compute_mae(subj.X, subj.X_recon),
                   subj.X.shape[0]])

    df = pd.DataFrame(ls, columns=['Subject', 'Bundle', 'MAE', 'Count'])
    df.to_csv(csv_path, index=False)
    return df

When running for the first time, set `save_df=True` to compute the MAE and save it to file.
On later runs, set `save_df=False` to load the saved results instead of recomputing them.

In [None]:
%%time
df_mae_all = get_bundle_mae(model_subfolder, eval_folder, epoch,
                            save_df=False, suffix=suffix)

# Keep only held-out subjects (those not in ``subjs_train``).
is_train_subject = df_mae_all.Subject.isin(subjs_train)
df_mae = df_mae_all.loc[~is_train_subject]
n_subjects = len(df_mae.Subject.unique())
print(df_mae.shape, n_subjects)
df_mae.head()

### TTest + FDR

In [None]:
def regress_age_sex(df, df_meta, cols=['MAE']):
    df = df.merge(df_meta[['Subject','Sex','Age']], on='Subject')
    regr = linear_model.LinearRegression()
    dummies = pd.get_dummies(df.Sex)
    regress = pd.DataFrame()
    regress['Age'] = df['Age']
    regress['Sex'] = dummies.M
    
    x = regress.values
    y = np.mean(df[cols], axis=1).values
    regr.fit(x, y)
    y_pred = regr.predict(x)
    
    corrected = df[cols].sub(y_pred, axis=0) \
                        .add(np.mean(df[cols], axis=1), axis=0)
    df.loc[:, cols]=corrected
    return df.drop(columns=['Age','Sex'])

In [None]:
# Correct bundle MAE for age and sex before group comparisons.
df_mae = regress_age_sex(df=df_mae, df_meta=df_meta, cols=['MAE'])

In [None]:
# Pooled-variance two-sample t-tests of corrected MAE between diagnosis
# groups, one set of tests per bundle.
dx_meta = df_meta[['Subject', 'DX']]
comparisons = (('CN', 'MCI'), ('CN', 'Dementia'), ('MCI', 'Dementia'))
p_ls = []
for bName in df_mae.Bundle.unique():
    df_b = (df_mae.loc[df_mae.Bundle == bName]
                  .merge(dx_meta, on='Subject')
                  .drop(columns=['Subject', 'Bundle', 'Count']))
    group_mae = {dx: df_b.loc[df_b.DX == dx].MAE.values
                 for dx in ('CN', 'MCI', 'Dementia')}
    row = [bName]
    for g1, g2 in comparisons:
        row.append(ttest_ind(group_mae[g1], group_mae[g2], equal_var=True).pvalue)
    p_ls.append(row)

df_bundle_pval = (pd.DataFrame(p_ls, columns=['Bundle', 'MCI', 'Dementia', 'MCI-AD'])
                    .sort_values('Bundle', ascending=True))
df_bundle_pval.head()

In [None]:
# Benjamini-Hochberg FDR correction across bundles, separately per comparison.
for comparison in ('MCI', 'Dementia', 'MCI-AD'):
    corrected = multipletests(df_bundle_pval[comparison], method='fdr_bh')[1]
    df_bundle_pval[f'{comparison}_FDR'] = corrected

In [None]:
# Heatmap of per-bundle p-values (FDR-corrected when ``fdr``), with "*"
# marking p < 0.05.
fdr = True
fdr_suffix = "_FDR" if fdr else ""
mci_col, ad_col = f'MCI{fdr_suffix}', f'Dementia{fdr_suffix}'

df_plot = df_bundle_pval[[mci_col, ad_col]]
d = {True: "*", False: ""}
df_rej = (df_plot < 0.05).apply(lambda column: column.map(d))

fig, ax = plt.subplots(figsize=(20, 2))
sns.heatmap(df_plot.T, cmap='Blues_r', linewidth=0.1,
            annot=df_rej.T, fmt='s', annot_kws={"size": 14}, ax=ax)
ax.set_xticklabels(df_bundle_pval.Bundle, rotation=90)
fig.show()
df_rej['Bundle'] = df_bundle_pval['Bundle']

In [None]:
# Bundles whose CN-vs-Dementia difference is significant.
is_ad_sig = df_rej[f'Dementia{fdr_suffix}'] == '*'
ad_sig_bundles = df_rej.loc[is_ad_sig, 'Bundle'].values
ad_sig_bundles

### Plot Bundle Anomaly

In [None]:
def compute_wa_bundle(df, key, metric='MAE'):
    """Count-weighted average of ``metric`` per bundle.

    Each row's ``metric`` is weighted by its ``Count`` and averaged within each
    ``Bundle``.

    Args:
        df: frame with ``Bundle``, ``Count`` and ``metric`` columns. Not modified.
        key: name of the output column.
        metric: column to average.

    Returns:
        One-column DataFrame (column ``key``) indexed by ``Bundle``, sorted in
        descending order of the weighted average.
    """
    # assign() works on a copy, so the caller's frame does not gain a WSum
    # column; only the two numeric columns are summed (summing everything also
    # concatenated the string Subject column).
    weighted = df.assign(WSum=df[metric] * df["Count"])
    df_wa = weighted.groupby('Bundle')[['WSum', 'Count']].sum()
    df_wa[key] = df_wa['WSum'] / df_wa['Count']
    return df_wa[[key]].sort_values(key, ascending=False)

def get_wa_bundle_by_dx(df_anom, df_meta, metric='MAE'):
    """Count-weighted average of ``metric`` per bundle for each diagnosis group.

    Args:
        df_anom: per-subject/per-bundle frame with Subject, Bundle, Count and
            ``metric`` columns.
        df_meta: metadata with Subject and DX columns.
        metric: column to average.

    Returns:
        DataFrame with a ``Bundle`` column plus one column per DX label. It keeps
        only bundles present in every group, because the merges are inner joins.
    """
    df = df_anom[['Bundle']].drop_duplicates()
    # pd.unique keeps first-appearance order, so the column order is
    # reproducible (a set of strings iterates in hash order). Missing DX
    # values are skipped: they select no subject, and the inner merge would
    # then drop every bundle.
    for dx in pd.unique(df_meta.DX.dropna()):
        subjs_idx = df_meta.loc[df_meta.DX == dx].Subject.values
        df_dx = df_anom.loc[df_anom.Subject.isin(subjs_idx)].copy()
        df_wa = compute_wa_bundle(df_dx, dx, metric=metric)
        df = df.merge(df_wa, on='Bundle')
    return df

In [None]:
# Weighted MAE per bundle and diagnosis group, ordered by the CN group.
metric = 'MAE'
df_wa = (get_wa_bundle_by_dx(df_mae, df_meta, metric=metric)
         .sort_values('CN', ascending=False))
print(df_wa.shape)
df_wa.head()

In [None]:
# Grouped bar chart of weighted MAE per bundle, one bar per diagnosis group.
group_colors = {'CN': 'g', 'MCI': 'b', 'Dementia': 'r'}
ax = df_wa.plot.bar(x='Bundle', y=list(group_colors), color=group_colors,
                    width=0.6, rot=90, fontsize=12, figsize=(18, 8))
ax.set_xlabel('Bundle', fontsize=15)
ax.set_ylabel(metric, fontsize=15)
legend = ax.legend(fontsize=12)
legend.set_title('Group', prop={'size': 15})

In [None]:
# Difference of each patient group's weighted MAE from the CN group, plus the
# significance markers from ``df_rej``.
for dx in ('MCI', 'Dementia'):
    df_wa[f'{dx}_Diff'] = df_wa[dx] - df_wa['CN']
df_wa = (df_wa.sort_values('Dementia_Diff', ascending=True)
              .merge(df_rej, on='Bundle'))

In [None]:
# Attach FDR p-values (suffix "_pval") next to the "*"/"" significance markers.
df_plot = df_wa.merge(df_bundle_pval[['Bundle', 'MCI_FDR', 'Dementia_FDR']],
                      on='Bundle', suffixes=('', '_pval'))
# Keep the rounded p-value only for significant bundles; others get "".
# Building the column from a list lets pandas pick the dtype (float if all
# significant, object otherwise) instead of writing "" into a float column,
# which pandas has deprecated (incompatible-dtype setitem).
for col in ('Dementia_FDR', 'MCI_FDR'):
    pval_col = f'{col}_pval'
    rounded = df_plot[pval_col].round(decimals=5)
    df_plot[pval_col] = [p if mark != '' else ""
                         for p, mark in zip(rounded, df_plot[col])]

In [None]:
# Horizontal bars of the (MCI - CN) and (Dementia - CN) weighted MAE per bundle,
# with full bundle names on a twin y-axis and significance markers on the bars.
ax=df_plot.plot.barh(x='Bundle', y=['MCI_Diff','Dementia_Diff'], width=0.9,
          color={'MCI_Diff':'b','Dementia_Diff':'r'}, align='center',
          fontsize=14, figsize=(20,12))
# Second y-axis sharing the x-axis; used only to show long bundle names.
ax2 = ax.twinx()

ax.set_xlabel("MAE", fontsize=16)
ax.yaxis.label.set_visible(False)
ax.legend(["MCI-CN",'AD-CN'], prop={'size': 16})

# get first axis tick labels
# 'all' (the whole-brain row) has no entry in the loaded mapping, so add one.
bundle_names['all'] = 'all_bundles'
axlabels = [item.get_text() for item in ax.get_yticklabels()]
ax2labels = map_list_with_dict(axlabels, bundle_names)

# get absolute tick position
# Ticks are rescaled to [0, 1] relative to ax's y-limits; this presumably
# relies on ax2 keeping its default 0-1 y-range (no data is drawn on it).
y_min, y_max = ax.get_ylim()
tickpos = [(tick - y_min)/(y_max - y_min) for tick in ax.get_yticks()]
# set secondary yticks positions and labels
ax2.set_yticks(tickpos)
ax2.set_yticklabels(ax2labels, fontsize=14)

# Label each bar group with its "*"/"" marker: the container label
# ('MCI_Diff' / 'Dementia_Diff') maps to column 'MCI' / 'Dementia' + fdr_suffix.
for i, container in enumerate(ax.containers):
    colname = container.get_label().split('_')[0]+fdr_suffix
    ax.bar_label(container, labels=df_plot[colname],
                 padding=2, fontsize=14)
plt.tight_layout()
plt.savefig(f"{plot_folder}mae_bundle.pdf")

## Segment MAE

Download the atlas dataset [here](https://figshare.com/articles/dataset/Atlas_of_30_Human_Brain_Bundles_in_MNI_space/12089652); its bundles are used as model bundles to divide each tract into segments.

In [None]:
# CHANGE THIS FOLDER NAME ACCORDINGLY
# Root of the downloaded atlas dataset; it is passed as ``data_folder`` to
# BundleData together with sub_folder_path="bundles/" (the path is presumably
# joined by concatenation, so keep a trailing "/" — TODO confirm).
atlas_data_folder = ""

def parse_atlas_tract_name(fname):
    """Return the tract name from an atlas file name: everything before the first '.'."""
    stem, _sep, _rest = fname.partition(".")
    return stem
    
# Atlas bundles, loaded with the same line preprocessing as the subjects;
# they serve as model bundles when segmenting along each tract.
args = dict(
    n_points=256, n_lines=None, min_lines=2,
    tracts_exclude=['CST_L_s', 'CST_R_s'], preprocess='3d',
    rng=None, verbose=False, parse_tract_func=parse_atlas_tract_name,
    data_folder=atlas_data_folder, sub_folder_path="bundles/",
)

atlas_data = BundleData("", **args)

### Compute Segment MAE

In [None]:
def compute_segment_mae(subj, bname, bIdx=None, n_segments=100):
    """Reconstruction MAE of one bundle at each of ``n_segments`` along-tract segments.

    Every point of the subject's bundle is assigned to a segment with dipy's
    ``assignment_map`` against the atlas bundle of the same name. The MAE
    between original and reconstructed points is then computed per segment.

    Args:
        subj: BundleData with ``X`` and ``X_recon`` loaded.
        bname: bundle name.
        bIdx: optional index selecting a subset of the bundle's streamlines.
        n_segments: number of segments along the tract.

    Returns:
        List of length ``n_segments``; segments without points get 0.
    """
    model_bundle = atlas_data.X[atlas_data.get_subj_bundle_idx(bname)]
    subj_idx = subj.get_subj_bundle_idx(bname)
    bundle = subj.X[subj_idx]
    bundle_recon = subj.X_recon[subj_idx]
    if bIdx is not None:
        bundle, bundle_recon = bundle[bIdx], bundle_recon[bIdx]

    indx = assignment_map(ArraySequence(bundle), ArraySequence(model_bundle),
                          n_segments)

    # Flatten to (n_points, 3) so the per-point segment labels index rows.
    points = bundle.reshape(-1, 3)
    points_recon = bundle_recon.reshape(-1, 3)
    maes = []
    for seg in range(n_segments):
        in_seg = np.where(indx == seg)
        seg_pts, seg_recon = points[in_seg], points_recon[in_seg]
        if len(seg_pts) == 0 and len(seg_recon) == 0:
            maes.append(0)
        else:
            maes.append(mean_absolute_error(seg_pts, seg_recon))
    return maes

def save_segment_mae(n_segments=100, suffix=None):
    """[RUN ONCE] Save per-segment MAE for every subject and bundle to Excel.

    Writes ``{eval_folder}mae_segment_S{n_segments}{suffix}.xlsx`` with one sheet
    per bundle. Each row holds ``[Subject, Count, mae_0, ..., mae_{n-1}]``.

    Args:
        n_segments: number of along-tract segments.
        suffix: file-name suffix; ``None`` is treated as ``""``. ``"_filtered"``
            is appended (once) when the module-level ``do_filter`` flag is set.

    NOTE(review): relies on the notebook globals ``do_filter`` and, when it is
    truthy, ``d_filtered`` (bundle name -> streamline index), neither of which
    is defined in this block.
    """
    suffix = suffix or ""
    if do_filter and not suffix.endswith("filtered"):
        suffix = suffix + "_filtered"

    result_dict = defaultdict(list)

    # "E{epoch}_" keeps e.g. epoch 10 from also matching E100_*.pkl; sorted
    # for a reproducible processing order.
    prefix = f"E{epoch}_"
    fnames = sorted(glob.glob(f"{result_data_folder}{model_subfolder}/{prefix}*.pkl"))
    for i, fname in enumerate(fnames, start=1):
        subj_name = os.path.basename(fname)[len(prefix):-len(".pkl")]
        print(f"{i}: {subj_name}")

        subj = load_inference(subj_name, model_subfolder, epoch)

        for bName in subj.bundle_idx.keys():
            bIdx = subj.get_subj_bundle_idx(bName)
            # Apply the streamline filter only when filtering is enabled, so
            # the file content matches the "_filtered" file-name suffix.
            keep = d_filtered[bName] if do_filter else None
            entry = [subj_name, len(bIdx)]
            entry.extend(compute_segment_mae(subj, bName, bIdx=keep,
                                             n_segments=n_segments))
            result_dict[bName].append(entry)

    # Save to file
    names = ['Subject', 'Count'] + list(range(n_segments))
    with pd.ExcelWriter(f"{eval_folder}mae_segment_S{n_segments}{suffix}.xlsx") as writer:
        for bundle, value in result_dict.items():
            result = pd.DataFrame(value, columns=names)
            result.to_excel(writer, sheet_name=bundle, index=False)

In [None]:
# RUN ONCE: writes the per-segment MAE workbook used by the cells below.
save_segment_mae(n_segments=100, suffix=suffix)

### Plot along-tract MAE

In [None]:
# Functions for loading per-segment MAE and computing count-weighted averages (WA).

def load_segment_mae(bName, n_segments=100, select=None, suffix=None):
    """Load the per-segment MAE sheet of bundle ``bName``.

    Reads ``{eval_folder}mae_segment_S{n_segments}{suffix}.xlsx`` (``None``
    suffix treated as ``""``). ``Subject`` is cast to str.

    Args:
        bName: bundle name (sheet name).
        n_segments: number of segments the workbook was written with.
        select: optional collection of subject IDs to keep.
        suffix: file-name suffix used when saving.

    Returns:
        DataFrame with Subject, Count and one column per segment.
    """
    suffix = suffix or ""
    path = f"{eval_folder}mae_segment_S{n_segments}{suffix}.xlsx"
    # Pass the path so pandas opens and closes the file itself; a bare
    # open(..., 'rb') handle here was never closed.
    df = pd.read_excel(path, sheet_name=bName, index_col=None)
    df['Subject'] = df['Subject'].astype(str)
    if select is not None:
        return df[df.Subject.isin(select)]
    return df

def compute_mae_wa_segment(df, n_segments=100):
    """Count-weighted mean over rows of each segment column ``0..n_segments-1``.

    Returns a NumPy array of length ``n_segments``.
    """
    weights = df['Count']
    weighted = df[range(n_segments)].mul(weights, axis=0)
    per_segment = weighted.sum(axis=0) / weights.sum()
    return per_segment.values

def get_mae_wa_segment_dx(bName, df_meta, n_segments=100, 
                          exclude=None, suffix=None, keep_subj=False):
    """Count-weighted along-tract MAE of one bundle for each diagnosis group.

    For each DX label in ``df_meta``, the per-segment MAE of ``bName`` is loaded
    for that group's subjects (minus ``exclude``), corrected for age and sex
    with ``regress_age_sex``, and averaged over subjects with ``Count`` weights.

    Returns:
        df: one column per DX label, one row per segment (weighted MAE).
        dict_dx: DX label -> corrected data. With ``keep_subj`` it is the
            per-subject frame (Subject, Count, segment columns); otherwise it is
            transposed so rows are segments and columns are subjects.
    """
    df = pd.DataFrame()
    dict_dx = {}
    # Deterministic group order (a set of strings iterates in hash order),
    # and no NaN group, which would select no subjects.
    for dx in pd.unique(df_meta.DX.dropna()):
        subjs = df_meta.loc[df_meta.DX == dx].Subject.values
        if exclude is not None:
            subjs = set(subjs) - set(exclude)
        df_mae_sel = load_segment_mae(bName, n_segments=n_segments,
                                      select=subjs, suffix=suffix)
        # Correct every segment column; a hard-coded range(100) broke any
        # n_segments other than 100.
        df_mae_sel = regress_age_sex(df_mae_sel, df_meta, cols=range(n_segments))
        if keep_subj:
            dict_dx[dx] = df_mae_sel
        else:
            dict_dx[dx] = df_mae_sel.drop(columns=['Count', 'Subject']).T
        df[dx] = compute_mae_wa_segment(df_mae_sel, n_segments=n_segments)
    return df, dict_dx

In [None]:
# Along-tract MAE of the left arcuate fasciculus, held-out subjects only.
bundle = 'AF_L'
df_wa, dict_dx = get_mae_wa_segment_dx(bName=bundle, df_meta=df_meta,
                                       exclude=subjs_train, suffix=suffix)

In [None]:
# Plot along-tract MAE per diagnosis group with a 95% confidence band.
fig, ax = plt.subplots(figsize=(13, 8))
d_color = {'CN': 'g', 'MCI': 'b', 'Dementia': 'r'}
for dx, df_dx in dict_dx.items():
    # Long format: one row per (segment index, subject) value.
    df = (df_dx.reset_index()
               .melt(id_vars='index', value_name='MAE')
               .drop(columns=['variable']))
    sns.lineplot(data=df, x='index', y='MAE', color=d_color[dx], label=dx,
                 alpha=0.6, ci=95, ax=ax)

ax.set_ylabel('MAE', fontsize=16)
ax.set_xlabel('Segment Number', fontsize=16)
ax.set_ylim([0, 4])

l = ax.legend(title=f"{bundle}", loc='upper center',
              fontsize=16, title_fontsize=16,
              ncol=3, fancybox=True, shadow=False)

### T-test + FDR

In [None]:
def ttest_along_tract_segment_fdr(bName, n1='CN', n2='AD', n_segments=100, suffix=None):
    """Per-segment pooled-variance t-test of corrected MAE between groups n1 and n2.

    Despite the name, no FDR correction is applied here. Raw p-values go into
    column ``pval_{n1}_{n2}`` of the weighted-average frame returned by
    ``get_mae_wa_segment_dx``, and a ``Bundle`` column is added.

    NOTE(review): ``n1``/``n2`` must be DX labels present in ``df_meta`` (the
    notebook passes 'CN' and 'Dementia'); uses the notebook global
    ``subjs_train`` to exclude training subjects.
    """
    # Forward n_segments so the data and the test loop agree on segment count.
    df_wa, dict_dx = get_mae_wa_segment_dx(bName, df_meta, n_segments=n_segments,
                                           exclude=subjs_train, suffix=suffix)
    group1, group2 = dict_dx[n1], dict_dx[n2]
    # Rows of the transposed frames are segments, columns are subjects.
    df_wa[f'pval_{n1}_{n2}'] = [
        ttest_ind(group1.iloc[i], group2.iloc[i], equal_var=True).pvalue
        for i in range(n_segments)
    ]
    df_wa['Bundle'] = bName
    return df_wa

def ttest_multi_tract_fdr(bNames, n1='CN', n2='AD', n_segments=100, suffix=None):
    """Run the per-segment t-tests for several bundles and apply BH-FDR jointly.

    Returns the concatenated per-segment frame with extra columns ``rej_fdr``
    (reject flags) and ``pval_fdr`` (corrected p-values), where the correction
    spans all segments of all bundles together.
    """
    per_bundle = [
        ttest_along_tract_segment_fdr(bName, n1, n2,
                                      n_segments=n_segments, suffix=suffix)
        for bName in tqdm(bNames)
    ]
    df = pd.concat(per_bundle, ignore_index=True)
    reject, pvals_corrected, *_ = multipletests(df[f'pval_{n1}_{n2}'], method='fdr_bh')
    df['rej_fdr'] = reject
    df['pval_fdr'] = pvals_corrected
    return df

# Reload the bundle-name mapping (drops the 'all' entry added for plotting) and
# test CN vs. Dementia along every AD-significant bundle.
bundle_names = load_pickle(data_files_folder + "bundle_names")
df_test = ttest_multi_tract_fdr(ad_sig_bundles, n1='CN', n2='Dementia',
                                suffix=suffix)
df_test.head()

In [None]:
def _bh_fdr(pvals):
    """Benjamini-Hochberg correction of one bundle's segment p-values."""
    return multipletests(pvals, method='fdr_bh')


# FDR applied within each bundle separately (as opposed to pval_fdr, which
# corrects across all bundles at once).
grouped_pvals = df_test.groupby('Bundle')['pval_CN_Dementia']
df_test['bundle_rej_fdr'] = grouped_pvals.transform(lambda s: _bh_fdr(s)[0])
df_test['bundle_pval_fdr'] = grouped_pvals.transform(lambda s: _bh_fdr(s)[1])
df_test.head()

In [None]:
# Number of significant segments (across-bundle FDR) per bundle.
df_test[df_test['rej_fdr']]['Bundle'].value_counts()

### Visualize

In [None]:
def generate_disk_colors(indx, n, colors, seed=SEED):
    """Give every point the color of its segment.

    Args:
        indx: per-point segment index.
        n: unused (kept for call compatibility).
        colors: per-segment colors, indexable by segment index.
        seed: unused (kept for call compatibility).

    Returns:
        List of color tuples, one per entry of ``indx``.
    """
    return [tuple(colors[segment]) for segment in indx]

In [None]:
# Along-tract results for the middle corpus callosum bundle.
bundle = 'CCMid'
df_wa, dict_dx = get_mae_wa_segment_dx(bundle, df_meta,
                                       exclude=subjs_train, suffix=suffix)
df_test_bundle = df_test.loc[df_test.Bundle == bundle].reset_index()

In [None]:
# -log10 of the FDR-corrected p-value per segment; red bars are significant.
pvals = df_test_bundle.pval_fdr.values
logp = -np.log10(pvals)

fig, ax = plt.subplots(figsize=(12, 8))

mask_sig = np.where(pvals < 0.05)
mask_nonsig = np.where(pvals >= 0.05)
seg_idx = np.arange(len(pvals))  # one bar per tested segment

ax.bar(seg_idx[mask_nonsig], logp[mask_nonsig], color='b', alpha=0.7)
ax.bar(seg_idx[mask_sig], logp[mask_sig], color='r', alpha=0.7)
ax.axhline(-np.log10(0.05), c='grey', linestyle='dashed', label='p=0.05')

ax.set_ylim([0, 2])
ax.set_xlabel('Segment Number', fontsize=16)
ax.set_ylabel("-log10(pval)", fontsize=16)

# sum() gives 0 when nothing is significant; value_counts().loc[True]
# raised KeyError in that case.
n_sig = int(df_test_bundle.rej_fdr.sum())
l = ax.legend(title=f"{bundle}\n N(significant segments)={n_sig}", loc='upper right',
              fontsize=16, title_fontsize=16,
              ncol=3, fancybox=True, shadow=False)
plt.setp(l.get_title(), multialignment='center')

In [None]:
# Atlas streamlines of the bundle, split into 100 segments using the atlas
# bundle itself as the model bundle.
atlas_idx = atlas_data.get_subj_bundle_idx(bundle)
lines = atlas_data.X[atlas_idx]
print(lines.shape)
indx = assignment_map(ArraySequence(lines), ArraySequence(lines), 100)

In [None]:
# Make colormap based on p-value of CN vs. AD: red = significant segment,
# blue = not significant (per ``rej_fdr``).
color_dict = {True: [1, 0, 0], False: [0, 0, 1]}
segment_colors = map_list_with_dict(df_test_bundle.rej_fdr, color_dict)
colors = [tuple(rgb) for rgb in segment_colors]
colors = generate_disk_colors(indx, 100, colors)
len(colors)

In [None]:
# Stand-alone legend (significant vs. non-significant) for the 3D renderings.
_, _ = plt.subplots(figsize=(10, 1))
legend_colors = [[1, 0, 0], [0, 0, 1]]


def f(m, c):
    """Invisible proxy artist with marker ``m`` and color ``c`` for the legend."""
    return plt.plot([], [], marker=m, color=c, ls="none")[0]


handles = [f("s", c) for c in legend_colors]
labels = ['Significant', 'Non-significant']
legend = plt.legend(handles, labels, fontsize=16, loc=3, ncol=2,
                    framealpha=1, frameon=True)

def export_legend(legend, filename="legend.png", expand=(-5, -5, 5, 5)):
    """Save only the area of ``legend`` to ``{plot_folder}{filename}``.

    Args:
        legend: matplotlib Legend to export.
        filename: output file name inside ``plot_folder``.
        expand: (x0, y0, x1, y1) offsets in display units added to the
            legend's extents to pad the crop. A tuple avoids a shared mutable
            default; any 4-element sequence works.
    """
    fig = legend.figure
    fig.canvas.draw()  # a draw is needed before the legend extent is known
    bbox = legend.get_window_extent()
    bbox = bbox.from_extents(*(bbox.extents + np.array(expand)))
    # savefig's bbox_inches expects inches, so convert from display units.
    bbox = bbox.transformed(fig.dpi_scale_trans.inverted())
    fig.savefig(f"{plot_folder}{filename}", dpi="figure", bbox_inches=bbox)

# Write the legend image, then display the figure.
export_legend(legend, filename="legend.png")
plt.show()

In [None]:
'''visualize streamlines based on t-test results'''
# Render the atlas bundle with per-point colors (red = significant segment).
scene = window.Scene()
scene.SetBackground(1, 1, 1)  # white background

# scene.set_camera(position=pos, 
#                  focal_point=foc,
#                  view_up=vup)

scene.add(actor.line(lines, fake_tube=True, linewidth=6, colors=colors))
window.show(scene, size=(1000,1000), reset_camera=False)
window.record(scene, size=(1200, 1200), out_path=f'{plot_folder}vis_{bundle}_ttest_seg{suffix}.png')
# Keep the camera so the MAE rendering below can reuse the same view.
pos, foc, vup = scene.get_camera()

In [None]:
# Make colormap based on MAE: each segment is colored by the Dementia group's
# weighted MAE on a fixed 0-3 'jet' scale.
values = df_test_bundle.Dementia
norm = mpl.colors.Normalize(vmin=0, vmax=3)
cmap = mpl.cm.ScalarMappable(norm=norm, cmap=plt.get_cmap('jet'))
colors = [tuple(rgb) for rgb in cmap.to_rgba(values)[:, :3]]  # drop alpha
print(len(colors))

# Stand-alone horizontal colorbar for the rendering.
fig, ax = plt.subplots(figsize=(6, 1))
fig.subplots_adjust(bottom=0.5)
fig.colorbar(cmap, cax=ax, orientation='horizontal', label='MAE')

colors = generate_disk_colors(indx, 100, colors)
len(colors)

In [None]:
'''visualize streamlines based on MAE'''
# Render the atlas bundle colored by Dementia-group MAE, reusing the camera
# (pos/foc/vup) captured after the t-test rendering.
scene = window.Scene()
scene.SetBackground(1, 1, 1)  # white background

scene.set_camera(position=pos, 
                 focal_point=foc,
                 view_up=vup)

scene.add(actor.line(lines, fake_tube=True, linewidth=6, colors=colors))
window.show(scene, size=(1000,1000), reset_camera=False)
window.record(scene, size=(1200, 1200), out_path=f'{plot_folder}vis_{bundle}_mae_seg{suffix}.png')

In [None]:
# Remember the current camera (position, focal point, view-up) for later renders.
camera = scene.get_camera()
pos, foc, vup = camera