In [5]:
import statsmodels.api as sm
from statsmodels.formula.api import ols
from statsmodels.stats.anova import anova_lm
from statsmodels.stats.anova import AnovaRM
from statsmodels.graphics.factorplots import interaction_plot
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.patches as mpatches
import os
import pandas as pd
import numpy as np
import seaborn as sns
from scipy import stats
from scipy import special
import json 
from sklearn.linear_model import LogisticRegression
from scipy.optimize import curve_fit
#Import all needed libraries
from matplotlib.lines import Line2D
from statsmodels.genmod.bayes_mixed_glm import BinomialBayesMixedGLM
from matplotlib.backends.backend_pdf import PdfPages
# from statannot import add_stat_annotation

# Folder locations used throughout the notebook (absolute paths on the G: drive).
utilities = 'G:/My Drive/WORKING_MEMORY/PAPER/WM_manuscript_FIGURES/'
data_path = 'G:/My Drive/WORKING_MEMORY/PAPER/2ND_SUBMISSION_NAT_NEURO/figures_for_resubmission/'
save_path = 'G:/My Drive/WORKING_MEMORY/PAPER/WM_manuscript_FIGURES/Fig. 2 Model/'
path = 'G:/My Drive/WORKING_MEMORY/PAPER/ANALYSIS_figures/'

# Step through `utilities` first (the `import functions as plots` that used to
# follow this chdir is disabled), then leave `path` as the working directory.
for working_dir in (utilities, path):
    os.chdir(working_dir)

In [6]:
def compute_window_centered(data, runningwindow, option):
    """
    Compute a per-row rolling mean of ``data[option]`` that never mixes sessions.

    Rows are expected in chronological order, with sessions stacked one after
    another. A new session is taken to start at row 0 and at every row whose
    ``trial`` value is lower than the previous row's.

    With ``half = int(runningwindow / 2)``, the window used for row ``i`` is:

    * start of a session (``trial <= half``): from the first row of the
      session up to, but excluding, row ``i + half``;
    * end of a session (at most ``runningwindow`` rows remain from ``i`` to
      the end of the session): from row ``i`` through the last row of the
      session, inclusive;
    * anywhere else: rows ``i - half`` to ``i + half - 1``.

    NOTE(review): the mid-session window spans ``2 * half`` rows, so for an
    odd ``runningwindow`` it covers ``runningwindow - 1`` rows and is not
    symmetric around ``i``. This is kept as it was; confirm it is intended.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain a ``'trial'`` column and the column named by ``option``.
    runningwindow : int
        Nominal window length, in rows.
    option : str
        Name of the column to average.

    Returns
    -------
    list
        One mean per row of ``data``, rounded to 2 decimals. Empty input
        returns an empty list.
    """
    half = int(runningwindow / 2)
    n_rows = len(data)
    if n_rows == 0:
        return []

    trials = data['trial'].to_numpy()
    column = data[option]

    # Locate the session boundaries once, up front. Tracking them with flags
    # inside the loop left the end bound undefined when a session was too
    # short to reach the mid-session branch, and excluded the last trial of
    # every session except the final one.
    is_start = np.ones(n_rows, dtype=bool)
    is_start[1:] = trials[1:] < trials[:-1]
    first_rows = np.flatnonzero(is_start)
    stop_rows = np.append(first_rows[1:], n_rows)  # exclusive end of each session
    session_id = np.cumsum(is_start) - 1

    performance = []
    for i in range(n_rows):
        first = first_rows[session_id[i]]
        stop = stop_rows[session_id[i]]
        if trials[i] <= half:
            # Session start: grow the window from the first row of the session.
            lo, hi = first, min(i + half, stop)
        elif i + runningwindow >= stop:
            # Session end: shrink the window towards the last row (inclusive).
            lo, hi = i, stop
        else:
            # Mid-session: never reach back before the session's first row.
            lo, hi = max(i - half, first), i + half
        # np.mean on a Series skips NaN entries, as the original call did.
        performance.append(round(np.mean(column.iloc[lo:hi]), 2))
    return performance

