# NTFA Visualization

1. Spatial factor plot
2. Embedding plots
3. Voxel Noise
4. Reconstruction
5. Embedding distance and Activity distance RSA

Yiyu 2022/06

In [None]:
# check we're in our env (*)
%conda env list

In [None]:
# path for the NTFA package
NTFA_path = "/work/abslab/NTFA_packages/NTFADegeneracy/"

import sys
sys.path.append(NTFA_path)
import htfa_torch.dtfa as DTFA
import htfa_torch.niidb as niidb
import htfa_torch.utils as utils
import htfa_torch.tardb as tardb
import logging
import numpy as np
import pandas as pd
import glob
import os
import webdataset as wds
import torch
import itertools
from ordered_set import OrderedSet

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
import plotly.express as px
from nilearn import plotting, image
from scipy.spatial.distance import pdist, squareform
from scipy.stats import spearmanr

logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)

from datetime import datetime
datetime.now().strftime('%Y-%m-%d %H:%M:%S')

In [None]:
from IPython.display import set_matplotlib_formats
%matplotlib inline
set_matplotlib_formats('pdf', 'svg')

In [None]:
nifti_dir = '/work/abslab/AVFP/denoised/'
logfiles_dir = '/work/abslab/AVFP/logfiles/AffVidsNovel_logfiles/'

mask_dir = '/home/wang.yiyu/AVFP/masks/'
base_dir = '/work/abslab/Yiyu/NTFA_AVFP/'

In [None]:
# **** parameters that define the model directory ******
subs = 20 #'All' #note, database file must have been created already for these subjects


included_data = pd.read_csv(base_dir + 'fmri_info/included_avfp_subjects.csv', header=None)
subIDs = included_data[0].astype('str').tolist()
print(subIDs)
total_subs = len(subIDs)
print(f"total subs = {total_subs}")

# using GM (and SNR) or SNR only?
mask_type = 'GM' #'GMandSNR' #'SNR' or 'GM'

# penalty weights (participant, stimulus, combination)
p_weight, s_weight, c_weight = 1, 1, 1
linear_opts = 'None' # 'C', 'PSC' 'None'
# additional parameters:
n_epoch = 1000
n_factor = 100
n_check = 50 # save checkpoints every n_check epochs for model



# condition:
condition = '_HeightsOnly'


# noise model:
noise_model = True
voxel_noise = 0.3
noise_opts = f'learned-{voxel_noise}' 

# noise_opts = f'fixed-{voxel_noise}'

In [None]:
# define database filename
if subs != 'All':
    #AVFP_FILE = 'data/AVFP_NTFA_memory_N' + str(n_files) + '_subsetN' + str(subs) + '_' + mask_type + 'mask.tar'
    AVFP_FILE = base_dir + f'data/downsampled_test/AVFP_NTFA_N{total_subs}_subsetN{subs}_{mask_type}mask{condition}.tar'
else: #including all subjects
    AVFP_FILE = base_dir + f'data/AVFP_NTFA_N{total_subs}_{subs}_{mask_type}mask{condition}.tar'
print('\nFetching database:',AVFP_FILE)

avfp_db = tardb.FmriTarDataset(AVFP_FILE)

In [None]:
# set up directory and filename for saving model

# The model directory name encodes the training configuration; models trained
# with a voxel-noise model carry an extra `_noise-<opts>` tag.
query_dir = (
    'models/ablation_comparison/'
    f'AVFP_NTFA_sub-{subs}_epoch-{n_epoch}_factor-{n_factor}_mask-{mask_type}'
    f'_{p_weight}{s_weight}{c_weight}_lin-{linear_opts}'
    + (f'_noise-{noise_opts}' if noise_model else '')
    + '/'
)

# Created if missing so that later cells can write outputs into it.
os.makedirs(query_dir, exist_ok=True)
print("\nFetching model from: ", query_dir, '\n')

In [None]:

# Rebuild the DeepTFA model for this database/configuration; its trained
# parameters are loaded from a checkpoint further down (dtfa.load_state).
# NOTE(review): voxel_noise is hard-coded to 0.7652 here, whereas the
# parameter cell sets voxel_noise = 0.3 (visibly used only to build
# noise_opts / the model directory name). Confirm 0.7652 is intended.
dtfa = DTFA.DeepTFA(avfp_db, num_factors=n_factor, linear_params=linear_opts, query_name=query_dir, voxel_noise = 0.7652)


In [None]:
# Spatial factors as initialised (k-means), before any checkpoint is loaded.
# Adapted from dtfa.visualize_factor_embedding.
results = dtfa.results()
centers = results['factor_centers']
widths = torch.exp(results['factor_log_widths'])

fig = plt.figure(figsize=(7, 3))
plot = plotting.plot_connectome(
    np.eye(dtfa.num_factors),
    node_coords=centers.view(dtfa.num_factors, 3).numpy(),
    node_size=widths.view(dtfa.num_factors).numpy(),
    figure=fig,
)
plt.suptitle('Spatial Factors: Initialized', size=20)
# plt.savefig(query_dir + 'factors_initialized.pdf')
plt.show()

