In [24]:
import numpy as np
import numpy.random as npr
import pandas as pd
from sklearn import preprocessing

from notebooks.imports import *
from config import dir_config, main_config
from src.utils import pmf_utils, glm_hmm_utils
import pickle
import copy

In [5]:
# Resolve the configured data locations into pathlib.Path objects: compiled_dir is read from
# (session metadata and per-session trial CSVs), processed_dir is where the filtered metadata is written.
compiled_dir, processed_dir = map(Path, (dir_config.data.compiled, dir_config.data.processed))

In [6]:
# Session-level metadata; later cells use its session_id and prior_direction columns.
session_metadata = pd.read_csv(compiled_dir.joinpath('sessions_metadata.csv'))

In [13]:
def extract_previous_data(trial_data, valid_idx, first_trial, n_trial_back=3):
	"""Build choice/target history features from the most recent valid trials.

	For every row position ``i`` in ``range(first_trial, len(trial_data))`` this takes the
	``n_trial_back`` valid trials (positions listed in ``valid_idx``) that occur strictly
	before ``i``, ordered oldest -> most recent, and maps their ``choice`` and ``target``
	values through ``x * 2 - 1`` (so 0/1 becomes -1/+1).

	Args:
		trial_data: DataFrame with ``choice`` and ``target`` columns. Rows are addressed
			by position, so the frame's index does not matter.
		valid_idx: integer row positions of the valid trials.
		first_trial: first row position to build features for; earlier rows are dropped.
		n_trial_back: number of previous valid trials per row (0 gives empty features).

	Returns:
		(prev_choice, prev_target): int arrays of shape
		``(len(trial_data) - first_trial, n_trial_back)``.

	Raises:
		ValueError: if a row at or after ``first_trial`` has fewer than ``n_trial_back``
			valid trials before it.

	Side effect: reseeds numpy's global RNG with ``np.random.seed(1)``.
	"""
	# Kept deliberately: the session loop draws from the global RNG right after this is
	# called (to fill invalid choices), so reseeding here makes those draws reproducible.
	np.random.seed(1)

	n_rows = len(trial_data) - first_trial
	prev_choice = np.empty((n_rows, n_trial_back), dtype=int)
	prev_target = np.empty((n_rows, n_trial_back), dtype=int)

	# Positional arrays: valid_idx holds positions (e.g. from np.where), not index labels.
	choice_signed = trial_data["choice"].to_numpy() * 2 - 1
	target_signed = trial_data["target"].to_numpy() * 2 - 1
	valid_sorted = np.sort(np.asarray(valid_idx, dtype=int))

	for row, i in enumerate(range(first_trial, len(trial_data))):
		# Number of valid trials strictly before position i (binary search on sorted positions).
		n_before = int(np.searchsorted(valid_sorted, i, side="left"))
		if n_before < n_trial_back:
			raise ValueError(
				f"trial {i} has only {n_before} valid trial(s) before it; need {n_trial_back}"
			)
		# Explicit start/stop slice (not [-n:]) so n_trial_back == 0 yields an empty window.
		recent = valid_sorted[n_before - n_trial_back:n_before]
		prev_choice[row] = choice_signed[recent]
		prev_target[row] = target_signed[recent]

	return prev_choice, prev_target

