In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
sns.set()
plt.style.use('default')
plt.style.use("seaborn-poster")
from combined_phot_funcs import *
from hexLevelAnalyses import *
from photometryQuantifications import *
from valPropRats import *
%matplotlib qt

In [3]:
loadpath = "/Volumes/Tim/Photometry/10MfRatDataSet/"
datName = "photLevelDf"
df = pd.read_csv(loadpath+datName+".csv")
photrats = PhotRats(None)
photrats.df = reduce_mem_usage(df)
del df
photrats.directory_prefix = loadpath

# recompute nom_rwd_chosen

In [6]:
photrats.factor = "nom_rwd"
photrats.get_vals_byChosenEtc(chosen_only=True)

In [10]:
create_triframe(photrats)

In [11]:
photrats.get_portQvals(qtype="port",level="rat")

In [12]:
create_hexDf_startEndRepeat(photrats)

converting hexlabels to hex + direction (pairedHexState)


In [18]:
np.shape(photrats.hexDf.session.unique())

(82,)

In [16]:
test = photrats.hexDf.loc[photrats.hexDf.session==37,['hexlabel','port']].copy()

In [21]:
portInds = photrats.hexDf.loc[photrats.hexDf.port!=-100].index
ports2drop = portInds.values[np.concatenate([[False],np.diff(portInds)==1])]
photrats.hexDf.drop(ports2drop,axis=0,inplace=True)

In [22]:
photrats.hexDf.to_csv(photrats.directory_prefix+"hexLevelDf_cornerHexCorrectAlignment.csv")

# DA vs vel x cor. Session wide and during running bouts.

In [50]:
from statsmodels.tsa.stattools import ccf

In [74]:
#for each session, run x-cor between DA and velocity
window = [-int(photrats.fs*5)+1,int(photrats.fs*5)]
xcors_rat = []
for rat in photrats.df.rat.unique():
    xcors = []
    for s in photrats.df.loc[photrats.df.rat==rat,"session"].unique():
        dat = photrats.df.loc[photrats.df.session==s,].copy()
        seshVel = dat.vel.interpolate().values[dat.vel.interpolate().notnull()]
        seshDA = dat.green_z_scored.values[dat.vel.interpolate().notnull()]
        backwards = ccf(seshDA[:window[1]], seshVel[:window[1]], adjusted=False)[::-1]
        forwards = ccf(seshVel[:window[1]], seshDA[:window[1]], adjusted=False)
        xcors.append(np.r_[backwards[:-1], forwards])
    xcors_rat.append(np.mean(xcors,axis=0))

In [72]:
fig = plt.figure()
xvals = np.arange(window[0],window[1])/photrats.fs
plt.plot(xvals,np.mean(xcors,axis=0))
plt.fill_between(xvals,np.mean(xcors,axis=0)+sem(xcors),np.mean(xcors,axis=0)-sem(xcors),alpha=.5)
plt.axvline(x=0,ls='--',color='k',lw=2)
plt.axhline(y=0,ls='--',color='k',lw=2)
plt.xlabel("Lag relative to speed (s)")
plt.ylabel("Correlation with DA")
plt.xticks(np.arange(-5,6))
plt.tight_layout()
fig.savefig(photrats.directory_prefix+"velVsDAxcor_avgOverSessions.pdf")

In [75]:
fig = plt.figure()
xvals = np.arange(window[0],window[1])/photrats.fs
plt.plot(xvals,np.mean(xcors_rat,axis=0))
plt.fill_between(xvals,np.mean(xcors_rat,axis=0)+sem(xcors_rat),np.mean(xcors_rat,axis=0)-sem(xcors_rat),alpha=.5)
plt.axvline(x=0,ls='--',color='k',lw=2)
plt.axhline(y=0,ls='--',color='k',lw=2)
plt.xlabel("Lag relative to speed (s)")
plt.ylabel("Correlation with DA")
plt.xticks(np.arange(-5,6))
plt.tight_layout()
fig.savefig(photrats.directory_prefix+"velVsDAxcor_avgOverRats.pdf")

In [48]:
import sys
def sizeof_fmt(num, suffix='B'):
    ''' by Fred Cirera,  https://stackoverflow.com/a/1094933/1870254, modified'''
    for unit in ['','Ki','Mi','Gi','Ti','Pi','Ei','Zi']:
        if abs(num) < 1024.0:
            return "%3.1f %s%s" % (num, unit, suffix)
        num /= 1024.0
    return "%.1f %s%s" % (num, 'Yi', suffix)

for name, size in sorted(((name, sys.getsizeof(value)) for name, value in locals().items()),
                         key= lambda x: -x[1])[:10]:
    print("{:>30}: {:>8}".format(name, sizeof_fmt(size)))


                paired_tmatrix: 248.2 KiB
                   paired_tmat: 124.1 KiB
                     tmatrix50: 115.0 KiB
                       tmatrix: 112.7 KiB
                        phexdf: 16.4 KiB
                          data: 13.5 KiB
                       arrowdf: 10.3 KiB
                            _5:  5.8 KiB
                            _6:  5.8 KiB
        pairedHex_adjacentDict:  4.6 KiB


## plot individual session dLight, reference, and velocity traces

In [19]:
plt.figure()
plt.hist(photrats.df.loc[photrats.df.session==25,"green_z_scored"].values,bins=100)
plt.tight_layout()

In [202]:
photrats.df.loc[photrats.df.session.diff()!=0,["rat","date","session","session_type"]].values