### Summary of the Notebook

This notebook contains the analysis and visualization of working memory data. The following key steps and variables are present:

1. **Imports and Setup**:
    - Various libraries such as `statsmodels`, `matplotlib`, `pandas`, `numpy`, and `seaborn` are imported.
    - Paths for data and saving results are defined.

2. **Functions**:
    - `compute_window_centered`: Computes a rolling average with a specified window size.

3. **Data Loading and Preparation**:
    - Data is loaded from CSV files into dataframes `df` and `df_model`.
    - Additional columns such as `WM_roll` and `state` are computed and added to the dataframe.

4. **Visualization**:
    - Bar plots and scatter plots are created to visualize the data.
    - Specific plots include:
      - Bar plot of `LL/trial` by subject.
      - Scatter plot and regression plot of `stim` vs. `substracted`.
      - Bar plots of `hit` and `LL/trial` by model and `notHMM`.

5. **Statistical Analysis**:
    - Linear regression is performed using `statsmodels` to analyze the relationship between `stim` and `substracted`.

6. **Additional Analysis**:
    - Analysis of lapse accuracy and plotting of accuracy by delay.
    - Violin and strip plots of `LL/trial` by model.

### Key Variables

- **Dataframes**:
  - `df`: Main dataframe containing working memory data.
  - `df_model`: Model dataframe loaded from `all_data_HMM_model.csv` and restricted to rows with `animal_delay == 10`.
  - `full_fit`: Dataframe containing model fit results.

- **Paths**:
  - `data_path`: Path to the data files.
  - `save_path`: Path to save the results.
  - `path`: Path for analysis figures.
  - `utilities`: Path for manuscript figures.

- **Other Variables**:
  - `threshold`: Cut-off applied to the rolling `WM` average (`WM_roll`); rows above it get `state = 1`, all others `state = 0`.
  - `groupings`: List of groupings used in the analysis.
  - `fig`, `ax`: Matplotlib figure and axes objects for plotting.
  - `file`, `file_name`: Filenames used for loading data.

#### **Recover all the data that we need**

In [8]:
# Cut-off used to binarize the smoothed WM trace, and the grouping keys.
threshold = 0.5
groupings = ['subject', 'delays', 'state']

# Main table: smooth the WM column with a 3-row window, then flag rows above the cut-off.
file = 'all_data_HMM'
df = pd.read_csv(f'{save_path}{file}.csv', index_col=0)
df['WM_roll'] = compute_window_centered(df, 3, 'WM')
df['state'] = (df['WM_roll'] > threshold).astype(int)

# Model table, stored next to the main table.
file = 'all_data_HMM_model'
df_model = pd.read_csv(f'{save_path}{file}.csv', index_col=0)

# Both tables are restricted to rows with animal_delay == 10.
df_model = df_model[df_model['animal_delay'] == 10]
df = df[df['animal_delay'] == 10]

# Table providing the LL/trial values used in the figures below.
file_name = 'pertrialLL'
full_fit = pd.read_csv(f'{save_path}{file_name}.csv', index_col=0)

In [None]:
# LL/trial per subject for model '11', coloured by `substracted`, with 68% CI error bars.
fit_model_11 = full_fit[full_fit['model'] == '11']
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 5))
sns.barplot(
    data=fit_model_11,
    x='subject',
    y='LL/trial',
    hue='substracted',
    palette='viridis',
    errorbar=('ci', 68),
    ax=ax,
)
plt.xticks(rotation=45)

#### **Test with the inferred states of the model**

In [78]:
# Per subject: share of delays == 1 rows whose WM exceeds 0.6.
# The ratio is carried in a column named 'stim' because the counts are taken on that column.
is_delay_1 = df['delays'] == 1
is_high_wm = df['WM'] > 0.6
high_wm_counts = df[is_high_wm & is_delay_1].groupby('subject')['stim'].count()
delay_1_counts = df[is_delay_1].groupby('subject')['stim'].count()
new_df = (high_wm_counts / delay_1_counts).reset_index()