def prepare_input_data(data, input_dim, valid_idx, first_trial, trials=None):
	"""Build the per-session input (design) matrix.

	Column layout, one row per trial from position ``first_trial`` onward:
	  * column 0: signed stimulus ``coherence * (2 * target - 1) / 100``
	  * column 1, only when ``"no_bias"`` is not in ``trials``: constant 1 (bias)
	  * next ``n_trial_back`` columns: previous valid choices mapped to -1/+1
	  * last ``n_trial_back`` columns: previous valid targets mapped to -1/+1
	where ``n_trial_back = (input_dim - n_current) // 2`` and ``n_current`` is 1 or 2.

	Args:
		data: DataFrame with ``coherence``, ``target`` and ``choice`` columns.
		input_dim: total number of input columns.
		valid_idx: row positions of valid trials, forwarded to ``extract_previous_data``.
		first_trial: first row position kept.
		trials: trial-selection option string (e.g. ``"all_trials"``). Defaults to the
			notebook-global ``_TRIALS`` for backward compatibility; pass it explicitly to
			avoid depending on hidden kernel state.

	Returns:
		A one-element list holding a ``(len(data) - first_trial, input_dim)`` float array.

	Raises:
		ValueError: if ``input_dim - n_current`` is negative or odd, i.e. the history
			columns cannot fill the matrix exactly.
	"""
	if trials is None:
		trials = _TRIALS  # legacy fallback: loop variable of the session-processing cell
	# Columns describing the current trial: stimulus, plus a bias column unless disabled.
	n_current = 1 if "no_bias" in trials else 2
	n_history_cols = input_dim - n_current
	if n_history_cols < 0 or n_history_cols % 2:
		raise ValueError(
			f"input_dim={input_dim} leaves {n_history_cols} history columns; expected an even, non-negative number"
		)
	n_trial_back = n_history_cols // 2

	# Start from ones so the bias column (when present) needs no explicit fill.
	X = np.ones((1, data.shape[0] - first_trial, input_dim))

	# NOTE(review): rows whose target is -1 (the session loop fills missing targets with -1)
	# get a factor of -3 here rather than +/-1 — confirm that is intended.
	signed_stimulus = (data.coherence * (2 * data.target - 1) / 100).to_numpy()
	X[0, :, 0] = signed_stimulus[first_trial:]  # positional slice, independent of the index

	choice_cols = slice(n_current, n_current + n_trial_back)
	target_cols = slice(n_current + n_trial_back, n_current + 2 * n_trial_back)
	X[0, :, choice_cols], X[0, :, target_cols] = extract_previous_data(
		data, valid_idx, first_trial, n_trial_back=n_trial_back
	)
	return list(X)

# Remove Outliers
def is_outlier_session(stimulus, choices, mask, prob_toRF):
	"""Flag a session by comparing psychometric fits before vs. after the prior switch.

	The kept (masked) trials are split at the first one whose ``prob_toRF`` is neither 50
	nor NaN: earlier trials form the equal-prior segment, that trial and later ones the
	unequal-prior segment. ``pmf_utils.get_psychometric_data`` is fitted to each segment and
	the session is flagged when ``eq_mean - uneq_mean < mean_threshold`` and
	``|eq_var - uneq_var| < var_threshold`` (``main_config.rejection_criteria``).

	Args:
		stimulus: per-trial signed stimulus; multiplied by 100 before fitting (the session
			loop passes values that were divided by 100).
		choices: per-trial choices, indexed along the first axis.
		mask: per-trial boolean mask of trials to use, or None to use every trial.
		prob_toRF: per-trial prior value; 50 is treated as the equal-prior value.

	Returns:
		Truthy when the session meets the rejection criteria.
	"""
	if mask is None:
		# 1-D per-trial mask: np.ones_like(choices) would be 2-D for (n, 1) choices and
		# could not index the 1-D prob_toRF.
		mask = np.ones(len(choices), dtype=bool)
	else:
		# Coerce so a 0/1 integer mask is not misread as fancy indices.
		mask = np.asarray(mask, dtype=bool)

	kept = np.flatnonzero(mask)
	kept_prior = np.asarray(prob_toRF)[mask]
	unequal_positions = np.flatnonzero((kept_prior != 50) & ~np.isnan(kept_prior))
	# Position (within kept trials) of the first unequal-prior trial. Without one, every
	# kept trial is equal-prior.
	# NOTE(review): the unequal segment is then empty; how get_psychometric_data handles
	# empty input is not visible here.
	task_switch = unequal_positions[0] if unequal_positions.size else kept.size

	equal_indices = kept[:task_switch]
	unequal_indices = kept[task_switch:]

	def _fit_segment(indices):
		"""Fit the psychometric model to the trials at ``indices`` and return the model."""
		segment = {"signed_coherence": np.array(stimulus[indices]) * 100, "choice": choices[indices]}
		_, _, model, _, _ = pmf_utils.get_psychometric_data(segment)
		return model

	eq_model = _fit_segment(equal_indices)
	uneq_model = _fit_segment(unequal_indices)

	mean_diff = eq_model.coefs_["mean"] - uneq_model.coefs_["mean"]
	var_diff = eq_model.coefs_["var"] - uneq_model.coefs_["var"]

	criteria = main_config.rejection_criteria
	return mean_diff < criteria["mean_threshold"] and np.abs(var_diff) < criteria["var_threshold"]

In [25]:
reject_sessions = []

