In [1]:
import numpy as np, matplotlib.pyplot as plt, random, json, pickle, datetime, copy, socket, math
from scipy.stats import sem
import matplotlib.colors as colors
from scipy.ndimage import gaussian_filter as gauss # for smoothing ratemaps
import sys

# The synced Google Drive root differs per machine; project code is imported from beneath it.
_hostname = socket.gethostname()
if _hostname == 'Tolman':
    codeDirBase = 'C:\\Users\\whockei1\\Google Drive'
elif _hostname == 'DESKTOP-BECTOJ9':
    codeDirBase = 'C:\\Users\\whock\\Google Drive'
else:
    # Previously an unknown host left codeDirBase unbound and failed below with a bare NameError.
    raise RuntimeError(f"Unknown host {_hostname!r}: add its Google Drive root to this cell")

sys.path.insert(0, codeDirBase + '\\KnierimLab\\Ratterdam\\Code')
sys.path.insert(0, codeDirBase + '\\Python_Code\\KLab\\mts_analysis')
sys.path.insert(0, codeDirBase + '\\Python_Code\\KLab\\mts_analysis')
import utility_fx as util
import ratterdam_ParseBehavior as pBehav
from ratterdam_Defaults import *
import ratterdam_CoreDataStructures as core

In [2]:
# Session directory for R765 / DFD4 (kept with a trailing path separator) and a cluster path under it.
datafile = "\\".join(["E:", "Ratterdam", "R765", "R765DFD4", ""])
clust = "\\".join(["TT14", "cl-maze1.5"])

In [3]:
# Open a Qt console attached to this kernel (native colour style). Interactive convenience only;
# no visible cell depends on it.
%qtconsole --style native

In [86]:
# Load behaviour for day DFD4 and the raw spike data of tetrode TT14 / cluster cl-maze1.5.
# NOTE(review): this rebinds `clust` (previously "TT14\\cl-maze1.5") to the bare cluster file name.
daycode = "DFD4"
tt = "TT14"
clust = "cl-maze1.5"
clustName = f"{tt}\\{clust}"

# NOTE(review): BehavioralData is called with two arguments here but with a third (`velo`)
# in the batch cell further down; confirm which signature the current module expects.
behav = core.BehavioralData(datafile, daycode)
ts, position, alleyTracking, alleyVisits, txtVisits = behav.loadData()

unit = core.UnitData(clustName, datafile, daycode, alleyBounds, alleyVisits, txtVisits, position, ts)
unit.loadData_raw()

  n = util.weird_smooth(n,1)
  n = util.weird_smooth(n,1)
  Z=VV/WW
  n = util.weird_smooth(n,1)
  n = util.weird_smooth(n,1)
  W=0*U.copy()+1


In [26]:
def calcPermutes(nA,nB,nC):
    """
    Count the distinct orderings of a visit sequence holding nA 'A', nB 'B' and nC 'C' labels.

    This is the multinomial coefficient (nA + nB + nC)! / (nA! * nB! * nC!), i.e. the number
    of distinct label shuffles available to a permutation test on that alley.

    Parameters are non-negative integers (Python or numpy ints).
    Returns an exact Python int.
    """
    result = math.factorial(sum((nA, nB, nC)))
    # Floor division stays exact: every intermediate quotient is itself an integer.
    # True division returned a float, which loses exactness above 2**53 and can overflow.
    for count in (nA, nB, nC):
        result //= math.factorial(count)
    return result
    
# For every alley that was visited with all three textures, report the per-texture visit
# counts and the number of distinct label orderings; other alleys are skipped.
for alleyIdx in range(17):
    visitTextures = txtVisits[alleyIdx]
    nA, nB, nC = (visitTextures.count(label) for label in ('A', 'B', 'C'))
    if min(nA, nB, nC) == 0:
        continue
    print(f"Alley {alleyIdx}")
    print(f"{nA} A's")
    print(f"{nB} B's")
    print(f"{nC} C's")
    print(calcPermutes(nA, nB, nC))
    print("----------------------")

