In [1]:
# Check the kernel is running in our conda env (the active env is marked with *).
%conda env list

# conda environments:
#
HTFATorch                /home/wang.yiyu/.conda/envs/HTFATorch
NTFA_env3             *  /home/wang.yiyu/.conda/envs/NTFA_env3
base                     /shared/centos7/anaconda3/3.7


Note: you may need to restart the kernel to use updated packages.


In [2]:
# path for the NTFA package
#NTFA_path = "/work/abslab/NTFA_packages/NTFADegeneracy/" #combination embedding NTFA
#NTFA_path = "/work/abslab/NTFA_packages/NTFATorch/" #base NTFA (just P and S)
NTFA_path = "/work/abslab/Yiyu/NTFA_packages/ntfa_degeneracy/"
print("Using NTFA code from:", NTFA_path)

import sys
# Put NTFA_path FIRST on sys.path so this checkout is the htfa_torch that gets
# imported; appending would let any earlier path entry providing htfa_torch win.
# Remove an existing entry first so re-running the cell does not pile up duplicates.
# (If htfa_torch was already imported in this kernel, the cached module is reused.)
sys.path[:] = [entry for entry in sys.path if entry != NTFA_path]
sys.path.insert(0, NTFA_path)
import htfa_torch.dtfa as DTFA
import htfa_torch.niidb as niidb
import htfa_torch.tardb as tardb
import htfa_torch.utils as utils
import logging
import numpy as np
import pandas as pd
import glob
import os
import re
import webdataset as wds

import torch
import itertools
import timeout_decorator
import matplotlib.pyplot as plt
from ordered_set import OrderedSet

# Timestamped log lines at INFO level (training below also logs at logging.INFO).
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
)

Using NTFA code from: /work/abslab/Yiyu/NTFA_packages/ntfa_degeneracy/




In [3]:
# Project root: fmri_info/ and data/ (read further down) live under this directory.
base_dir = '/work/abslab/Yiyu/NTFA_AVFP/'

# Input locations (not referenced by the cells below).
nifti_dir = '/work/abslab/AVFP/denoised/'
logfiles_dir = '/work/abslab/AVFP/logfiles/AffVidsNovel_logfiles/'
mask_dir = '/home/wang.yiyu/AVFP/masks/'

In [4]:
# **** parameters that define the model directory ******
subs = 20 #'All' #note, database file must have been created already for these subjects

# Subjects included in the study: a header-less CSV whose first column holds the IDs.
included_data = pd.read_csv(f'{base_dir}fmri_info/included_avfp_subjects.csv', header=None)
subIDs = included_data.iloc[:, 0].astype(str).tolist()
total_subs = len(subIDs)
print(subIDs)
print(f"total subs = {total_subs}")

# using GM (and SNR) or SNR only?
mask_type = 'GM'  # 'GMandSNR' #'SNR' or 'GM'

# penalty weights (participant, stimulus, combination)
p_weight = s_weight = c_weight = 1
# tuning=True: continue training from the existing baseline model (penalty weights 111)
# while *keeping the factor weights consistent*
tuning = False

# which embedding-to-weight mappings should be linear? ('Base', 'None' or 'C')
linear_opts = 'None'
# ******************************************************

# training length / model size
n_epoch = 1000
n_factor = 100
n_check = 50  # save a model checkpoint every n_check epochs

# learning rates (**change/check if re-running based on latest html**); defaults 1e-2 / 1e-4
lr_q, lr_p = 1e-2, 1e-4

# condition suffix used in the database filename
condition = '_HeightsOnly'

# voxel noise: set_noise picks a custom level (run check_voxel_noise.ipynb first to get
# a sense of it) instead of the 0.1 default; learn_noise is passed to training as
# learn_voxel_noise and is recorded in the model-directory name.
learn_noise = True
set_noise = True
voxel_noise = 0.3 if set_noise else 0.1
if set_noise:
    print(f"setting voxel noise to {voxel_noise}")

noise_opts = f"{'learned' if learn_noise else 'fixed'}-{voxel_noise}"
print(f'noise option: {noise_opts}')

['100', '103', '104', '105', '106', '107', '108', '109', '111', '112', '114', '115', '116', '117', '118', '119', '120', '121', '122', '123', '124', '125', '127', '128', '130', '131', '132', '134', '135', '136', '137', '138', '139', '140', '142', '143', '144', '145', '146', '149', '150', '151', '152', '153', '154', '157', '158', '159', '160', '161', '162', '163', '164', '165', '166', '167', '169', '170', '171', '172', '174', '175', '176', '177', '179', '181', '182', '183', '184', '185', '186']
total subs = 71
setting voxel noise to 0.8
noise option: fixed-0.8


