# Ratterdam - Place Field Repetition - Field Dynamics Exploration
### WH Mid August 2020
### Goal: explore the variables that contribute to place-field (PF) repetition, specifically temporal field dynamics

In [2]:
import ratterdam_CoreDataStructures as Core
import ratterdam_ParseBehavior as Parse
import numpy as np
from importlib import reload
from scipy.stats import sem
import utility_fx as util
import os
import matplotlib.gridspec as gridspec
from matplotlib import pyplot as plt
import ratterdam_Defaults as Def
import ratterdam_visBasic as Vis
import RateMapClass_William_20190308 as RateMapClass
# import repeatingPC
# import placeFieldBorders
import ratterdam_RepetitionCoreFx as RepCore
import williamDefaults as wmDef
from matplotlib.backends.backend_pdf import PdfPages
import more_itertools, itertools
from sklearn.metrics import auc
import alphashape
from descartes import PolygonPatch
from scipy.interpolate import splrep, splev
from scipy.spatial import ConvexHull
import scipy

In [3]:
# IPython magics: open a Qt console attached to this kernel (native colour style)
# and send matplotlib figures to interactive Qt5 windows instead of inline output.
%qtconsole --style native
%matplotlib qt5

In [4]:
# Recording session to analyse.
rat = "R859"
day = "D3"
# Windows paths: every backslash is escaped ("\\") so no invalid escape
# sequences (e.g. "\R") appear in the literals.
savepath = f'E:\\Ratterdam\\{rat}\\ratterdam_plots\\{day}\\'
# Data directory for this rat/day (a path string, not a DataFrame).
df = f'E:\\Ratterdam\\{rat}\\{rat}_RatterdamOpen_{day}\\'
clust = 'TT2\\cl-maze1.1'
unit = RepCore.loadRepeatingUnit(df, clust)

FileNotFoundError: [Errno 2] No such file or directory: 'E:\\Ratterdam\\R859\\R859_RatterdamOpen_D3\\TT2\\cl-maze1.1'

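## Visit-Triggered Average
#### For each visit to a field, look back over a window (5 min) ending at that visit. Compute the Pearson r between that field's trace segment and each other field's segment over the same window, pairwise.
#### Shuffle (circularly rotate each field's values in time) to test $H_0$: statistic near 0 vs. $H_a$: statistic shifted negative.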

In [449]:
fieldArray = list(unit.fields)  # plain list copy of the unit's fields, one entry per field

In [450]:
# Spline smoothing factor (s) and degree (k), shared by both fits below.
s = 1
k = 3  # should be 3 usually (cubic)

# Spline representation of each field's trace: column 0 holds the visit
# timestamp, column 1 the per-visit value (presumably rate -- confirm).
fieldFx = [splrep(d[:, 0], d[:, 1], k=k, task=0, s=s) for d in fieldArray]
# Time span covered by all fields. The lower bound uses floor (not ceil) so
# that fmin never lies after the earliest visit.
fmax = int(np.ceil(max([max(field[:, 0]) for field in fieldArray])))
fmin = int(np.floor(min([min(field[:, 0]) for field in fieldArray])))

# Spline representation of each field's *gradient* trace.
# NOTE(review): np.gradient here differentiates with respect to visit index,
# not time; use np.gradient(field[:, 1], field[:, 0]) if a per-unit-time
# derivative is intended.
diffs = [np.column_stack((field[:, 0], np.gradient(field[:, 1]))) for field in unit.fields]
diffFx = [splrep(d[:, 0], d[:, 1], k=k, task=0, s=s) for d in diffs]

In [323]:
def getTraceInWindow(fieldFx,fieldID,visitTs,winsize=1e6*5*60, numPts=75):
    """
    Evaluate one field's spline over a window looking back in time from a visit.

    Input:  fieldFx - list of spline reps (tck from splrep), one per field in
                      unit.fields (same order)
            fieldID - index into fieldFx of the field to evaluate
            visitTs - NL timestamp of a visit to a field
            winsize - time to look back, in us (same units as visitTs).
                      Default is 5 min (1e6*5*60)
            numPts  - number of evenly spaced points at which the spline is
                      evaluated. Default is 75

    The window is [visitTs - winsize, visitTs], with its start clamped at 0.
    Return: (numPts,) array of the evaluated spline segment
    """
    # Clamp so an early visit never produces a negative window start.
    begin = max(visitTs - winsize, 0)
    x = np.linspace(begin, visitTs, numPts)
    return splev(x, fieldFx[fieldID])

def rotateFields(unit, rng=None):
    """
    Circularly shift each field's values in time to build a shuffled surrogate.

    Input   - unit: Unit class object; each entry of unit.fields is an array
                    whose column 0 (timestamps) and column 1 (values) are
                    used. Any further columns are dropped.
              rng:  optional np.random.Generator for reproducible shuffles.
                    Default None draws from the global np.random state, as
                    before.

    For each field independently, a shift in [0, nvisits) is drawn and
    column 1 is rolled by it while column 0 stays fixed. The shift is
    rotational: whatever 'falls off' the end of the vector 'comes back' to
    the front. A shift of 0 (no rotation) is possible.

    Returns - GuttedUnit object (class defined elsewhere) with the shifted
              fields in its .fields attribute
    """
    gunit = GuttedUnit()
    sfields = []
    for field in unit.fields:
        nvisits = field.shape[0]
        if nvisits == 0:
            # randint(0, 0) would raise ValueError; nothing to rotate.
            shift = 0
        elif rng is None:
            shift = np.random.randint(0, nvisits)
        else:
            shift = rng.integers(0, nvisits)
        sf = np.column_stack((field[:, 0], np.roll(field[:, 1], shift)))
        sfields.append(sf)
    gunit.fields = sfields
    return gunit

def corrTraces(a,b):
    """
    Return the Pearson correlation coefficient between traces a and b.

    Only the coefficient is returned; the p-value from pearsonr is discarded.
    """
    coeff, _pval = scipy.stats.pearsonr(a, b)
    return coeff

In [483]:
# Routine to look at each visit and perform a corr on the lagging window between it and each other field
import scipy.signal  # explicit: `import scipy` alone does not load scipy.signal on older SciPy

lookback_us = 1e6*60*5  # 5-min look-back window; equals getTraceInWindow's default winsize

nf = len(unit.fields)
# NOTE(review): keys are concatenated digit strings ("ij"), which become
# ambiguous once nf > 10 (e.g. "111"); later cells index them as pair[0].
pairs = [f"{i}{j}" for i, j in itertools.product(range(nf), repeat=2)]
traces = {p: [] for p in pairs}  # per visit: argmax index of the full cross-correlation
corrs = {p: [] for p in pairs}   # per visit: Pearson r of the two window segments

for i, field in enumerate(unit.fields):
    for visit in field:
        ts = visit[0]
        # Skip visits whose look-back window would start before the earliest data.
        if ts > fmin + lookback_us:
            # Field i's segment depends only on the visit, not on j: compute once.
            segI = getTraceInWindow(diffFx, i, ts, winsize=lookback_us)
            for j in range(nf):
                if i != j:
                    segJ = getTraceInWindow(diffFx, j, ts, winsize=lookback_us)
                    key = f"{i}{j}"
                    traces[key].append(np.argmax(scipy.signal.correlate(segI, segJ, mode='full')))
                    corr = corrTraces(segI, segJ)
                    corrs[key].append(corr)

In [485]:
# Visit timestamps per field, using the same filter (later than fmin + 5-min
# look-back) that decided which visits were appended to traces/corrs.
lookback_us = 1e6*60*5
xvs = {str(n): [ts for ts in field[:, 0] if ts > fmin + lookback_us]
       for n, field in enumerate(unit.fields)}
# Per-field aliases kept for later cells; empty list if the unit has fewer than 3 fields.
x0, x1, x2 = (xvs.get(str(n), []) for n in range(3))

In [488]:
plt.figure()
plot_pairs = ("12", "21")  # "ij": visits of field i compared against field j
plt.title(" / ".join(plot_pairs) + " Grad")
for pair in plot_pairs:
    # x-values are the visit times of field pair[0], which index traces[pair].
    plt.plot(xvs[pair[0]], traces[pair], linestyle='-', marker='.', label=pair)
plt.xlabel("Visit timestamp (NL, us)")
plt.ylabel("argmax of full cross-correlation")
plt.legend()

[<matplotlib.lines.Line2D at 0x1a78de84ac8>]