Alley 0
5 A's
9 B's
9 C's
1636014380.0
----------------------
Alley 1
7 A's
4 B's
7 C's
10501920.0
----------------------
Alley 2
2 A's
2 B's
10 C's
6006.0
----------------------
Alley 3
9 A's
4 B's
7 C's
55426800.0
----------------------
Alley 4
1 A's
8 B's
10 C's
831402.0
----------------------
Alley 5
3 A's
5 B's
18 C's
87487400.0
----------------------
Alley 6
3 A's
13 B's
5 C's
11395440.0
----------------------
Alley 8
9 A's
2 B's
20 C's
4656977325.0
----------------------
Alley 9
18 A's
7 B's
9 C's
25213318759200.0
----------------------
Alley 10
5 A's
4 B's
13 C's
62674920.0
----------------------
Alley 11
7 A's
5 B's
6 C's
14702688.0
----------------------
Alley 12
17 A's
3 B's
5 C's
60568200.0
----------------------
Alley 14
14 A's
5 B's
3 C's
17907120.0
----------------------
Alley 15
15 A's
3 B's
8 C's
1274816400.0
----------------------
Alley 16
3 A's
4 B's
11 C's
1113840.0
----------------------


In [2]:
def poolTrials(unit, alley, labels, txt):
    """
    Collect the 1-D ratemaps of every visit to `alley` whose label equals `txt`.

    `labels` supplies one texture label per visit (in visit order) and may be the real
    labels or a shuffled copy. Visits whose 'ratemap1d' is not a numpy array are skipped;
    NaNs in the kept ratemaps are replaced with 0 (np.nan_to_num).

    No subsampling to balance group sizes is done here.

    Returns
    -------
    idx : list of the visit indices that were kept
    rms : np.ndarray stacking the kept ratemaps, one visit per row
    """
    keptIdx = []
    keptMaps = []
    for visitIdx, visit in enumerate(unit.alleys[alley]):
        if labels[visitIdx] != txt:
            continue
        ratemap = visit['ratemap1d']
        if type(ratemap) is not np.ndarray:
            continue
        keptMaps.append(np.nan_to_num(ratemap))
        keptIdx.append(visitIdx)
    return keptIdx, np.asarray(keptMaps)

In [3]:
def computeTestStatistic_Diffs(groupX, groupY):
    """
    Smoothed bin-wise difference between two group-average ratemaps.

    groupX, groupY : stacks of single-trial ratemaps with one trial per row
        (the `rms` arrays returned by poolTrials). Each is averaged over
        axis 0 with NaNs ignored.

    Returns util.weird_smooth(mean(groupX) - mean(groupY), 2).
    """
    meanX = np.nanmean(groupX, axis=0)
    meanY = np.nanmean(groupY, axis=0)
    difference = meanX - meanY
    return util.weird_smooth(difference, 2)

In [4]:
def getLabels(alley, unitObj=None):
    """
    Return the true stimulus (texture) label of every visit to `alley`, in visit order.

    Parameters
    ----------
    alley : key into unit.alleys
    unitObj : unit whose visits are read. Defaults to the notebook-global `unit`
        for backward compatibility. Pass it explicitly so the result does not
        silently depend on whichever unit was loaded last.

    Returns
    -------
    list of visit['metadata']['stimulus'] values
    """
    # The global `unit` is only looked up when no unit is supplied.
    source = unit if unitObj is None else unitObj
    return [visit['metadata']['stimulus'] for visit in source.alleys[alley]]

In [5]:
def genSingleNullStat(unit, alley, txtX, txtY, labels):
    """
    Draw one null test statistic by shuffling the visit labels once.

    The labels are permuted with np.random.permutation, both groups are re-pooled
    under the shuffled labels, and the smoothed difference of group means
    (avg txtX - avg txtY) is returned.
    """
    permutedLabels = np.random.permutation(labels)
    _, shuffledX = poolTrials(unit, alley, permutedLabels, txtX)
    _, shuffledY = poolTrials(unit, alley, permutedLabels, txtY)
    return computeTestStatistic_Diffs(shuffledX, shuffledY)
    

In [6]:
def genRealStat(unit, alley, txtX, txtY):
    """
    Observed test statistic: smoothed avg(txtX) - avg(txtY) ratemap on `alley`
    using the true visit labels of `unit`.
    """
    # Labels are read from the `unit` argument. The previous getLabels(alley) call read the
    # notebook-global `unit`, which mismatches whenever this is called for a different unit.
    labels = [visit['metadata']['stimulus'] for visit in unit.alleys[alley]]
    _, rmsX = poolTrials(unit, alley, labels, txtX)
    _, rmsY = poolTrials(unit, alley, labels, txtY)
    return computeTestStatistic_Diffs(rmsX, rmsY)