In [None]:
# load mask:
mask = image.load_img(mask_dir + 'gm_mask_icbm152_brain.nii.gz')
mask_data = mask.get_fdata()

In [None]:
def check_factor_locations(factor_centers=None):
    """Report spatial-factor centers that fall outside the brain mask.

    Each center (world coordinates) is mapped to voxel indices of the global
    `mask` image using the inverse of its affine, rounded to the nearest voxel,
    and looked up in `mask_data`. A center counts as outside the mask if it
    lands on a zero-valued voxel or outside the mask volume altogether.

    Parameters
    ----------
    factor_centers : torch.Tensor or array-like, optional
        Factor centers, reshapeable to (n_factors, 3). Defaults to the global
        `centers` (from the most recent `dtfa.results()` call).

    Returns
    -------
    int
        Number of factor centers outside the mask.
    """
    if factor_centers is None:
        factor_centers = centers
    if torch.is_tensor(factor_centers):
        factor_centers = factor_centers.detach().cpu().numpy()
    centers_np = np.asarray(factor_centers, dtype=float).reshape(-1, 3)

    inv_affine = np.linalg.inv(mask.affine)
    volume_shape = mask_data.shape[:3]
    count = 0
    for coords in centers_np:
        # apply inverse affine to go from world to voxel coordinates:
        voxel_coords = np.round(
            image.coord_transform(coords[0], coords[1], coords[2], inv_affine)
        ).astype(int)
        # A negative index would silently wrap to the opposite side of the
        # volume and a too-large one would raise, so both count as outside.
        in_volume = all(0 <= v < n for v, n in zip(voxel_coords, volume_shape))
        if not in_volume or mask_data[tuple(voxel_coords)] == 0:
            count += 1
            print('Factor center:', coords, voxel_coords)
    if count == 0:
        print('\nAll spatial factors are within the mask\n')
    else:
        print('\n', count, 'factors are outside of the mask\n')
    return count

In [None]:
# Load the most recent checkpoint in query_dir. The state name is the shared
# prefix of the .dtfa_model / .dtfa_guide files.
checkpoint_files = glob.glob(query_dir + 'CHECK*dtfa*')
if not checkpoint_files:
    raise FileNotFoundError(f'No checkpoint files matching CHECK*dtfa* in {query_dir}')
# Order by modification time: on Unix, ctime is the inode-change time, which
# also moves on chmod/chown/rename and can misorder checkpoints.
state_name = max(checkpoint_files, key=os.path.getmtime).split('.dtfa')[0]
print('\nLoading most recent checkpoint:',state_name,'\n')

dtfa.load_state(state_name)

# Spatial Factors

In [None]:
# Learned spatial factors for a single block; these differ between subjects.
block_plot = 0

results = dtfa.results(block=block_plot)
centers = results['factor_centers']
widths = torch.exp(results['factor_log_widths'])

fig = plt.figure(figsize=(7, 3))
plot = plotting.plot_connectome(
    np.eye(dtfa.num_factors),
    node_coords=centers.view(dtfa.num_factors, 3).numpy(),
    node_size=widths.view(dtfa.num_factors).numpy(),
    figure=fig,
)
plt.suptitle(f'Spatial Factors (Block: {block_plot})', size=20)
plt.show()

In [None]:

hyperparams = dtfa.variational.hyperparams.state_vardict()
tasks = dtfa.tasks()
subjects = dtfa.subjects()
z_p_mu = hyperparams['subject_weight']['mu'].data
z_s_mu = hyperparams['task']['mu'].data

In [None]:
# Interaction (participant x stimulus) embedding for every task each
# participant actually has; `combinations` holds the matching
# (participant, stimulus) pairs, one row per embedding.
first_task_index = {}
for idx, task_name in enumerate(tasks):
    first_task_index.setdefault(task_name, idx)

z_ps_mu, combinations = [], []
for p, subject in enumerate(subjects):
    sub_tasks = [blk['task'] for blk in avfp_db.blocks.values() if blk['subject'] == subject]
    combinations.append(np.vstack([np.repeat(subject, len(sub_tasks)), np.array(sub_tasks)]))
    for task_name in sub_tasks:
        joint_embed = torch.cat((z_p_mu[p], z_s_mu[first_task_index[task_name]]), dim=-1)
        z_ps_mu.append(dtfa.decoder.interaction_embedding(joint_embed).data.numpy())
z_ps_mu = np.vstack(z_ps_mu)
combinations = np.hstack(combinations).T

