# Ratterdam - Place Field Repetition - Field Dynamics Exploration
### WH Mid August 2020
### Goal: explore the variables that contribute to place-field (PF) repetition, specifically temporal field dynamics

In [5]:
import ratterdam_CoreDataStructures as Core
import ratterdam_ParseBehavior as Parse
import numpy as np
from importlib import reload
from scipy.stats import sem
import utility_fx as util
import os
import matplotlib.gridspec as gridspec
from matplotlib import pyplot as plt
import ratterdam_Defaults as Def
import ratterdam_visBasic as Vis
import RateMapClass_William_20190308 as RateMapClass
import ratterdam_RepetitionCoreFx as RepCore
import williamDefaults as wmDef
from matplotlib.backends.backend_pdf import PdfPages
import more_itertools, itertools
from sklearn.metrics import auc
import alphashape
from descartes import PolygonPatch
from scipy.interpolate import splrep, splev
from scipy.spatial import ConvexHull
import scipy

In [3]:
%qtconsole --style native
%matplotlib qt5

In [149]:
rat = "R859"
day = "D3"
# Windows paths: every backslash is escaped ("\\"). A lone "\R" is an invalid
# escape sequence in a Python string literal (DeprecationWarning since 3.6) and
# only produced the intended text by accident.
savepath = f'E:\\Ratterdam\\{rat}\\ratterdam_plots\\{day}\\'
df = f'E:\\Ratterdam\\{rat}\\{rat}_RatterdamOpen_{day}\\'
clust = 'TT1_0001\\cl-maze1.1'
unit = RepCore.loadRepeatingUnit(df, clust)

  aboveThreshold = np.where(rateMap >= max(fieldThreshold,0), True, False)
