# Saving NTFA reconstructions as NIfTI (.nii.gz) files

Yiyu Wang

2023/05

In [4]:
# Check that the active conda environment (marked with * in the listing below)
# is the intended NTFA environment before importing anything.
%conda env list

# conda environments:
#
HTFATorch                /home/wang.yiyu/.conda/envs/HTFATorch
NTFA_env3             *  /home/wang.yiyu/.conda/envs/NTFA_env3
base                     /shared/centos7/anaconda3/3.7
                         /work/abslab/Yiyu/DNN_env


Note: you may need to restart the kernel to use updated packages.


In [5]:
import matplotlib.pyplot as plt
import pickle
import glob
import os
import numpy as np
import pandas as pd
from nilearn import plotting, image
import nibabel as nib
import itertools

mask_dir = "masks/"  # folder holding the brain-mask NIfTI files, relative to the working directory



In [6]:
# **** parameters that define the model directory ******
# Project root prepended to data paths below. It must be defined in this cell:
# otherwise a fresh kernel raises NameError. '' keeps paths relative to the
# working directory, consistent with mask_dir, NTFA_path and query_dir.
base_dir = ''

# path for the NTFA package
which_ntfa_model = 'v2'

NTFA_path = "NTFA_v2/"
print("NTFA code from:", NTFA_path)
import sys
# guard so re-running this cell does not append duplicate sys.path entries
if NTFA_path not in sys.path:
    sys.path.append(NTFA_path)

# 20 or 'All'; note, database file must have been created already for these subjects
subs = 'All'

included_data = pd.read_csv(base_dir + 'fmri_info/included_avfp_subjects.csv', header=None)
subIDs = included_data[0].astype('str').tolist()
print(subIDs)
total_subs = len(subIDs)
print(f"total subs = {total_subs}")

# using GM (and SNR) or SNR only?
mask_type = 'GMgroup'
# pick the mask file that matches mask_type
if mask_type == 'GMgroup':
    mask_file = mask_dir + 'GM_fmriprep_novelgroup_mask_N71.nii.gz'
else:
    mask_file = mask_dir + 'gm_mask_icbm152_brain.nii.gz'

print('using mask: ', mask_file)

# penalty weights (participant, stimulus, combination)
p_weight, s_weight, c_weight = 1, 1, 1
linear_opts = 'None'  # 'C', 'PSC' 'None'
# additional parameters:
n_epoch = 2000
n_factor = 100
n_check = 50

# load mask:
mask = image.load_img(mask_file)
mask_data = mask.get_fdata()

SEED = 2023
# Apply the seed to numpy and torch so any stochastic steps later in the
# notebook (presumably including dtfa.results(..., generative=True) — TODO
# confirm) are reproducible under Restart & Run All.
np.random.seed(SEED)
import torch
torch.manual_seed(SEED)

NTFA code from: /work/abslab/NTFA_packages/NTFADegeneracy_merged/
['100', '103', '104', '105', '106', '107', '108', '109', '111', '112', '114', '115', '116', '117', '118', '119', '120', '121', '122', '123', '124', '125', '127', '128', '130', '131', '132', '134', '135', '136', '137', '138', '139', '140', '142', '143', '144', '145', '146', '149', '150', '151', '152', '153', '154', '157', '158', '159', '160', '161', '162', '163', '164', '165', '166', '167', '169', '170', '171', '172', '174', '175', '176', '177', '179', '181', '182', '183', '184', '185', '186']
total subs = 71
using mask:  /work/abslab/AVFP/NTFA/masks/GM_fmriprep_novelgroup_mask_N71.nii.gz


In [7]:
# Model directory: its name encodes every hyperparameter from the config cell,
# so different runs never collide. It must end with '/' because later cells
# append file names and glob patterns to it directly.
model_tag = '_'.join([
    f'AVFP_NTFA_sub-{subs}',
    f'epoch-{n_epoch}',
    f'factor-{n_factor}',
    f'mask-{mask_type}',
    f'{p_weight}{s_weight}{c_weight}',
    f'lin-{linear_opts}',
    f'ntfa-{which_ntfa_model}',
    'visreg',
])
query_dir = f'models/{model_tag}/'

print("\nFetching model from: ", query_dir, '\n')




Fetching model from:  models/AVFP_NTFA_sub-All_epoch-2000_factor-500_mask-GMgroup_111_lin-None_ntfa-v2_visreg/ 



In [8]:
# import ntfa functions
import htfa_torch.dtfa as DTFA
import htfa_torch.niidb as niidb
import htfa_torch.utils as utils
import htfa_torch.tardb as tardb

import torch


In [7]:
# Database filename; the file must already exist for these subjects and mask.
AVFP_FILE = 'data/' + f'AVFP_NTFA_N-{total_subs}_mask-{mask_type}.tar'
print('\nFetching database:', AVFP_FILE)

# Loading is slow: roughly 15 min for the ~2556 AVFP trials.
avfp_db = tardb.FmriTarDataset(AVFP_FILE)


Fetching database: /work/abslab/Yiyu/NTFA_AVFP/data/AVFP_NTFA_N-71_mask-GMgroup.tar


