### Replay composition by day
Analysis 1: (past no reward,previous reward,future)

In [1]:
%reload_ext autoreload
%autoreload 2

In [3]:
import os
import pickle
import spyglass as nd
import pandas as pd
import statsmodels.api as sm
# ignore datajoint+jupyter async warnings
import warnings
warnings.simplefilter('ignore', category=DeprecationWarning)
warnings.simplefilter('ignore', category=ResourceWarning)

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
import logging
import multiprocessing

FORMAT = '%(asctime)s %(message)s'

logging.basicConfig(level='INFO', format=FORMAT, datefmt='%d-%b-%y %H:%M:%S')

from spyglass.common import (Session, IntervalList,IntervalPositionInfo,
                             LabMember, LabTeam, Raw, Session, Nwbfile,
                            Electrode,LFPBand,interval_list_intersect)
from spyglass.common.common_interval import _intersection

from spyglass.utils.nwb_helper_fn import get_nwb_copy_filename
from spyglass.common.common_position import IntervalPositionInfo, IntervalPositionInfoSelection


# Here are the analysis tables specific to Shijie Gu
from spyglass.shijiegu.Analysis_SGU import (TrialChoice,
                                   TrialChoiceReplay,
                                   RippleTimes,
                                   DecodeResultsLinear,get_linearization_map,
                                   find_ripple_times,classify_ripples,classify_ripple_content)
from spyglass.shijiegu.PastFuture_Replay import (replay_in_categories,find_distinct_subset,proportion,
                                                 unravel_replay,count_replay_by_category,category_day)
from spyglass.shijiegu.decodeHelpers import runSessionNames

[2024-10-31 15:34:12,546][INFO]: Connecting shijiegu-alt@lmf-db.cin.ucsf.edu:3306
31-Oct-24 15:34:12 Connecting shijiegu-alt@lmf-db.cin.ucsf.edu:3306
[2024-10-31 15:34:12,591][INFO]: Connected shijiegu-alt@lmf-db.cin.ucsf.edu:3306
31-Oct-24 15:34:12 Connected shijiegu-alt@lmf-db.cin.ucsf.edu:3306


In [16]:
# the only cell to be edited

# 'molly20220415.nwb': run only session 4 6 8
# all Seq2 days
nwb_file_names = ['molly20220415.nwb','molly20220416.nwb',
                  'molly20220417.nwb','molly20220418.nwb',
                  'molly20220419.nwb','molly20220420.nwb']
all_epochs_flag=[[4,6,8],[],[],[],[],[]]
p_value=0.05/6


#nwb_file_names = ['lewis20240107.nwb','lewis20240108.nwb',
#                  'lewis20240109.nwb','lewis20240110.nwb',
#                  'lewis20240116.nwb','lewis20240118.nwb',
#                  'lewis20240119.nwb','lewis20240120.nwb']
nwb_file_names = ['lewis20240106.nwb','lewis20240107.nwb','lewis20240108.nwb','lewis20240109.nwb','lewis20240118.nwb']
encoding_set = '2Dheadspeed_above_4'
classifier_param_name = 'default_decoding_gpu_4armMaze'
all_epochs_flag=[[],[],[],[],[],[],[],[]]
p_value=0.05/4

In [17]:
# BY DAY
categories_H=['home','past','past_reward','future_H']
plot_categories_H=['home','past','past_reward','future',]
categories_O=['home','past','past_reward','current','future_O']
plot_categories_O=['home','past','past_reward','current','future']

In [18]:
def bootstrap_random_day(b,nwb_copy_file_name,epochs_to_run,categories_H,plot_categories_H):
    np.random.seed(b)
    categories_H_day=category_day(nwb_copy_file_name,epochs_to_run,
                                  categories_H,plot_categories_H,
                                  simulate_random_flag=True)

    return proportion(categories_H_day)

#### 1. subset trials analysis

