## Interpolation Test Stimuli Performance
Test set 1: Linear Interpolation. 

Two subjects are tested on linearly interpolated shuffle gaps. Here is what I need:
- Load in data and remove debug trials. 
- For each test trial, examine if the previous 64 training trials had over 80% accuracy. 
    - 4 training stimuli, means on average >13/16 of each stimuli
- Combined, individuals
    - Psychometric Functions

### Load data

In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from behav import plotting, utils, loading

import seaborn as sns
sns.set_style("whitegrid")

data_path = '/mnt/cube/RawData/Magpi/'

subjects = (
    ['B1520', 'B1535']
)

In [2]:
%%time
behav_data = loading.load_data_pandas(subjects, data_path)

CPU times: user 1.03 s, sys: 78 ms, total: 1.11 s
Wall time: 3.18 s


### Clean out debug trials before shaping

In [3]:
behav_data['B1520'] = behav_data['B1520'][behav_data['B1520'].index > '2023-01-19 00:00:00.000000'] 
## B1520 started trials on the 19th, previous are debug trials

In [4]:
behav_data['B1535'] = behav_data['B1535'][behav_data['B1535'].index > '2023-01-20 00:00:00.000000'] 
## B1535 started trials on the 20th, previous are debug trials

### For every test trial, only keep if the previous 64 trial accuracy is > 0.8

In [5]:
import pandas as pd
from tqdm.autonotebook import tqdm

  from tqdm.autonotebook import tqdm


In [6]:
max_trials = 100
accuracy_threshold = 0.8

In [None]:
%%time
test_data = {}
## for each subject
for subj in subjects:
    numbered_trials = behav_data[subj].reset_index()
    test_trials = numbered_trials[numbered_trials.type_ == 'test']
    valid_trials = pd.DataFrame(columns = test_trials.columns)
    ## iterative through each test trial
    for i, row in tqdm(test_trials.iterrows()):
        training_trials = []
        pointer = i
        ## while training trials are not fully collected
        while len(training_trials) < max_trials:
            ## go to previous trial to the pointer
            pointer = pointer - 1
            pointed_trial = numbered_trials.loc[pointer]
            ## if the pointer is on a normal trial, append to training_trials
            if pointed_trial.type_ == 'normal':
                training_trials.append(pointed_trial)
        
        ## check if training trials exceed criteria
        training_trials = pd.concat(training_trials)
        training_accuracy = np.mean(training_trials.correct)
        if training_accuracy > 0.8:
            valid_trials = pd.concat([valid_trials, pd.DataFrame([row], columns = test_trials.columns)], ignore_index = False)
    
    test_data[subj] = valid_trials

1007it [00:56, 18.05it/s]

In [None]:
## For every test trial, parse relevant info
for subj in subjects:
    
    stim_types = []
    pair_indices = []
    inter_nums = []

    for i, row in test_data[subj].iterrows():
        parsed = row.stimulus.split('_')
        stim_types.append(parsed[2])
        pair_indices.append(parsed[3])
        inter_nums.append(int(parsed[4].split('.')[0]))
        

    test_data[subj]['stim_type'] = stim_types
    test_data[subj]['pair_indices'] = pair_indices
    test_data[subj]['inter_nums'] = inter_nums

In [None]:
for subj in subjects:
    test_data[subj] = test_data[subj][test_data[subj]['inter_nums'] != 0]

In [None]:
test_data['B1520']

## Plot stimuli functions

In [None]:
from starling_rhythm.utils.paths import PROCESSED_DIR
import pandas as pd
import seaborn as sns
bID = "s_b1555_22"
SAVE_PATH = PROCESSED_DIR / bID / 'salvage_inter_tmf.pickle'

In [None]:
stims = pd.read_pickle(SAVE_PATH)
pair0_stims = stims[stims.pair_index == 0][::4]
pair1_stims = stims[stims.pair_index == 1][::4]

In [None]:
stims.head()

In [None]:
def normalize(x, newRange=(0, 1)): #x is an array. Default range is between zero and one
    xmin, xmax = np.min(x), np.max(x) #get max and min from input array
    norm = (x - xmin)/(xmax - xmin) # scale between zero and one
    
    if newRange == (0, 1):
        return(norm) # wanted range is the same as norm
    elif newRange != (0, 1):
        return norm * (newRange[1] - newRange[0]) + newRange[0] #scale to a different range.    
    #add other conditions here. For example, an error message

In [None]:
def linear_func(x, c, d):
    return c*x + d

def cubic_func(x, a, b, c, d):
    return a*x**3 + b*x**2 + c*x + d