In [None]:
def fetch_embeddings():
    """Return participant, stimulus and interaction embedding means as DataFrames.

    Reads the variational means of the participant ('subject_weight') and
    stimulus ('task') embeddings from the global `dtfa`. It also computes the
    interaction embedding for every (participant, task) block in the global
    `avfp_db`, by passing the concatenated pair through
    `dtfa.decoder.interaction_embedding`.

    Returns
    -------
    z_p : pandas.DataFrame
        One row per participant: 'participant', then coordinate columns.
    z_s : pandas.DataFrame
        One row per stimulus: 'stimulus', then coordinate columns.
    z_ps : pandas.DataFrame
        One row per participant/stimulus block: 'participant', 'stimulus',
        then coordinate columns.

    Coordinate columns are float and named 'x', 'y' (plus 'z' for 3-D
    embeddings; 'dim_<k>' for higher dimensions).
    """
    def _embedding_frame(label_columns, mu):
        # Label columns followed by float coordinate columns, built directly
        # (no round-trip of the coordinates through a mixed str/float array).
        mu = np.asarray(mu, dtype=float)
        dim = mu.shape[1]
        names = ['x', 'y', 'z'][:dim] if dim <= 3 else [f'dim_{k}' for k in range(dim)]
        return pd.concat([pd.DataFrame(label_columns),
                          pd.DataFrame(mu, columns=names)], axis=1)

    hyperparams = dtfa.variational.hyperparams.state_vardict()
    tasks = dtfa.tasks()
    subjects = dtfa.subjects()
    z_p_mu = hyperparams['subject_weight']['mu'].data
    z_s_mu = hyperparams['task']['mu'].data

    # Index of the first occurrence of each task name in `tasks`.
    task_index = {}
    for i, task in enumerate(tasks):
        task_index.setdefault(task, i)

    z_ps_mu, combo_subjects, combo_tasks = [], [], []
    with torch.no_grad():  # evaluation only; no autograd graph needed
        for p, subject in enumerate(subjects):
            # Because tasks were coded by memory, each participant only has
            # 1/2 of the unique tasks - look up the ones this participant has.
            sub_tasks = [b['task'] for b in avfp_db.blocks.values() if b['subject'] == subject]
            for task in sub_tasks:
                joint_embed = torch.cat((z_p_mu[p], z_s_mu[task_index[task]]), dim=-1)
                interaction_embed = dtfa.decoder.interaction_embedding(joint_embed)
                z_ps_mu.append(interaction_embed.detach().numpy())
                combo_subjects.append(subject)
                combo_tasks.append(task)

    z_p = _embedding_frame({'participant': list(subjects)}, z_p_mu.numpy())
    z_s = _embedding_frame({'stimulus': list(tasks)}, z_s_mu.numpy())
    z_ps = _embedding_frame({'participant': combo_subjects, 'stimulus': combo_tasks},
                            np.vstack(z_ps_mu))
    return z_p, z_s, z_ps

In [None]:
p_embedding, s_embedding, c_embedding = fetch_embeddings()

In [None]:
p_embedding.participant = p_embedding.participant.astype('int').astype('string')
p_embedding[['x','y']] = p_embedding[['x','y']].astype('float')
p_embedding.head()

In [None]:
s_embedding[['x','y']] = s_embedding[['x','y']].astype('float')
s_embedding.head()

In [None]:
c_embedding.participant = c_embedding.participant.astype('int').astype('string')
c_embedding[['x','y']] = c_embedding[['x','y']].astype('float')
c_embedding.head()

In [None]:


# save embedding coordinates in pickle:
p_embedding.to_pickle(query_dir + 'p_embedding.pkl')
s_embedding.to_pickle(query_dir + 's_embedding.pkl')
c_embedding.to_pickle(query_dir + 'c_embedding.pkl')

In [None]:
def plot_embeddings(embeddings, hue_var, label_vars=None, marker=16):
    """Scatter-plot embeddings with plotly, coloured by one column.

    Parameters
    ----------
    embeddings : pandas.DataFrame
        One row per embedding, with coordinate columns 'x' and 'y'. If a 'z'
        column is present, a 3-D scatter is drawn instead. `hue_var` and
        every entry of `label_vars` must also be columns (add them to the
        DataFrame beforehand).
    hue_var : str
        Column used for colour; also shown on hover.
    label_vars : list of str, optional
        Extra columns shown on hover. A 'participant' column, if present,
        is always shown.
    marker : int
        Marker size.

    Returns
    -------
    plotly.graph_objects.Figure
        The figure, after it has been shown.
    """
    hover_vars = ['participant'] if 'participant' in embeddings.columns else []
    hover_vars.append(hue_var)
    hover_vars.extend(label_vars or [])
    # Drop duplicates (e.g. hue_var == 'participant') while keeping order.
    hover_vars = list(dict.fromkeys(hover_vars))

    # `titlefont` is deprecated and was dropped in plotly.js 3 / plotly.py 6;
    # `title.font` is the supported spelling. update_layout merges it into the
    # axis titles that plotly express already set.
    axis_style = dict(
        linecolor='black', mirror=True, linewidth=2,
        tickfont=dict(size=16, color='black'),
        title=dict(font=dict(size=20)),
    )

    if 'z' not in embeddings.columns:
        fig = px.scatter(embeddings, x='x', y='y',
                         hover_data=hover_vars, color=hue_var,
                         color_discrete_sequence=px.colors.qualitative.Light24)
        fig.update_layout(height=500, width=600,
                          xaxis=axis_style, yaxis=axis_style)
    else:
        fig = px.scatter_3d(embeddings, x='x', y='y', z='z',
                            hover_data=hover_vars, color=hue_var)
        fig.update_layout(height=500, width=700,
                          scene_aspectmode='cube',
                          scene=dict(xaxis=axis_style, yaxis=axis_style, zaxis=axis_style),
                          margin=dict(l=0.1, r=0.1, b=0.1, t=0.1))
    fig.update_traces(marker=dict(size=marker))
    fig.show()
    return fig