In [19]:
for ni in range(len(nwb_file_names)):
    nwb_file_name=nwb_file_names[ni]
    nwb_copy_file_name = get_nwb_copy_filename(nwb_file_name)
    session_interval, position_interval = runSessionNames(nwb_copy_file_name)
    
    if len(all_epochs_flag[ni])==0:
        epochs_to_run = session_interval
    else:
        epochs_to_run = session_interval[ni]
    
    count_H_day=category_day(nwb_copy_file_name,
                              epochs_to_run,categories_H,plot_categories_H)
    #count_O_day=category_day(nwb_copy_file_name,
    #                          epochs_to_run,categories_O,plot_categories_O)
    
    prop_H_day,total_H_day=proportion(count_H_day)
    #prop_O_day,total_O_day=proportion(count_O_day)
    
    B=100
    #with multiprocessing.Pool(10) as p:
    #    results=p.map(bootstrap_random_day,[b for b in range(B)])
    results = []
    for b in range(B):
        print(b)
        results.append( bootstrap_random_day(b,nwb_copy_file_name,epochs_to_run,categories_H,plot_categories_H) )
    
    prop_H_boot=np.zeros((B,len(count_H_day)))
    #prop_O_boot=np.zeros((B,len(count_O_day)))

    for b in range(B):
        categories_names=list(results[b][0].keys())
        for ki in range(len(categories_names)):
            prop_H_boot[b,ki]=results[b][0][categories_names[ki]]

        #categories_names=list(results[b][1][0].keys())
        #for ki in range(len(categories_names)):
        #    prop_O_boot[b,ki]=results[b][1][0][categories_names[ki]]
    
    
    H_boot_mean=np.mean(prop_H_boot,axis=0)
    H_boot_CI=np.quantile(prop_H_boot,[p_value,1-p_value],axis=0)
    
    animal=nwb_file_name[:5]
    date=nwb_copy_file_name[5:-5]
    decoding_path=(DecodeResultsLinear &
                   {'nwb_file_name': nwb_copy_file_name}).fetch('posterior')[0]
    resultfolder=os.path.join(os.path.split(decoding_path)[0],'analysis')
    isExist = os.path.exists(resultfolder)
    if not isExist:
        os.makedirs(resultfolder)

    data={'categories_H_day_prop':prop_H_day,
          'categories_H_num':total_H_day,
          'H_boot_mean':H_boot_mean,
          'H_boot_CI':H_boot_CI,
         }

    with open(os.path.join(resultfolder,'replay_category_'+animal+'_'+date+'.p'), 'wb') as handle:
        pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)

number of valid trials: 28
number of valid trials: 43
number of valid trials: 22
number of valid trials: 27
0
number of valid trials: 28
number of valid trials: 43
number of valid trials: 22
number of valid trials: 27
1
number of valid trials: 28
number of valid trials: 43
number of valid trials: 22
number of valid trials: 27
2
number of valid trials: 28
number of valid trials: 43
number of valid trials: 22
number of valid trials: 27
3
number of valid trials: 28
number of valid trials: 43
number of valid trials: 22
number of valid trials: 27
4
number of valid trials: 28
number of valid trials: 43
number of valid trials: 22
number of valid trials: 27
5
number of valid trials: 28
number of valid trials: 43
number of valid trials: 22
number of valid trials: 27
6
number of valid trials: 28
number of valid trials: 43
number of valid trials: 22
number of valid trials: 27
7
number of valid trials: 28
number of valid trials: 43
number of valid trials: 22
number of valid trials: 27
8
number of 

In [15]:
nwb_copy_file_name

'lewis20240108.nwblewis20240109_.nwb'



In [269]:
pickle.load(open(os.path.join(resultfolder,'replay_category_'+animal+'_'+date+'.p'),"rb"))

{'categories_H_day_prop': {'home': 0.928921568627451,
  'past': 0.00980392156862745,
  'past_reward': 0.024509803921568627,
  'future': 0.03676470588235294},
 'categories_H_num': 408,
 'categories_O_day_prop': {'home': 0.055865921787709494,
  'past': 0.00558659217877095,
  'past_reward': 0.0148975791433892,
  'current': 0.9087523277467412,
  'future': 0.0148975791433892},
 'categories_O_num': 537,
 'O_boot_mean': array([0.04022813, 0.2159799 , 0.19078607, 0.34898956, 0.20401634]),
 'O_boot_std': array([0.00737292, 0.02316715, 0.01598242, 0.02195183, 0.03288566]),
 'H_boot_mean': array([0.75362365, 0.07331474, 0.09694438, 0.07611724]),
 'H_boot_std': array([0.02135901, 0.0164067 , 0.01245653, 0.01091013])}

#### 2. all trials with GLM