In [7]:
def computeBandThresh(nulls, alpha, side):
    '''
    Per-bin significance threshold taken from a stack of null traces.

    For each bin (column) the null values are ranked from most extreme on the requested
    side, and the value at rank int((alpha / 2) * N + 1) (1-based, N = number of nulls)
    is returned. Exactly that many nulls lie at or beyond the threshold.

    Parameters
    ----------
    nulls : array-like, shape (N, L) - one null trace per row
    alpha : two-sided alpha; alpha / 2 is applied to this side
    side  : 'upper' (rank from the largest values) or 'lower' (rank from the smallest)

    Returns
    -------
    np.ndarray of length L with the threshold for each bin.

    Raises
    ------
    ValueError if `side` is neither 'upper' nor 'lower' (previously an UnboundLocalError).
    '''
    if side not in ('upper', 'lower'):
        raise ValueError(f"side must be 'upper' or 'lower', got {side!r}")

    nullArr = np.asarray(nulls)
    propNull = int(((alpha / 2) * len(nullArr)) + 1)
    if side == 'upper':
        # Sorting the negated values then negating back gives descending columns,
        # while still keeping any NaNs at the end rather than the front.
        ordered = -np.sort(-nullArr, axis=0)
    else:
        ordered = np.sort(nullArr, axis=0)
    # The rank is 1-based, so subtract 1 for the 0-based row index.
    return ordered[propNull - 1]

In [8]:
def computeGlobalCrossings(nulls, lowerBand, upperBand):
    """
    Fraction of null traces that leave the [lowerBand, upperBand] envelope in at least one bin.

    A trace counts once, no matter how many of its bins are strictly above upperBand or
    strictly below lowerBand. The result is the observed family-wise crossing rate for
    these bands.
    """
    crossingCount = 0
    for trace in nulls:
        outside = np.logical_or(trace > upperBand, trace < lowerBand)
        if outside.any():
            crossingCount += 1
    return crossingCount / len(nulls)


In [9]:
def global_FWER_alpha(nulls, alpha=0.05, fwerModifier=3*17):
    """
    Find the pointwise alpha whose bands keep the family-wise crossing rate below alpha / fwerModifier.

    Candidate pointwise alphas are scanned from 0.01 down to 1e-10 in 100 evenly spaced
    steps. For each candidate, lower/upper bands are built from `nulls` (computeBandThresh),
    and the fraction of nulls crossing them anywhere is measured (computeGlobalCrossings).
    The first (largest) candidate with a rate strictly below alpha / fwerModifier is selected.
    The default fwerModifier of 3*17 presumably means 3 texture pairs x 17 alleys.

    Returns
    -------
    (selectedAlpha, lowerBand, upperBand). If no candidate satisfies the criterion,
    selectedAlpha is None and the bands are those of the last (smallest) candidate.
    """
    targetRate = alpha / fwerModifier  # a proportion of nulls, not a count
    lowerBand = upperBand = None
    for candidateAlpha in np.linspace(0.01, 1e-10, 100):
        lowerBand = computeBandThresh(nulls, candidateAlpha, 'lower')
        upperBand = computeBandThresh(nulls, candidateAlpha, 'upper')
        if computeGlobalCrossings(nulls, lowerBand, upperBand) < targetRate:
            return candidateAlpha, lowerBand, upperBand
    return None, lowerBand, upperBand

In [10]:
def genNNulls(n, unit, alley, txtX, txtY):
    """
    Generate `n` null test statistics (smoothed avg(txtX) - avg(txtY) under shuffled labels).

    Returns
    -------
    np.ndarray of shape (n, L), one null trace per row, where L is the length of the
    statistic returned by genSingleNullStat. For n == 0 an empty (0, 30) array is returned,
    as before.
    """
    # Use the `unit` argument's labels. getLabels(alley) read the notebook-global `unit`
    # instead, silently shuffling the wrong unit's labels if the two differ.
    labels = [visit['metadata']['stimulus'] for visit in unit.alleys[alley]]
    # Collect the traces first and stack once; vstack inside the loop was O(n^2).
    nullTraces = [genSingleNullStat(unit, alley, txtX, txtY, labels) for _ in range(n)]
    if not nullTraces:
        return np.empty((0, 30))
    return np.vstack(nullTraces)