array([['IM-1272', 12082020, 2, 'prob'],
       ['IM-1272', 12092020, 3, 'prob'],
       ['IM-1272', 12102020, 4, 'prob'],
       ['IM-1273', 12092020, 12, 'prob'],
       ['IM-1273', 12152020, 15, 'prob'],
       ['IM-1276', 12042020, 23, 'prob'],
       ['IM-1276', 12072020, 24, 'prob'],
       ['IM-1276', 12112020, 25, 'prob'],
       ['IM-1276', 12152020, 26, 'prob'],
       ['IM-1276', 12172020, 27, 'prob'],
       ['IM-1291', 12092020, 39, 'prob'],
       ['IM-1291', 12102020, 40, 'prob'],
       ['IM-1291', 12082020, 38, 'prob'],
       ['IM-1291', 12152020, 41, 'prob'],
       ['IM-1291', 12162020, 42, 'prob'],
       ['IM-1292', 12142020, 52, 'prob'],
       ['IM-1292', 12152020, 53, 'prob'],
       ['IM-1292', 12162020, 54, 'prob'],
       ['IM-1272', 8182020, 5, 'barrier'],
       ['IM-1272', 8242020, 6, 'barrier'],
       ['IM-1272', 8282020, 7, 'barrier'],
       ['IM-1272', 10162020, 0, 'barrier'],
       ['IM-1272', 11032020, 1, 'barrier'],
       ['IM-1273', 8192020, 17

In [238]:
photrats.directory_prefix = '/Volumes/Tim-1/Photometry Data/'

In [241]:
fig,s = plot_individualSeshTrace_RwdAndNewHexLabeled(photrats,\
                                        75,plot_newlyAvail=False,plot_ref=True)

In [240]:
fig.savefig(loadpath+"s"+str(s)+"_individual_sesh_trace_rampRwdAndNewlyAvailEntry.pdf")

In [243]:
fig.savefig(loadpath+"s"+str(s)+"_individual_sesh_trace_rampAndRwd.pdf")

## plot peak/trough after port entry as a function of p(rwd)

In [9]:
photrats.set_plot_trace("green_z_scored")

In [7]:
photrats.pool_factor = "nom_rwd_chosen"
photrats.dat = photrats.df.loc[(photrats.df.rwd==1)&(photrats.df.tri>25),]
lowIndsRwd,midIndsRwd,highIndsRwd = photrats.getTriIndsByTerc(rwdtype="rwd")
photrats.dat = photrats.df.loc[(photrats.df.rwd!=1)&(photrats.df.tri>25),]
lowIndsOm,midIndsOm,highIndsOm = photrats.getTriIndsByTerc(rwdtype="om")

In [10]:
fig = plt.figure(figsize=(4,5.5))
plot_peakTroughDaDifAfterPortEntry_barWithRats(photrats,highIndsRwd,midIndsRwd,lowIndsRwd)
plot_peakTroughDaDifAfterPortEntry_barWithRats(photrats,highIndsOm,midIndsOm,lowIndsOm,peak=False,pltCol="blue")

In [15]:
def calc_DaChangeVprobCors(photrats):
    rwdCors = []
    omCors = []
    for rat in photrats.df.rat.unique():
        photrats.dat = photrats.df.loc[(photrats.df.rat==rat)\
                                       &(photrats.df.tri>25)&(photrats.df.rwd==1),]
        photrats.dat_visinds = photrats.dat.loc[photrats.dat.port!=-100].index
        lowInds,midInds,highInds = photrats.getTriIndsByTerc()
        daChanges = np.concatenate([calc_DaChangeAtIndsOneRat(photrats,highInds,peak=True),
                     calc_DaChangeAtIndsOneRat(photrats,midInds,peak=True),
                     calc_DaChangeAtIndsOneRat(photrats,lowInds,peak=True)])
        probs = np.concatenate([photrats.dat.loc[highInds,"nom_rwd_chosen"].values/100,\
        photrats.dat.loc[midInds,"nom_rwd_chosen"].values/100,\
        photrats.dat.loc[lowInds,"nom_rwd_chosen"].values/100])
        rwdCors.append(pearsonr(probs,daChanges))
        photrats.dat = photrats.df.loc[(photrats.df.rat==rat)\
                                       &(photrats.df.tri>25)&(photrats.df.rwd==0),]
        photrats.dat_visinds = photrats.dat.loc[photrats.dat.port!=-100].index
        lowInds,midInds,highInds = photrats.getTriIndsByTerc()
        daChanges = np.concatenate([calc_DaChangeAtIndsOneRat(photrats,highInds,peak=False),
                     calc_DaChangeAtIndsOneRat(photrats,midInds,peak=False),
                     calc_DaChangeAtIndsOneRat(photrats,lowInds,peak=False)])
        probs = np.concatenate([photrats.dat.loc[highInds,"nom_rwd_chosen"].values/100,\
        photrats.dat.loc[midInds,"nom_rwd_chosen"].values/100,\
        photrats.dat.loc[lowInds,"nom_rwd_chosen"].values/100])
        omCors.append(pearsonr(probs,daChanges))
        pd.DataFrame(rwdCors,columns=["coef","p-val"]).to_csv(photrats.directory_prefix+"pearsonR_result_DaVsRpe_rwd.csv")
        pd.DataFrame(omCors,columns=["coef","p-val"]).to_csv(photrats.directory_prefix+"pearsonR_result_DaVsRpe_om.csv")
    return rwdCors,omCors

In [16]:
rwdCors,omCors = calc_DaChangeVprobCors(photrats)
plt.suptitle("rwd pval = "+str(wilcoxon(np.array(rwdCors)[:,0])[1])+\
            "\n om pval = "+str(wilcoxon(np.array(omCors)[:,0])[1]))
fig.savefig(photrats.directory_prefix+"pRwdDAPeak_Trough_Dif_RPE_BarPlot_correctAlignment.pdf")

In [None]:
rwdCors = pd.read_csv(photrats.directory_prefix+"pearsonR_result_DaVsRpe_rwd_correctAlignment.csv")
omCors = pd.read_csv(photrats.directory_prefix+"pearsonR_result_DaVsRpe_om_correctAlignment.csv")

In [None]:
omCors.coef

In [17]:
wilcoxon(rwdCors.coef,alternative="less")

AttributeError: 'list' object has no attribute 'coef'

In [None]:
wilcoxon(omCors.coef,alternative="less")

- get new vs blocked hex entry info
  - first need to load transition matrices and barrier configs

In [29]:
photrats.directory_prefix = photrats.phot_directory_prefix
photrats.load_tmats()
photrats.get_barIDs()
photrats.directory_prefix = loadpath#photrats.directory_prefix+"

no tmat saved for sesh 0 block 7
no tmat saved for sesh 45 block 4
briefly switching directory_prefix to new Nas...
directory_prefix back to nas 8
briefly switching directory_prefix to new Nas...
directory_prefix back to nas 8
briefly switching directory_prefix to new Nas...
directory_prefix back to nas 8
briefly switching directory_prefix to new Nas...
directory_prefix back to nas 8
briefly switching directory_prefix to new Nas...
directory_prefix back to nas 8
briefly switching directory_prefix to new Nas...
directory_prefix back to nas 8
briefly switching directory_prefix to new Nas...
directory_prefix back to nas 8
briefly switching directory_prefix to new Nas...
directory_prefix back to nas 8
briefly switching directory_prefix to new Nas...
directory_prefix back to nas 8
briefly switching directory_prefix to new Nas...
directory_prefix back to nas 8
briefly switching directory_prefix to new Nas...
directory_prefix back to nas 8
briefly switching directory_prefix to new Nas...
dire

- plot port-aligned DA pooled by p(rwd) second half of blocks

In [27]:
create_triframe(photrats)

In [28]:
photrats.triframe.to_csv(photrats.directory_prefix+"triframe.csv")

In [20]:
photrats.get_portQvals(qtype="port",level="rat")
#get_portQvals(photrats,qtype="port",level="rat")

In [21]:
photrats.factor = "Q"
photrats.get_vals_byChosenEtc(chosen_only=True)

In [22]:
def plot_portAlignedDaInTime(photrats,secondHalfOnly=True,
                             poolFactor="nom_rwd_chosen",useRatGroupLevel=True):
    photrats.set_pool_factor(poolFactor)
    photrats.set_plot_trace("green_z_scored")
    fig = plt.figure(figsize = (7,5))
    rwdhigh,omhigh,rwdmid,ommid,rwdlow,omlow = photrats.getSessionTercMeans(\
        secondHalf=secondHalfOnly,useRat=useRatGroupLevel)
    
    high_color = "red"#"indianred"
    mid_color = "firebrick"
    low_color = "maroon"
    high_colorOm = "dodgerblue"
    mid_colorOm = "blue"
    low_colorOm ="darkblue"
    
    xvals = np.arange(photrats.fs*photrats.plot_window[0],photrats.fs*photrats.plot_window[1]+1)/photrats.fs
    ax1 = plt.gca()
    plot_avgWithSem(photrats,ax1,xvals,np.vstack((rwdhigh,omhigh)),'lightgrey','-',\
                    [None,-photrats.plot_window[0]*photrats.fs],"high")
    plot_avgWithSem(photrats,ax1,xvals,rwdhigh,high_color,'-',[-photrats.plot_window[0]*photrats.fs,None])
    plot_avgWithSem(photrats,ax1,xvals,omhigh,high_colorOm,':',[-photrats.plot_window[0]*photrats.fs,None])
    plot_avgWithSem(photrats,ax1,xvals,np.vstack((rwdmid,ommid)),'darkgrey','-',\
                    [None,-photrats.plot_window[0]*photrats.fs],"medium")
    plot_avgWithSem(photrats,ax1,xvals,rwdmid,mid_color,'-',[-photrats.plot_window[0]*photrats.fs,None])
    plot_avgWithSem(photrats,ax1,xvals,ommid,mid_colorOm,':',[-photrats.plot_window[0]*photrats.fs,None])
    ax1.axvline(x=0.0,ymin=-.1,ymax=1.0,color='k',linestyle='--')
    ax1.set_xlabel('time (s) from port entry')
    plot_avgWithSem(photrats,ax1,xvals,np.vstack((rwdlow,omlow)),'dimgrey','-',\
                    [None,-photrats.plot_window[0]*photrats.fs],"low")
    plot_avgWithSem(photrats,ax1,xvals,rwdlow,low_color,'-',[-photrats.plot_window[0]*photrats.fs,None])
    plot_avgWithSem(photrats,ax1,xvals,omlow,low_colorOm,':',[-photrats.plot_window[0]*photrats.fs,None])
    ax1.axvline(x=0.0,ymin=-.1,ymax=1.0,color='k',linestyle='--')
    ax1.set_xlabel('time (s) from port entry')
    ax1.legend()
    plt.xlabel("time from port entry (s)",fontsize='xx-large')
    plt.ylabel("DA (z-scored)",fontsize='xx-large')
    plt.tight_layout()
    return fig

In [23]:
photrats.set_plot_window([-5,5])
fig = plot_portAlignedDaInTime(photrats,secondHalfOnly=False,poolFactor="Q_chosen",useRatGroupLevel=True)
plt.xticks(np.arange(-5,6),np.arange(-5,6))
plt.ylim(-.9,2.2)

  ax.plot(xvals[subset[0]:subset[1]],np.nanmean\
  ax.fill_between(xvals[subset[0]:subset[1]],(np.nanmean\
  [subset[0]:subset[1]],(np.nanmean\
  ax.plot(xvals[subset[0]:subset[1]],np.nanmean\
  ax.fill_between(xvals[subset[0]:subset[1]],(np.nanmean\
  [subset[0]:subset[1]],(np.nanmean\
  ax.plot(xvals[subset[0]:subset[1]],np.nanmean\
  ax.fill_between(xvals[subset[0]:subset[1]],(np.nanmean\
  [subset[0]:subset[1]],(np.nanmean\
  ax.plot(xvals[subset[0]:subset[1]],np.nanmean\
  ax.fill_between(xvals[subset[0]:subset[1]],(np.nanmean\
  [subset[0]:subset[1]],(np.nanmean\
  ax.plot(xvals[subset[0]:subset[1]],np.nanmean\
  ax.fill_between(xvals[subset[0]:subset[1]],(np.nanmean\
  [subset[0]:subset[1]],(np.nanmean\
  ax.plot(xvals[subset[0]:subset[1]],np.nanmean\
  ax.fill_between(xvals[subset[0]:subset[1]],(np.nanmean\
  [subset[0]:subset[1]],(np.nanmean\
  ax.plot(xvals[subset[0]:subset[1]],np.nanmean\
  ax.fill_between(xvals[subset[0]:subset[1]],(np.nanmean\
  [subset[0]:subset[1]],(np.n

(-0.9, 2.2)

In [24]:
fig.savefig(photrats.directory_prefix+"da_rpe_by_portQ_ratAvg_correctAlignment.pdf")

In [25]:
photrats.set_plot_window([-5,5])
fig = plot_portAlignedDaInTime(photrats,secondHalfOnly=True,poolFactor="nom_rwd_chosen",useRatGroupLevel=True)
plt.xticks(np.arange(-5,6),np.arange(-5,6))
plt.ylim(-.9,2.2)

  ax.plot(xvals[subset[0]:subset[1]],np.nanmean\
  ax.fill_between(xvals[subset[0]:subset[1]],(np.nanmean\
  [subset[0]:subset[1]],(np.nanmean\
  ax.plot(xvals[subset[0]:subset[1]],np.nanmean\
  ax.fill_between(xvals[subset[0]:subset[1]],(np.nanmean\
  [subset[0]:subset[1]],(np.nanmean\
  ax.plot(xvals[subset[0]:subset[1]],np.nanmean\
  ax.fill_between(xvals[subset[0]:subset[1]],(np.nanmean\
  [subset[0]:subset[1]],(np.nanmean\
  ax.plot(xvals[subset[0]:subset[1]],np.nanmean\
  ax.fill_between(xvals[subset[0]:subset[1]],(np.nanmean\
  [subset[0]:subset[1]],(np.nanmean\
  ax.plot(xvals[subset[0]:subset[1]],np.nanmean\
  ax.fill_between(xvals[subset[0]:subset[1]],(np.nanmean\
  [subset[0]:subset[1]],(np.nanmean\
  ax.plot(xvals[subset[0]:subset[1]],np.nanmean\
  ax.fill_between(xvals[subset[0]:subset[1]],(np.nanmean\
  [subset[0]:subset[1]],(np.nanmean\
  ax.plot(xvals[subset[0]:subset[1]],np.nanmean\
  ax.fill_between(xvals[subset[0]:subset[1]],(np.nanmean\
  [subset[0]:subset[1]],(np.n

(-0.9, 2.2)

In [26]:
fig.savefig(photrats.directory_prefix+"da_rpe_byportQ_2ndHalf_ratAvg_correctAlignment.pdf")

In [324]:
print("done")

done


In [64]:
s = 95#photrats.df.loc[photrats.df.session_type=="prob","session"].unique()[0]
tridat = photrats.df.loc[(photrats.df.session==s)&(photrats.df.port!=-100),:].copy()
tridat.reset_index(inplace=True)
seshQs = tridat.loc[:,["Q_a","Q_b","Q_c"]].values
fig = plt.figure(figsize=(10,5))
x1 = np.arange(0,len(tridat))
A1 = tridat.rwd.loc[tridat.port==0] + 1.5
yA1 = np.zeros(len(tridat))
yA1[A1.index.values] = A1
B1 = tridat.rwd.loc[tridat.port==1] + 1.5
yB1 = np.zeros(len(tridat))
yB1[B1.index.values] = B1
C1 = tridat.rwd.loc[tridat.port==2] + 1.5
yC1 = np.zeros(len(tridat))
yC1[C1.index.values] = C1
ax1 = plt.subplot2grid((10,1),(0,0),colspan = 1, rowspan = 1)
ax1.bar(x1,yA1,color = '#1f77b4')
ax1.axis('off')
ax2 = plt.subplot2grid((10,1),(1,0),colspan = 1, rowspan = 1,sharex=ax1)
ax2.bar(x1,yB1,color = '#ff7f0e')
ax2.axis('off')
ax3 = plt.subplot2grid((10,1),(2,0),colspan = 1, rowspan = 1,sharex=ax1)
ax3.bar(x1,yC1,color = '#2ca02c')
ax3.axis('off')
ax4 = plt.subplot2grid((10,1),(3,0),colspan = 1, rowspan = 7,sharex=ax1)
ax4.plot(x1,seshQs[:,0],label = "port A",color="#1f77b4")
ax4.plot(x1,seshQs[:,1],label = "port B",color="#ff7f0e")
ax4.plot(x1,seshQs[:,2],label = "port C",color="#2ca02c")
plt.xlabel("Trial",fontsize=20,fontweight="bold")
plt.ylabel("Q Value",fontsize=20,fontweight="bold")
plt.legend()
plt.tight_layout()

In [57]:
fig.savefig(photrats.directory_prefix+"port_QvalEvolution_wTicks_s"+str(s)+".pdf")

## plot q rpe lag reg

In [65]:
photrats.bin_size = int(photrats.fs/15)
ratRwdRpes,ratOmRpes,ratRwdNs,ratOmNs = calcRpeRegByRatAndSesh(photrats,useQ=True)

100%|██████████| 10/10 [05:43<00:00, 34.36s/it]


In [82]:
def plot_rpeLagRegCoefs(photrats,binsize=100):
    xvals = np.arange(0,photrats.fs*2,250/(1000/binsize))/photrats.fs#np.arange(0,photrats.fs*2)/photrats.fs
    fig = plt.figure(figsize=(4.5,5))
    ax1 = plt.subplot2grid((6,4),(2,0),colspan = 4, rowspan =4)
    plot_ratMeans(xvals,ratRwdRpes,'darkred',pltLabel="reward")
    plot_ratMeans(xvals,ratOmRpes,'darkblue',pltLabel="omission")
    plt.xlabel("time from port entry (s)",fontsize='xx-large')
    plt.axhline(0,ls=':',color='k')
    plt.ylabel("RPE ß",fontsize='xx-large')
    plt.xlim(0,2)
    plt.legend()
    ax2 = plt.subplot2grid((6,4),(0,0),colspan = 4, rowspan =2,sharex=ax1)
    plot_sigPoints(xvals,ratRwdRpes,'darkred',plot99=False)
    plot_sigPoints(xvals,ratOmRpes,'darkblue',plot99=False)
    ax2.tick_params('x', labelbottom=False)
    plt.ylabel("Fraction\nsignificant")
    ax2.set_ylim(0,1)
    plt.tight_layout()
    fig.savefig(photrats.directory_prefix+"rpeLagReg_binned.pdf")

In [83]:
plot_rpeLagRegCoefs(photrats,binsize=1000/12.5)

In [38]:
regWeightsByRat = calc_choiceRegWeightsByRat(photrats)

In [18]:
np.array(photrats.triframe.rat.unique()).astype(str)

array(['IM-1272', 'IM-1273', 'IM-1276', 'IM-1291', 'IM-1292', 'IM-1322',
       'IM-1398', 'IM-1434', 'IM-1458', 'IM-1478'], dtype='<U7')

In [39]:
regWeightsByRat.loc[:,"rat"] = np.array(photrats.triframe.rat.unique())
regWeightsByRat

Unnamed: 0,intercept,relative p(R),relative distance,rat
0,0.788138,0.840152,-2.901712,IM-1272
1,-0.067526,0.81995,-1.460891,IM-1273
2,0.703347,0.881085,-2.522147,IM-1276
3,0.030442,0.549869,-1.716633,IM-1291
4,0.166691,1.393542,-1.911691,IM-1292
5,0.668904,1.152568,-2.430597,IM-1322
6,0.86557,0.666288,-2.379002,IM-1398
7,-0.615642,1.637882,-1.195741,IM-1434
8,-0.605195,1.244594,-0.726433,IM-1478


- plot difference between peak DA within 0.5s after port entry and DA at port entry.
  - first take average trace for each rat and THEN take difference, not average of differences

- plot DA aligned to new/blocked hex discovery
  - identify indices of blocked and new path discovery
  - identify indices of new path before enter and new path before ignored

In [32]:
portstrings = ['A','B','C']
newDistsToPort = []
oldDistsToPort = []
for i in range(len(newHexAdjInds)):
    if newHexAdjStates[i] == -1:
        newDistsToPort.append(np.nan)
        oldDistsToPort.append(np.nan)
        continue
    state = newHexAdjStates[i]
    sesh = photrats.df.loc[newHexAdjInds[i],'session']
    block = photrats.df.loc[newHexAdjInds[i],'block']
    chosenPort = photrats.df.port.replace(-100,method='bfill')[newHexAdjInds[i]]#photrats.df.loc[newHexAdjInds[i],'']
    newDistsToPort.append(photrats.sesh_hexDists['dto'+portstrings[chosenPort]][sesh][block-1][state])
    oldDistsToPort.append(photrats.sesh_hexDists['dto'+portstrings[chosenPort]][sesh][block-2][state])
newDistsToPort = np.array(newDistsToPort)
oldDistsToPort = np.array(oldDistsToPort)

In [33]:
#drop all entries where newHexAdjStates==-1
newHexAdjStates = np.array(newHexAdjStates)
newHexAdjInds = np.array(newHexAdjInds)
toRemove = np.where(newHexAdjStates==-1)[0]
newHexAdjStates = np.delete(newHexAdjStates,toRemove)
newHexAdjInds = np.delete(newHexAdjInds,toRemove)
newDistsToPort = np.delete(newDistsToPort,toRemove)
oldDistsToPort = np.delete(oldDistsToPort,toRemove)

In [71]:
newDistsToPort

array([ 4., 10.,  9.,  7., 15., 16.,  9.,  8.,  9.,  9., 17.,  7.,  4.,
        7.,  4.,  9.,  8.,  6., 11.,  8.,  8.,  4., 15.,  9.,  8.,  5.,
        8.,  9.,  9.,  7.,  9.,  8., 10., nan,  7.,  7.,  9., nan,  5.,
       15.,  4., 14.,  8.,  8., 10.,  4.,  7.,  8., 10., 12.,  7., 11.,
        9.,  7., 15., 16., 12.,  8.,  4., 10., 18.,  7., 11.,  7., 10.,
        8.,  6.,  5., 10.,  7., 11.,  7.,  9.,  6.,  7.,  4.,  4.,  6.,
        9.,  6.,  7.,  6.,  7., 18.,  7.,  8.,  5., 10.,  9., 13., 10.,
        8.,  9., 10.,  7.,  6.,  9.,  7., 12., 11.,  7.,  9.,  5., 13.])

In [73]:
relDist2Port = newDistsToPort-oldDistsToPort

In [74]:
adjHexIndsSortedByRelDist = newHexAdjInds[relDist2Port.argsort()]
#relDist2Port = relDist2Port[relDist2Port.argsort()]
#toRemove = np.where(np.isnan(relDist2Port))[0]
#adjHexIndsSortedByRelDist = np.delete(adjHexIndsSortedByRelDist,toRemove)
#relDist2Port = np.delete(relDist2Port,toRemove)

In [34]:
adjHexIndsSortedByDist = newHexAdjInds[newDistsToPort.argsort()]
#newDistsToPort = newDistsToPort[newDistsToPort.argsort()]
#toRemove = np.where(np.isnan(newDistsToPort))[0]
#adjHexIndsSortedByDist = np.delete(adjHexIndsSortedByDist,toRemove)
#newDistsToPort = np.delete(newDistsToPort,toRemove)

In [35]:
photrats.set_plot_trace("green_z_scored")
photrats.set_plot_window([-5,5])

In [36]:
adjHexNextPortVal = photrats.df.loc[newHexAdjInds,"nom_rwd_chosen"].values
adjHexIndsSortedByPortVal = newHexAdjInds[adjHexNextPortVal.argsort()]
#adjHexNextPortVal = adjHexNextPortVal[adjHexNextPortVal.argsort()]

In [37]:
def get_newPathTracesByDistToPort(adjHexIndsSortedByDist):
    smoothWin = int(photrats.fs/4)
    shortTrace = []
    midTrace = []
    longTrace = []
    terc_cutoff = int(len(adjHexIndsSortedByDist)/3)
    for i in range(len(adjHexIndsSortedByDist)):
        adjInd = adjHexIndsSortedByDist[i]
        trace = photrats.df.loc[adjInd+photrats.plot_window[0]*photrats.fs:\
                            adjInd+photrats.plot_window[1]*photrats.fs,photrats.plot_trace].rolling(smoothWin).mean().values
        if i<=terc_cutoff:
            shortTrace.append(trace)
        elif i<=terc_cutoff*2:
            midTrace.append(trace)
        else:
            longTrace.append(trace)
    return shortTrace,midTrace,longTrace

def get_newPathTracesByDistToPort_absoluteDist(newHexAdjInds,distsToPort,dist_cutoff=7):
    smoothWin = int(photrats.fs/4)
    shortTrace = []
    midTrace = []
    longTrace = []
    for i in range(len(newHexAdjInds)):
        adjInd = newHexAdjInds[i]
        dist = distsToPort[i]
        trace = photrats.df.loc[adjInd+photrats.plot_window[0]*photrats.fs:\
                adjInd+photrats.plot_window[1]*photrats.fs,photrats.plot_trace].rolling(smoothWin).mean().values
        if dist<=dist_cutoff:
            shortTrace.append(trace)
        elif dist<=dist_cutoff*2:
            midTrace.append(trace)
        else:
            longTrace.append(trace)
    return shortTrace,midTrace,longTrace

### now do it by rat

In [63]:
newHexAdjStates = np.delete(newHexAdjStates,toRemove)
newHexAdjInds = np.delete(newHexAdjInds,toRemove)

In [38]:
ratNewPathLens = {r:[] for r in photrats.df.rat.unique()}
for rat in photrats.df.rat.unique():
    ratNewPathLens[rat] = newDistsToPort[np.where(np.isin(newHexAdjInds,photrats.df.loc[photrats.df.rat==rat].index))[0]]
print(ratNewPathLens)


{'IM-1272': array([ 4., 10.,  9.,  7., 15., 16.,  9.,  8.,  9.,  9., 17.,  7.,  4.,
        7.,  4.,  9.,  8.,  6.]), 'IM-1273': array([11.,  8.,  8.]), 'IM-1276': array([ 4., 15.,  9.,  8.,  5.,  8.,  9.,  9.,  7.,  9.,  8., 10., nan,
        7.,  7.,  9., nan,  5.]), 'IM-1291': array([15.,  4., 14.,  8.,  8., 10.,  4.,  7.,  8.]), 'IM-1292': array([10., 12.,  7., 11.,  9.,  7., 15., 16., 12.,  8.,  4., 10., 18.,
        7.]), 'IM-1322': array([11.,  7., 10.,  8.,  6.,  5., 10.,  7., 11.,  7.,  9.,  6.,  7.,
        4.,  4.]), 'IM-1398': array([ 6.,  9.,  6.,  7.,  6.,  7., 18.,  7.]), 'IM-1434': array([ 8.,  5., 10.,  9., 13., 10.,  8.]), 'IM-1478': array([ 9., 10.,  7.,  6.,  9.,  7., 12., 11.,  7.,  9.]), 'IM-1532': array([ 5., 13., 13., 12.])}


In [69]:
with open(loadpath+"lenFromNewHexNumsByRat.txt", 'w') as f:
    for key, value in ratNewPathLens.items():
        f.write(f'{key}: {value}\n')

In [161]:
ratsWithSufficientNumbers = ['IM-1272','IM-1276','IM-1291',\
            'IM-1292','IM-1322','IM-1398','IM-1434','IM-1478']

In [165]:
len(ratNewHexAdjInds)

10

In [164]:
n_plots = len(ratsWithSufficientNumbers)
n_cols = int(np.ceil(np.sqrt(n_plots)))
n_rows = int(np.ceil(n_plots / n_cols))
dist_cut = 6
use_relativeDist = False

fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(16,10))

xvals = np.arange(photrats.plot_window[0]*photrats.fs,photrats.plot_window[1]*photrats.fs+1)/photrats.fs

for i, ax in enumerate(axes.flatten()):
    if i < n_plots:
        rat = ratsWithSufficientNumbers[i]
        ratInds = np.where(np.isin(newHexAdjInds,photrats.df.loc[photrats.df.rat==rat].index))[0]
        ratNewHexAdjInds = newHexAdjInds[ratInds]
        ratNewHexDists = relDist2Port[ratInds] if use_relativeDist else newDistsToPort[ratInds]
        sortedRatInds = ratNewHexAdjInds[ratNewHexDists.argsort()]
        shortTrace,midTrace,longTrace = get_newPathTracesByDistToPort(sortedRatInds)
        #shortTrace,midTrace,longTrace = get_newPathTracesByDistToPort_absoluteDist(\
        #                        sortedRatInds,ratNewHexDists,dist_cutoff=dist_cut)
        ax.set_title(rat+"; "+str(len(shortTrace))+" short; "+str(len(midTrace))+" mid; "\
                    +str(len(longTrace))+" long")
        if len(shortTrace)>0:
            ax.plot(xvals,np.mean(shortTrace,axis=0),color='hotpink',label='short')
        if len(midTrace)>0:
            ax.plot(xvals,np.mean(midTrace,axis=0),color='red',label='mid')
        if len(longTrace)>0:
            ax.plot(xvals,np.mean(longTrace,axis=0),color='darkred',label='long')
        ax.axvline(x=0,ls='--',color='k',alpha=.8,lw=1)
        if (i + n_cols) >= n_plots:# i >= n_cols * (n_rows - 1):
            ax.set_xticks(np.arange(-5,6))
            ax.set_xlabel('time from hex entry (s)')
        else:
            ax.set_xticks([])
        # Only show the y axis if the subplot is in the first column
        if i % n_cols == 0:
            ax.set_ylabel('Mean DA (Z)')
        if i==0:
            ax.legend()
    else:
        # Remove any extra axes that are not needed
        ax.remove()

In [96]:
#fig.savefig(loadpath+"newPathDAbyDistToPort_byRat_"+str(dist_cut)+"hexGroups.pdf")
fig.savefig(loadpath+"newPathDAbyDistToPort_byRat_distTercile.pdf")

In [130]:
ratShortTraces = []
ratMidTraces = []
ratLongTraces = []
for rat in photrats.df.rat.unique():
    if rat == "IM-1273":
        continue
    ratInds = np.where(np.isin(newHexAdjInds,photrats.df.loc[photrats.df.rat==rat].index))[0]
    ratNewHexAdjInds = newHexAdjInds[ratInds]
    ratNewHexDists = newDistsToPort[ratInds]
    sortedRatInds = ratNewHexAdjInds[ratNewHexDists.argsort()]
    shortTrace,midTrace,longTrace = get_newPathTracesByDistToPort_absoluteDist(sortedRatInds,ratNewHexDists,dist_cutoff=7)
    ratShortTraces.append(np.mean(shortTrace,axis=0))
    ratMidTraces.append(np.mean(midTrace,axis=0))
    ratLongTraces.append(np.mean(longTrace,axis=0))

In [None]:
plt.figure()
xvals = np.arange(photrats.plot_window[0]*photrats.fs,photrats.plot_window[1]*photrats.fs+1)/photrats.fs
toplt = np.mean(ratShortTraces,axis=0)
plt.plot(xvals,toplt,color='hotpink',label='short')
plt.fill_between(xvals,toplt+sem(ratShortTraces),toplt-sem(ratShortTraces),color='hotpink',alpha=0.5)
toplt = np.mean(ratMidTraces,axis=0)
plt.plot(xvals,toplt,color='red',label='mid')
plt.fill_between(xvals,toplt+sem(ratMidTraces),toplt-sem(ratMidTraces),color='red',alpha=0.5)
toplt = np.mean(ratLongTraces,axis=0)
plt.plot(xvals,toplt,color='darkred',label='long')
plt.fill_between(xvals,toplt+sem(ratLongTraces),toplt-sem(ratLongTraces),color='darkred',alpha=0.5)
plt.axvline(x=0,ls='--',color='k',alpha=.8,lw=1)
plt.xticks(np.arange(-5,6))
plt.legend()
plt.xlabel("time from hex entry (s)",fontsize="xx-large")
plt.ylabel("Mean DA (Z)",fontsize="xx-large")
plt.title("Average of rat averages")
plt.tight_layout()

In [66]:
%matplotlib qt

## plot newly available by distance to goal port from the current hex
- make sure I have shortestDistToPort for each block
- for each adjacent hex entered
  - identify the port the rat ran to
  - identify shortest distance to that port from the adjacent hex
  - save the length paired with the adjacent hex index in an array
  - identify the prior distance to the chosen port from the adjacent hex (block b-1) and save in array
  

In [43]:
shortTrace,midTrace,longTrace = get_newPathTracesByDistToPort(adjHexIndsSortedByDist)#,relDist2Port,dist_cut)
fig = plt.figure(figsize=(7,6))
xvals = np.arange(photrats.plot_window[0]*photrats.fs,photrats.plot_window[1]*photrats.fs+1)/photrats.fs
toplt = np.mean(shortTrace,axis=0)
plt.plot(xvals,toplt,color='hotpink',label='short')
plt.fill_between(xvals,toplt+sem(shortTrace),toplt-sem(shortTrace),color='hotpink',alpha=0.5)
toplt = np.mean(midTrace,axis=0)
plt.plot(xvals,toplt,color='red',label='mid')
plt.fill_between(xvals,toplt+sem(midTrace),toplt-sem(midTrace),color='red',alpha=0.5)
toplt = np.mean(longTrace,axis=0)
plt.plot(xvals,toplt,color='darkred',label='long')
plt.fill_between(xvals,toplt+sem(longTrace),toplt-sem(longTrace),color='darkred',alpha=0.5)
plt.axvline(x=0,ls='--',color='k',alpha=.8,lw=1)
plt.xticks(np.arange(-5,6))
plt.legend()
plt.xlabel("time from hex entry (s)",fontsize="xx-large")
plt.ylabel("Mean DA (Z)",fontsize="xx-large")
plt.title("Sorted by hex distance tercile\nAll rats pooled together.")
plt.ylim(-.6,2.8)
plt.tight_layout()
fig.savefig(photrats.directory_prefix+"newPathDAbyDistToPort_allRatsPooled_byTercile.pdf")

In [87]:
dist_cut = 7
shortTrace,midTrace,longTrace = get_newPathTracesByDistToPort_absoluteDist(newHexAdjInds,\
                                                                           newDistsToPort,dist_cut)
fig = plt.figure(figsize=(7,6))
xvals = np.arange(photrats.plot_window[0]*photrats.fs,photrats.plot_window[1]*photrats.fs+1)/photrats.fs
toplt = np.mean(shortTrace,axis=0)
plt.plot(xvals,toplt,color='hotpink',label='short')
plt.fill_between(xvals,toplt+sem(shortTrace),toplt-sem(shortTrace),color='hotpink',alpha=0.5)
toplt = np.mean(midTrace,axis=0)
plt.plot(xvals,toplt,color='red',label='mid')
plt.fill_between(xvals,toplt+sem(midTrace),toplt-sem(midTrace),color='red',alpha=0.5)
toplt = np.mean(longTrace,axis=0)
plt.plot(xvals,toplt,color='darkred',label='long')
plt.fill_between(xvals,toplt+sem(longTrace),toplt-sem(longTrace),color='darkred',alpha=0.5)
plt.axvline(x=0,ls='--',color='k',alpha=.8,lw=1)
plt.xticks(np.arange(-5,6))
plt.legend()
plt.xlabel("time from hex entry (s)",fontsize="xx-large")
plt.ylabel("Mean DA (Z)",fontsize="xx-large")
plt.title("Sorted by hex distance (groups of 6) from port\nAll rats pooled together.")
plt.tight_layout()
fig.savefig(photrats.directory_prefix+"newPathDAbyDistToPort_allRatsPooled_"+str(dist_cut)+"hexGroups.pdf")

In [171]:
adjHexNextPortVal[adjHexNextPortVal.argsort()]

array([10., 10., 10., 10., 10., 10., 10., 10., 20., 20., 20., 20., 20.,
       20., 20., 20., 20., 20., 20., 20., 20., 20., 20., 20., 20., 20.,
       50., 50., 50., 50., 50., 50., 50., 50., 50., 50., 50., 50., 50.,
       50., 50., 50., 50., 50., 50., 50., 50., 50., 50., 50., 50., 50.,
       50., 50., 50., 50., 50., 50., 50., 50., 50., 80., 80., 80., 80.,
       80., 80., 80., 80., 80., 80., 80., 80., 80., 80., 80., 80., 80.,
       80., 80., 80., 80., 80., 80., 80., 80., 80., 80., 80., 80., 80.,
       80., 80., 80., 90., 90., 90., 90., 90., 90., 90., 90., 90., 90.],
      dtype=float16)

In [42]:

rwdCut=30
shortTrace,midTrace,longTrace = get_newPathTracesByDistToPort_absoluteDist(adjHexIndsSortedByPortVal,\
                                        adjHexNextPortVal[adjHexNextPortVal.argsort()],rwdCut)
fig = plt.figure(figsize=(7,6))
xvals = np.arange(photrats.plot_window[0]*photrats.fs,photrats.plot_window[1]*photrats.fs+1)/photrats.fs
toplt = np.mean(shortTrace,axis=0)
plt.plot(xvals,toplt,color='darkred',label='low')
plt.fill_between(xvals,toplt+sem(shortTrace),toplt-sem(shortTrace),color='darkred',alpha=0.5)
toplt = np.mean(midTrace,axis=0)
plt.plot(xvals,toplt,color='red',label='mid')
plt.fill_between(xvals,toplt+sem(midTrace),toplt-sem(midTrace),color='red',alpha=0.5)
toplt = np.mean(longTrace,axis=0)
plt.plot(xvals,toplt,color='hotpink',label='high')
plt.fill_between(xvals,toplt+sem(longTrace),toplt-sem(longTrace),color='hotpink',alpha=0.5)
plt.axvline(x=0,ls='--',color='k',alpha=.8,lw=1)
plt.xticks(np.arange(-5,6))
plt.legend()
plt.xlabel("time from hex entry (s)",fontsize="xx-large")
plt.ylabel("Mean DA (Z)",fontsize="xx-large")
plt.title("Sorted by port value tercile\nAll rats pooled together.")
plt.tight_layout()
fig.savefig(photrats.directory_prefix+"newPathDAbyPortVal_allRatsPooled_byCutoff.pdf")

In [40]:

shortTrace,midTrace,longTrace = get_newPathTracesByDistToPort(adjHexIndsSortedByPortVal)
fig = plt.figure(figsize=(7,6))
xvals = np.arange(photrats.plot_window[0]*photrats.fs,photrats.plot_window[1]*photrats.fs+1)/photrats.fs
toplt = np.mean(shortTrace,axis=0)
plt.plot(xvals,toplt,color='darkred',label='low')
plt.fill_between(xvals,toplt+sem(shortTrace),toplt-sem(shortTrace),color='darkred',alpha=0.5)
toplt = np.mean(midTrace,axis=0)
plt.plot(xvals,toplt,color='red',label='mid')
plt.fill_between(xvals,toplt+sem(midTrace),toplt-sem(midTrace),color='red',alpha=0.5)
toplt = np.mean(longTrace,axis=0)
plt.plot(xvals,toplt,color='hotpink',label='high')
plt.fill_between(xvals,toplt+sem(longTrace),toplt-sem(longTrace),color='hotpink',alpha=0.5)
plt.axvline(x=0,ls='--',color='k',alpha=.8,lw=1)
plt.xticks(np.arange(-5,6))
plt.legend()
plt.xlabel("time from hex entry (s)",fontsize="xx-large")
plt.ylabel("Mean DA (Z)",fontsize="xx-large")
plt.title("Sorted by port value tercile\nAll rats pooled together.")
plt.tight_layout()
plt.ylim(-.9,3.0)
fig.savefig(photrats.directory_prefix+"newPathDAbyPortVal_allRatsPooled_byTercile.pdf")

## should also plot by the p(rwd) of the subsequently entered port. And the p(rwd) of the port that hex is now closest to - the p(rwd) of the port the hex used to be closest to

In [30]:
photrats.get_newlyAvailHexesBySesh()
photrats.get_newlyBlockedHexesBySesh()
#add entries to df
photrats.add_newlyAvailHexesToDf()
photrats.add_adjacent2newlyAvail()
photrats.add_adjacent2newlyBlocked()

newHexInds,newHexStates,adjHexStates,adjHexInds = find_newHexEntryAndPriorHexInds(photrats)
blockedHexAdjInds,blockedHexAdjStates = find_blockedHexAdjInds(photrats)

blockedHexAdjInds = np.array(blockedHexAdjInds)[np.where(~np.isnan(blockedHexAdjInds))]
adjHexInds = np.array(adjHexInds)[np.where(~np.isnan(adjHexInds))]
adjHexInds = adjHexInds[adjHexInds!=-1]
blockedHexAdjInds = blockedHexAdjInds[blockedHexAdjInds!=-1]

newHexAdjInds,newHexAdjStates,enteredHex,enteredHexSoon = find_newHexAdjInds(photrats)
enteredInds,ignoredInds = find_enteredVignoredNewlyAvailInds(newHexAdjInds,enteredHex)
enteredInds = np.array(enteredInds)
ignoredInds = np.array(ignoredInds)

No previous adjacent hex entry detected for session  0  block  2
session  1  block  5
session  28  block  3
session  22  block  3
No previous adjacent hex entry detected for session  32  block  1
session  57  block  3
session  59  block  3
session  61  block  2
No previous adjacent hex entry detected for session  64  block  1
No previous adjacent hex entry detected for session  89  block  1
No previous adjacent hex entry detected for session  89  block  2
No previous adjacent hex entry detected for session  93  block  1
No previous adjacent hex entry detected for session  93  block  2
No previous adjacent hex entry detected for session  98  block  1
No previous adjacent hex entry detected for session  102  block  1
session  102  block  3
session  104  block  3
session  1  block  2
session  1  block  5
session  28  block  3
session  22  block  3
session  57  block  3
session  59  block  3
session  102  block  3
session  104  block  3
session  1  block  5
session  28  block  3
session  2

In [31]:
photrats.get_distanceToPort(getDistsFromDeadEnds=False)

# identify da change at hex discovery. add to df where other columns are distance to port and pRwd at port. then run regression of DA to distance and pRwd. If signal is a bonus - expected reward, makes sense that magnitude would scale inversely with val of approached port.

In [97]:
newHexDAchangeMeans,_ = calc_DaPeakDiffAfterNewPathInds(photrats,newHexAdjInds)

In [98]:
np.shape(newHexDAchangeMeans)

(10,)

In [123]:
def calc_DaPeakIndividualDiffsAfterNewPathInds(photrats,indices):
    photrats.set_plot_window([-1,0.25])
    tracePeakRats = photrats.df.loc[indices,"rat"].astype(str).values
    tracePeakChanges = []
    missingInds = []
    for rat in photrats.df.rat.unique():
        if rat not in tracePeakRats:
            missingInds.append(np.where(photrats.df.rat.unique()==rat)[0][0])
            continue
        tracesPost = get_TracesAroundIndex(photrats,indices[tracePeakRats==rat])
        bline = tracesPost[:,0]#np.mean(tracesPost,axis=0)[0]
        #tracePost = np.mean(tracesPost,axis=0)[photrats.fs*1:]
        daChanges = np.max(tracesPost[:,photrats.fs*1:],axis=1)-bline
        tracePeakChanges += list(daChanges)
    return tracePeakChanges

In [114]:
def calc_DaPeakDiffAfterNewPathInds(photrats,indices):
    photrats.set_plot_window([-1,0.25])
    tracePeakRats = photrats.df.loc[indices,"rat"].astype(str).values
    tracePeakRatMeans = []
    missingInds = []
    for rat in photrats.df.rat.unique():
        if rat not in tracePeakRats:
            missingInds.append(np.where(photrats.df.rat.unique()==rat)[0][0])
            continue
        tracesPost = get_TracesAroundIndex(photrats,indices[tracePeakRats==rat])
        bline = tracesPost[:,0]#np.mean(tracesPost,axis=0)[0]
        #tracePost = np.mean(tracesPost,axis=0)[photrats.fs*1:]
        daChanges = np.max(tracesPost[:,photrats.fs*1:],axis=1)-bline
        tracePeakRatMeans.append(np.mean(daChanges))#max(tracePost)-bline)
    return tracePeakRatMeans,missingInds

In [124]:
newHexDaChanges = calc_DaPeakIndividualDiffsAfterNewPathInds(photrats,newHexAdjInds)

In [179]:
newHexRegDf = pd.DataFrame({"DA_change":newHexDaChanges,"dist2port":newDistsToPort,\
              "portVal":adjHexNextPortVal/100,'rat':photrats.df.loc[\
                                        newHexAdjInds,"rat"].astype(str).values})
newHexRegDf = newHexRegDf.loc[(newHexRegDf.notnull().all(axis=1)),:]

In [186]:
newHexRegDf.loc[:,"dist2port"] = valscale.fit_transform(\
        newHexRegDf.loc[:,"dist2port"].values.reshape(-1,1))
newHexRegDf.loc[:,"portVal"] = valscale.fit_transform(\
        newHexRegDf.loc[:,"portVal"].values.reshape(-1,1))

In [200]:
ratBetas = []
for rat in ratsWithSufficientNumbers:
    distFromNewDfRat = newHexRegDf.loc[newHexRegDf.rat==rat,:]
    y = distFromNewDfRat.loc[:,"DA_change"]
    X = distFromNewDfRat.loc[:,["dist2port","portVal"]]
    betas,pvals = run_smLinRegWithPval(X,y)
    ratBetas.append(betas)

  warn("omni_normtest is not valid with less than 8 observations; %i "
  warn("omni_normtest is not valid with less than 8 observations; %i "


In [201]:
ratBetas

[array([ 4.2819347 , -0.2860685 , -3.43139157]),
 array([ 3.15090363,  0.44409047, -1.89022239]),
 array([ 4.63947051,  0.69144871, -1.0705713 ]),
 array([ 1.24829136,  4.17295315, -1.27010763]),
 array([  8.94329217, -12.39171072,  -3.0183711 ]),
 array([ 7.21973196, -5.51762129, -3.21643269]),
 array([ 2.28281531, -1.07706806,  0.21844859]),
 array([  3.64496033, -11.2918529 ,   4.44983239])]

In [234]:
fig = plt.figure()
sns.barplot(data=pd.DataFrame(ratBetas,columns=["icept","dist2port","pRwd"]),ci=95,color='grey',alpha=0.5)
sns.stripplot(data=pd.DataFrame(ratBetas,columns=["icept","dist2port","pRwd"]),\
              marker='D',color='deeppink',edgecolor='k',size=8,linewidth=3)
plt.ylabel(r"Regression $\beta \pm 95\% ci$ ")
plt.tight_layout()
fig.savefig(photrats.directory_prefix+"newHexDistAndProbBetas.pdf")

In [154]:
def run_smLinRegWithPval(X,y):
    X = np.hstack([np.ones(len(X)).reshape(-1,1),X])
    mod = sm.OLS(y, X).fit()
    coefs = mod.summary2().tables[1]['Coef.'].values
    pvals = mod.summary2().tables[1]["P>|t|"].values
    return coefs,pvals

In [155]:
y=newHexRegDf.loc[:,"DA_change"]
X = newHexRegDf.loc[:,["dist2port","portVal"]]
betas,pvals = run_smLinRegWithPval(X,y)

## betas for distance and port val are not significant. Should run regression for each animal that had enough data. Otherwise could do barplot and test if categories are different.

In [111]:
np.max(tracesPost[:,photrats.fs*1:],axis=1)

array([ 3.146  ,  5.     ,  1.124  ,  3.227  ,  0.5205 ,  5.992  ,
        0.5254 ,  4.54   ,  0.0371 ,  3.217  , -0.05737,  0.1938 ,
        0.4473 ,  0.6143 ,  3.06   ,  0.6396 ,  1.342  ,  1.486  ],
      dtype=float16)

# I was taking mean trace max - mean trace baseline, instead of mean of (max-baseline) for each trace

In [113]:
tracePeakRatMeans

[2.254, 5.824, 2.08, 4.2, 2.326, 4.516, 3.938, 1.992, 2.523, 2.266]

In [99]:
newHexDAchangeMeans

[1.442, 4.48, 0.889, 3.176, 1.177, 2.91, 3.246, 1.249, 1.642, 2.266]

In [86]:
photrats.set_plot_trace("green_z_scored")

In [618]:
availMeans,_ = calc_DaPeakDiffAfterNewPathInds(photrats,np.concatenate([enteredInds,ignoredInds]))

In [115]:
blockedMeans,missingInd = calc_DaPeakDiffAfterNewPathInds(photrats,blockedHexAdjInds)

In [116]:
blockedMeans

[1.214, 2.059, 0.6025, 0.8047, 2.543, 1.237, 2.438, 0.8555, 1.56, 1.081]

In [118]:
get_sigRatsPaired_from2samples(tracePeakRatMeans,blockedMeans,"greater")

p-value =  0.001953125


(False, True, True)

In [122]:
wilcoxon(tracePeakRatMeans)

WilcoxonResult(statistic=0.0, pvalue=0.001953125)

In [119]:
wilcoxon(blockedMeans)

WilcoxonResult(statistic=0.0, pvalue=0.001953125)

In [94]:
fig = plot_meanRatDaChangeAfterHexEntry(photrats,np.concatenate([enteredInds,ignoredInds]),blockedHexAdjInds,pltCol1="deeppink",pltCol2="k")
#plt.ylim(-.6,2.6)#-.5,2.3)
plt.xticks([0,1],["newly\navailable","newly\nblocked"])
plt.tight_layout()
fig.savefig(photrats.directory_prefix+"DaChangeNewlyAvailVblockedBarplot_firstDiscovery.pdf")
#fig.savefig(photrats.directory_prefix+photrats.plot_trace+"ChangeNewlyAvailVblockedBarplot_firstDiscovery.pdf")

blocked hex: 
avail hex: 
paired test
p-value =  0.0009765625


In [102]:
get_sigRats_fromMeanList(availMeans)

0.001953125


(False, True, True)

In [146]:
def plotFirstEntryHexChange(photrats,adjHexInds,blockedHexAdjInds,legend_on=False):
    #photrats.set_plot_trace("green_z_scored")
    photrats.set_plot_window([-5,5])
    smoothWin = int(photrats.fs/4)
    fig = plt.figure(figsize=(7,5))#(4.8,5))
    xvals = np.arange(photrats.plot_window[0]*photrats.fs,photrats.plot_window[1]*photrats.fs+1)/photrats.fs
    
    avail_traces,blocked_traces = get_availAndBlockedTraces(photrats,adjHexInds,blockedHexAdjInds)
    
    #toplt = pd.Series(np.median(avail_traces,axis=0)).rolling(smoothWin).mean().values
    toplt = pd.Series(np.mean(avail_traces,axis=0)).rolling(smoothWin).mean().values
    topltSem = pd.Series(sem(avail_traces,axis=0)).rolling(smoothWin).mean().values
    plt.plot(xvals,toplt,label="Newly available",color='deeppink',lw=3)
    plt.fill_between(xvals,toplt-topltSem,toplt+topltSem,color='deeppink',alpha=.3)
    toplt = pd.Series(np.mean(blocked_traces,axis=0)).rolling(smoothWin).mean().values
    topltSem = pd.Series(sem(blocked_traces,axis=0)).rolling(smoothWin).mean().values
    plt.plot(xvals,toplt,label="Newly blocked",color='k',ls=':',lw=3)
    plt.fill_between(xvals,toplt-topltSem,toplt+topltSem,color='k',ls=':',alpha=.3)
    plt.xlabel("time from hex entry (s)",fontsize="xx-large")
    #plt.ylabel("median z-scored DA",fontsize="xx-large"")
    plt.ylabel("Mean z-scored DA",fontsize="xx-large")
    plt.axvline(x=0,ls='--',color='k',alpha=.8,lw=1)
    plt.xticks(np.arange(-5,6))
    #plt.ylim([-.4,1.3])
    if legend_on:
        plt.legend()
    plt.tight_layout()
    return fig

In [149]:
fig = plotFirstEntryHexChange(photrats,np.concatenate([enteredInds,ignoredInds]),blockedHexAdjInds,legend_on=False)
#plt.ylim(-.9,2.2)
#fig.savefig(photrats.directory_prefix+"DaNewlyAvailVblocked_mean.pdf")
#fig.savefig(photrats.directory_prefix+photrats.plot_trace+"NewlyAvailVblocked_mean.pdf")

In [151]:
plt.ylabel("Mean running speed (cm/s)")
fig.savefig(photrats.directory_prefix+photrats.plot_trace+"NewlyAvailVblocked_mean.pdf")

- plot DA aligned to new hex discovery pooled by whether rat entered or ignored new path

In [None]:
def get_availAndBlockedTraces(photrats,adjHexInds,blockedHexAdjInds):
    avail_traces = []
    blocked_traces = []
    for i in adjHexInds:
        if np.isnan(i) or i == -1:
            continue
        avail_traces.append(photrats.df.loc[i+photrats.fs*photrats.plot_window[0]:\
                i+photrats.fs*photrats.plot_window[1],photrats.plot_trace].values)
    for i in blockedHexAdjInds:
        if np.isnan(i):
            continue
        blocked_traces.append(photrats.df.loc[i+photrats.fs*photrats.plot_window[0]:\
                i+photrats.fs*photrats.plot_window[1],photrats.plot_trace].values)
    return np.array(avail_traces),np.array(blocked_traces)

In [263]:
def plotFirstAdjEntryByEnteredVsIgnored(photrats,enteredInds,ignoredInds,legend_on=False,pltCol1='deeppink',pltCol2='k',ls2='-'):
    photrats.set_plot_trace("green_z_scored")
    photrats.set_plot_window([-5,5])
    smoothWin = int(photrats.fs/4)
    fig = plt.figure(figsize=(7,5))
    xvals = np.arange(photrats.plot_window[0]*photrats.fs,photrats.plot_window[1]*photrats.fs+1)/photrats.fs
    
    avail_traces,blocked_traces = get_availAndBlockedTraces(photrats,enteredInds,ignoredInds)
    
    toplt = pd.Series(np.mean(avail_traces,axis=0)).rolling(smoothWin).mean().values
    topltSem = pd.Series(sem(avail_traces,axis=0)).rolling(smoothWin).mean().values
    plt.plot(xvals,toplt,label="entered",color=pltCol1,lw=3)
    plt.fill_between(xvals,toplt-topltSem,toplt+topltSem,color=pltCol1,alpha=.2)
    toplt = pd.Series(np.mean(blocked_traces,axis=0)).rolling(smoothWin).mean().values
    topltSem = pd.Series(sem(blocked_traces,axis=0)).rolling(smoothWin).mean().values
    plt.plot(xvals,toplt,label="ignored",ls=ls2,color=pltCol2,lw=3)
    plt.fill_between(xvals,toplt-topltSem,toplt+topltSem,color=pltCol2,alpha=.2)
    plt.xlabel("Time from changed-hex discovery (s)",fontsize="xx-large")
    plt.ylabel("Mean z-scored DA",fontsize="xx-large")
    plt.axvline(x=0,ls='--',color='k',alpha=.9,lw=1)
    plt.xticks(np.arange(-5,6))
    if legend_on:
        plt.legend()
    plt.tight_layout()
    return fig

In [29]:
photrats.set_plot_trace("green_z_scored")
fig = plotFirstAdjEntryByEnteredVsIgnored(photrats,enteredInds,ignoredInds,legend_on=True,pltCol1="#27aeef",pltCol2= "#b33dc6",ls2='-')
plt.ylim(-.9,2.2)
fig.savefig(photrats.directory_prefix+"DaEnteredVsIgnoredNewlyAvail_mean.pdf")

In [73]:
ignoredInclusionRats = np.array(list(ignored_n_PerRat.keys()))[np.where(np.array(list(ignored_n_PerRat.values()))>=3)]
inclusionInds = photrats.df.loc[photrats.df.rat.isin(ignoredInclusionRats)].index
ignoredInds2include = ignoredInds[np.where(np.isin(ignoredInds,inclusionInds))]

In [59]:
def plot_meanRatDaChangeAfterHexEntry(photrats,adjHexInds,blockedHexAdjInds,pltCol1="#27aeef",pltCol2= "#b33dc6"):
    #availRatMeans,blockedRatMeans = calc_DaChangeAtHexEntry(photrats,adjHexInds,blockedHexAdjInds)
    availMeans,_ = calc_DaPeakDiffAfterNewPathInds(photrats,adjHexInds)
    blockedMeans,missingInd = calc_DaPeakDiffAfterNewPathInds(photrats,blockedHexAdjInds)
    fig = plt.figure(figsize=(4,5.5))
    plt.bar([0,1],[np.mean(availMeans),\
            np.mean(blockedMeans)],color='k',alpha=0.3)
    plt.ylabel("mean $\Delta$DA",fontsize='xx-large',fontweight='bold')
    plt.xlabel("hex type",fontsize='xx-large',fontweight='bold')
    for r in range(len(availMeans)):
        plt.scatter(x=0,y=np.mean(availMeans[r]),color=pltCol1,marker='o')
        try:
            plt.scatter(x=1,y=np.mean(blockedMeans[r]),color=pltCol2,marker='o')
            plt.plot([0,1],[np.mean(availMeans[r]),np.mean(blockedMeans[r])],color='k',alpha=0.5,lw=1)
        except:
            continue
    #for r in [np.mean(rm) for rm in blockedRatMeans]:
    #    plt.scatter(x=1,y=np.mean(r),color=pltCol2,marker='o')
    print("blocked hex: ")
    sigBlocked = get_sigRats_fromMeanList(blockedMeans)
    print("avail hex: ")
    sigAvail = get_sigRats_fromMeanList(availMeans)
    #missingInd = [i for i in range(len(blockedRatMeans)) if len(blockedRatMeans[i])==0]
    if len(missingInd)>0:
        availMeans = np.delete(availMeans,missingInd)
    print("paired test")
    sigPaired = get_sigRatsPaired_from2samples(availMeans,blockedMeans,"greater")
    plot_sigMarkers(sigPaired,0.5,2.4)
    plot_sigMarkers(sigAvail,-.05,2.1)
    plot_sigMarkers(sigBlocked,1,2.1)
    plt.tight_layout()
    return fig

In [76]:
fig = plot_meanRatDaChangeAfterHexEntry(photrats,enteredInds,ignoredInds)
plt.xticks([0,1],["entered\navail","ignored\navail"])
plt.tight_layout()
fig.savefig(photrats.directory_prefix+"DaChangeEnteredVsIgnoredNewlyAvailBarplot.pdf")

blocked hex: 
p-value =  0.0546875
avail hex: 
p-value =  0.001953125
paired test
p-value =  0.23046875


- calculate difference in average rat DA trace (peak DA within 0.5s after hex entry - DA 1s before hex entry)

In [606]:
photrats.set_plot_trace("green_z_scored")

In [607]:
fig,availratmeans,avail_n_PerRat = plot_ratTracesAtHexChangeDiscovery(photrats,np.concatenate([enteredInds,ignoredInds]),pltCol="deeppink")
#fig.savefig(photrats.directory_prefix+"DaNewlyAvail_byRat.pdf")
fig,blockedratmeans,blocked_n_PerRat = plot_ratTracesAtHexChangeDiscovery(photrats,blockedHexAdjInds)
#fig.savefig(photrats.directory_prefix+"DaNewlyBlocked_byRat.pdf")
with open(photrats.directory_prefix+"availPathNumbers.txt", 'w') as f:
        f.write(str(avail_n_PerRat))
with open(photrats.directory_prefix+"blockedPathNumbers.txt", 'w') as f:
        f.write(str(blocked_n_PerRat))
fig = plotFirstEntryHexChangeMeanOverRats(photrats,availratmeans,blockedratmeans,ls2=':')
#fig.savefig(photrats.directory_prefix+"DaNewlyAvailvBlocked_byRat.pdf")


In [611]:
np.mean(availratmeans,axis=1)

array([0.0955 , 0.501  , 0.1608 , 0.2825 , 0.10706, 0.2976 , 0.6567 ,
       0.4949 , 0.2069 , 0.67   ], dtype=float16)

In [626]:
availMeans

[1.442, 4.48, 0.889, 3.176, 1.177, 2.91, 3.246, 1.249, 2.23, 1.72]

In [629]:
wilcoxon(availMeans,blockedMeans)#,alternative="greater")

WilcoxonResult(statistic=0.0, pvalue=0.001953125)

In [165]:
fig = plotFirstEntryHexChangeMeanOverRats(photrats,availratmeans,blockedratmeans,ls2=':')

In [194]:
def get_ratTracesAtHexChangeDiscovery(photrats,xvals,indices,plotTraces=True):
    tracePeakRats = photrats.df.loc[indices,"rat"].astype(str).values
    ratmeans = []
    n_PerRat = {r:[] for r in photrats.df.rat.unique()}
    for rat in photrats.df.rat.unique():
        tracesPost = get_TracesAroundIndex(photrats,indices[tracePeakRats==rat])
        tracePost = np.mean(tracesPost,axis=0)
        n_PerRat[rat] = len(tracesPost)
        if len(tracesPost)>=3:
            ratmeans.append(tracePost)
            if plotTraces:
                plt.plot(xvals,tracePost,color='k',alpha=0.3,lw=1)
    return ratmeans,n_PerRat

def plot_ratTracesAtHexChangeDiscovery(photrats,inds,pltCol='k'):
    '''Plot individual rat averages at discovery of newly available and newly blocked paths.
    Return average of rat average traces.'''
    xvals = np.arange(photrats.plot_window[0]*photrats.fs,photrats.plot_window[1]*photrats.fs+1)/photrats.fs
    fig = plt.figure()
    tracePeakRats = photrats.df.loc[inds,"rat"].astype(str).values
    ratmeans,n_PerRat = get_ratTracesAtHexChangeDiscovery(photrats,xvals,inds)
    plt.plot(xvals,np.mean(ratmeans,axis=0),color=pltCol)
    #plt.ylim(-1.5,4.9)
    plt.ylabel("mean DA")
    plt.xlabel("time from port entry (s)")
    plt.tight_layout()
    return fig,ratmeans,n_PerRat

In [174]:
ratsWithOkNum = [rat for rat in ignored_n_PerRat.keys() if ignored_n_PerRat[rat]>=3]

['IM-1272', 'IM-1276', 'IM-1292', 'IM-1322', 'IM-1478']

In [192]:
ignoredratmeans

[array([-0.3533 , -0.3564 , -0.3528 , ..., -0.087  , -0.0444 ,  0.03134],
       dtype=float16),
 array([ 0.2947 ,  0.225  ,  0.1694 , ..., -0.302  , -0.1936 , -0.09625],
       dtype=float16),
 array([0.153  , 0.1316 , 0.11334, ..., 0.00236, 0.04196, 0.07404],
       dtype=float16),
 array([-0.10345, -0.1081 , -0.1132 , ..., -0.2113 , -0.2081 , -0.2112 ],
       dtype=float16)]

In [195]:
fig,enteredratmeans,entered_n_PerRat = plot_ratTracesAtHexChangeDiscovery(photrats,enteredInds,pltCol="#27aeef")
fig.savefig(photrats.directory_prefix+"DaEnteredPath_byRat.pdf")
fig,ignoredratmeans,ignored_n_PerRat = plot_ratTracesAtHexChangeDiscovery(photrats,ignoredInds,pltCol="#b33dc6")
fig.savefig(photrats.directory_prefix+"DaIgnoredPath_byRat.pdf")
with open(photrats.directory_prefix+"enteredPathNumbers.txt", 'w') as f:
        f.write(str(entered_n_PerRat))
with open(photrats.directory_prefix+"ignoredPathNumbers.txt", 'w') as f:
        f.write(str(ignored_n_PerRat))
fig = plotFirstEntryHexChangeMeanOverRats(photrats,enteredratmeans,ignoredratmeans,pltCol1="#27aeef",pltCol2="#b33dc6")
fig.savefig(photrats.directory_prefix+"DaNewlyEnteredvIgnored_byRat.pdf")

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


In [198]:
entered_n_PerRat

{'IM-1272': 11,
 'IM-1273': 3,
 'IM-1276': 13,
 'IM-1291': 8,
 'IM-1292': 11,
 'IM-1322': 9,
 'IM-1398': 6,
 'IM-1434': 5,
 'IM-1478': 5,
 'IM-1532': 3}

In [197]:
np.sum([entered_n_PerRat[rat] for rat in entered_n_PerRat])

74

In [200]:
np.sum([ignored_n_PerRat[rat] for rat in ignored_n_PerRat if ignored_n_PerRat[rat]>2])

26

In [None]:
np.where(ignored_n_PerRat>)

In [None]:
fig = plotFirstEntryHexChangeMeanOverRats(photrats,enteredratmeans,ignoredratmeans,pltCol1="#27aeef",pltCol2="#b33dc6")

- CONTROL: plot avg da in all entries into newly available/blocked hex state on all other trials in session.

In [59]:
# identify the hex of first discovery

# identify all times during the same session when rat entered. (get indices)

# get traces aligned to these indices

# plot average response to same hexes

- instead, first plot DA at entry into all choice points.. then average response upon entry into all hexes (with distance from port >4)

In [63]:
del photrats.sesh_tmats

In [66]:
photrats.get_allChoicePoints()

## plot regression results

In [29]:
regCoefs = pd.read_csv(photrats.directory_prefix+chooseLregCoefsByRat.csv")
regSum = pd.read_csv(photrats.directory_prefix+chooseLregSummary.csv")

In [42]:
regSum

Unnamed: 0.1,Unnamed: 0,Estimate,Std. Error,z value,Pr(>|z|)
0,(Intercept),2.185913,0.333764,6.549282,5.78142e-11
1,pRwdScaled,1.685901,0.151973,11.093423,1.3502070000000001e-28
2,ldifScaled,-6.839847,0.595627,-11.483439,1.597954e-30


In [48]:
fig = plt.figure(figsize=(5,7))
xvals = ["Intercept","p(reward)","Distance"]
#sns.boxplot(data=regCoefs.loc[:,["pRwdScaled","ldifScaled"]],color='lightgrey')#,alpha=.5)
sns.barplot(data=regCoefs.loc[:,["(Intercept)","pRwdScaled","ldifScaled"]],color='lightgrey',ci=95)#,alpha=.5)
#plt.bar([0,1,2],regSum.Estimate.values)
sns.stripplot(data=regCoefs.loc[:,["(Intercept)","pRwdScaled","ldifScaled"]],color='k',size=8,marker='D',alpha=.9)
plt.axhline(y=0,ls='--',color='k')
plt.xticks([0,1,2],xvals)
#plotRegSigLevel(regSum,0,2.0)
#plotRegSigLevel(regSum,1,2.0)
#plt.ylim(-10,2.2)
plt.ylabel("Path choice ß value\n(mixed-effects)",fontsize='xx-large')
plt.tight_layout()

In [49]:
fig.savefig(photrats.directory_prefix+"mixedEffectsFullResults.pdf")

# plot single trial examples of maze configurations and tick plots.
- plot position over maze early vs late in block

In [409]:
photrats.df.loc[photrats.df.session_type=='barrier','session'].unique()

array([  5,   6,   7,   0,   1,  17,   9,  11,  28,  29,  30,  20,  21,
        22,  32,  35,  36,  37,  43,  44,  45,  46,  47,  48,  55,  56,
        57,  58,  59,  60,  61,  64,  65,  68,  69,  72,  73,  75,  76,
        89,  92,  93,  98, 100], dtype=int8)

In [39]:
hexlist = [2,47,46,45,44,43,3,\
49,42,41,40,39,48,\
38,37,36,35,34,33,\
32,31,30,29,28,\
27,26,25,24,23,\
22,21,20,19,\
18,17,16,15,\
14,13,12,\
11,10,9,\
8,7,\
6,5,\
4,\
1]
#hexlist = np.subtract(hexlist,1) #convert to index-based states
scaleFactor = [(490)/13,365/18]
#scaleFactor = [(460)/13,350/18]
#scaleFactor = [(550-70)/13,380/18]
coords = []
cols = [7,6,6,5,5,4,4,3,3,2,2,1,1]
maxrows = 13
r = 0
x = 1
y = 18
startr = 1
while r < maxrows:
    maxcols = cols[r]
    c = 0
    if r%2!=0:
        startr+=1
    x=startr
    while c < maxcols:
        coords.append([x*scaleFactor[0]+35,y*scaleFactor[1]+55])
        #coords.append([x*scaleFactor[0]+70,y*scaleFactor[1]+55])
        #coords.append([x*scaleFactor[0]+70,y*scaleFactor[1]+40])
        x += 2
        c += 1
    if r%2!=0:
        y -= 2
    else:
        y-=1
    r += 1
cents = {h: c for h,c in zip(hexlist,coords)}
centdf = pd.DataFrame(cents)
centdf = centdf.T

In [26]:
from photometrySessionVizualizations import *

In [247]:
photrats.directory_prefix = loadpath


In [95]:
import matplotlib.colors as mc

In [436]:
def plot_hex_outline(barriers):
    bardf = centdf.drop(barriers,axis=0)
    plt.scatter(bardf.loc[:,0].values,bardf.loc[:,1].values,c=\
            'none',marker='H',s=1000,edgecolors="k",alpha=1,lw=2)
    plt.scatter(bardf.loc[:,0].values,bardf.loc[:,1].values,c=\
            'darkgrey',marker='H',s=1000,edgecolors="none",alpha=0.4)
    
def plot_hex_outline(barriers):
    bardf = centdf.drop(barriers,axis=0)
    plt.scatter(bardf.loc[:,0].values,bardf.loc[:,1].values,c=\
            'darkgrey',marker='H',s=1000,edgecolors="k",alpha=1,lw=2)

In [545]:
from scipy.stats import gaussian_kde
from scipy.ndimage import gaussian_filter

def plot_posOverlayAndTickPlot(photrats,s,blks = [1,2],
        posColor='cyan',saveFig=True,secondHalf=False,plot_ticks=False,
        edgCol='k', plotOverlay=True,plotProbs=True,trans=0.2,density=False,vmax=60,vmin=5):
    dat = photrats.df.loc[photrats.df.session==s].copy()
    halfString = '_only2ndHalfBlkPos' if secondHalf else ''
    densString = "_density" if density else ""
    fig = plt.figure(figsize=(15.5,9))
    startrows = 10
    if plot_ticks:
        plot_sesh_pathChoices_ticks(dat,blocks=blks,startrows=startrows)
    plt.subplots_adjust(wspace=0, hspace=0,top=0.95)
    #plt.figure(figsize=(17.4,6))
    for blk in blks:#dat.block.unique():
        prwds = dat.loc[dat.block==blk,["nom_rwd_a","nom_rwd_b",\
        "nom_rwd_c"]].values[0].astype(int)
        if dat.session_type.values[0] == "barrier":
            bars = np.add(photrats.sesh_barIDs[s][blk-1],1)
        else:
            bars = np.add(photrats.sesh_barIDs[s],1)
        ax = plt.subplot2grid((6+startrows,len(blks)*10),(0,(blk-np.min(blks))*10),\
            colspan = 8, rowspan =startrows)
        #plot_hex_outline(bars)
        #plot_hex_barriers(bars)
        #plt.ylim(0,522)
        #plt.xlim(55,590)
        #ax = plt.gca()
        #ax.invert_yaxis()
        #plt.tight_layout()
        #plt.xticks([])
        #plt.yticks([])
        #ax = plt.subplot(1,3,blk-np.min(blks)+1)
        minTri = 25 if secondHalf else 0 
        if plotOverlay:
            x = dat.loc[(dat.block==blk)&(dat.tri>minTri)&(dat.x.notnull()),'x'].values
            y = dat.loc[(dat.block==blk)&(dat.tri>minTri)&(dat.x.notnull()),'y'].values
            if density:
                H, xedges, yedges = np.histogram2d(x,y,bins=40)#,range=([[0,640],[0,590]]))
                X, Y = np.meshgrid(xedges, yedges)
                plt.pcolormesh(X,Y,H.T,cmap="Greys",norm=mc.Normalize(vmin=vmin,vmax=vmax))
                #xy = np.vstack([x,y])
                #z = gaussian_kde(xy)(xy)
                #plt.scatter(x,y,c=z,s=20)
                #plt.hist2d(x, y, bins=(100,100), cmap=plt.cm.BuPu,alpha=0.6,edgecolors='none',density=True,vmax=0.0001)
                #ax = plt.gca()
                #ax.invert_yaxis()
                #plt.hexbin(x, y, gridsize=40,alpha=1,cmap="Greys",norm=mc.Normalize(vmin=vmin,vmax=vmax),edgecolors='none',reduce_C_function=np.mean)
                #plt.colorbar()
                plot_hex_outline(bars)
                plot_hex_barriers(bars)
                plt.ylim(0,522)
                plt.xlim(55,590)
                ax = plt.gca()
                ax.invert_yaxis()
                plt.tight_layout()
                plt.xticks([])
                plt.yticks([])
            else:
                plot_hex_outline(bars)
                plot_hex_barriers(bars)
                plt.ylim(0,522)
                plt.xlim(55,590)
                ax = plt.gca()
                ax.invert_yaxis()
                plt.tight_layout()
                plt.xticks([])
                plt.yticks([])
                xy = np.vstack([x,y])
                z = gaussian_kde(xy)(xy)
                plt.scatter(x,y,c=z,s=20,cmap="viridis")#"cividis")
                #plt.scatter(x,y,alpha=trans,color=posColor,\
                #        s=20,edgecolors=edgCol,lw=2)#"darkviolet")#"salmon")#"fuchsia"
        if plotProbs:
            plt.text(x=300,y=0,s=str(prwds[0])+"%",fontsize=30,\
                fontweight='bold',backgroundcolor="k",color="white")#backgroundcolor="darkblue",color="white")
            plt.text(x=30,y=470,s=str(prwds[1])+"%",fontsize=30,\
                fontweight='bold',backgroundcolor="k",color="white")#backgroundcolor="darkorange",color="white")
            plt.text(x=570,y=470,s=str(prwds[2])+"%",fontsize=30,\
                fontweight='bold',backgroundcolor="k",color="white")
        plt.axis("off")
    if saveFig:
        fig.savefig(photrats.directory_prefix+"sesh"+\
            str(s)+"blocks"+str(blks)+"positionOverlayPlot"+halfString+densString+".pdf")
        fig.savefig(photrats.directory_prefix+"sesh"+\
            str(s)+"blocks"+str(blks)+"positionOverlayPlot"+halfString+densString+".png")


In [593]:
def plot_10tri_posOverlay(photrats,s,blks = [1,2],
        posColor='cyan',saveFig=True,groupOfTen=1,
        edgCol='k', plotOverlay=True,plotProbs=True):
    dat = photrats.df.loc[photrats.df.session==s].copy()
    triString = str(groupOfTen)
    fig = plt.figure(figsize=(15.5,9))
    startrows = 10
    plt.subplots_adjust(wspace=0, hspace=0,top=0.95)
    #plt.figure(figsize=(17.4,6))
    for blk in blks:#dat.block.unique():
        prwds = dat.loc[dat.block==blk,["nom_rwd_a","nom_rwd_b",\
        "nom_rwd_c"]].values[0].astype(int)
        if dat.session_type.values[0] == "barrier":
            bars = np.add(photrats.sesh_barIDs[s][blk-1],1)
        else:
            bars = np.add(photrats.sesh_barIDs[s],1)
        
        #fig,ax = plt.subplots(figsize=(5.6,5.5))#figsize=(7.5,5))
        ax = plt.subplot2grid((6+startrows,len(blks)*10),(0,(blk-np.min(blks))*10),\
            colspan = 8, rowspan =startrows)
        #ax = plt.subplot(1,3,blk-np.min(blks)+1)
        plot_hex_outline(bars)
        plot_hex_barriers(bars)
        plt.ylim(0,522)
        plt.xlim(55,590)
        ax = plt.gca()
        ax.invert_yaxis()
        plt.tight_layout()
        plt.xticks([])
        plt.yticks([])
        minTri = (groupOfTen-1)*10
        maxTri = groupOfTen*10
        if plotOverlay:
            x = dat.loc[(dat.block==blk)&(dat.tri>minTri)&(dat.x.notnull()),'x'].values
            y = dat.loc[(dat.block==blk)&(dat.tri>minTri)&(dat.x.notnull()),'y'].values
            #plt.scatter(dat.loc[(dat.block==blk)&(dat.tri>minTri)&(dat.tri<=maxTri),'x'].values,\
            #            dat.loc[(dat.block==blk)&(dat.tri>minTri)&(dat.tri<=maxTri),'y'].values,\
            #            alpha=0.4,color=posColor,s=40)#,edgecolors=edgCol,lw=2)#"darkviolet")#"salmon")#"fuchsia"
            xy = np.vstack([x,y])
            z = gaussian_kde(xy)(xy)
            plt.scatter(x,y,c=z,s=20,cmap="viridis")
        if plotProbs:
            plt.text(x=300,y=0,s=str(prwds[0])+"%",fontsize=30,\
                fontweight='bold',backgroundcolor="k",color="white")#backgroundcolor="darkblue",color="white")
            plt.text(x=30,y=470,s=str(prwds[1])+"%",fontsize=30,\
                fontweight='bold',backgroundcolor="k",color="white")#backgroundcolor="darkorange",color="white")
            plt.text(x=570,y=470,s=str(prwds[2])+"%",fontsize=30,\
                fontweight='bold',backgroundcolor="k",color="white")#backgroundcolor="darkgreen",color="white")
        plt.axis("off")
    if saveFig:
        fig.savefig(photrats.directory_prefix+"sesh"+\
            str(s)+"blocks"+str(blks)+"positionOverlayPlot"+str(triString)+"_10tris.pdf")
        fig.savefig(photrats.directory_prefix+"sesh"+\
            str(s)+"blocks"+str(blks)+"positionOverlayPlot"+str(triString)+"_10tris.png")

- create acceleration column from velocity

In [584]:
photrats.df.loc[photrats.df.vel.notnull(),"acc"] = photrats.df.loc[photrats.df.vel.notnull(),'vel'].diff()
photrats.df.loc[:,"acc"] = photrats.df.loc[:,"acc"].fillna(method="ffill")

In [583]:
plot_posOverlayAndTickPlot(photrats,s=99,posColor='cyan',edgCol='none',saveFig=True,\
                           secondHalf=True,plotOverlay=True,plotProbs=True,trans=0.1,vmin=1,vmax=8,density=False,blks=[1,2])

  plt.scatter(centdf.loc[barriers,0].values,centdf.loc[barriers,1].values,c=\
  plt.tight_layout()
  plt.scatter(centdf.loc[barriers,0].values,centdf.loc[barriers,1].values,c=\
  plt.tight_layout()


In [201]:
photrats.df.loc[photrats.df.session.diff()!=0,["session_type","session","rat"]].values

array([['prob', 2, 'IM-1272'],
       ['prob', 3, 'IM-1272'],
       ['prob', 4, 'IM-1272'],
       ['prob', 12, 'IM-1273'],
       ['prob', 15, 'IM-1273'],
       ['prob', 23, 'IM-1276'],
       ['prob', 24, 'IM-1276'],
       ['prob', 25, 'IM-1276'],
       ['prob', 26, 'IM-1276'],
       ['prob', 27, 'IM-1276'],
       ['prob', 39, 'IM-1291'],
       ['prob', 40, 'IM-1291'],
       ['prob', 38, 'IM-1291'],
       ['prob', 41, 'IM-1291'],
       ['prob', 42, 'IM-1291'],
       ['prob', 52, 'IM-1292'],
       ['prob', 53, 'IM-1292'],
       ['prob', 54, 'IM-1292'],
       ['barrier', 5, 'IM-1272'],
       ['barrier', 6, 'IM-1272'],
       ['barrier', 7, 'IM-1272'],
       ['barrier', 0, 'IM-1272'],
       ['barrier', 1, 'IM-1272'],
       ['barrier', 17, 'IM-1273'],
       ['barrier', 9, 'IM-1273'],
       ['barrier', 11, 'IM-1273'],
       ['barrier', 28, 'IM-1276'],
       ['barrier', 29, 'IM-1276'],
       ['barrier', 30, 'IM-1276'],
       ['barrier', 20, 'IM-1276'],
       ['barr

In [168]:
sesh = 27
#plot_posOverlayAndTickPlot(photrats,s=69,posColor='dodgerblue',edgCol='k',saveFig=True,secondHalf=False,plotOverlay=True,plotProbs=True)
plot_posOverlayAndTickPlot(photrats,s=sesh,posColor='cyan',edgCol='blue',saveFig=True,secondHalf=True,plotOverlay=True,plotProbs=True)
plot_posOverlayAndTickPlot(photrats,s=sesh,posColor='cyan',edgCol='blue',saveFig=True,secondHalf=False,plotOverlay=True,plotProbs=True)
#plot_posOverlayAndTickPlot(photrats,s=69,posColor='dodgerblue',edgCol='k',saveFig=True,secondHalf=True,plotOverlay=True,plotProbs=True)
#plot_posOverlayAndTickPlot(photrats,s=36,posColor='dodgerblue',edgCol='k',saveFig=True,secondHalf=True,plotOverlay=True,plotProbs=True)

  ax3.set_yticklabels([''])
  ax2.set_yticklabels([''])
  ax3.set_yticklabels([''])
  plt.tight_layout()
  plt.scatter(centdf.loc[barriers,0].values,centdf.loc[barriers,1].values,c=\
  plt.tight_layout()
  plt.scatter(centdf.loc[barriers,0].values,centdf.loc[barriers,1].values,c=\
  plt.tight_layout()
  ax3.set_yticklabels([''])
  ax2.set_yticklabels([''])
  ax3.set_yticklabels([''])
  plt.tight_layout()
  plt.scatter(centdf.loc[barriers,0].values,centdf.loc[barriers,1].values,c=\
  plt.tight_layout()
  plt.scatter(centdf.loc[barriers,0].values,centdf.loc[barriers,1].values,c=\
  plt.tight_layout()


In [91]:
plt.close("all")

In [595]:
sesh = 90
blocks = [2,3]
plot_10tri_posOverlay(photrats,s=sesh,posColor='midnightblue',edgCol='none',saveFig=True,groupOfTen=1,plotOverlay=True,plotProbs=True,blks=blocks)
plot_10tri_posOverlay(photrats,s=sesh,posColor='midnightblue',edgCol='none',saveFig=True,groupOfTen=2,plotOverlay=True,plotProbs=True,blks=blocks)
plot_10tri_posOverlay(photrats,s=sesh,posColor='midnightblue',edgCol='none',saveFig=True,groupOfTen=3,plotOverlay=True,plotProbs=True,blks=blocks)
#plot_10tri_posOverlay(photrats,s=sesh,posColor='cyan',edgCol='blue',saveFig=True,groupOfTen=4,plotOverlay=True,plotProbs=True)

  plt.scatter(centdf.loc[barriers,0].values,centdf.loc[barriers,1].values,c=\
  plt.tight_layout()
  plt.scatter(centdf.loc[barriers,0].values,centdf.loc[barriers,1].values,c=\
  plt.tight_layout()
  plt.scatter(centdf.loc[barriers,0].values,centdf.loc[barriers,1].values,c=\
  plt.tight_layout()
  plt.scatter(centdf.loc[barriers,0].values,centdf.loc[barriers,1].values,c=\
  plt.tight_layout()
  plt.scatter(centdf.loc[barriers,0].values,centdf.loc[barriers,1].values,c=\
  plt.tight_layout()
  plt.scatter(centdf.loc[barriers,0].values,centdf.loc[barriers,1].values,c=\
  plt.tight_layout()


In [89]:
for sesh in photrats.df.loc[photrats.df.session_type=="barrier","session"].unique()[:10]:
    plot_10tri_posOverlayAndTickPlot(photrats,s=sesh,posColor='cyan',edgCol='blue',saveFig=True,secondTen=False,plotOverlay=True,plotProbs=True)
    plot_10tri_posOverlayAndTickPlot(photrats,s=sesh,posColor='cyan',edgCol='blue',saveFig=True,secondTen=True,plotOverlay=True,plotProbs=True)

  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