In [5]:
# Make GPU 0 the current CUDA device, then display which card that is.
torch.cuda.set_device(0)
torch.cuda.get_device_name(torch.cuda.current_device())

'Tesla V100-SXM2-32GB'

In [6]:
# Build the path of the prebuilt tar database for this subject subset / mask /
# condition, then open it. The database is not created here, so fail early with a
# clear message if it is missing.
if subs != 'All':
    AVFP_FILE = base_dir + f'data/downsampled_test/AVFP_NTFA_N{total_subs}_subsetN{subs}_{mask_type}mask{condition}.tar'
else:  # including all subjects
    AVFP_FILE = base_dir + f'data/AVFP_NTFA_N{total_subs}_{subs}_{mask_type}mask{condition}.tar'
print('\nFetching database:',AVFP_FILE)

if not os.path.isfile(AVFP_FILE):
    raise FileNotFoundError(f'Database not found: {AVFP_FILE} (it must be created before training)')
avfp_db = tardb.FmriTarDataset(AVFP_FILE)


Fetching database: /work/abslab/Yiyu/NTFA_AVFP/data/downsampled_test/AVFP_NTFA_N71_subsetN20_GMmask_HeightsOnly.tar


In [7]:
# Hold out some of each subject's task blocks to test generalization?
hold_out_data = True
n_per_subj = 2  # tasks held out per subject when hold_out_data is True

In [8]:
# specify training and testing sample:
# for each subject, draw `n_per_subj` distinct tasks (fixed seed) and hold out all of
# that subject's blocks for those tasks.

def _select_test_blocks(db, n_per_subject, seed=2022):
    """Return the sorted IDs of the blocks to exclude from training.

    Tasks are deduplicated per subject (first-seen order) before sampling, so a
    subject with several blocks of one task can never draw that task twice (which
    would list its blocks twice). When every task is unique per subject, the draws
    are exactly those of sampling over the raw per-block task list.

    Raises:
        ValueError: if a subject has fewer distinct tasks than `n_per_subject`.
    """
    rng = np.random.default_rng(seed)
    held_out = []
    for subject in db.subjects():
        subject_blocks = [b for b in db.blocks.values() if b['subject'] == subject]
        tasks = list(dict.fromkeys(b['task'] for b in subject_blocks))
        if n_per_subject > len(tasks):
            raise ValueError(f'subject {subject} has only {len(tasks)} distinct tasks; '
                             f'cannot hold out {n_per_subject}')
        for i in rng.choice(len(tasks), n_per_subject, replace=False):
            held_out.extend(b['id'] for b in subject_blocks if b['task'] == tasks[i])
    return np.sort(held_out).tolist()

if hold_out_data:
    test_blocks = _select_test_blocks(avfp_db, n_per_subj)
    print('Excluding',len(test_blocks),'blocks from training\nIDs:',test_blocks)
else:
    test_blocks = []
    print('Including all blocks in training')

Excluding 40 blocks from training
IDs: [2, 7, 13, 14, 24, 34, 43, 46, 49, 57, 60, 61, 81, 83, 84, 94, 102, 107, 109, 110, 123, 130, 142, 143, 144, 155, 156, 166, 168, 174, 187, 190, 195, 199, 204, 215, 217, 225, 228, 239]


# Pass the data to DeepTFA

In [9]:
# set up directory and filename for saving model; the name encodes this run's settings.
# NOTE(review): query_dir is relative to the current working directory, not base_dir — confirm intended.
query_dir = f'models/ablation_comparison/AVFP_NTFA_sub-{subs}_epoch-{n_epoch}_factor-{n_factor}_mask-{mask_type}_{p_weight}{s_weight}{c_weight}_lin-{linear_opts}_noise-{noise_opts}/'

# exist_ok replaces the separate isdir check (and its check-then-create race).
os.makedirs(query_dir, exist_ok=True)
print("\nsaving model to: ", query_dir,'\n')




saving model to:  models/ablation_comparison/AVFP_NTFA_sub-20_epoch-1000_factor-100_mask-GM_111_lin-PSC_noise-fixed-0.8/ 



In [14]:

# Build the NTFA model on the tar database. query_name is the model directory, which
# the resume cell below searches for checkpoints.
dtfa = DTFA.DeepTFA(
    avfp_db,
    num_factors=n_factor,
    linear_params=linear_opts,
    query_name=query_dir,
    voxel_noise=voxel_noise,
)

