In [1]:
import numpy as np
import random
import types
from datetime import datetime
import pandas as pd
from tqdm import tqdm
import numpy as np
from pybads.bads import BADS
from itertools import chain
from pyvbmc import VBMC
from functools import partial
import time
from joblib import Parallel, delayed, cpu_count
import os
import pickle
import scipy
import submitit
import runpy
from pyvbmc import priors
import importlib
from pprint import pprint

In [2]:
## Select and load model space
# The model-space name is also the name of the module imported from the
# `models` package.
model_space = 'basic_modelspace'
model_module = importlib.import_module(f"models.{model_space}")

# `main()` returns six objects describing the model space; they are unpacked
# in the order in which the module returns them.
(
    model_list,
    bounds_list,
    plausible_bounds_list,
    all_models,
    prior_shapes,
    default_values,
) = model_module.main()

# Dimensions handed to every agent constructor further down the notebook.
nstates, nactions, nfutures = 3, 3, 3

Models loaded:
0. SASSS_Omega_standard
1. SASSS_Omega_splitBias
2. SASSS_Omega_splitBeta
3. SASonly


In [3]:
### Select dataset
compiled_datadir = 'compiled_data'
dataset_name = 'dev'

## Load data
# The dataset lives in <compiled_datadir>/<dataset_name>/dataset.pkl.
data_file = os.path.join(compiled_datadir, dataset_name, 'dataset.pkl')
df_full = pd.read_pickle(data_file)

# Load metadata (stored next to the dataset).
# NOTE: unpickling executes arbitrary code; only load files you produced yourself.
metadata_file = os.path.join(compiled_datadir, dataset_name, 'metadata.pkl')
with open(metadata_file, mode='rb') as file:
    metadata = pickle.load(file)


### Toolbox selection
# 'pyvbmc' or 'pybads' (note that vbmc writes an additional output structure for each model)
fitToolbox = 'pybads'

### Other fit options
parallel_job = False
# 'joblib' for local parallelization or 'slurm' for cluster parallelization
parallel_mode = 'slurm'
parallel_maxjobs = 420
robustMode = True
verbose = False

priors_constrainFit = True
# reset mode
reset_mode = 'all'

### Development options
# if not empty: will restrict to this list of subjects
devParticipants = []

### Set fitting options
# `nrestart` / `maxrestart` are read (and removed) by the fitting cell below;
# the remaining keys are handed to the optimizers as their `options`.
badsOptions = dict(
    max_iter=50,
    max_fun_evals=200,
    display="off",
    nrestart=2,
    maxrestart=2,
)
vbmcPyBadsInitOptions = dict(
    max_iter=25,
    max_fun_evals=250,
    display="off",
)
vbmcOptions = dict(
    max_iter=50,
    max_fun_evals=250,
    display="off",
    nrestart=4,
    maxrestart=8,
)
# Sequential runs switch the optimizers' display setting from "off" to "iter".
if not parallel_job:
    vbmcOptions.update(display="iter")
    badsOptions.update(display="iter")
    
    
### Make output folder
# The folder name encodes the model space, the toolbox and its iteration /
# function-evaluation / restart settings, plus the reset mode.
fit_folder = 'model_fits'
if fitToolbox == 'pybads':
    fit_name = f'{model_space}_bads_it{badsOptions["max_iter"]}_funit{badsOptions["max_fun_evals"]}_nres{badsOptions["nrestart"]}_reset{reset_mode}'
elif fitToolbox == 'pyvbmc':
    fit_name = f'{model_space}_vbmc_it{vbmcOptions["max_iter"]}_funit{vbmcOptions["max_fun_evals"]}_nres{vbmcOptions["nrestart"]}_reset{reset_mode}'
else:
    # Fail with a clear message instead of a NameError on `fit_name` below.
    raise ValueError(
        f"Unsupported fitToolbox {fitToolbox!r}: expected 'pybads' or 'pyvbmc'"
    )

fit_folder = os.path.join(fit_folder, fit_name)
# exist_ok avoids the check-then-create race when several jobs start together.
os.makedirs(fit_folder, exist_ok=True)


In [4]:
# Show the metadata of the dataset that the fits below will use.
print("DATA TO BE FITTED")
pprint(metadata)

DATA TO BE FITTED
{'SAmap': [[1, 2], [0, 2], [0, 1]],
 'dataset_name': 'dev',
 'folder_compiled_data': 'compiled_data',
 'folder_logfiles': 'logfiles',
 'included_participants': ['1', '1bis'],
 'included_runs': ['1', '2', '3'],
 'included_sessions': ['1', '2']}


In [None]:
### fit all the data with all agents