In [11]:
def unitPermutationTest_SinglePair(unit, alley, txtX, txtY, nnulls, plot=True, returnInfo=True):
    """
    Run the label-shuffle test for one texture pair on one alley of one unit.

    Builds `nnulls` null traces, derives FWER-corrected (global) bands with
    global_FWER_alpha() and fixed alpha = 0.05 pointwise bands with computeBandThresh().
    It then compares the observed avg(txtX) - avg(txtY) trace against both.

    Parameters
    ----------
    unit, alley : unit object and alley key (see poolTrials)
    txtX, txtY  : texture labels being contrasted
    nnulls      : number of shuffles
    plot        : if True, draw nulls (black), observed trace (green), global bands
                  (solid red) and pointwise bands (dashed red) with pyplot
    returnInfo  : if True, return the crossing indices; otherwise return None

    Returns
    -------
    (globalCrossings, pointwiseCrossings): indices of bins where the observed trace lies
    outside the global / pointwise bands, or (None, None) when it never leaves the
    global band. Returned only when returnInfo is True.
    """
    nulls = genNNulls(nnulls, unit, alley, txtX, txtY)
    _, globalLower, globalUpper = global_FWER_alpha(nulls)
    stat = genRealStat(unit, alley, txtX, txtY)

    # A region is significantly modulated only if the trace crosses the global band
    # somewhere. Its extent is then the whole stretch outside the pointwise band
    # (approach taken from the Buzsaki paper referenced by the original author).
    pointwiseUpper = computeBandThresh(nulls, 0.05, 'upper')
    pointwiseLower = computeBandThresh(nulls, 0.05, 'lower')

    outsideGlobal = np.logical_or(stat > globalUpper, stat < globalLower)
    globalCrossings = np.where(outsideGlobal)[0]
    if globalCrossings.shape[0] == 0:
        globalCrossings, pointwiseCrossings = None, None
    else:
        outsidePointwise = np.logical_or(stat > pointwiseUpper, stat < pointwiseLower)
        pointwiseCrossings = np.where(outsidePointwise)[0]

    if plot:
        plt.plot(nulls.T, 'k', alpha=0.4)
        plt.plot(stat, 'g')
        plt.xlabel("Linearized Position, Long Axis of Alley")
        plt.ylabel("Difference in Firing Rate")
        plt.title(f"Permutation Test Results for Texture {txtX} vs {txtY} on Alley {alley}")
        bandStyles = (
            (globalLower, 'r'),
            (globalUpper, 'r'),
            (pointwiseLower, 'r--'),
            (pointwiseUpper, 'r--'),
        )
        for band, lineStyle in bandStyles:
            plt.plot(band, lineStyle)

    if returnInfo:
        return globalCrossings, pointwiseCrossings
    
    

In [12]:
def permutationResultsLogger(d,fname):
    """
    Write permutation-test crossings to `<fname>.csv` as space-delimited rows.

    Parameters
    ----------
    d     : nested dict indexed as d[alley][pair][crossType] for alley in 1..17,
            pair in "AB"/"BC"/"CA" and crossType in "global"/"pointwise"
            (the structure built by unitPermutationTest_AllPairsAllAlleys)
    fname : output path without the ".csv" extension

    Layout: an [alley] row; then for each pair a [pair] row followed by one
    [crossType, value] row per crossing type. The csv writer stringifies non-string values.
    """
    # csv is not among the notebook's visible imports; without this the call raised NameError.
    import csv

    # newline='' is what the csv module requires. Without it, Windows gets blank lines between rows.
    # The `with` block closes the file, so no explicit close() is needed.
    with open(fname + '.csv', "w", newline='') as f:
        w = csv.writer(f, delimiter=' ')
        for alley in range(1, 18):
            w.writerow([alley])
            for pair in ["AB", "BC", "CA"]:
                w.writerow([pair])
                for crossType in ["global", "pointwise"]:
                    w.writerow([crossType, d[alley][pair][crossType]])
                