In [None]:
fig = plot_embeddings(p_embedding, 'participant', marker=24)
#fig.write_html(query_dir + 'participant_embeddings.html')

In [None]:
situations = ['Heights', 'Social', 'Spiders']
# Label each stimulus with the situation name it contains ('' if none; when
# several match, the later entry of `situations` wins).
s_embedding['situation'] = [
    next((s for s in reversed(situations) if s in stim), '')
    for stim in s_embedding.stimulus
]

fig = plot_embeddings(s_embedding, 'stimulus', label_vars=['stimulus'])

In [None]:
situations = ['Heights', 'Social', 'Spiders']
# Situation label per combination: the last situation name contained in the
# stimulus name, or '' if none matches.
c_embedding['situation'] = c_embedding['stimulus'].map(
    lambda stim: ([s for s in situations if s in stim] or [''])[-1]
)

fig = plot_embeddings(c_embedding, 'participant', label_vars=['stimulus'], marker=8)
#fig.write_image(query_dir + 'combination_embeddings.jpeg', scale=2)

In [None]:
situations = ['Heights', 'Social', 'Spiders']
# Same labelling as the previous cell, done one situation at a time; later
# situations overwrite earlier ones, so the last match wins.
c_embedding['situation'] = ''
for s in situations:
    c_embedding.loc[c_embedding['stimulus'].str.contains(s, regex=False), 'situation'] = s

fig = plot_embeddings(c_embedding, 'stimulus', label_vars=['stimulus'], marker=8)
#fig.write_image(query_dir + 'combination_embeddings.jpeg', scale=2)

# Reconstruction

In [None]:
tr_data = torch.utils.data.DataLoader(dtfa._dataset.data(selector=lambda block: True), batch_size=128, pin_memory=True)

In [None]:
# Stack the observed activations of every TR row-wise and record the block id
# of each TR, both in DataLoader order.
activation_batches, tr_blocks = [], []
for batch in tr_data:
    activation_batches.append(batch['activations'].numpy())
    tr_blocks.extend(batch['block'].numpy())
tr_activations = np.concatenate(activation_batches, axis=0)
del activation_batches  # drop the per-batch references once stacked
tr_activations.shape

In [None]:
# fetch *all* reconstructions, as a (TRs x voxels) array whose rows line up
# with tr_blocks / tr_activations, so positions in tr_blocks can index it
# directly.
tr_block_ids = np.asarray(tr_blocks)
reconstructions = None
for b in np.unique(tr_block_ids):
    # note, results contains the embedding info and factor weights
    results = dtfa.results(block=b, generative=True)
    voxel_values = (results['weights'] @ results['factors']).numpy()
    rows = np.flatnonzero(tr_block_ids == b)
    if voxel_values.shape[0] != rows.size:
        raise ValueError(f'block {b}: {voxel_values.shape[0]} reconstructed TRs '
                         f'but {rows.size} observed TRs')
    if reconstructions is None:
        reconstructions = np.empty((tr_block_ids.size, voxel_values.shape[1]),
                                   dtype=voxel_values.dtype)
    reconstructions[rows] = voxel_values
reconstructions.shape

In [None]:
# custom functions for plotting the original brain and reconstruction
# (adapted from dtfa and utils functions)

