In [1]:
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

In [2]:
# Raw per-trial experiment table; every later cell derives from this frame.
exp_df = pd.read_csv(filepath_or_buffer='outExp.csv')


In [3]:
# Remove rows where success == 0 and abort_event is -1 or -2.
# Before dropping them, print the RTwrtStim distribution of each group
# (abort_event == -2 first, then -1) so the removed trials can be inspected.
failed = exp_df['success'] == 0
for abort_code in (-2, -1):
    print(exp_df.loc[failed & (exp_df['abort_event'] == abort_code), 'RTwrtStim'].value_counts(dropna=False))
exp_df = exp_df[~(failed & exp_df['abort_event'].isin([-1, -2]))]

RTwrtStim
NaN         23
1.121829     1
0.012729     1
0.523252     1
0.072325     1
Name: count, dtype: int64
RTwrtStim
NaN         20
0.178688     1
Name: count, dtype: int64


In [4]:
#### Fill in a missing response_poke only when success is 1 or -1 and the ILD sign is known.
# success == 1 gets the poke whose choice (defined below) has the same sign as ILD;
# success == -1 gets the opposite poke. Rows with ILD == 0 or NaN ILD are left as NaN.
poke_was_missing = exp_df['response_poke'].isna()
ild_sign = np.sign(exp_df['ILD'])
# (success value, ILD sign) -> response_poke to write
poke_fill_rules = {
    (1, 1): 3,
    (1, -1): 2,
    (-1, 1): 2,
    (-1, -1): 3,
}
for (success_value, sign), poke in poke_fill_rules.items():
    fill_rows = poke_was_missing & (exp_df['success'] == success_value) & (ild_sign == sign)
    exp_df.loc[fill_rows, 'response_poke'] = poke

#### Keep batches Comparable, SD, LED1, LED2 and LED34
selected_batch_names = ['Comparable', 'SD', 'LED1', 'LED2', 'LED34']
exp_df_selected_batches = exp_df[exp_df['batch_name'].isin(selected_batch_names)]

#### LED34: keep only odd-numbered animals in session types 1 and 2; other batches pass through unchanged
is_led34 = exp_df_selected_batches['batch_name'] == 'LED34'
keep_led34_rows = (
    is_led34
    & (exp_df_selected_batches['animal'] % 2 == 1)
    & exp_df_selected_batches['session_type'].isin([1, 2])
)
exp_df_selected_batches_1 = exp_df_selected_batches[keep_led34_rows | ~is_led34]

### LED off: LED_trial is NaN or 0
led_trial = exp_df_selected_batches_1['LED_trial']
exp_df_led_off = exp_df_selected_batches_1[led_trial.isna() | led_trial.eq(0)].copy()

### choice: response_poke 3 -> +1, 2 -> -1 (NaN stays NaN)
### accuracy: 1 when choice and ILD have the same sign, otherwise 0.
# NOTE(review): a NaN choice or ILD == 0 also yields accuracy 0 here, not "undefined".
exp_df_led_off.loc[:, 'choice'] = 2 * exp_df_led_off['response_poke'] - 5
exp_df_led_off.loc[:, 'accuracy'] = (exp_df_led_off['choice'] * exp_df_led_off['ILD']).gt(0).astype(int)

# Helper functions

In [5]:
def plot_psycho(df):
    """Compute psychometric-curve data: P(choice == 1) at each unique ILD.

    Despite the name, nothing is drawn; the caller plots the returned arrays.

    Parameters
    ----------
    df : pandas.DataFrame
        Trial table with 'ILD' and 'choice' columns.

    Returns
    -------
    ILD_unique : numpy.ndarray
        Sorted unique ILD values of ``df``.
    prob_choice1 : numpy.ndarray
        For each ILD, the fraction of trials with a non-NaN choice whose
        choice equals 1. NaN when an ILD has no non-NaN choices.
    """
    ILD_unique = np.sort(df['ILD'].unique())
    prob_choice1 = np.full(len(ILD_unique), np.nan)
    for idx, ild in enumerate(ILD_unique):
        # Unknown choices (NaN) must not be counted as "not 1";
        # that would bias the probability toward 0.
        choices = df.loc[df['ILD'] == ild, 'choice'].dropna()
        if len(choices) > 0:
            prob_choice1[idx] = np.mean(choices == 1)

    return ILD_unique, prob_choice1