C:\Users\whockei1\AppData\Roaming\Python\Python36\site-packages\skimage\morphology\_deprecated.py:5: skimage_deprecation: Function ``watershed`` is deprecated and will be removed in version 0.19. Use ``skimage.segmentation.watershed`` instead.
  def watershed(image, markers=None, connectivity=1, offset=None, mask=None,
C:\Users\whockei1\AppData\Roaming\Python\Python36\site-packages\skimage\morphology\_deprecated.py:5: skimage_deprecation: Function ``watershed`` is deprecated and will be removed in version 0.19. Use ``skimage.segmentation.watershed`` instead.
  def watershed(image, markers=None, connectivity=1, offset=None, mask=None,
C:\Users\whockei1\AppData\Roaming\Python\Python36\site-packages\skimage\morphology\_deprecated.py:5: skimage_deprecation: Function ``watershed`` is deprecated and will be removed in version 0.19. Use ``skimage.segmentation.watershed`` instead.
  def watershed(image, markers=None, co

1.4774651645812056
0.26971992137828066
0.7272776122077441
0.04595286884490898
0.0759418890819691
0.003553151880919795
0.4807152384711971
0.009616525717953465


In [44]:
# Detect sustained shifts in a field's visit-by-visit firing rate
def find_rateRegime_shift(field, winSz=5, look_ahead=5, thresh=2):
    """
    Find visits where a field's firing rate shifts to a new, sustained regime.

    For each start index i, the rates of visits i..i+winSz-1 form a baseline
    window. The first visit after the window (i+winSz) is a candidate shift if
    its rate is above mean + thresh*std, or below max(mean - thresh*std, 0), of
    the window. The shift is accepted only if all `look_ahead` visits starting
    at the candidate stay beyond that same bound in the same direction.
    Because the window slides by one visit, one shift can be re-detected at
    several consecutive i; an i is kept only if it is at least winSz after the
    previously kept one.

    Parameters
    ----------
    field : array-like, shape (n, >=2)
        One row per visit; column 1 is the rate. Only column 1 is read here
        (column 0 is presumably the timestamp -- confirm against callers).
    winSz : int
        Baseline window length, in visits.
    look_ahead : int
        Number of consecutive visits that must stay beyond the threshold.
    thresh : float
        Multiplier on the window's (population, ddof=0) std for the bounds.

    Returns
    -------
    list of int
        The baseline-window start index i for each accepted shift; the shift
        itself begins at visit i + winSz.
    """
    if len(field) == 0:
        return []
    rates = np.asarray(field, dtype=float)[:, 1]
    n_visits = rates.shape[0]
    rate_changes = []

    # Last valid i is the one whose look-ahead run
    # rates[i+winSz : i+winSz+look_ahead] still fits inside the data.
    for i in range(n_visits - winSz - look_ahead + 1):
        window = rates[i:i + winSz]
        mu, sd = window.mean(), window.std()
        thresh_up = mu + thresh * sd
        thresh_down = max(mu - thresh * sd, 0)  # a firing rate can't be negative

        first = i + winSz  # first visit after the baseline window
        # `bound` is kept separate so the `thresh` multiplier is never clobbered.
        if rates[first] > thresh_up:
            compfx, bound = np.greater, thresh_up
        elif rates[first] < thresh_down:
            compfx, bound = np.less, thresh_down
        else:
            continue

        # The crossing must persist, in the same direction, for every look-ahead visit.
        if not np.all(compfx(rates[first:first + look_ahead], bound)):
            continue

        # De-duplicate: the sliding window re-detects one shift at neighbouring i.
        if rate_changes and i - rate_changes[-1] < winSz:
            continue
        rate_changes.append(i)
    return rate_changes

In [145]:
def findInflections(field):
    """
    Find the indices of turning points (local extrema) of a value series.

    Despite the name, this does not use the 2nd derivative. It flags points
    where the sign of the first difference changes (rising -> falling or vice
    versa). It then adds the first and last index so that the result brackets
    the whole series.

    Parameters
    ----------
    field : array-like
        Either a 1-d array of values, or a 2-d array whose LAST column holds
        the values (e.g. rows of [timestamp, rate]). Only that column is used,
        so a non-increasing or repeated timestamp column cannot inject
        spurious or duplicate indices.

    Returns
    -------
    np.ndarray of int
        0, every turning-point index (ascending), then len(field) - 1.

    Notes
    -----
    A zero first difference (plateau) has sign 0, so entering and leaving a
    plateau are each reported as a change of sign.
    """
    values = np.asarray(field)
    if values.ndim == 2:
        values = values[:, -1]
    signdelta = np.sign(np.diff(values))
    # np.sign yields -1/0/+1; any nonzero difference of consecutive signs is a turn.
    # +1 because the diff lands on the point just before the turning point.
    idx = np.flatnonzero(np.diff(signdelta)) + 1
    return np.hstack((0, idx, values.shape[0] - 1))

def length_filter(idx, data, thresh=2):
    """
    Drop epoch boundaries that start an epoch which is too short.

    A boundary idx[k] survives when the next boundary idx[k+1] is at least
    thresh + 1 rows later, i.e. there are at least `thresh` rows strictly
    between them. The last boundary has no successor, so it is not tested.
    Instead, the last row index of `data` is always appended.

    Returns
    -------
    np.ndarray, shape (m, 1)
        Surviving boundaries plus data.shape[0] - 1, as a column vector.
    """
    gaps = np.diff(idx, axis=0)
    long_enough = gaps >= thresh + 1
    starts = idx[:-1][long_enough]
    kept = np.append(starts, data.shape[0] - 1)
    return kept.reshape(-1, 1)

def magnitude_filter(idx, data, thresh=0.20):
    """
    Keep epoch boundaries whose rate change is large relative to the whole field.

    For each consecutive pair of boundaries (idx[k], idx[k+1]), the absolute
    difference of their rates (column 1 of `data`) is divided by the overall
    rate range of `data`. idx[k] is kept when that fraction is >= thresh. The
    last row index of `data` is always appended.

    Parameters
    ----------
    idx : array of int, shape (k,) or (k, 1)
        Candidate boundary row indices; length_filter returns the (k, 1) form.
    data : np.ndarray, shape (n, >=2)
        Column 1 holds the rate.
    thresh : float
        Minimum fraction of the rate range (0.20 means 20 %).

    Returns
    -------
    np.ndarray, shape (m,)
        Kept boundary indices followed by data.shape[0] - 1. If the rates are
        constant (zero range), no boundary passes.
    """
    flat_idx = np.asarray(idx).ravel()
    keep = []
    if flat_idx.size >= 2:
        rates = np.asarray(data)[:, 1]
        data_range = rates.max() - rates.min()  # loop-invariant, computed once
        # Zero range -> every fraction would be 0/0 (nan), so nothing can pass.
        if data_range > 0:
            for start, stop in zip(flat_idx[:-1], flat_idx[1:]):
                if abs(rates[start] - rates[stop]) / data_range >= thresh:
                    keep.append(start)
    # Builtin `int` dtype via idx's own dtype: `np.int` was removed in NumPy 1.24.
    return np.append(np.asarray(keep, dtype=flat_idx.dtype), data.shape[0] - 1)

def findRateEpochTs(field):
    """
    Return the timestamps of rate-epoch boundaries for one field.

    Pipeline:
      1. findInflections proposes candidate boundaries.
      2. length_filter (thresh=2) drops boundaries with fewer than 2 rows
         before the next boundary.
      3. magnitude_filter (thresh=0.2) drops boundaries whose start-to-next
         rate change is under 20 % of the field's overall rate range.

    Parameters
    ----------
    field : np.ndarray
        Assumed (n, 2) array of [timestamp, rate] rows. Column 0 is read here
        as the timestamp; magnitude_filter reads column 1 as the rate.

    Returns
    -------
    np.ndarray
        field[idx, 0] for the surviving boundary indices.
    """
    idx = findInflections(field)
    idx = length_filter(idx, field, thresh=2)
    # The rate data for the magnitude test is this same field array
    # (there is no separate `data` in scope).
    idx = magnitude_filter(idx, field, thresh=0.2)
    ts = field[idx, 0]
    return ts