In [8]:
# Wrap the AVFP database in a DeepTFA model; like loading the database, this
# can take a while depending on the data size.
dtfa = DTFA.DeepTFA(
    avfp_db,
    num_factors=n_factor,
    linear_params=linear_opts,
    query_name=query_dir,
)
n_blocks = dtfa.num_blocks  # number of data blocks iterated over below

In [9]:

# Load the most recent checkpoint. Each checkpoint is a pair of files sharing a
# prefix (.dtfa_model and .dtfa_guide); take the newest file and strip the
# '.dtfa...' suffix to recover that prefix.
# NOTE(review): os.path.getctime is the metadata-change time on Unix, not the
# creation time; getmtime may be a safer "newest" criterion.
checkpoint_files = glob.glob(query_dir + 'CHECK*dtfa*')
if not checkpoint_files:
    # max() on an empty list would fail with an unhelpful ValueError
    raise FileNotFoundError(f"No checkpoint files matching 'CHECK*dtfa*' in {query_dir!r}")
state_name = max(checkpoint_files, key=os.path.getctime).split('.dtfa')[0]
print('\nLoading most recent checkpoint:',state_name,'\n')

dtfa.load_state(state_name)


Loading most recent checkpoint: models/AVFP_NTFA_sub-All_epoch-2000_factor-500_mask-GMgroup_111_lin-None_ntfa-v2_visreg/CHECK_06042023_192120_Epoch1085 



# Create NIfTI files for the original activations and the reconstructions

In [None]:
# Save one reconstruction NIfTI per block. This takes about 1-2 hours; blocks
# whose output file already exists are skipped, so the cell can be resumed.
recon_dir = os.path.join(query_dir, 'reconstruction')
os.makedirs(recon_dir, exist_ok=True)  # nib.save does not create missing directories
voxel_locations = dtfa.voxel_locations.numpy()  # same for every block, so computed once

for block_i in range(n_blocks):
    # block metadata (subject / run / task) determines the output file name
    tr = dtfa._dataset.blocks[block_i]
    file_name = f"sub-{tr['subject']}_run-{tr['run']}_video-{tr['task']}.nii.gz"
    out_path = os.path.join(recon_dir, file_name)

    # print progress
    if block_i % 50 == 0:
        print(block_i)

    if os.path.exists(out_path):
        continue

    # dtfa.results contains the factor weights and factors for this block
    results = dtfa.results(block=block_i, generative=True)
    recon = (results['weights'] @ results['factors']).numpy()
    # average over axis 0 (presumably time points — TODO confirm) into one volume
    recon = np.mean(recon, axis=0, keepdims=True)
    # to nifti image
    recon_brain_image = utils.cmu2nii(recon, voxel_locations, tr['template'])
    nib.save(recon_brain_image, out_path)

In [None]:
# Save the original activations of each block as a NIfTI file, averaged over
# axis 0 the same way as the reconstructions. Blocks whose output file already
# exists are skipped, so the cell can be resumed.
act_dir = os.path.join(query_dir, 'activations')
os.makedirs(act_dir, exist_ok=True)  # nib.save does not create missing directories
voxel_locations = dtfa.voxel_locations.numpy()  # same for every block, so computed once

for block_i in range(n_blocks):
    # block metadata (subject / run / task) determines the output file name
    tr = dtfa._dataset.blocks[block_i]
    file_name = f"sub-{tr['subject']}_run-{tr['run']}_video-{tr['task']}.nii.gz"
    out_path = os.path.join(act_dir, file_name)

    # print progress
    if block_i % 50 == 0:
        print(block_i)

    if os.path.exists(out_path):
        continue

    # original activations used for NTFA training, read from dtfa._dataset
    activations = dtfa._dataset[block_i]['activations'].numpy()
    activations = np.mean(activations, axis=0, keepdims=True)

    act_brain_image = utils.cmu2nii(activations, voxel_locations, tr['template'])
    nib.save(act_brain_image, out_path)
    

# Create embedding files (participant, stimulus, and participant-by-stimulus combination)