# initialize the complete output structure
results_df = pd.DataFrame()

# Optionally restrict the data and the participant list to a development subset.
if len(devParticipants) > 0:
    df_full = df_full.loc[np.isin(df_full['participant'], devParticipants), :]
    participants = devParticipants
else:
    participants = metadata['included_participants']

# Bookkeeping for cluster jobs; filled once per model by the slurm branch below.
massive_jobstruct = []
massive_jobs_id = []
massive_jobs_count = 0

# The restart settings are taken OUT of the selected option dict, which is
# afterwards passed as `options` to the optimizer (presumably the optimizers
# do not accept these two keys - TODO confirm).
# The removal is guarded so that re-running this cell does not fail with
# KeyError: on a second run the keys are already gone and the values read on
# the first run are reused. Re-run the configuration cell after changing
# `fitToolbox` or the restart settings so that fresh values are picked up.
restart_options = badsOptions if fitToolbox == 'pybads' else vbmcOptions
if 'nrestart' in restart_options:
    nrestart = restart_options.pop('nrestart')
    maxrestart = restart_options.pop('maxrestart')
elif 'nrestart' not in globals() or 'maxrestart' not in globals():
    raise KeyError(
        "'nrestart'/'maxrestart' are missing from the fitting options; "
        "re-run the configuration cell before this one"
    )
    
    