def quad_func(x, b, c, d):
    return b*x**2 + c*x + d

def quartic_func(x, a, b, c, d, e):
    return a * x**4 + b * x**3 + c * x**2 + d * x + e

def logistic_4pm(x, A, K, B, M):
    return A + (K - A) / (1 + np.exp(-B * (x - M)))

In [None]:
## Test tmf model fit

In [None]:
from starling_rhythm.utils import logistic
from scipy.optimize import curve_fit

In [None]:
fig, axs = plt.subplots(3, sharex = True, figsize=(6, 6), dpi=300)
## plot linear function
input_frame = pair0_stims

axs[0].plot(input_frame['interpolation_num'], np.linspace(0.8, 0.2, len(input_frame)))
axs[0].set_title('Hypothesis 1: Perceive by Linear Interpolation')

## plot tmf
lower = 0.2
upper = 0.8
range_adjust = normalize(input_frame['mean_tMF'], newRange = (0.2, 0.8))
## plot quad function
popt, pcov = curve_fit(quad_func, input_frame['interpolation_num'], range_adjust)
new_y = quad_func(input_frame['interpolation_num'], *popt)
axs[1].plot(input_frame['interpolation_num'], new_y, label = 'Quadratic Fit')
axs[1].plot(input_frame['interpolation_num'], range_adjust, label = 'Scaled tMF')
axs[1].set_title('Hypothesis 2: Perceive by Multifractality')
axs[1].legend()

## plot logistic function
logistic_y = logistic_4pm(input_frame['interpolation_num'], 0.8, 0.2, 0.25, 64)
axs[2].plot(input_frame['interpolation_num'], logistic_y)
axs[2].set_title('Hypothesis 3: Perceive by Categorization')

fig.supylabel('Proportion of Left Response')
fig.supxlabel('Interpolation # (0 = MaxMF, 128 = MinMF)')

In [None]:
import numpy as np
import statsmodels.api as sm

In [None]:
X = sm.add_constant(input_frame['interpolation_num'])

order = []
aics = []
bics = []

for i in np.arange(0, 10):
    model = sm.OLS(range_adjust, X**i).fit()
    print(i)
    order.append(i)
    print('AIC for model: {:.3f}'.format(model.aic))
    aics.append(model.aic)
    print('BIC for model: {:.3f}'.format(model.bic))
    bics.append(model.bic)
    
tmf_models = pd.DataFrame(
    {
        'order': order,
        'aic': aics,
        'bic': bics
    }
)

In [None]:
plt.plot(tmf_models)

In [None]:
fig, axs = plt.subplots(2, sharex = True, figsize=(6, 4), dpi=300)
## plot linear function
input_frame = pair1_stims

axs[0].plot(input_frame['interpolation_num'], np.linspace(0.8, 0.2, len(input_frame)))
axs[0].set_title('Hypothesis 1: Perceive by Linear Interpolation')

## plot tmf
lower = 0.2
upper = 0.8
range_adjust = normalize(input_frame['mean_tMF'], newRange = (0.2, 0.8))
axs[1].plot(input_frame['interpolation_num'], range_adjust, label = 'Normalized tMF')
axs[1].set_title('Hypothesis 2: Perceive by Multifractality')

## plot cubic function
popt, pcov = curve_fit(cubic_func, input_frame['interpolation_num'], range_adjust)
new_y = cubic_func(input_frame['interpolation_num'], *popt)
axs[1].plot(input_frame['interpolation_num'], new_y, label = 'Cubic Fit')
plt.legend()

fig.supylabel('Proportion of Left Response')
fig.supxlabel('Interpolation # (0 = MaxMF, 128 = MinMF)')

In [None]:
X = sm.add_constant(input_frame['interpolation_num'])

order = []
aics = []
bics = []

for i in np.arange(0, 10):
    model = sm.OLS(range_adjust, X**i).fit()
    print(i)
    order.append(i)
    print('AIC for model: {:.3f}'.format(model.aic))
    aics.append(model.aic)
    print('BIC for model: {:.3f}'.format(model.bic))
    bics.append(model.bic)
    
tmf_models = pd.DataFrame(
    {
        'order': order,
        'aic': aics,
        'bic': bics
    }
)

In [None]:
plt.plot(tmf_models)

## Plot psychometric function

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
PMFX = {}
for subj in subjects:
    test_proportion = pd.DataFrame(test_data[subj].groupby(
            ['stim_type', 'pair_indices', 'inter_nums']
        )['response'].agg('value_counts', normalize = True))
    test_proportion = test_proportion.rename(columns = {'response': 'prop'})
    PMFX[subj] = test_proportion