def plot_tacho(df, bins):
    """Compute tachometric-curve data: mean accuracy per RTwrtStim bin.

    Despite the name, nothing is drawn. ``df`` is not modified; the original
    version added an 'RT_bin' column to the caller's frame as a side effect.

    Parameters
    ----------
    df : pandas.DataFrame
        Trial table with 'RTwrtStim' and 'accuracy' columns.
    bins : sequence of float
        Bin edges passed to ``pd.cut`` (right-closed, lowest edge included).

    Returns
    -------
    bin_mid : pandas.Series
        Midpoint of every bin, indexed by the bin interval.
    mean : pandas.Series
        Mean accuracy per bin; NaN for bins with no trials. All bins are
        kept because of ``observed=False``.
    """
    # Build the bin labels as a separate Series instead of writing into df.
    rt_bin = pd.cut(df['RTwrtStim'], bins=bins, include_lowest=True).rename('RT_bin')
    grouped_by_rt_bin = df.groupby(rt_bin, observed=False)['accuracy'].agg(['mean', 'count'])
    grouped_by_rt_bin['bin_mid'] = grouped_by_rt_bin.index.map(lambda interval: interval.mid)
    return grouped_by_rt_bin['bin_mid'], grouped_by_rt_bin['mean']

In [6]:
# Custom color palette (extend or modify as needed). Each animal gets one fixed
# color so it looks the same in every batch/session panel.
custom_colors = [
    "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd",
    "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf",
    "#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00",
    "#a65628", "#f781bf", "#999999", "#66c2a5", "#fc8d62",
    "#8da0cb", "#e78ac3", "#a6d854", "#ffd92f", "#e5c494"
]

all_animals = sorted(exp_df_led_off['animal'].unique())
if len(all_animals) > len(custom_colors):
    raise ValueError("Not enough custom colors for all animals. Please extend the color list.")

animal_to_color = {animal: custom_colors[i] for i, animal in enumerate(all_animals)}


# One PDF page per batch and one row per session type. Each row has four panels:
# intended_fix histogram, RTwrtStim histogram, psychometric curve, tachometric curve.
pdf_path = "pdfs/batch_tstim_rt_psy_tach.pdf"
# PdfPages does not create missing directories.
os.makedirs(os.path.dirname(pdf_path), exist_ok=True)
with PdfPages(pdf_path) as pdf:
    for batch in exp_df_led_off['batch_name'].unique():
        batch_df = exp_df_led_off[exp_df_led_off['batch_name'] == batch]
        session_types = batch_df['session_type'].dropna().unique()
        n_sessions = len(session_types)
        if n_sessions == 0:
            # plt.subplots(0, 4) raises ValueError, and there is nothing to draw anyway.
            print(f"Skipping batch {batch}: no non-NaN session_type")
            continue

        fig, axes = plt.subplots(n_sessions, 4, figsize=(20, 3*n_sessions), squeeze=False)
        fig.suptitle(f'Batch: {batch}', fontsize=16)

        for i, session in enumerate(session_types):
            session_df = batch_df[batch_df['session_type'] == session]
            ax_fix, ax_rt, ax_psycho, ax_tacho = axes[i]

            for animal in session_df['animal'].unique():
                animal_df = session_df[session_df['animal'] == animal]
                color = animal_to_color[animal]
                # Only the first row carries legend labels, so the legend is not repeated.
                label = f'Animal {animal}' if i == 0 else None

                # Intended fix
                ax_fix.hist(
                    animal_df['intended_fix'].dropna(),
                    bins=np.arange(0, 2, 0.02),
                    histtype='step',
                    color=color,
                    label=label,
                    density=True
                )
                # RTwrtStim (negative values are responses before stimulus onset)
                ax_rt.hist(
                    animal_df['RTwrtStim'].dropna(),
                    bins=np.arange(-1, 2, 0.02),
                    histtype='step',
                    color=color,
                    label=label,
                    density=True
                )

                # The curves use only trials with RTwrtStim > 0 and success of 1 or -1.
                # Copy defensively in case the helper adds columns.
                animal_df_valid = animal_df[(animal_df['RTwrtStim'] > 0) & (animal_df['success'].isin([1, -1]))].copy()

                # Psychometric curve
                x_psycho, y_psycho = plot_psycho(animal_df_valid)
                ax_psycho.scatter(x_psycho, y_psycho, color=color, label=label)

                # Tachometric curve (best-effort: one failing animal must not abort the PDF)
                try:
                    x_tacho, y_tacho = plot_tacho(animal_df_valid, bins=np.arange(0, 2, 0.05))
                    ax_tacho.plot(x_tacho, y_tacho, color=color, label=label)
                except Exception as e:
                    print(f"plot_tacho failed for animal {animal} in session {session}: {e}")

            ax_fix.set_title(f'Session {session}: Intended Fix')
            ax_fix.set_xlabel('intended_fix')
            ax_fix.set_ylabel('Density')
            ax_rt.set_title(f'Session {session}: RTwrtStim')
            ax_rt.set_xlabel('RTwrtStim')
            ax_rt.set_ylabel('Density')
            ax_rt.set_xlim(-0.5, 1)
            ax_psycho.set_title(f'Session {session}: Psychometric')
            ax_psycho.set_xlabel('ILD')
            ax_psycho.set_ylabel('P(choice==1)')
            ax_psycho.set_ylim(0, 1)
            ax_tacho.set_title(f'Session {session}: Tachometric')
            ax_tacho.set_xlabel('RTwrtStim')
            ax_tacho.set_ylabel('P(correct)')
            ax_tacho.set_xlim(0, 1)
            ax_tacho.set_ylim(0.5, 1)
            if i == 0:
                ax_fix.legend()
                ax_rt.legend()
                ax_psycho.legend()
                ax_tacho.legend()
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        pdf.savefig(fig)
        plt.close(fig)