In [15]:
# Sanity check: the voxel noise held in the generative model's hyperparameters
# (should match `voxel_noise`, which was passed to DeepTFA).
dtfa.generative.hyperparams.voxel_noise

tensor([0.8000])

In [16]:
print(f'number of voxels: {dtfa.num_voxels!s}')  # voxels in the mask used by the database

number of voxels: 191002


In [17]:
print('number of trials:',dtfa.num_blocks) # expected: N subjects * videos per subject (36 for the full set; under '_HeightsOnly' there are 12 tasks, so 20 * 12 = 240 here)

number of trials: 240


In [18]:
print(f'number of unique tasks: {len(dtfa.tasks())}')  # distinct stimuli/tasks in the database

number of unique tasks: 12


In [19]:
print(f'subjects analyzed: {dtfa.subjects()!s}')  # subject IDs present in the database

subjects analyzed: [100, 103, 104, 105, 106, 107, 108, 109, 111, 112, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123]


In [20]:
print(f'subjects analyzed: {dtfa.subjects()!s}')  # NOTE(review): repeats the previous cell; safe to delete

subjects analyzed: [100, 103, 104, 105, 106, 107, 108, 109, 111, 112, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123]


# Train the model!!!

In [21]:
# Decide where training starts:
#   - normal run: resume from this model's most advanced checkpoint, if any;
#   - tuning: start from the baseline model (same settings, penalty weights 111).
# Checkpoints are saved as <prefix>.dtfa_model / <prefix>.dtfa_guide; load_state takes <prefix>.

def _latest_checkpoint(directory):
    """Return (prefix, epoch) of the furthest-trained checkpoint in `directory`, or None.

    Checkpoint files match 'CHECK*.dtfa*' and carry their epoch count as 'Epoch<N>'
    in the name. Choosing by that number (rather than file ctime, which on Linux
    changes on chmod/copy) picks the furthest-trained state. Files whose name has no
    'Epoch<N>' are ignored.
    """
    candidates = []
    for path in glob.glob(directory + 'CHECK*.dtfa*'):
        prefix = path.split('.dtfa')[0]
        match = re.search(r'Epoch(\d+)', os.path.basename(prefix))
        if match:
            candidates.append((int(match.group(1)), prefix))
    if not candidates:
        return None
    epoch, prefix = max(candidates)
    return prefix, epoch

if not tuning:
    checkpoint = _latest_checkpoint(query_dir)
    if checkpoint is not None:
        state_name, n_checkpoint_epochs = checkpoint
        print('\nRestarting from most recent checkpoint:',state_name,'\n')
        dtfa.load_state(state_name)
    else:
        n_checkpoint_epochs = 0
        print('\nNo checkpoint found — starting training\n')

# if starting tuning, load in the baseline model:
else:
    # The baseline lives in the directory whose name differs only in the penalty
    # weights (111). The old split on 'mask_' no longer matched the directory
    # format ('..._mask-GM_111_lin-...'), so it produced a nonexistent path.
    weights_tag = f'_mask-{mask_type}_{p_weight}{s_weight}{c_weight}_lin-'
    if weights_tag not in query_dir:
        raise ValueError(f'Cannot derive the baseline directory from {query_dir}')
    baseline_dir = query_dir.replace(weights_tag, f'_mask-{mask_type}_111_lin-')
    baseline = _latest_checkpoint(baseline_dir)
    if baseline is None:
        raise FileNotFoundError(f'No baseline model found in {baseline_dir}')
    baseline_name, n_checkpoint_epochs = baseline
    print('\nStarting from baseline model:',baseline_name,'\n')
    dtfa.load_state(baseline_name)


No checkpoint found — starting training



In [None]:

# Disabled: apparently meant to resume with the learning rates last reported as
# "reducing learning rate ... to X" in the most recent *train*.html export in query_dir.
# NOTE(review): if re-enabled, `data` below should be `html_data` (as written it is
# undefined), the file is never closed (use `with open(...)`), and the patterns should
# be raw strings with an escaped '.' (r'... \d+\.\d+e-\d+').
# html_files = glob.glob(query_dir + '*train*.html')
# if len(html_files) > 0:
#     html_name = max(html_files, key=os.path.getctime)
#     f=open(html_name,'r')
#     html_data=f.read()
#     if 'reducing learning rate' in html_data:
#         lr_q = float(re.findall('learning rate of group 2 to \d+.\d+e-\d+',data)[-1].split('to ')[-1])
#         lr_p = float(re.findall('learning rate of group 3 to \d+.\d+e-\d+',data)[-1].split('to ')[-1])
#         print('\nUsing updated learning rate:',lr_q, lr_p)
#     else: print('\nUsing default learning rate:',lr_q, lr_p)
# else: 
#     print('\nStarting with default learning rate:',lr_q, lr_p) 