### Fit psychometric function

## Plot

In [None]:
subject_list = []
stim_type_list = []
pair_index_list = []
response_type_list = []
parameter_list = []

## plotting
for subj in subjects:
    ## for each stim_type
    for stim_type in np.unique(test_data[subj].stim_type):
        ## for each pair index
        for pair_index in np.unique(test_data[subj].pair_indices):
            plt.figure(figsize = (16, 4))
            
            PMFX2 = PMFX[subj].reset_index()
            PMFX2 = PMFX2[PMFX2.response == 'left']
            PMFX2_specific = PMFX2[(PMFX2.pair_indices == pair_index) & (PMFX2.response == 'left')]
            
            ## plot empirical accuracies
            sns.scatterplot(
                data = PMFX2_specific, 
                x = 'inter_nums', y = 'prop', marker = "X", alpha = 1, color = 'salmon'
            )
            plt.ylim([0, 1])
            
            ## fit and plot psychometric functions
            
            for response_type in np.unique(PMFX2.response):
                if response_type == 'none':
                    continue
                #print("Generating psychometric curve for " + str(pair_index) + '_' + str(response_type))
                PMFX2_specific = PMFX2[(PMFX2.pair_indices == pair_index) & (PMFX2.response == response_type)]
                x = PMFX2_specific['inter_nums'].values
                y = PMFX2_specific['prop'].values
                assert len(x) == len(y)
                try:
                    solutions = logistic.fit_4pl(x, y, p_start = [0, 1, 0, 64])
                    parameter_list.append(solutions)
                    subject_list.append(subj)
                    stim_type_list.append(stim_type)
                    pair_index_list.append(pair_index)
                    response_type_list.append(response_type)
                    y_sig = logistic.four_param_logistic(solutions)(x)
                    line_col = 'green'
                    if response_type == 'left':
                        line_col = 'cornflowerblue'
                    if response_type == 'right':
                        line_col = 'orange'
                        
                    linear_popt, linear_popv = curve_fit(linear_func, x, y)
                    y_lin = linear_func(x, *linear_popt)
                    
                    quad_popt, quad_popv = curve_fit(quad_func, x, y)
                    y_quad = quad_func(x, *quad_popt)
                        
                    ## graph linear model
                    sns.lineplot(x, y_lin, color = line_col, linewidth = 3, alpha = 0.9, label = 'Linear')
                    ## graph quadratic model
                    sns.lineplot(x, y_quad, color = line_col, linewidth = 3, linestyle = 'dashed', alpha = 0.9, label = 'Quadratic')
                    ## graph sigmoidal 4-parameter logistic model
                    sns.lineplot(x, y_sig, color = line_col, linewidth = 3, linestyle = 'dotted', alpha = 0.9, label = 'Logistic')
                    plt.legend()
                    
                except Exception as e:
                    print(e)
            
            plt.xlabel('Interpolation # (0 = MaxMF, 128 = MinMF)')
            plt.ylabel('Proportion of Response')
            title = subj + '_' + stim_type + '_index:' + pair_index
            plt.title(title)
            plt.ylim([0, 1])
            
PMFX_results = pd.DataFrame(
    {
        "subject": subject_list,
        "stim_type": stim_type_list,
        "pair_index": pair_index_list,
        "response_type": response_type_list,
        "parameters": parameter_list
    }
)

In [None]:
subject_list = []
stim_type_list = []
pair_index_list = []
response_type_list = []
parameter_list = []