print(f"PDF saved to {pdf_path}")

PDF saved to pdfs/batch_tstim_rt_psy_tach.pdf


# QQ-style plots: RTwrtStim percentiles of each ABL minus those of the highest ABL

In [7]:
# Batches that remain after filtering (displayed as the cell output).
exp_df_led_off.batch_name.unique()

array(['Comparable', 'SD', 'LED1', 'LED2', 'LED34'], dtype=object)

In [8]:
# Add abs_ILD column: stimuli of equal magnitude on either side are pooled.
exp_df_led_off['abs_ILD'] = np.abs(exp_df_led_off['ILD'])

# One PDF page per (batch, session_type): a grid of animals (rows) x |ILD| (columns).
# Each panel plots, for every ABL except the highest, Q_ABL - Q_highest_ABL
# against Q_highest_ABL, where Q are the RTwrtStim percentiles below.
pdf_path = "pdfs/qq_percentile_per_batch_session_animal_absILD.pdf"
percentiles = np.arange(5, 100, 10)  # 5, 15, ..., 95

# PdfPages does not create missing directories.
os.makedirs(os.path.dirname(pdf_path), exist_ok=True)
with PdfPages(pdf_path) as pdf:
    for batch in exp_df_led_off['batch_name'].unique():
        batch_df = exp_df_led_off[exp_df_led_off['batch_name'] == batch]
        for session in batch_df['session_type'].dropna().unique():
            session_df = batch_df[batch_df['session_type'] == session]
            animals = session_df['animal'].unique()
            abs_ILDs = np.sort(session_df['abs_ILD'].unique())
            n_animals = len(animals)
            n_abs_ILDs = len(abs_ILDs)
            fig, axes = plt.subplots(n_animals, n_abs_ILDs, figsize=(5*n_abs_ILDs, 3*n_animals), squeeze=False)
            fig.suptitle(f'Batch: {batch}, Session: {session}', fontsize=16)
            for i, animal in enumerate(animals):
                animal_df = session_df[session_df['animal'] == animal]
                for j, abs_ILD in enumerate(abs_ILDs):
                    ax = axes[i, j]
                    abs_ILD_df = animal_df[animal_df['abs_ILD'] == abs_ILD]
                    # Same trial filter as the tachometric analysis: RTwrtStim > 0
                    # and success of 1 or -1.
                    RTwrtStim_pos = abs_ILD_df[(abs_ILD_df['RTwrtStim'] > 0) & (abs_ILD_df['success'].isin([1, -1]))]
                    # NaN ABLs never match `== abl` below and would hand np.percentile
                    # an empty array, so they are dropped here.
                    ABLs = np.sort(RTwrtStim_pos['ABL'].dropna().unique())
                    if len(ABLs) == 0:
                        # No usable trials for this animal / |ILD| combination.
                        ax.set_visible(False)
                        continue
                    # RTwrtStim percentiles for each ABL
                    q_dict = {
                        abl: np.percentile(RTwrtStim_pos.loc[RTwrtStim_pos['ABL'] == abl, 'RTwrtStim'], percentiles)
                        for abl in ABLs
                    }
                    abl_highest = ABLs.max()
                    Q_highest = q_dict[abl_highest]
                    for abl in ABLs:
                        if abl == abl_highest:
                            continue  # the highest ABL is the x-axis reference
                        ax.plot(Q_highest, q_dict[abl] - Q_highest, marker='o', label=f'ABL {abl}')
                    # Zero line: an ABL whose percentiles equal the highest ABL's lies on it.
                    ax.axhline(0, color='k', linestyle='--', linewidth=1)
                    ax.set_xlabel(f'Percentiles of RTwrtStim (ABL={abl_highest})')
                    ax.set_ylabel('Q_ABL - Q_highest_ABL')
                    ax.set_title(f'Animal: {animal}, abs(ILD): {abs_ILD}')
                    # With a single ABL nothing labelled was drawn, so there is no legend.
                    if len(ABLs) > 1:
                        ax.legend(title='ABL')
            plt.tight_layout(rect=[0, 0.03, 1, 0.95])
            pdf.savefig(fig)
            plt.close(fig)

print(f"PDF saved to {pdf_path}")

PDF saved to pdfs/qq_percentile_per_batch_session_animal_absILD.pdf