# Note: even when the learning rate does not change during training, the loss still jumps whenever training restarts from a checkpoint.

In [25]:
@timeout_decorator.timeout(25200)  # 7 hours
def train_ntfa():
    """Train `dtfa` on the non-held-out blocks for the epochs remaining up to `n_epoch`.

    Reads the notebook-level settings: the epochs already done (`n_checkpoint_epochs`),
    learning rates, checkpoint interval, noise option and `test_blocks` (excluded via
    `avfp_db.inference_filter_blocks`). The call is wrapped in a 25200 s (7 hour)
    timeout by `timeout_decorator`.
    """
    remaining_epochs = n_epoch - n_checkpoint_epochs
    training_filter = avfp_db.inference_filter_blocks(training=True, exclude_blocks=test_blocks)
    rates = {'q': lr_q, 'p': lr_p}
    dtfa.train(
        num_steps=remaining_epochs,
        num_steps_exist=n_checkpoint_epochs,
        learning_rate=rates,
        log_level=logging.INFO,
        num_particles=1,
        batch_size=128,
        use_cuda=True,
        checkpoint_steps=n_check,
        patience=50,
        blocks_filter=training_filter,
        learn_voxel_noise=learn_noise,
    )

In [None]:
# Run training. Hitting the 7-hour timeout is an expected way for a long run to stop
# (the next run resumes from the latest checkpoint), so it is reported, not raised.
# Any other error is logged with its traceback instead of being silently discarded,
# so the loss-plotting cells below still run. KeyboardInterrupt is no longer swallowed.
try:
    print('training')
    train_ntfa()
except timeout_decorator.TimeoutError:
    logging.warning('Training stopped by the 7-hour timeout; re-run to resume from the latest checkpoint.')
except Exception:
    logging.exception('Training failed')

training


In [None]:
# Re-check the voxel noise after training.
# NOTE(review): with learn_voxel_noise=True this presumably reflects the learned value — confirm in htfa_torch.
dtfa.generative.hyperparams.voxel_noise

# Plot the loss after training

In [None]:
# Load all saved *_losses.txt files (one per training session) and join them into a
# single curve. Sort the names numerically ("Epoch50" before "Epoch100"): a plain
# string sort would concatenate sessions resumed from checkpoints out of order.

def _natural_key(path):
    """Split `path` into text and integer chunks so embedded numbers compare by value."""
    return [int(chunk) if chunk.isdigit() else chunk for chunk in re.split(r'(\d+)', path)]

loss_files = sorted(glob.glob(query_dir + '*_losses.txt'), key=_natural_key)
if not loss_files:
    raise FileNotFoundError(f'No *_losses.txt files found in {query_dir}')
# atleast_1d: a log holding a single value loads as a 0-d array, which np.concatenate rejects.
losses = np.concatenate([np.atleast_1d(np.loadtxt(path)) for path in loss_files])

In [None]:
# adapted from the loss-plotting function in utils (htfa_torch.utils)
def plot_losses(losses, save_path=None):
    """Plot the training loss curve, save it as a PDF, and show it.

    Args:
        losses: sequence/array of per-epoch loss values (free energy / -ELBO).
        save_path: file to save the figure to; defaults to `query_dir + 'losses.pdf'`
            (the model directory set up earlier in this notebook).

    Returns None so that a bare `plot_losses(...)` at the end of a cell does not
    display the figure a second time.
    """
    if save_path is None:
        save_path = query_dir + 'losses.pdf'
    losses = np.asarray(losses)
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(np.arange(len(losses)), losses, '-', lw=3)
    ax.set_title('Free-energy / -ELBO change over training', fontsize=20)
    ax.set_xlabel('Epoch', fontsize=16)
    ax.set_ylabel('Free-energy / -ELBO (nats)', fontsize=16)
    fig.savefig(save_path)
    plt.show()

In [None]:
print(f'Loss function over {len(losses)} epochs\n')  # one loss value per trained epoch
plot_losses(losses)

In [None]:
# Model directory: checkpoints, *_losses.txt logs and losses.pdf are read from / written here.
query_dir

In [None]:
html_name = query_dir + 'NTFA_AVFP_train_Epoch' + str(len(losses)) + '.html'
%store html_name

In [None]:
# notebook_name = 'NTFA_train_ablation.ipynb'
# os.system('jupyter nbconvert ' + notebook_name + ' --to html --output ' + html_name)