## plotting
for subj in subjects:
    ## for each stim_type
    for stim_type in np.unique(test_data[subj].stim_type):
        ## for each pair index
        for pair_index in np.unique(test_data[subj].pair_indices):
            plt.figure(figsize = (16, 4))
            
            PMFX2 = PMFX[subj].reset_index()
            PMFX2 = PMFX2[PMFX2.response == 'right']
            PMFX2_specific = PMFX2[(PMFX2.pair_indices == pair_index) & (PMFX2.response == 'right')]
            
            ## plot empirical accuracies
            sns.scatterplot(
                data = PMFX2_specific, 
                x = 'inter_nums', y = 'prop', marker = "X", alpha = 1, color = 'salmon'
            )
            plt.ylim([0, 1])
            
            ## fit and plot psychometric functions
            
            for response_type in np.unique(PMFX2.response):
                if response_type == 'none':
                    continue
                #print("Generating psychometric curve for " + str(pair_index) + '_' + str(response_type))
                PMFX2_specific = PMFX2[(PMFX2.pair_indices == pair_index) & (PMFX2.response == response_type)]
                x = PMFX2_specific['inter_nums'].values
                y = PMFX2_specific['prop'].values
                assert len(x) == len(y)
                try:
                    solutions = logistic.fit_4pl(x, y, p_start = [0, 1, 0, 64])
                    parameter_list.append(solutions)
                    subject_list.append(subj)
                    stim_type_list.append(stim_type)
                    pair_index_list.append(pair_index)
                    response_type_list.append(response_type)
                    y_sig = logistic.four_param_logistic(solutions)(x)
                    line_col = 'green'
                    if response_type == 'left':
                        line_col = 'cornflowerblue'
                    if response_type == 'right':
                        line_col = 'orange'
                        
                    linear_popt, linear_popv = curve_fit(linear_func, x, y)
                    y_lin = linear_func(x, *linear_popt)
                    
                    quad_popt, quad_popv = curve_fit(quad_func, x, y)
                    y_quad = quad_func(x, *quad_popt)
                        
                    ## graph linear model
                    sns.lineplot(x, y_lin, color = line_col, linewidth = 3, alpha = 0.9, label = 'Linear')
                    ## graph quadratic model
                    sns.lineplot(x, y_quad, color = line_col, linewidth = 3, linestyle = 'dashed', alpha = 0.9, label = 'Quadratic')
                    ## graph sigmoidal 4-parameter logistic model
                    sns.lineplot(x, y_sig, color = line_col, linewidth = 3, linestyle = 'dotted', alpha = 0.9, label = 'Logistic')
                    plt.legend()
                    
                except Exception as e:
                    print(e)
            
            plt.xlabel('Interpolation # (0 = MaxMF, 128 = MinMF)')
            plt.ylabel('Proportion of Response')
            title = subj + '_' + stim_type + '_index:' + pair_index
            plt.title(title)
            plt.ylim([0, 1])
            
PMFX_results = pd.DataFrame(
    {
        "subject": subject_list,
        "stim_type": stim_type_list,
        "pair_index": pair_index_list,
        "response_type": response_type_list,
        "parameters": parameter_list
    }
)

In [None]:
subject_list = []
stim_type_list = []
pair_index_list = []
response_type_list = []
parameter_list = []

## plotting
for subj in subjects:
    ## for each stim_type
    for stim_type in np.unique(test_data[subj].stim_type):
        ## for each pair index
        for pair_index in np.unique(test_data[subj].pair_indices):
            plt.figure(figsize = (16, 4))
            
            PMFX2 = PMFX[subj].reset_index()
            PMFX2 = PMFX2[(PMFX2.pair_indices == pair_index)]
            
            ## plot empirical accuracies
            sns.scatterplot(
                data = PMFX2, 
                x = 'inter_nums', y = 'prop', marker = "X", alpha = 1, hue = 'response'
            )
            plt.ylim([0, 1])
            
            ## fit and plot psychometric functions
            
            for response_type in np.unique(PMFX2.response):
                if response_type == 'none':
                    continue
                #print("Generating psychometric curve for " + str(pair_index) + '_' + str(response_type))
                PMFX2_specific = PMFX2[(PMFX2.pair_indices == pair_index) & (PMFX2.response == response_type)]
                x = PMFX2_specific['inter_nums'].values
                y = PMFX2_specific['prop'].values
                assert len(x) == len(y)
                try:
                    solutions = logistic.fit_4pl(x, y, p_start = [0, 1, 0, 64])
                    parameter_list.append(solutions)
                    subject_list.append(subj)
                    stim_type_list.append(stim_type)
                    pair_index_list.append(pair_index)
                    response_type_list.append(response_type)
                    y_sig = logistic.four_param_logistic(solutions)(x)
                    line_col = 'green'
                    if response_type == 'left':
                        line_col = 'cornflowerblue'
                    if response_type == 'right':
                        line_col = 'orange'
                        
                    linear_popt, linear_popv = curve_fit(linear_func, x, y)
                    y_lin = linear_func(x, *linear_popt)
                    
                    quad_popt, quad_popv = curve_fit(quad_func, x, y)
                    y_quad = quad_func(x, *quad_popt)
                        
                    ## graph linear model
                    sns.lineplot(x, y_lin, color = line_col, linewidth = 3, alpha = 0.9, label = 'Linear')
                    ## graph quadratic model
                    sns.lineplot(x, y_quad, color = line_col, linewidth = 3, linestyle = 'dashed', alpha = 0.9, label = 'Quadratic')
                    ## graph sigmoidal 4-parameter logistic model
                    sns.lineplot(x, y_sig, color = line_col, linewidth = 3, linestyle = 'dotted', alpha = 0.9, label = 'Logistic')
                    plt.legend()
                    
                except Exception as e:
                    print(e)
            
            plt.xlabel('Interpolation # (0 = MaxMF, 128 = MinMF)')
            plt.ylabel('Proportion of Response')
            title = subj + '_' + stim_type + '_index:' + pair_index
            plt.title(title)
            plt.ylim([0, 1])
            