# Attach the fit results and keep only subjects whose ratio is above 0.6.
merge_df = new_df.merge(full_fit, on='subject')
merge_df = merge_df[merge_df['stim'] > 0.6]
# notHMM is True wherever `substracted` is negative.
merge_df['notHMM'] = merge_df['substracted'].lt(0)

In [None]:
# 2x2 overview of the current merge_df; model '11' panels are titled DW-B, model 'all' is titled HMM.
dwb_fit = merge_df[merge_df['model'] == '11']
hmm_fit = merge_df[merge_df['model'] == 'all']
fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(10, 10))

# Top-left: `substracted` against `stim`, points coloured by notHMM, plus a regression line.
axes = ax[0, 0]
sns.scatterplot(data=dwb_fit, x='stim', y='substracted', hue='notHMM', palette='viridis', ax=axes)
sns.regplot(data=dwb_fit, x='stim', y='substracted', scatter=False, color='black', ax=axes)

# Top-right: `stim` plotted against itself, coloured by notHMM.
axes = ax[0, 1]
sns.barplot(data=dwb_fit, x='stim', y='stim', hue='notHMM', palette='viridis', dodge=False, ax=axes)

# Bottom-left: LL/trial plotted against itself for model '11'.
axes = ax[1, 0]
sns.barplot(data=dwb_fit, x='LL/trial', y='LL/trial', hue='notHMM', palette='viridis', dodge=False, ax=axes)
axes.set_title('DW-B')

# Bottom-right: LL/trial plotted against itself for model 'all' (no hue).
axes = ax[1, 1]
sns.barplot(data=hmm_fit, x='LL/trial', y='LL/trial', palette='viridis', dodge=False, ax=axes)
axes.set_title('HMM')

In [None]:
# Ordinary least squares of `substracted` on `stim` for the model '11' rows.
# `sm` (statsmodels.api) comes from the import cell at the top of the notebook.
data = merge_df.loc[merge_df.model == '11']
X = data['stim']
y = data['substracted']

# Add an intercept column ('const') next to the 'stim' regressor.
X = sm.add_constant(X)

# Fit the regression model and show the full summary table.
model = sm.OLS(y, X).fit()
print(model.summary())

# Look the slope's p-value up by label. `model.pvalues` is indexed by regressor
# name, and integer positions on a label-indexed Series (`pvalues[1]`) are
# deprecated in recent pandas releases.
p_value = model.pvalues['stim']
r_squared = model.rsquared

print(f'P-value: {p_value}')
print(f'R-squared: {r_squared}')

#### **Test using the lapse accuracy**

In [23]:
# Mean `hit` per subject over the rows with delays == 0.1.
rows_delay_01 = df[df['delays'] == 0.1]
new_df = rows_delay_01.groupby('subject')['hit'].mean().reset_index()

# Attach the fit results; notHMM is True wherever `substracted` is negative.
merge_df = new_df.merge(full_fit, on='subject')
merge_df['notHMM'] = merge_df['substracted'].lt(0)

# Optional filter on accuracy, currently disabled:
# merge_df = merge_df.loc[merge_df["hit"]>0.6]

In [None]:
# `hit` plotted against itself for the model '11' rows, coloured by notHMM; saved as a PNG under save_path.
fit_model_11 = merge_df[merge_df['model'] == '11']
fig, ax = plt.subplots(figsize=(4, 4))
sns.barplot(data=fit_model_11, x='hit', y='hit', hue='notHMM', palette='viridis', dodge=False, ax=ax)

ax.set_xlabel('Mice')
ax.set_ylabel('Accuracy at Delay 0s')
ax.set_xticks([])  # the x positions are accuracy values, so the ticks are hidden
ax.legend(bbox_to_anchor=(1.4, 1), title='DW-B better', loc='upper right')
sns.despine()

fig.savefig(save_path + '/3.6. mice_alternative_model_better.png', dpi=300, bbox_inches='tight')