def fetch_block_index(subjects=None, task=None, measure=None,
                      value=None, value_type='exact', value_direction='higher'):
    """Return ids of blocks in the global `avfp_db` that match all given filters.

    Each block appears at most once, in `avfp_db.blocks` order.

    Parameters
    ----------
    subjects : value or iterable, optional
        Keep blocks whose 'subject' equals one of these values.
    task : str or iterable of str, optional
        Keep blocks whose 'task' contains any of these strings (substring
        match, so 'Heights' selects every Heights stimulus).
    measure : str, optional
        Key into each block's 'individual_differences' dict. Intended to hold
        trial-specific "individual differences" values.
    value : float, optional
        With value_type='exact', the value to match. With 'quantile', the
        quantile (0-1) of |measure| across the currently selected blocks used
        as the split point.
    value_type : {'exact', 'quantile'}
    value_direction : {'higher', 'lower'}
        For 'quantile': keep blocks above, or at/below, the split.

    Returns
    -------
    list
        Block ids.

    Raises
    ------
    ValueError
        If `measure` is given with an unrecognised value_type or
        value_direction.
    """
    # A bare value would otherwise be iterated element-wise (character-wise
    # for a task string).
    if subjects is not None and np.isscalar(subjects):
        subjects = [subjects]
    if task is not None and np.isscalar(task):
        task = [task]

    blocks = list(avfp_db.blocks.values())
    if subjects is not None:
        wanted = list(subjects)
        blocks = [b for b in blocks if b['subject'] in wanted]
    if task is not None:
        patterns = list(task)
        # any() keeps a block once even if several patterns match it
        blocks = [b for b in blocks if any(t in b['task'] for t in patterns)]
    if measure is not None:
        if value_type == 'exact':
            blocks = [b for b in blocks if b['individual_differences'][measure] == value]
        elif value_type == 'quantile':
            # abs() is specific to our fear ratings - might want to adjust
            # later. The quantile is across the blocks selected so far, so
            # values may need z-scoring first. The same abs() is used when
            # filtering, so values are compared on the scale the split was
            # computed on. NaN values fall on neither side.
            magnitudes = np.array(
                [np.abs(b['individual_differences'][measure]) for b in blocks], dtype=float)
            split = np.quantile(magnitudes[~np.isnan(magnitudes)], value)
            if value_direction == 'higher':
                keep = magnitudes > split
            elif value_direction == 'lower':
                keep = magnitudes <= split
            else:
                raise ValueError(f"value_direction must be 'higher' or 'lower', got {value_direction!r}")
            blocks = [b for b, k in zip(blocks, keep) if k]
        else:
            raise ValueError(f"value_type must be 'exact' or 'quantile', got {value_type!r}")

    return [b['id'] for b in blocks]

In [None]:
def create_brain_image(blocks, method):
    """Average observed or reconstructed activity over blocks into a brain image.

    All TRs of all listed blocks are pooled and averaged, so blocks with more
    TRs weigh more.

    Parameters
    ----------
    blocks : sequence
        Ids of the blocks to average.
    method : {'original', 'reconstruction'}
        'original' uses the observed TRs in the global `tr_activations`
        (selected via `tr_blocks`). 'reconstruction' uses the generative
        model output `weights @ factors` from `dtfa.results`.

    Returns
    -------
    brain_image
        `utils.cmu2nii` image of the mean map, built with the template of the
        last block in `blocks`.
    activations : numpy.ndarray
        The mean map, shape (1, n_voxels).

    Raises
    ------
    ValueError
        If `blocks` is empty or `method` is not recognised.
    """
    blocks = list(blocks)
    if not blocks:
        raise ValueError('create_brain_image: no blocks to average')
    if method not in ('original', 'reconstruction'):
        raise ValueError(f"method must be 'original' or 'reconstruction', got {method!r}")

    tr_block_ids = np.asarray(tr_blocks)
    activations = []
    for b in blocks:
        if method == 'original':
            voxel_values = tr_activations[tr_block_ids == b, :]
        else:
            # note, results contains the embedding info and factor weights
            results = dtfa.results(block=b, generative=True)
            voxel_values = (results['weights'] @ results['factors']).numpy()
        activations.append(voxel_values)
    # mean across all TRs/blocks
    activations = np.mean(np.concatenate(activations, axis=0), axis=0, keepdims=True)

    # to nifti image
    brain_image = utils.cmu2nii(activations,
                                dtfa.voxel_locations.numpy(),
                                dtfa._dataset.blocks[blocks[-1]]['template'])
    return brain_image, activations

def plot_brain(blocks, method, ax, zbound=3):
    """Draw the block-averaged `method` map as a glass brain on `ax`.

    The colour bar is symmetric around zero with vmax=`zbound`, and the
    panel is titled with `method`. Returns the (1, n_voxels) mean activation
    array that was plotted.
    """
    img, mean_activation = create_brain_image(blocks, method)
    plotting.plot_glass_brain(
        img,
        axes=ax,
        plot_abs=False,
        colorbar=True,
        symmetric_cbar=True,
        vmax=zbound,
    )
    ax.set_title(method, fontsize=16)
    return mean_activation

    
def plot_activations(blocks, subjects=None,
                     task=None, measure=None, value=None):
    """Plot mean observed vs. reconstructed activity for `blocks`.

    Draws the block-averaged observed map and its model reconstruction side
    by side. The figure title reports two correlations:

    - r: between the observed and reconstructed mean maps.
    - mean r: between the observed mean map and the mean of the global
      `reconstructions` over all TRs not in `blocks`. This is a baseline for
      how much of r is not specific to these blocks; it is nan when every TR
      belongs to `blocks`.

    `subjects`, `task`, `measure` and `value` are only used in the title.
    """
    plt.figure(figsize=(12, 2.5))

    ax = plt.subplot(1, 2, 1)
    orig = plot_brain(blocks, 'original', ax, zbound=1.5)
    ax = plt.subplot(1, 2, 2)
    recon = plot_brain(blocks, 'reconstruction', ax, zbound=1.5)

    corr = np.round(np.corrcoef(orig[0, :], recon[0, :])[0, 1], 3)

    # NOTE(review): assumes rows of `reconstructions` are aligned with
    # `tr_blocks` (one row per TR, same order) - confirm.
    selected = set(blocks)  # O(1) membership instead of scanning a list per TR
    other_index = [i for i, e in enumerate(tr_blocks) if e not in selected]
    if other_index:
        other_values = np.mean(reconstructions[other_index, :], axis=0)
        mean_corr = np.round(np.corrcoef(orig[0, :], other_values)[0, 1], 3)
    else:
        mean_corr = np.nan

    plt.subplots_adjust(wspace=0.1)
    title = str(subjects) + ' ' + str(task) + ' ' + str(measure) + ' ' + str(value)
    title = title.replace("None", "")
    plt.suptitle(title + "   r = " + str(corr) + " (mean r = " + str(mean_corr) + ")", fontsize=18, y=1.1)
    plt.show()