In [None]:
def GLM_replay(nwb_file_name,epochs,replay_location,categories,plot_categories):
    '''
    replay_location: 'replay_H' or 'replay_O'
    '''

    nwb_copy_file_name = get_nwb_copy_filename(nwb_file_name)

    betas={}
    
    y_day=[]
    x_day=[]
    
    for epoch_num in epochs:

        # get all replay
        key={'nwb_file_name':nwb_copy_file_name,'epoch':epoch_num}
        replay_df_subset=pd.DataFrame((TrialChoiceReplay & key).fetch1('choice_reward_replay'))

        # get behavior
        behavior_df=get_df_tally(nwb_file_name,epoch_num)

        # create regresser matrix
        if categories[0]=='home':
            y=np.zeros((5*len(replay_df_subset.index),1))
            x=np.zeros((5*len(replay_df_subset.index),len(categories)))
        else:
            y=np.zeros((4*len(replay_df_subset.index),1))
            x=np.zeros((4*len(replay_df_subset.index),len(categories)))

        for i in range(len(replay_df_subset.index)):
            t=replay_df_subset.index[i]

            replay=np.array(unravel_replay([list(replay_df_subset.loc[t,replay_location])])[0])
            if categories[0]=='home':
                for a in range(4):
                    arm=a+1
                    y[i*5+a+1,0]=np.sum(replay==arm)
                for c in range(1,len(categories)):
                    correponding_arm=behavior_df.loc[t,categories[c]]
                    if not np.isnan(correponding_arm):
                        x[i*5+int(correponding_arm),c]=1
                # home
                x[i*5,0]=1
                y[i*5,0]=np.sum(replay==0)
            else:
                for a in range(4):
                    arm=a+1
                    y[i*4+a,0]=np.sum(replay==arm)
                for c in range(len(categories)):
                    correponding_arm=behavior_df.loc[t,categories[c]]
                    if not np.isnan(correponding_arm):
                        x[i*4+int(correponding_arm)-1,c]=1
                        
        x_day.append(x)
        y_day.append(y)
    x_day=np.concatenate(x_day,axis=0)
    y_day=np.concatenate(y_day,axis=0)
       
    
    # do GLM
    x_ = sm.add_constant(x)
    glm_poisson = sm.GLM(y,x_,family=sm.families.Poisson())
    res = glm_poisson.fit()
    CI=res.conf_int(alpha=0.05)
    print(res.summary())
    
    for c in range(len(plot_categories)):
        betas[plot_categories[c]]=(np.exp(res.params[c+1]),np.exp(CI[c+1])) #plus one because of the constant term
    
    return betas

In [None]:
categories=['past','past_reward','current','future']
plot_categories=['past','past_reward','future','future t+1']
replay_location='replay_H'
betas_H=GLM_replay(nwb_file_name,epochs,replay_location,categories,plot_categories)

categories=['past','past_reward','current','future']
plot_categories=['past','past_reward','current','future']
replay_location='replay_O'
betas_O=GLM_replay(nwb_file_name,epochs,replay_location,categories,plot_categories)


epoch name 02_SeqSession1
epoch name 04_Seq2Session1
epoch name 06_Seq2Session2
epoch name 08_Seq2Session3
                 Generalized Linear Model Regression Results                  
Dep. Variable:                      y   No. Observations:                  304
Model:                            GLM   Df Residuals:                      299
Model Family:                 Poisson   Df Model:                            4
Link Function:                    Log   Scale:                          1.0000
Method:                          IRLS   Log-Likelihood:                -174.87
Date:                Sat, 09 Jul 2022   Deviance:                       216.23
Time:                        13:55:46   Pearson chi2:                     323.
No. Iterations:                     6   Pseudo R-squ. (CS):             0.1018
Covariance Type:            nonrobust                                         
                 coef    std err          z      P>|z|      [0.025      0.975]
------------------------

In [None]:
betas_H

{'past': (0.2518171081386721, array([0.08449668, 0.75046568])),
 'past_reward': (0.8322678663709008, array([0.40975283, 1.69045763])),
 'future': (1.6947690205170083, array([0.92057769, 3.12004306])),
 'future t+1': (1.8403140628336758, array([1.01900775, 3.32358202]))}

In [None]:
betas_O

{'past': (0.6835910288831625, array([0.17562188, 2.6608114 ])),
 'past_reward': (0.31981827412932257, array([0.12867794, 0.79488161])),
 'current': (52.769523322950114, array([ 23.97368396, 116.15330362])),
 'future': (1.2546815295040081, array([0.44761144, 3.51694705]))}

In [181]:
categories_O_day_hat

{'home': 7, 'past': 44, 'past_reward': 45, 'current': 54, 'future': 44}

In [188]:
categories_H_day_hat

{'home': 259, 'past': 29, 'past_reward': 39, 'future': 25}