PMFX_results = pd.DataFrame(
    {
        "subject": subject_list,
        "stim_type": stim_type_list,
        "pair_index": pair_index_list,
        "response_type": response_type_list,
        "parameters": parameter_list
    }
)

In [None]:
PMFX_results

## Model Comparisons

In [None]:
import statsmodels.formula.api as smf

In [None]:
import numpy as np
import pandas as pd
import statsmodels.api as sm

def calc_aic_bic_manual(nparams, func, p, y, x):
    """
    Calculate the AIC and BIC for a given model and data.

    Parameters:
        model (statsmodels.api object): the fitted model object
        y (array-like): the dependent variable
        x (array-like): the independent variable(s)

    Returns:
        AIC (float): the Akaike Information Criterion
        BIC (float): the Bayesian Information Criterion
    """
    # calculate the number of observations and parameters in the model
    nobs = len(y)
    #nparams = model.params.shape[0]

    # calculate the residual sum of squares
    resid = y - func(x, *p)
    rss = np.sum(resid**2)

    # calculate the maximum log-likelihood
    llf = -0.5 * nobs * (np.log(2*np.pi) + np.log(rss/nobs) + 1)

    # calculate the AIC and BIC
    aic = 2 * nparams - 2 * llf
    bic = np.log(nobs) * nparams - 2 * llf

    return aic, bic

In [None]:
subject_list = []
stim_type_list = []
pair_index_list = []
response_type_list = []

## plotting
for subj in subjects:
    print(subj)
    ## for each stim_type
    for stim_type in np.unique(test_data[subj].stim_type):
        ## for each pair index
        for pair_index in np.unique(test_data[subj].pair_indices):
            
            print(pair_index)
            response_type = 'left'
            if response_type == 'none':
                continue
            if response_type == 'right':
                continue
            
            PMFX2 = PMFX[subj].reset_index()
            PMFX2_specific = PMFX2[(PMFX2.pair_indices == pair_index) & (PMFX2.response == response_type)]
            
            
            x = PMFX2_specific['inter_nums'].values
            y = PMFX2_specific['prop'].values
            
            # create and fit linear model
            X_linear = sm.add_constant(x)
            linear_model = sm.OLS(y, X_linear).fit()

            # create and fit quadratic model
            X_quadratic = sm.add_constant(np.column_stack((x, x**2)))
            quadratic_model = sm.OLS(y, X_quadratic).fit()    
            
            # generate predictions for each model
            linear_pred = linear_model.predict(X_linear)
            
            quadratic_pred = quadratic_model.predict(X_quadratic)
            plt.plot(quadratic_pred)
            
            ## generate aic and bic for 4p_log
            imported_parameters = PMFX_results[
                (PMFX_results.subject == subj) &
                (PMFX_results.pair_index == pair_index) & 
                (PMFX_results.response_type == response_type)
            ]
            p = imported_parameters.parameters.values[0]
            log_aic, log_bic = calc_aic_bic_manual(4, logistic_4pm, p, y, x)
            
            
            
            # aic, bic
            print('lin')
            print(linear_model.aic)
            #print(linear_model.bic)
            print('quad')
            print(quadratic_model.aic)
            #print(quadratic_model.bic)
            print('log')
            print(log_aic)
            #print(log_bic)

## Sample size

In [None]:
sample_size = {}
for subj in subjects:
    test_count = pd.DataFrame(test_data[subj].groupby(
            ['stim_type', 'pair_indices', 'inter_nums']
        )['response'].agg('count'))
    sample_size[subj] = test_count

In [None]:
sample_size['B1520']

In [None]:
## plotting
for subj in subjects:
    ## for each stim_type
    for stim_type in np.unique(test_data[subj].stim_type):
        ## for each pair index
        for pair_index in np.unique(test_data[subj].pair_indices):
            plt.figure(figsize = (16, 4))
            sns.lineplot(
                data = sample_size[subj].loc[stim_type, pair_index], 
                x = 'inter_nums',
                y = 'response',
                linewidth = 3
            )
            plt.xlabel('Interpolation # (0 = MaxMF, 128 = MinMF)')
            title = subj + '_' + stim_type + '_index:' + pair_index
            plt.title(title)