In [None]:
def calculate_difference(blocks, absolute_difference=False):
    """Compare observed and reconstructed activity block by block.

    For each block, compute the per-voxel difference between its mean
    observed activation (over its TRs in the global `tr_activations`) and
    the mean of its generative reconstruction. Optionally take the absolute
    value.

    Returns
    -------
    activations : numpy.ndarray
        Observed TRs of all blocks, stacked row-wise.
    reconstructions : numpy.ndarray
        Reconstructed TRs of all blocks, stacked row-wise.
    differences : numpy.ndarray
        One row per block, shape (len(blocks), n_voxels).
    brain_image
        `utils.cmu2nii` image of `differences`, built with the template of
        the last block.

    Raises
    ------
    ValueError
        If `blocks` is empty.
    """
    blocks = list(blocks)
    if not blocks:
        raise ValueError('calculate_difference: no blocks given')

    tr_block_ids = np.asarray(tr_blocks)
    observed, reconstructed, differences = [], [], []
    for b in blocks:
        tr_values = tr_activations[tr_block_ids == b, :]
        # note, results contains the embedding info and factor weights
        results = dtfa.results(block=b, generative=True)
        recon_values = (results['weights'] @ results['factors']).numpy()

        observed.append(tr_values)
        reconstructed.append(recon_values)

        diff = tr_values.mean(axis=0) - recon_values.mean(axis=0)
        differences.append(np.abs(diff) if absolute_difference else diff)

    differences = np.vstack(differences)
    brain_image = utils.cmu2nii(differences,
                                dtfa.voxel_locations.numpy(),
                                dtfa._dataset.blocks[blocks[-1]]['template'])

    return (np.concatenate(observed, axis=0), np.concatenate(reconstructed, axis=0),
            differences, brain_image)
        
def plot_voxel_noise(blocks, subjects=None,task=None, plot_individual=False, plot_average=False,absolute_difference=False):
    """Summarise per-block observed-minus-reconstruction voxel differences.

    Uses `calculate_difference` to get one voxel-wise difference map per
    block, and records the mean of each map.

    Parameters
    ----------
    blocks : sequence
        Block ids.
    subjects, task : optional
        Only used in plot titles.
    plot_individual : bool
        Plot a histogram of each block's voxel differences.
    plot_average : bool
        Plot a histogram of the per-block mean differences.
    absolute_difference : bool
        Passed to `calculate_difference`.

    Returns
    -------
    list of float
        Mean (over voxels) difference for each block, at full precision.
    """
    _, _, differences, _ = calculate_difference(blocks, absolute_difference)

    avg_list = []
    for block_id, slice_data in zip(blocks, differences):
        # Keep full precision for aggregation and the histogram. Round only
        # for display, since mean differences can be smaller than 1e-3.
        avg = float(np.mean(slice_data))
        avg_list.append(avg)

        if plot_individual:
            std = np.round(np.std(slice_data), 3)
            value = '[blocks- ' + str(block_id) + ']'
            _, ax = plt.subplots(1, 1, figsize=(5, 5))
            ax.hist(slice_data, alpha=.8, bins=128)
            title = str(subjects) + ' ' + str(task) + ' ' + value + ' mean = ' + str(np.round(avg, 3)) + ' std = ' + str(std)
            title = title.replace("None", "")
            ax.set_title(title)
            ax.grid(True)

    if plot_average:
        _, ax1 = plt.subplots(1, 1, figsize=(5, 5))
        ax1.hist(avg_list, alpha=.8, bins=128)
        # np.std here is the spread of the block means (not their standard
        # error), so it is labelled 'std'.
        title = ('distribution of mean difference for each block' + ': mean (of mean) = '
                 + str(np.round(np.mean(avg_list), 3)) + ' std = ' + str(np.round(np.std(avg_list), 3)))
        ax1.set_title(title)
        ax1.grid(True)

    return avg_list



In [None]:
# blocks_to_plot = fetch_block_index(task=['Heights'])
# avg_list = plot_voxel_noise(blocks_to_plot)

In [None]:
# blocks_to_plot = fetch_block_index(task=['Heights'])
# avg_list = plot_voxel_noise(blocks_to_plot,absolute_difference=True)