In [13]:
def unitPermutationTest_AllPairsAllAlleys(unit, nnulls,fpath):
    """
    Run the permutation test for every texture pair (AB, BC, CA) on alleys 1..17 of one unit.

    The pointwise alpha is 0.05. The global bands use global_FWER_alpha's default
    family-wise target of 0.05 / (3*17) ~= 0.00098.

    Results go to `fpath + unit.name + "_permutationResults.csv"` via
    permutationResultsLogger. Alley/pair combinations with no global crossing keep the
    placeholder "XXX". The alley number is printed as a progress marker.
    """
    pairs = ["AB", "BC", "CA"]
    alleys = range(1, 18)
    fname = fpath + unit.name + "_permutationResults"

    crossings = {}
    for alley in alleys:
        crossings[alley] = {pair: {'global': "XXX", 'pointwise': "XXX"} for pair in pairs}

    for alley in alleys:
        print(alley)
        for pair in pairs:
            txtX, txtY = pair
            globalCrossings, pointwiseCrossings = unitPermutationTest_SinglePair(
                unit, alley, txtX, txtY, nnulls, plot=False, returnInfo=True
            )
            if globalCrossings is None:
                continue
            entry = crossings[alley][pair]
            entry['global'] = globalCrossings
            entry['pointwise'] = pointwiseCrossings

    permutationResultsLogger(crossings, fname)
            

In [28]:
# os is used for os.walk below but is not imported in the notebook's import cell.
import os

# Run the all-pairs/all-alleys permutation test for every usable cluster of session R765 RFD7.
exp = "RFD7"
datafile = f'E:\\Ratterdam\\R765\\R765{exp}\\'
# NOTE(review): `velo` is not defined in any visible cell; presumably it comes from
# `from ratterdam_Defaults import *` - confirm.
behav = core.BehavioralData(datafile, exp, velo)
# Derive the output folder from the per-machine Drive root chosen in the import cell,
# rather than hard-coding one user's home directory.
fpath = codeDirBase + f'\\KnierimLab\\Ratterdam\\Data\\R765\\permutationTests\\{exp}\\'
ts, position, alleyTracking, alleyVisits, txtVisits = behav.loadData()

for subdir, dirs, fs in os.walk(datafile):
    for f in fs:
        # Only current cluster files; skip ones marked OLD or Undefined.
        if 'cl-maze' in f and 'OLD' not in f and 'Undefined' not in f:
            # Cluster name relative to the session, starting at the tetrode folder (e.g. "TT15\\cl-maze1.1").
            clustname = subdir[subdir.index("TT"):] + "\\" + f
            print(clustname)
            unit = core.UnitData(clustname, datafile, "", alleyBounds, alleyVisits, txtVisits, position, ts)
            unit.loadData_raw()
            unitPermutationTest_AllPairsAllAlleys(unit, 500,fpath)

TT15\cl-maze1.1


  n = (hs*np.reciprocal(ho))*33
  n = (hs*np.reciprocal(ho))*33
  Z=VV/WW
  n = (ls* np.reciprocal(lo)) * 33
  n = (ls* np.reciprocal(lo)) * 33
  W=0*U.copy()+1


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
TT15\cl-maze1.2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
TT15\cl-maze1.3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
TT15\cl-maze1.4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
TT15\cl-maze1.5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
TT4\cl-maze1.1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
TT4\cl-maze1.10
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
TT4\cl-maze1.11
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
TT4\cl-maze1.12
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
TT4\cl-maze1.2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
TT4\cl-maze1.3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
TT4\cl-maze1.4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
TT4\cl-maze1.5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
TT4\cl-maze1.6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
TT4\cl-maze1.7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
TT4\cl-maze1.8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
TT4\cl-maze1.9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
TT6\cl-maze1.1
1
2
3
4
5
6
7
8
9
10
11


In [24]:
from importlib import reload

In [25]:
# Re-import the core data-structures module so edits to its source take effect without a
# kernel restart (reload updates the module object in place and returns it).
core = reload(core)
core

<module 'ratterdam_CoreDataStructures' from 'C:\\Users\\whockei1\\Google Drive\\KnierimLab\\Ratterdam\\Code\\ratterdam_CoreDataStructures.py'>