In [None]:
# 2x2 overview of the current merge_df; model '11' panels are titled DW-B, model 'all' is titled HMM.
dwb_fit = merge_df[merge_df['model'] == '11']
hmm_fit = merge_df[merge_df['model'] == 'all']
fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(10, 10))

# Top-left: `substracted` against `hit`, points coloured by notHMM, plus a regression line.
axes = ax[0, 0]
sns.scatterplot(data=dwb_fit, x='hit', y='substracted', hue='notHMM', palette='viridis', ax=axes)
sns.regplot(data=dwb_fit, x='hit', y='substracted', scatter=False, color='black', ax=axes)

# Top-right: `hit` plotted against itself, coloured by notHMM.
axes = ax[0, 1]
sns.barplot(data=dwb_fit, x='hit', y='hit', hue='notHMM', palette='viridis', dodge=False, ax=axes)

# Bottom-left: LL/trial plotted against itself for model '11'.
axes = ax[1, 0]
sns.barplot(data=dwb_fit, x='LL/trial', y='LL/trial', hue='notHMM', palette='viridis', dodge=False, ax=axes)
axes.set_title('DW-B')

# Bottom-right: LL/trial plotted against itself for model 'all' (no hue).
axes = ax[1, 1]
sns.barplot(data=hmm_fit, x='LL/trial', y='LL/trial', palette='viridis', dodge=False, ax=axes)
axes.set_title('HMM')

### **Where does the outlier in accuracy land?**

In [None]:
# Quick per-delay check for a single subject, currently disabled:
# df.loc[df.subject == 'E13_10'].groupby('delays').hit.mean()

# Faint grey curve per subject (hit vs. delays), with subject E04_10 overlaid in red.
highlighted = df[df['subject'] == 'E04_10']
sns.lineplot(
    data=df, x='delays', y='hit', hue='subject',
    palette='Greys', marker='o', alpha=0.3, errorbar=None, legend=False,
)
sns.lineplot(data=highlighted, x='delays', y='hit', marker='o', color='red')
plt.ylim(0.6, 1)

### **Plot using the LL/trial**

In [None]:
# Display the `full_fit` table (read from 'pertrialLL.csv' earlier in the notebook) for inspection.
full_fit

In [None]:
# LL/trial distribution per model: jittered points under two overlaid violins,
# with a dotted reference line at the mean LL/trial of model 'all'.
model_order = ["all", '12', '9', '10', '11']
color_list = ['darkgrey', 'midnightblue', 'lightgrey', 'steelblue', 'black']
labels = ['HMM', 'DW', 'DW-L', 'DW-M', 'DW-B']

fig, panel = plt.subplots(figsize=(8, 4))

# Not used by the plots below, but it advances the global NumPy random state
# ahead of the jittered strip plot, so it stays in place.
xA = np.random.normal(0, 0.05, len(full_fit))

sns.stripplot(
    data=full_fit, x='model', y='LL/trial', hue='model',
    order=model_order, palette=color_list,
    jitter=0.3, size=2, s=5, edgecolor='white', linewidth=0.1, ax=panel,
)
# First violin has no outline (linewidth=0); the second adds a 1.5-wide outline.
sns.violinplot(
    data=full_fit, x='model', y='LL/trial', hue='model',
    order=model_order, palette=color_list,
    saturation=0.7, linewidth=0, width=0.5, ax=panel,
)
sns.violinplot(
    data=full_fit, x='model', y='LL/trial', hue='model',
    order=model_order, palette=color_list,
    linewidth=1.5, width=0.5, ax=panel,
)

reference_ll = full_fit.loc[full_fit['model'] == 'all', 'LL/trial'].mean()
panel.hlines(y=reference_ll, xmin=-0.5, xmax=4.5, linestyle=':', color='black')

panel.set_xlabel('')
panel.set_ylabel('LL (bits/trial)')
# panel.set_ylim(-200,700)
panel.set_xticklabels(labels)
sns.despine()
fig.savefig(save_path + '/3.5. LL_per_trial.svg', dpi=300, bbox_inches='tight')