for m, model in enumerate(model_list):

    # Number of free parameters: highest index referenced by any list-valued
    # entry of the parameter mapping, plus one.
    nparameters = 1 + np.max([
        max(model['parameter_mapping'][mapval])
        for mapval in model['parameter_mapping']
        if isinstance(model['parameter_mapping'][mapval], list)
    ])

    # Hard bounds (lb/hb) and plausible bounds (plb/phb), one entry per parameter index.
    lb = np.zeros(nparameters)
    hb = np.zeros(nparameters)
    plb = np.zeros(nparameters)
    phb = np.zeros(nparameters)

    # One prior slot per parameter index. Priors are stored BY INDEX (not
    # appended in mapping order) so that they stay aligned with lb/hb/plb/phb
    # even when the mapping lists indices out of order or shares an index
    # between entries. With an in-order mapping the list is unchanged.
    if priors_constrainFit:
        prior_array = [None] * nparameters
    else:
        prior_array = None

    for prm in model['parameter_mapping']:
        param_indices = model['parameter_mapping'][prm]
        # Only list-valued entries map onto fitted parameters.
        if not isinstance(param_indices, list):
            continue
        for param_ind in param_indices:
            lb[param_ind] = model['bounds_list'][prm][0]
            hb[param_ind] = model['bounds_list'][prm][1]
            plb[param_ind] = model['plausible_bounds_list'][prm][0]
            phb[param_ind] = model['plausible_bounds_list'][prm][1]
            if priors_constrainFit:
                # An unrecognised shape leaves the slot as None.
                if prior_shapes[prm] == 'UniformBox':
                    prior_array[param_ind] = priors.UniformBox(lb[param_ind], hb[param_ind])
                elif prior_shapes[prm] == 'Trapezoidal':
                    prior_array[param_ind] = priors.Trapezoidal(lb[param_ind], plb[param_ind], phb[param_ind], hb[param_ind])
                elif prior_shapes[prm] == 'SmoothBox':
                    prior_array[param_ind] = priors.SmoothBox(plb[param_ind], phb[param_ind], 0.8)

    mappingParam = model['parameter_mapping']
    extraParam = model['parameter_preset']
    agent = model['agent']
    
    def pyBADSparallel_func(fitId):
        """Fit the current model to one participant using several random restarts.

        Uses the enclosing loop's `model`, bounds (`lb`, `hb`, `plb`, `phb`),
        `prior_array`, `mappingParam` and `nparameters`, plus the notebook-level
        fitting options.

        Parameters
        ----------
        fitId
            Value matched against `df_full['participant']` to select the data.

        Returns
        -------
        tuple
            `(bestNLL, bestAIC, bestBIC, priorNLL, fitted parameters,
            100 * n_kept_restarts / maxrestart, fitId, agentMemory, mappingX,
            vposteriors)`. If no restart succeeds, or the final evaluation
            raises, the first four entries are NaN, the parameters are NaN, and
            `agentMemory` / `mappingX` are an empty DataFrame / empty dict.
        """
        if parallel_job and parallel_mode=='slurm':
            print(submitit.JobEnvironment())

        # Never filled in this function; returned unchanged as the last tuple entry.
        vposteriors=[]

        df_subdata=df_full.loc[df_full['participant']==fitId,:].copy(deep=True)

        # Number of rows with state_choice >= 0; used as the sample size in the BIC.
        npoints=np.sum(df_subdata['state_choice']>=0)

        def single_fit_run(seed, model, fitId, mappingParam, prior_array, phb, plb, lb, hb,
                           fitToolbox, badsOptions, vbmcOptions, vbmcPyBadsInitOptions,
                           df_subdata):
            """Run one fit from a random starting point.

            Returns `(fitted, nll)`; on any exception the error is printed and
            `(array of NaN, NaN)` is returned so the caller can skip this run.
            """
            np.random.seed(seed)

            try:
                # Random start drawn uniformly inside the plausible box [plb, phb].
                paramDim = len(phb)
                param0 = plb + np.random.uniform(size=(paramDim,)) * (phb - plb)

                agent = model['agent'](model['factors'], nstates, nactions, nfutures)
                agent.init(model['parameter_preset'])

                # NOTE(review): `prior_array=None` is passed to the objective even when
                # priors_constrainFit is set; only the VBMC call below receives the
                # priors. Confirm this is intended for the pybads path.
                # NOTE(review): `resets` is the raw 'newblock' column here, whereas the
                # final evaluation further down passes the indices where newblock == 1.
                # Confirm which form `agent.fit` expects.
                funforfit = partial(agent.fit,
                                    mappingParam=mappingParam,
                                    arrayS=df_subdata['state'].values,
                                    arrayA=df_subdata['action'].values,
                                    arraySnext=df_subdata['next_state'].values,
                                    arrayR=df_subdata['reward'].values,
                                    arrayType=df_subdata['trial_tag_bool'].values,
                                    arrayMissed=df_subdata['missed'].values,
                                    arrayPrediction=df_subdata['state_choice'].values-1,
                                    arraySplit=df_subdata['visit'].values-1,
                                    resets=df_subdata['newblock'].values,
                                    returnMemory=False,
                                    prior_array=None,
                                    default_values=model['default_values'])

                if fitToolbox == 'pybads':
                    bads = BADS(funforfit, param0, lb, hb, plb, phb, options=badsOptions)
                    optimize_result = bads.optimize()
                    fitted = np.array(optimize_result['x'])
                    nll = optimize_result['fval']
                else:
                    # BADS provides the starting point for VBMC.
                    bads = BADS(funforfit, param0, lb, hb, plb, phb, options=vbmcPyBadsInitOptions)
                    optimize_result = bads.optimize()
                    vbmc = VBMC(funforfit, optimize_result['x'], lb, hb, plb, phb,
                                options=vbmcOptions, prior=prior_array)
                    vp, optimize_result = vbmc.optimize()
                    fitted = vp.moments()
                    nll = -optimize_result['elbo']

                return (fitted, nll)

            except Exception as e:
                print("An error with fitId:", fitId)
                print("An error occurred:", e)
                print("Type of exception:", type(e))
                return (np.full(len(phb), np.nan), np.nan)


        # --- Setup for Parallel Execution ---
        fitNLL = np.full((maxrestart,), np.nan)
        fittedParameters = np.full((maxrestart, nparameters), np.nan)

        # One seed per run; drawn from the global (unseeded here) NumPy RNG.
        seeds = np.random.randint(0, 1e6, size=maxrestart)

        # All `maxrestart` runs are launched; only the first `nrestart` valid ones are kept below.
        results = Parallel(n_jobs=-1)(
            delayed(single_fit_run)(
                seed, model, fitId, mappingParam, prior_array,
                phb, plb, lb, hb,
                fitToolbox, badsOptions, vbmcOptions, vbmcPyBadsInitOptions,
                df_subdata
            ) for seed in seeds
        )

        # Collect valid results up to `nrestart`
        rep = 0
        for i, (fitted, nll) in enumerate(results):
            if not np.any(np.isnan(fitted)) and not np.isnan(nll):
                fittedParameters[rep, :] = fitted
                fitNLL[rep] = nll
                rep += 1
                if verbose:
                    print(f'rep {rep}')
                    print(fitNLL)
                if rep >= nrestart:
                    break

        # --- Final selection ---
        try:
            if rep == 0:
                # Explicit failure instead of relying on an index error from an all-NaN array.
                raise RuntimeError("no restart returned a valid fit")

            # Unused slots are NaN and are skipped by idxmin.
            min_index = pd.Series(fitNLL).idxmin()
            agent = model['agent'](model['factors'], nstates, nactions, nfutures)
            agent.init(model['parameter_preset'])

            # Re-evaluate the best parameters, this time also returning the agent's memory.
            bestNLL, priorNLL, agentMemory, mappingX = agent.fit(
                fittedParameters[min_index, :],
                mappingParam=mappingParam,
                arrayS=df_subdata['state'].values,
                arrayA=df_subdata['action'].values,
                arraySnext=df_subdata['next_state'].values,
                arrayR=df_subdata['reward'].values,
                arrayType=df_subdata['trial_tag_bool'].values,
                arrayMissed=df_subdata['missed'].values,
                arrayPrediction=df_subdata['state_choice'].values-1,
                arraySplit=df_subdata['visit'].values-1,
                resets=np.where(df_subdata['newblock'].values==1)[0],
                returnMemory=True,
                prior_array=None,
                default_values=model['default_values'],
            )

            bestAIC = 2*nparameters + 2*bestNLL
            bestBIC = np.log(npoints)*nparameters + 2*bestNLL

            # Printed only after AIC/BIC exist: printing earlier raised NameError, which
            # the except below turned into a NaN result for every fit when verbose was on.
            if verbose:
                print(fitId, bestNLL, bestAIC, bestBIC, priorNLL, fittedParameters[min_index,:], 100*(rep/maxrestart))

            return bestNLL, bestAIC, bestBIC, priorNLL, fittedParameters[min_index,:], 100*(rep/maxrestart),fitId, agentMemory, mappingX, vposteriors
        except Exception as e:
            print("Final fit selection failed:", e)
            print("Terminal error (presumably failed fit): ", e, flush=True)
            return np.nan, np.nan, np.nan, np.nan, fittedParameters[0,:], 100*rep/maxrestart,fitId, pd.DataFrame(), {}, vposteriors,
    
    # --- Run the fits for this model: locally (sequential / joblib) or on the cluster ---
    if parallel_job==False or parallel_mode=='joblib':
        if parallel_job==False:
            # Sequential execution, one participant at a time.
            result_jobs=[]
            with tqdm(total=len(participants)) as pbar:
                for fitId in participants:
                    resultjob=pyBADSparallel_func(fitId)
                    result_jobs.append(resultjob)
                    pbar.update(1)
            # Pause kept from the original code; its purpose is not evident here - TODO confirm.
            time.sleep(10)
        else:
            # NOTE(review): the worker count is hard-coded; `parallel_maxjobs` is not used in the visible code.
            result_jobs=Parallel(n_jobs=40)(delayed(pyBADSparallel_func)(fitId) for fitId in participants)

        # Persist the results of BOTH local modes. Sequential results used to be
        # computed and then dropped: nothing was stored in results_df or written to disk.
        batchdf = pd.DataFrame(list(result_jobs), columns=["bestNLL", "bestAIC", "bestBIC", "priorNLL", "fitParameters", "fitSuccess","fitId", "fitMemory", "mappingX", "VP"])
        batchdf["fitAgent"]=model_list[m]['name']
        batchdf.to_pickle(os.path.join(fit_folder, model_list[m]['name'] + '.pkl'))
        with open(os.path.join(fit_folder, model_list[m]['name'] + '_info.pkl'), 'wb') as handle:
            pickle.dump(model, handle, protocol=pickle.HIGHEST_PROTOCOL)
        print('files saved', flush=True)
        results_df=pd.concat([results_df, batchdf])
    elif parallel_mode=='slurm':
        # One log folder per model and submission time (month-day-hour-minute).
        now = datetime.now()
        compact_str = now.strftime("%m%d%H%M")
        fit_logfolder=os.path.join(fit_folder,'slurm_logs', model['name'],compact_str)
        os.makedirs(fit_logfolder, exist_ok=True)
        executor = submitit.SlurmExecutor(folder=fit_logfolder, max_num_timeout=5)
        executor.update_parameters(mem=3000, time=7000, partition ="CPU", cpus_per_task=2, signal_delay_s=300)
        # Not used in this cell; kept because they are notebook-level names that
        # later cells may read.
        unordered_results=[]
        returning_order=[]
        nreturned=0
        modeljobs = []
        with tqdm(total=len(participants), desc=f"Submitting jobs {model['name']}...") as pbar:
            for fitId in participants:
                submititjob = executor.submit(pyBADSparallel_func, fitId)
                modeljobs.append(submititjob)
                # Short pause between submissions.
                time.sleep(0.25)
                pbar.update(1)
        time.sleep(1)
        # Map each job id to its position in `modeljobs`; results are not collected in this cell.
        modeljobs_id={job.job_id: idx for idx, job in enumerate(modeljobs)}
        massive_jobstruct.append(modeljobs)
        massive_jobs_id.append(modeljobs_id)
        massive_jobs_count+=len(modeljobs_id)
        print(f'finished computation of {model_list[m]["name"]}', flush=True)



  0%|          | 0/2 [00:00<?, ?it/s]