In [None]:
# this_sub = 107
# blocks_to_plot = fetch_block_index(subjects=[this_sub])
# avg_list_107 = plot_voxel_noise(blocks_to_plot, absolute_difference=True,plot_individual=True)

In [None]:
# blocks_to_plot = fetch_block_index(task=['New_Heights_12_AVFP'])
# h12= plot_voxel_noise(blocks_to_plot, absolute_difference=True,plot_individual=True)

In [None]:
# this_sub = 103
# blocks_to_plot = fetch_block_index(subjects=[this_sub])
# sub_103_list = plot_voxel_noise(blocks_to_plot, absolute_difference=True,plot_individual=True)

In [None]:
hold_out_data = True
n_per_subj = 2

# Held-out test blocks: for each subject, every block of `n_per_subj`
# randomly chosen tasks (fixed seed). Defined even when nothing is held out,
# so the filter cells below still run.
test_blocks = []
if hold_out_data:
    rng = np.random.default_rng(2022)
    for p in avfp_db.subjects():
        sub_tasks = [b['task'] for b in avfp_db.blocks.values() if b['subject'] == p]
        # cannot draw more tasks than this subject has
        idx = rng.choice(len(sub_tasks), min(n_per_subj, len(sub_tasks)), replace=False)
        for i in idx:
            test_blocks.extend(b['id'] for b in avfp_db.blocks.values()
                               if b['subject'] == p and b['task'] == sub_tasks[i])
    # Sorted and de-duplicated: a task listed twice for a subject can be drawn
    # twice, which would otherwise add its blocks twice.
    test_blocks = sorted(set(test_blocks))
    print('Using',len(test_blocks),'blocks for testing\nIDs:',test_blocks)

In [None]:
training_filter = avfp_db.inference_filter_blocks(training=True, exclude_blocks=test_blocks)
training_blocks = [b for (b, block) in dtfa._dataset.blocks.items() if training_filter(block)]
print(training_blocks)

In [None]:
validation_filter = avfp_db.inference_filter_blocks(training=False, exclude_blocks=test_blocks)
testing_blocks = [b for (b, block) in dtfa._dataset.blocks.items() if validation_filter(block)]
print(testing_blocks)

In [None]:
plot_activations(testing_blocks)

In [None]:
plot_activations(training_blocks)

In [None]:
plot_activations([training_blocks[1]])

In [None]:
plot_activations([training_blocks[40]])

In [None]:
plot_activations([testing_blocks[1]])

In [None]:
plot_activations([testing_blocks[5]])

In [None]:
# possible options: subject list, task list, measure string, value float

blocks_to_average = fetch_block_index(task=['Heights'])
plot_activations(blocks_to_average, task=['Heights'])

In [None]:
blocks_to_average = fetch_block_index(task=['New_Heights_11_AVFP'])
plot_activations(blocks_to_average, task=['New_Heights_11_AVFP'])

In [None]:
blocks_to_average = fetch_block_index(task=['New_Heights_10_AVFP'])
plot_activations(blocks_to_average, task=['New_Heights_10_AVFP'])

In [None]:
blocks_to_average = fetch_block_index(task=['New_Heights_12_AVFP'])
plot_activations(blocks_to_average, task=['New_Heights_12_AVFP'])

In [None]:
blocks_to_average = fetch_block_index(task=['New_Heights_2_AVFP'])
plot_activations(blocks_to_average, task=['New_Heights_2_AVFP'])

In [None]:
this_sub = 103
blocks_to_average = fetch_block_index(subjects=[this_sub])
plot_activations(blocks_to_average, subjects=[this_sub])  

In [None]:

this_sub = 107
blocks_to_average = fetch_block_index(subjects=[this_sub])
plot_activations(blocks_to_average, subjects=[this_sub]) 

In [None]:
this_sub = 111
blocks_to_average = fetch_block_index(subjects=[this_sub])
plot_activations(blocks_to_average, subjects=[this_sub]) 

In [None]:
this_sub = 119
blocks_to_average = fetch_block_index(subjects=[this_sub])
plot_activations(blocks_to_average, subjects=[this_sub]) 

In [None]:
this_sub = 118
blocks_to_average = fetch_block_index(subjects=[this_sub])
plot_activations(blocks_to_average, subjects=[this_sub]) 

# Validation: embedding distances vs. activity distances (RSA)