In [None]:
def fetch_embeddings_v1():
    """Collect posterior-mean embeddings from a v1 NTFA model.

    Reads the module-level ``dtfa`` model and ``avfp_db`` database.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(z_p, z_s, z_ps)``: participant embeddings (participant, x, y),
        stimulus embeddings (stimulus, x, y) and participant-by-stimulus
        interaction embeddings (participant, stimulus, x, y). Identifiers and
        coordinates are stacked into one numpy array, so values come back as
        text; callers cast the coordinate columns.
    """
    params = dtfa.variational.hyperparams.state_vardict()
    task_names = dtfa.tasks()
    subject_ids = dtfa.subjects()
    participant_mu = params['subject_weight']['mu'].data
    stimulus_mu = params['task']['mu'].data

    pair_blocks = []
    interaction_rows = []
    for subj_idx, subj in enumerate(subject_ids):
        # each participant saw only about half of the unique tasks, so collect
        # the tasks actually present in this participant's blocks
        seen = [blk['task'] for blk in avfp_db.blocks.values() if blk['subject'] == subj]
        pair_blocks.append(np.vstack([np.repeat(subj, len(seen)), np.array(seen)]))
        for task in seen:
            first_match = [i for i, name in enumerate(task_names) if name == task][0]
            joint = torch.cat((participant_mu[subj_idx], stimulus_mu[first_match]), dim=-1)
            embedded = dtfa.decoder.interaction_embedding(joint).data
            interaction_rows.append(embedded.data.numpy())
    interaction_mu = np.vstack(interaction_rows)
    pairs = np.hstack(pair_blocks).T

    subject_col = np.reshape(subject_ids, (len(subject_ids), 1))
    task_col = np.reshape(task_names, (len(task_names), 1))
    z_p = pd.DataFrame(np.hstack([subject_col, participant_mu.numpy()]),
                       columns=['participant', 'x', 'y'])
    z_s = pd.DataFrame(np.hstack([task_col, stimulus_mu.numpy()]),
                       columns=['stimulus', 'x', 'y'])
    z_ps = pd.DataFrame(np.hstack([pairs, interaction_mu]),
                        columns=['participant', 'stimulus', 'x', 'y'])
    return z_p, z_s, z_ps

def fetch_embeddings_v2():
    """Collect posterior embeddings (mean and sigma) from a v2 NTFA model.

    Reads the module-level ``dtfa`` model.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(z_p, z_s, z_ps)``: participant, stimulus and interaction
        embeddings. Each has x, y (posterior mean) and x_sigma, y_sigma
        (exp of the stored log-sigma). The identifier columns are participant,
        stimulus, and (participant, stimulus) taken from ``dtfa._interactions``.
        Values are stacked into one numpy array together with the identifiers,
        so they come back as text; callers cast.
    """
    params = dtfa.variational.hyperparams.state_vardict()
    task_names = dtfa.tasks()
    subject_ids = dtfa.subjects()
    pairs = dtfa._interactions

    def _mu_sigma(key):
        # posterior mean and standard deviation (exp of log_sigma) for one group
        entry = params[key]
        return entry['mu'].data, torch.exp(entry['log_sigma'].data)

    participant_mu, participant_sigma = _mu_sigma('subject_weight')
    stimulus_mu, stimulus_sigma = _mu_sigma('task')
    interaction_mu, interaction_sigma = _mu_sigma('interaction')

    value_cols = ['x', 'y', 'x_sigma', 'y_sigma']
    subject_col = np.reshape(subject_ids, (len(subject_ids), 1))
    task_col = np.reshape(task_names, (len(task_names), 1))

    z_p = pd.DataFrame(
        np.hstack([subject_col, participant_mu.numpy(), participant_sigma.numpy()]),
        columns=['participant'] + value_cols)
    z_s = pd.DataFrame(
        np.hstack([task_col, stimulus_mu.numpy(), stimulus_sigma.numpy()]),
        columns=['stimulus'] + value_cols)
    z_ps = pd.DataFrame(
        np.hstack([pairs, interaction_mu.numpy(), interaction_sigma.numpy()]),
        columns=['participant', 'stimulus'] + value_cols)
    return z_p, z_s, z_ps

In [None]:
# Name of the most recent model (prefix for .dtfa_model and .dtfa_guide).
# The checkpoint is reloaded here so this section can also be run on its own.
checkpoint_files = glob.glob(query_dir + 'CHECK*dtfa*')
if not checkpoint_files:
    # max() on an empty list would fail with an unhelpful ValueError
    raise FileNotFoundError(f"No checkpoint files matching 'CHECK*dtfa*' in {query_dir!r}")
state_name = max(checkpoint_files, key=os.path.getctime).split('.dtfa')[0]
print('\nLoading most recent checkpoint:',state_name,'\n')

dtfa.load_state(state_name)
if which_ntfa_model == 'v1':
    p_embedding, s_embedding, c_embedding = fetch_embeddings_v1()
elif which_ntfa_model == 'v2':
    p_embedding, s_embedding, c_embedding = fetch_embeddings_v2()
else:
    raise ValueError("Specify or check the NTFA version you are running!")

# The fetch_embeddings_* helpers stack identifiers and coordinates into one
# numpy array, so the coordinate columns arrive as text. Cast every coordinate
# column that is present back to float, including the v2 x_sigma / y_sigma.
coord_cols = ['x', 'y', 'x_sigma', 'y_sigma']
for emb in (p_embedding, s_embedding, c_embedding):
    present = [col for col in coord_cols if col in emb.columns]
    emb[present] = emb[present].astype('float')

# normalise participant ids to pandas' string dtype (via int)
for emb in (p_embedding, c_embedding):
    emb['participant'] = emb['participant'].astype('int').astype('string')

# save embedding information in pickle
p_embedding.to_pickle(query_dir + 'p_embedding.pkl')
s_embedding.to_pickle(query_dir + 's_embedding.pkl')
c_embedding.to_pickle(query_dir + 'c_embedding.pkl')
print("\nNew embedding pkl created at: ", query_dir)