for _TRIALS in ["all_trials"]:
	n_states = 2  # number of discrete states
	obs_dim = 1  # number of observed dimensions: choice(toRF/awayRF)
	num_categories = 2  # number of categories for output

	# Columns describing the current trial: signed stimulus, plus a constant bias column unless "no_bias".
	if "no_bias" in _TRIALS:
		current_trial_param = 1
	else:
		current_trial_param = 2

	n_trial_back = 1

	input_dim = current_trial_param + 2*n_trial_back  # input dimensions: current signed coherence, 1(bias), previous choice(toRF/awayRF), previous target side(toRF/awayRF)

	# Pre-allocate lists for session data (not filled by this cell)
	inputs_session_wise = []
	choices_session_wise = []
	invalid_idx_session_wise = []
	masks_session_wise = []
	GP_trial_num_session_wise = []
	prob_toRF_session_wise = []

	# Pre-build a mapping from session_id to prior_direction for efficient lookup
	prior_direction_map = session_metadata.set_index("session_id")["prior_direction"].to_dict()

	# Process each session
	for session_id in session_metadata["session_id"]:
		# Read trial data and keep task_type == 1 trials with a fresh 0..n-1 index
		# (drop=True avoids carrying the old index along as an extra column).
		trial_data = pd.read_csv(Path(compiled_dir, session_id, f"{session_id}_trial.csv"), index_col=None)
		GP_trial_data = trial_data[trial_data.task_type == 1].reset_index(drop=True)

		if "eq_prior" in _TRIALS:
			GP_trial_data = GP_trial_data[GP_trial_data.prob_toRF == 50].reset_index(drop=True)

		# Fill missing values for important columns (-1 marks "missing")
		GP_trial_data['choice'] = GP_trial_data.choice.fillna(-1)
		GP_trial_data['target'] = GP_trial_data.target.fillna(-1)
		GP_trial_data['outcome'] = GP_trial_data.outcome.fillna(-1)

		# Valid trials are those with a non-negative outcome
		valid_idx = np.where(GP_trial_data.outcome >= 0)[0]
		if len(valid_idx) < n_trial_back:
			raise ValueError(
				f"Session {session_id} has {len(valid_idx)} valid trial(s); at least {n_trial_back} needed"
			)

		# First trial that has n_trial_back valid trials before it
		first_trial = valid_idx[n_trial_back - 1] + 1

		# Prepare inputs and choices (both cropped to start at first_trial)
		inputs = prepare_input_data(GP_trial_data, input_dim, valid_idx, first_trial)
		choices = GP_trial_data.choice.values.reshape(-1, 1).astype("int")
		choices = choices[first_trial:]

		# Rows whose choice is missing
		invalid_idx = np.where(choices == -1)[0]

		if "all_trials" in _TRIALS:
			# Replace missing choices with random 0/1 draws. The global RNG was just reseeded
			# inside extract_previous_data (called via prepare_input_data), so these draws
			# are reproducible per session.
			choices[choices == -1] = np.random.choice(2, invalid_idx.shape[0])

			# Mask out the randomly filled trials
			mask = np.ones_like(choices, dtype=bool)
			mask[invalid_idx] = 0

			# Get trial numbers and prob_toRF for the cropped session
			GP_trial_num = np.array(GP_trial_data.trial_number)[first_trial:]
			prob_toRF = np.array(GP_trial_data.prob_toRF)[first_trial:]
		else:
			raise ValueError(f"Invalid trials option: {_TRIALS}")

		# Align awayRF-prior sessions with toRF ones: flip the stimulus, every history column
		# (those after the current-trial columns, so the bias column is untouched), and the choices.
		prior_direction = prior_direction_map.get(session_id, 'awayRF')
		if prior_direction == 'awayRF':
			inputs[0][:, 0] = -inputs[0][:, 0]
			inputs[0][:, current_trial_param:] = -inputs[0][:, current_trial_param:]
			choices = 1 - choices

		assert len(choices) == len(inputs[0]), f"Length mismatch: {len(choices)} vs {len(inputs[0])}"
		assert len(mask) == len(inputs[0]), f"Length mismatch: {len(mask)} vs {len(inputs[0])}"
		assert len(GP_trial_num) == len(inputs[0]), f"Length mismatch: {len(GP_trial_num)} vs {len(inputs[0])}"
		assert len(prob_toRF) == len(inputs[0]), f"Length mismatch: {len(prob_toRF)} vs {len(inputs[0])}"

		if is_outlier_session(inputs[0][:, 0], choices, mask[:, 0], prob_toRF):
			reject_sessions.append(session_id)

In [26]:
# Sessions flagged by is_outlier_session in the loop above; the next cell removes them from the metadata.
reject_sessions

['210216_GP_JP']

In [28]:
# Drop the rejected sessions, then write the filtered metadata to the processed data directory.
session_metadata = session_metadata.loc[
	~session_metadata["session_id"].isin(reject_sessions)
]
session_metadata.to_csv(processed_dir.joinpath('sessions_metadata.csv'), index=False)