In [None]:
def embedding_activity_plot(embedding_dist, activation_dist, r):
    """Show embedding and activity distance matrices side by side.

    For display, the off-diagonal entries of each square matrix are
    standardised (z-scored); the diagonal is left unchanged. `r` (e.g. the
    Spearman correlation between the two matrices) is printed between the
    panels, rounded to 3 decimals. The input arrays are not modified.
    """
    # Work on float copies so the caller's distance matrices stay untouched.
    embedding_dist = np.array(embedding_dist, dtype=float)
    activation_dist = np.array(activation_dist, dtype=float)

    # `np.bool` was removed in NumPy 1.24; use a plain boolean mask.
    off_diag = ~np.eye(embedding_dist.shape[0], dtype=bool)
    for dist in (embedding_dist, activation_dist):
        vals = dist[off_diag]
        dist[off_diag] = (vals - np.mean(vals)) / np.std(vals)

    fig, axes = plt.subplots(1, 2, figsize=(9, 3.25),
                             gridspec_kw={'wspace': 0.2})
    panels = ((axes[0], embedding_dist, 'Embedding distances'),
              (axes[1], activation_dist, 'Activity distances'))
    for ax, dist, title in panels:
        sns.heatmap(dist, square=True, center=0, vmax=3, cmap="icefire_r", ax=ax)
        ax.set_title(title, size=18, y=1.05)
        ax.set_ylim(dist.shape[0], 0)

    axes[1].text(-0.4, 1.1, 'r= ' + str(np.round(r, 3)), fontweight='semibold',
                 transform=axes[1].transAxes, size=12, fontstyle='italic')
    plt.show()

In [None]:
# Mean observed activation per participant, stacked as (participants x voxels)
# in p_embedding row order.
p_activations = np.empty((len(dtfa.subjects()), dtfa.num_voxels))

for row, participant in enumerate(p_embedding.participant):
    participant_blocks = fetch_block_index(subjects=[int(participant)])
    p_activations[row, :] = create_brain_image(participant_blocks, 'original')[1][0, :]

p_activations.shape

In [None]:
# activation and embedding dissimilarity (p x p matrices):
embedding_dist = squareform(pdist(p_embedding[['x','y']]))
activation_dist = squareform(pdist(p_activations))

In [None]:
# correlate the lower triangles
l_idx = np.tril_indices(embedding_dist.shape[0], k=-1)
corr, _ = spearmanr(embedding_dist[l_idx],activation_dist[l_idx])
print('Spearmans correlation between participant embedding and activity distances: %.3f' % corr)

In [None]:
embedding_activity_plot(embedding_dist, activation_dist, corr)

## Stimulus embedding

In [None]:
s_embedding_sort = s_embedding.sort_values('situation').reset_index(drop=True)
s_activations = np.empty((len(dtfa.tasks()),dtfa.num_voxels))

for i in range(s_embedding_sort.shape[0]):
    blocks_to_average = fetch_block_index(task=[s_embedding_sort.stimulus[i]])
    _, i_activations = create_brain_image(blocks_to_average, 'original')
    s_activations[i,:] = i_activations[0,:]
    
s_activations.shape

In [None]:
# activation and embedding dissimilarity (stimulus x stimulus matrices, rows
# in s_embedding_sort order):
embedding_dist = squareform(pdist(s_embedding_sort[['x','y']]))
activation_dist = squareform(pdist(s_activations))

In [None]:
# correlate the lower triangles
l_idx = np.tril_indices(embedding_dist.shape[0], k=-1)
corr, _ = spearmanr(embedding_dist[l_idx],activation_dist[l_idx])
print('Spearmans correlation between stimulus embedding and activity distances: %.3f' % corr)


In [None]:
embedding_activity_plot(embedding_dist, activation_dist, corr)

## Combination embedding

In [None]:
c_activations = np.empty((c_embedding.shape[0],dtfa.num_voxels))

for i in range(c_embedding.shape[0]):
    block = fetch_block_index(subjects=[int(c_embedding.participant[i])],
                              task=[c_embedding.stimulus[i]])
    if len(block) > 1:
        print('warning: more than one block found for this trial')
    _, i_activations = create_brain_image(block, 'original')
    c_activations[i,:] = i_activations[0,:]
    
c_activations.shape

In [None]:
# activation and embedding dissimilarity (combination x combination matrices,
# one row per row of c_embedding):
embedding_dist = squareform(pdist(c_embedding[['x','y']]))
activation_dist = squareform(pdist(c_activations))

In [None]:
# correlate the lower triangles
l_idx = np.tril_indices(embedding_dist.shape[0], k=-1)
corr, _ = spearmanr(embedding_dist[l_idx],activation_dist[l_idx])
print('Spearmans correlation between combination embedding and activity distances: %.3f' % corr)


In [None]:
embedding_activity_plot(embedding_dist, activation_dist, corr)

In [None]:
# Save the embedding coordinates, refetched from the current model state.
# (These statements were indented at top level, which is an IndentationError.)
p_embedding, s_embedding, c_embedding = fetch_embeddings()

p_embedding.participant = p_embedding.participant.astype('int').astype('string')
p_embedding[['x','y']] = p_embedding[['x','y']].astype('float')

s_embedding[['x','y']] = s_embedding[['x','y']].astype('float')

c_embedding.participant = c_embedding.participant.astype('int').astype('string')
c_embedding[['x','y']] = c_embedding[['x','y']].astype('float')

# save embedding information in pickle
p_embedding.to_pickle(query_dir + 'p_embedding.pkl')
s_embedding.to_pickle(query_dir + 's_embedding.pkl')
c_embedding.to_pickle(query_dir + 'c_embedding.pkl')



In [None]:
# save
html_name = query_dir + f'NTFA_visualization_{linear_opts}.html'
%store html_name