In [1]:
import sklearn as skl
import matplotlib.pyplot as plt
from sklearn import svm, preprocessing, metrics
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.grid_search import GridSearchCV
from sklearn.metrics import classification_report, precision_score, recall_score, f1_score, accuracy_score
from sklearn import svm, datasets
from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import label_binarize
from sklearn.multiclass import OneVsRestClassifier
from scipy import interp
from sklearn.metrics import confusion_matrix
from sklearn.utils.multiclass import unique_labels
from scipy.ndimage import center_of_mass


import numpy as np, random, json, pickle, datetime, copy, socket, os, sys, scipy
from scipy.stats import sem
import matplotlib.colors as colors
from importlib import reload

if socket.gethostname() == 'Tolman':
    codeDirBase = 'C:\\Users\\whockei1\\Google Drive'
elif socket.gethostname() == 'DESKTOP-BECTOJ9':
    codeDirBase = 'C:\\Users\\whock\\Google Drive'
    
sys.path.insert(0, codeDirBase + '\\KnierimLab\\Ratterdam\\Code')
import utility_fx as util
import ratterdam_ParseBehavior as Parse
import ratterdam_CoreDataStructures as Core
import ratterdam_PermutationTests as Perm
import ratterdam_Defaults as Def
import ratterdam_DataFiltering as Filt


from itertools import cycle



In [2]:
%qtconsole --style native
%matplotlib qt5

In [27]:
datafile = "E:\\Ratterdam\\R808\\R808_Beltway_D7\\"
expCode = "BRD7"
alleyTracking, alleyVisits,  txtVisits, p_sess, ts_sess = Parse.getDaysBehavioralData(datafile, expCode)
population = {}
for subdir, dirs, fs in os.walk(datafile):
    for f in fs:
        if 'cl-maze1' in f and 'OLD' not in f and 'Undefined' not in f:
            clustname = subdir[subdir.index("TT"):] + "\\" + f
            unit = Core.UnitData(clustname, datafile, expCode, Def.alleyBounds, alleyVisits, txtVisits, p_sess, ts_sess)
            unit.loadData_raw()
            rm = util.makeRM(unit.spikes, unit.position)            
            if np.nanpercentile(rm,Def.wholetrack_imshow_pct_cutoff) >= 1.:
                print(clustname)
                population[unit.name+"___"] = unit

  n = (hs*np.reciprocal(ho))*30
  n = (hs*np.reciprocal(ho))*30
  n = (ls* np.reciprocal(lo)) * 30
  n = (ls* np.reciprocal(lo)) * 30
  Z=VV/WW
  n = (hs*np.reciprocal(ho))*30
  n = (hs*np.reciprocal(ho))*30


TT1\cl-maze1.1
TT1\cl-maze1.2
TT12\cl-maze1.1
TT14\cl-maze1.2
TT14\cl-maze1.3
TT15\cl-maze1.1
TT15\cl-maze1.5
TT15\cl-maze1.6
TT6_0001\cl-maze1.1
TT6_0001\cl-maze1.2
TT6_0001\cl-maze1.3
TT9\cl-maze1.1


### Define Helper Functions and Structures for Decoding (not general enough to warrant inclusion in utility_fx.py)

In [3]:
def compute_epoch(val,size):
    """hardcode that session is divided
    into thirds. find in which third the trial is in"""
    propthrusess = val/size
    if propthrusess < 0.25:
        epoch = '0'
    elif propthrusess <= 0.5:
        epoch = '1'
    elif propthrusess <= 0.75:
        epoch = '2'
    elif propthrusess < 1.:
        epoch = '3'
    return epoch

def checkRM(ratemap):
    """Utility function to take a 1-d linear ratemap
    and see if it is valid.
    May 2019: it's not empty ie. there's data
    and the nanmax of that data exceeds firing rate
    thresh defined locally"""
    if type(ratemap) == np.ndarray and np.nanmax(ratemap) >= frThresh:
        return True
    else:
        return False
    
def generateLabel(target, alley, stimulus, epoch):
    if target == 'Alley':
        label = str(alley)
    elif target == 'Stimulus':
        label = stimulus
    elif target == 'Epoch':
        label = epoch
    elif target == 'AlleyXStimulus':
        label = f"{alley}{stimulus}"
    elif target == 'AlleyXEpoch':
        label  = f"{alley}{epoch}"
    elif target == 'StimulusXEpoch':
        label = f"{stimulus}{epoch}"
    elif target == 'AlleyXStimulusXEpoch':
        label = f"{alley}{stimulus}{epoch}"
    
    return label

def shuffle(Y, target):
    """Shuffles list of labels in place according to target value"""
    if target == 'AlleyXStimulus':
        txt = [i[-1] for i in Y]
        np.random.shuffle(txt)
        a = [i[:-1] for i in Y]
        Y = [f"{x}{y}" for x,y in zip(a,txt)]
    elif target == 'Stimulus':
        np.random.shuffle(Y)
    elif target == 'Alley':
        np.random.shuffle(Y)
    return None

target_classes = {'Alley':[str(i) for i in [16,17,3,1,5,7,8,10,11]],
                  'Stimulus': ['A','B','C'],
                  'AlleyXStimulus': ['10A', '10B', '10C', '11A', '11B', '11C', '16A', '16B', '16C',
                                    '17A', '17B', '17C', '1A', '1B', '1C', '3A', '3B', '3C', '5A', '5B',
                                   '5C', '7A', '7B', '7C', '8A', '8B', '8C']
                 }


## Define Defaults

In [4]:
frThresh = 0 #measured in Hz. Pick something close to 0, or 0 itself. 
target = 'Alley' #choices are Alley, Texture, Epoch, or some 2- or 3-member combination of these
beltwayAlleys = [16, 17, 3, 1, 5, 7, 8, 10, 11] # beltway alley IDs in terms of their full track, 17-alley ID
nbins = Def.singleAlleyBins[0]-1
avgType = 'macro' # for signal detection / performance metrics which are not inherently multiclass (e.g. all but accuracy), pick how to aggregate individual class results
nRuns = 250# number of repeats for multiple subsampling
shuffle = False
split_size = 0.75 # defined in terms of train size, proportion 0-1

## Create data matrix X and label list Y
### where each row is a single unit's response to a visit under a certain alley/txt combo

In [358]:
X = np.empty((0, nbins+1))
Y = []

for alley in beltwayAlleys:
    visitSize = len(population[list(population.keys())[0]].alleys[alley]) # all units have same behavioral data obviously so use first unit by default to get num viists to alley.
    for visitNum in range(60): 
        #dataRow = np.empty((0))
        #epoch = compute_epoch(visitNum, visitSize)
        
        invalidRow = 0 # initialize to valid, set to invalid upon finding an invalid rm
        
        unitID = 1
        for unitname, Unit in population.items():
            dataRow = np.empty((0))
            try:
                stimulus = population[list(population.keys())[0]].alleys[alley][visitNum]['metadata']['stimulus'] #again, stims are same for all units so use first unit to grab it
                label =  generateLabel(target, alley, stimulus, '')
                rm = Unit.alleys[alley][visitNum]['ratemap1d']
                if checkRM(rm)== True:
                    rm = np.hstack((unitID, rm))
                    dataRow = np.concatenate((dataRow, rm))
                    Y.append(label)
                    X = np.vstack((X, dataRow))

                else:
                    dataRow = np.concatenate((dataRow, np.zeros((nbins+1)) ))
            except:
                pass
            
            unitID += 10
                
        # Y.append(label)
        # X = np.vstack((X, dataRow))

X[np.where(~np.isfinite(X))] = 0
X = preprocessing.StandardScaler().fit_transform(X)


if shuffle is True:
    if target == 'AlleyXStimulus':
        txt = [i[-1] for i in Y]
        np.random.shuffle(txt)
        a = [i[:-1] for i in Y]
        Y = [f"{x}{y}" for x,y in zip(a,txt)]
    elif target == 'Stimulus':
        np.random.shuffle(Y)
    elif target == 'Alley':
        np.random.shuffle(Y)


### Creating X and Y
#### where each row is whole population's response to a visit under a given alley/txt combo

In [535]:
X = np.empty((0, nbins*len(population.keys())))
Y = []

for alley in beltwayAlleys:
    visitSize = len(population[list(population.keys())[0]].alleys[alley]) # all units have same behavioral data obviously so use first unit by default to get num viists to alley.
    for visitNum in range(60): 
        
        invalidRow = 0 # initialize to valid, set to invalid upon finding an invalid rm
        unitID = 0
        
        dataRow = np.empty((0))
        for unitname, Unit in population.items():
            try:
                stimulus = population[list(population.keys())[0]].alleys[alley][visitNum]['metadata']['stimulus'] #again, stims are same for all units so use first unit to grab it
                label =  generateLabel(target, alley, stimulus, '')
                rm = Unit.alleys[alley][visitNum]['ratemap1d']
                if checkRM(rm)== True:
                    #rm = np.hstack((unitID, rm))
                    dataRow = np.hstack((dataRow, rm))
                else:

                    dataRow = np.hstack((dataRow, np.zeros((nbins)) ))
            except:
                dataRow = np.hstack((dataRow, np.zeros((nbins)) ))
            
            unitID += 1
            
        Y.append(label)
        X = np.vstack((X, dataRow))
                

X[np.where(~np.isfinite(X))] = 0
X = preprocessing.StandardScaler().fit_transform(X)


if shuffle is True:
    if target == 'AlleyXStimulus':
        txt = [i[-1] for i in Y]
        np.random.shuffle(txt)
        a = [i[:-1] for i in Y]
        Y = [f"{x}{y}" for x,y in zip(a,txt)]
    elif target == 'Stimulus':
        np.random.shuffle(Y)
    elif target == 'Alley':
        np.random.shuffle(Y)


## Random Forest Classifier 

In [394]:
oobs, precisions, recalls, f1s, accuracies = [], [], [], [],[]

for i in range(500):
    print(i)
    clf = RandomForestClassifier(n_estimators=1500, oob_score=True)
    Xtrain, Xtest, ytrain, ytest = train_test_split(X,Y,shuffle=True,random_state=0)
    clf.fit(Xtrain,ytrain)
    oobs.append(clf.oob_score_)
    yfit = clf.predict(Xtest)
    p = precision_score(ytest, yfit, average=avgType)
    r = recall_score(ytest, yfit, average=avgType)
    f1 = f1_score(ytest, yfit, average=avgType)
    acc = accuracy_score(ytest,yfit)
    precisions.append(p)
    recalls.append(r)
    f1s.append(f1)
    accuracies.append(acc)


0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69


KeyboardInterrupt: 

## Plotting Histograms of Signal Detection Metrics

In [395]:
plt.figure()
plt.hist(precisions, color='b', alpha=0.5)
plt.hist(recalls, color='r', alpha=0.5)
plt.hist(f1s, color='g', alpha=0.5)
plt.hist(accuracies, color='k', alpha=0.5)
plt.hist(oobs,color='purple',alpha=0.5)
plt.legend(["Precision", "Recall", "F1 Score", "Accuracy", "OOB"])
#plt.vlines(weightedChance(.56,.3),0,plt.ylim()[1])
plt.ylabel("Frquency")
plt.xlabel("Performance")
plt.title(f"RF Decoding Performance Metrics on {target}, {(lambda x: 'Real' if x == False else 'Shuffle')(shuffle)}")

Text(0.5,1,'RF Decoding Performance Metrics on Alley, Real')

## Multiday Figures

In [2]:
fig1, ax1 = plt.subplots(3,1)
fig2, ax2 = plt.subplots(3,1)
fig3, ax3 = plt.subplots(3,1)
fig1.suptitle("Alley Decoding",fontsize=24)
fig2.suptitle("Stimulus Decoding",fontsize=24)
fig3.suptitle("Alley-by-Stimulus Decoding",fontsize=24)

Text(0.5,0.98,'Alley-by-Stimulus Decoding')

In [20]:
figures = {1:fig1, 2:fig2, 3:fig3}
for i in [1,2,3]:
    figures[i].axes[0].set_ylabel("Out-of-Bag Score", fontsize=24)
    figures[i].axes[1].set_ylabel("Precision", fontsize=24)
    figures[i].axes[2].set_ylabel("Recall", fontsize=24)
    figures[i].axes[2].set_xlabel("Recording Days (Real paired with shuffle)", fontsize=24)

In [25]:
# alley
# real_files = ["RFdecoding_Alley_R781BRD3_20191009_2157",
#         "RFdecoding_Alley_R781BRD4_20191010_1514",
#         "RFdecoding_Alley_R808BRD4_20191010_1853",
#         "RFdecoding_Alley_R808BRD5_20191010_2027",
#         "RFdecoding_R808BRD6_20191009_1900",
#         "RFdecoding_Alley_R808BRD7_20191010_1201"]

# shuffle_files = ["RFdecoding_Alley_R781BRD3_20191009_2304",
#             "RFdecoding_Alley_R781BRD4_20191010_1628",
#             "RFdecoding_Alley_R808BRD4_20191010_1931",
#             "RFdecoding_Alley_R808BRD5_20191010_2134",
#             "RFdecoding_R808BRD6_20191009_1950",
#             "RFdecoding_Alley_R808BRD7_20191010_1332"    
#             ]

#stimulus

# real_files = ["RFdecoding_Stimulus_R781BRD3_20191009_2216",
#              "RFdecoding_Stimulus_R781BRD4_20191010_1535",
#              "RFdecoding_Stimulus_R808BRD4_20191010_1905",
#              "RFdecoding_Stimulus_R808BRD5_20191010_2048",
#              "RFdecoding_R808BRD6_20191009_1915",
#              "RFdecoding_Stimulus_R808BRD7_20191010_1231"]

# shuffle_files = ["RFdecoding_Stimulus_R781BRD3_20191009_2323",
#                 "RFdecoding_Stimulus_R781BRD4_20191010_1653",
#                 "RFdecoding_Stimulus_R808BRD4_20191010_1943",
#                 "RFdecoding_Stimulus_R808BRD5_20191010_2154",
#                 "RFdecoding_R808BRD6_20191009_2006",
#                 "RFdecoding_Stimulus_R808BRD7_20191010_1401"]

# alley-by-stimulus

real_files = ["RFdecoding_AlleyXStimulus_R781BRD3_20191009_2242",
             "RFdecoding_AlleyXStimulus_R781BRD4_20191010_1600",
             "RFdecoding_AlleyXStimulus_R808BRD4_20191010_1919",
             "RFdecoding_AlleyXStimulus_R808BRD5_20191010_2110",
             "RFdecoding_R808BRD6_20191009_1932",
             "RFdecoding_AlleyXStimulus_R808BRD7_20191010_1300"]

shuffle_files = ["RFdecoding_AlleyXStimulus_R781BRD3_20191009_2349",
                 "RFdecoding_AlleyXStimulus_R781BRD4_20191010_1719",
                 "RFdecoding_AlleyXStimulus_R808BRD4_20191010_1957",
                 "RFdecoding_AlleyXStimulus_R808BRD5_20191010_2217",
                 "RFdecoding_R808BRD6_20191009_2023",
                 "RFdecoding_AlleyXStimulus_R808BRD7_20191010_1430"]

In [4]:
%qtconsole --style native

In [26]:
f = fig3
p=0
for i in range(6):
    s = shuffle_files[i]
    r = real_files[i]
    with open('E:\\Ratterdam\\multidayFigures\\randomForest\\'+s+".json","r") as file:
        shuffle = json.load(file)
    with open('E:\\Ratterdam\\multidayFigures\\randomForest\\'+r+".json","r") as file:
        real = json.load(file)
        
    for data  in [shuffle,real]:
        for metric,a,c in zip(['oobs', 'precisions', 'recalls'],[0,1,2],['green','blue','red']):
            vp = f.axes[a].violinplot(data[metric],[p],showmeans=True,widths=1.5)
            for vp_element in ['cmeans','cmins','cmaxes','cbars']:
                vp[vp_element].set_facecolor(c)
                vp[vp_element].set_edgecolor(c)
            for part in vp['bodies']:
                part.set_facecolor(c)
                part.set_edgecolor(c)

        p += 2
for a in [0,1,2]:
    f.axes[a].set_xticks([])
    f.axes[a].tick_params(axis='y',labelsize=20)
    f.axes[a].spines['right'].set_visible(False)
    f.axes[a].spines['top'].set_visible(False)
f.axes[2].set_xticks([0,2])
f.axes[2].set_xticklabels(["Shuffle", "Real"])
for tick in f.axes[2].xaxis.get_major_ticks():
    tick.label.set_fontsize(20)

# Non spatial decoding - representing responses in different ways besides the firing rate
#### first thing mid dec 2019 - test stat diffs, deviation of a single visit rm from [window median, window mean, etc], summary parm vec [max, mean, etc of a visit]

In [7]:
def checkRM(ratemap):
    """Utility function to take a 1-d linear ratemap
    and see if it is valid.
    May 2019: it's not empty ie. there's data
    and the nanmax of that data exceeds firing rate
    thresh defined locally"""
    if type(ratemap) == np.ndarray and np.nanmax(ratemap) >= frThresh:
        return True
    else:
        return False

In [8]:
datafile = "Z:\\CheetahData\\R859BRD3\\"
expCode = "BRD3"
alleyTracking, alleyVisits,  txtVisits, p_sess, ts_sess = Parse.getDaysBehavioralData(datafile, expCode)
population = {}
for subdir, dirs, fs in os.walk(datafile):
    for f in fs:
        if 'cl-maze1' in f and 'OLD' not in f and 'Undefined' not in f:
            clustname = subdir[subdir.index("TT"):] + "\\" + f
            unit = Core.UnitData(clustname, datafile, expCode, Def.alleyBounds, alleyVisits, txtVisits, p_sess, ts_sess)
            unit.loadData_raw()
            rm = util.makeRM(unit.spikes, unit.position)            
            if np.nanpercentile(rm,Def.wholetrack_imshow_pct_cutoff) >= 1.:
                print(clustname)
                population[unit.name+"___"] = unit

22
22
22
22
22
22
22
22
22
35


  n = (hs*np.reciprocal(ho))*30
  n = (hs*np.reciprocal(ho))*30
  n = (ls* np.reciprocal(lo)) * 30
  n = (ls* np.reciprocal(lo)) * 30
  Z=VV/WW
  n = (hs*np.reciprocal(ho))*30
  n = (hs*np.reciprocal(ho))*30


TT10\cl-maze1.1
TT10\cl-maze1.2
TT10\cl-maze1.3
TT10\cl-maze1.4
TT10\cl-maze1.6
TT10\cl-maze1.7
TT11\cl-maze1.1
TT13\cl-maze1.1
TT13\cl-maze1.2
TT13\cl-maze1.4
TT13\cl-maze1.5
TT13\cl-maze1.7
TT14\cl-maze1.1
TT14\cl-maze1.3
TT3_0001\cl-maze1.1
TT3_0001\cl-maze1.3
TT3_0001\cl-maze1.4
TT3_0001\cl-maze1.5
TT3_0001\cl-maze1.6
TT3_0001\cl-maze1.7
TT4_0001\cl-maze1.1
TT4_0001\cl-maze1.2
TT4_0001\cl-maze1.3
TT4_0001\cl-maze1.4
TT4_0001\cl-maze1.5
TT5_0001\cl-maze1.2
TT5_0001\cl-maze1.3
TT6_0001\cl-maze1.1
TT6_0001\cl-maze1.2
TT6_0001\cl-maze1.3
TT6_0001\cl-maze1.4
TT6_0001\cl-maze1.5
TT6_0001\cl-maze1.6
TT6_0001\cl-maze1.7
TT6_0001\cl-maze1.8
TT7\cl-maze1.2
TT7\cl-maze1.3
TT7\cl-maze1.4
TT8\cl-maze1.1
TT8\cl-maze1.2
TT8\cl-maze1.3
TT8\cl-maze1.4
TT8\cl-maze1.5
TT8\cl-maze1.6
TT9\cl-maze1.10
TT9\cl-maze1.2
TT9\cl-maze1.3
TT9\cl-maze1.4
TT9\cl-maze1.5
TT9\cl-maze1.6
TT9\cl-maze1.7
TT9\cl-maze1.8
TT9\cl-maze1.9


In [9]:
frThresh = 0 #measured in Hz. Pick something close to 0, or 0 itself. 
target = 'Alley' #choices are Alley, Texture, Epoch, or some 2- or 3-member combination of these
beltwayAlleys = [16, 17, 3, 1, 5, 7, 8, 10, 11] # beltway alley IDs in terms of their full track, 17-alley ID
nbins = Def.singleAlleyBins[0]-1
avgType = 'macro' # for signal detection / performance metrics which are not inherently multiclass (e.g. all but accuracy), pick how to aggregate individual class results
nRuns = 250# number of repeats for multiple subsampling
shuffle = False
split_size = 0.75 # defined in terms of train size, proportion 0-1

In [10]:
def calcStatPair(unit, alley, pair):
    x,y = pair #unpack the pair of stimuli eg AB
    rma, rmb = unit.linRMS[alley][x], unit.linRMS[alley][y]
    stat = np.abs(rma-rmb)
    return stat

In [11]:
def generateLabel2(target, alley, stimulus):
    if target == 'Alley':
        label = str(alley)
    elif target == 'Stimulus':
        if len(stimulus) == 1:
            label = stimulus
        elif len(stimulus) == 2:
            label = stimulus[0] # if label is a pair of stims for a test stat label
    elif target == 'AlleyXStimulus':
        if len(stimulus) == 1:
            label = f"{alley}{stimulus}"
        elif len(stimulus) == 2:
            label = f"{alley}{stimulus[0]}"
    
    return label

In [12]:
def calcRMdeviation(unit, alley, visit, window=5, deviation='diff'):
    """
    Compute the deviation of a visit's ratemap
    for a cell from the average within a window
    of all visits' rms. Window var is one-sided
    so whole window size is 2*window. 
    
    Deviation is 
    - 'diff' for binwise difference
    """
                    
    rmvisit = unit.alleys[alley][visit]['ratemap1d']
    rms_in_window = np.empty((0, nbins))
    
    if visit < window:
        for i in range(0, (window-visit)+1):
            rm = unit.alleys[alley][i]['ratemap1d']
            rms_in_window = np.vstack((rms_in_window, rm))
            
    elif (visit+window) >= len(unit.alleys[alley]):
        for i in range(visit-window, visit+(len(unit.alleys[alley])-visit)):
            rm = unit.alleys[alley][i]['ratemap1d']
            rms_in_window = np.vstack((rms_in_window, rm))     
            
    else:
        for i in range(visit-window, visit+window+1):
            rm = unit.alleys[alley][i]['ratemap1d']
            rms_in_window = np.vstack((rms_in_window, rm))
            
    mask = np.ma.masked_invalid(rms_in_window)
    avg = np.median(mask.data,axis=0) # ignores inf and nan
    if deviation == 'diff':
        stat = np.abs(rmvisit-avg)
    if deviation == 'grad': 
        stat = np.gradient(rmvisit)
    elif deviation.lower() == 'none':
        stat = rmvisit # this passthrough is so you dont have to change anything but a toggle in setup fx

    else:
        return "Incorrect deviation argument"
    
    return stat
        

In [31]:
def calcRMSummary(unit, alley, visit):
    """
    For a given pass for a given unit
    summarize the 1d ratemap into a simpler,
    explicit vector of attributes
    - period within session (half, third?,)
    - max
    - min
    - mean
    - loc max
    - loc min
    """
    rm = unit.alleys[alley][visit]['ratemap1d']
    epoch = np.digitize(visit,[0,30,60])
    maximum, minimum, mean = np.nanpercentile(rm, 95), np.nanmin(rm), np.nanmean(rm)
    locmax, locmin = np.nanargmax(rm), np.nanargmin(rm)
    auc = simps(rm)
    avgdrds = np.mean(np.abs(np.diff(rm))) # avg dr/ds change in rate / change in pos. 
    maxdrds = np.percentile(np.abs(np.diff(rm)), 95)
    com = center_of_mass(rm)
    rm = np.nan_to_num(rm)
    distA = cosine(rm, unit.linRMS[alley]['A'])
    distB = cosine(rm, unit.linRMS[alley]['B'])
    distC = cosine(rm, unit.linRMS[alley]['C'])
    #x = preprocessing.StandardScaler().fit_transform(rm)
    #pca = PCA(n_components=3)
    #principalComponents = pca.fit_transform(x)
    return np.asarray((maximum, minimum, locmax, locmin, com[0], auc))

In [14]:
def setupRFdata(target, representationFx):

    """
    This block collects data into data matrix X and label matrix Y.
    Each entry is an averaged linear ratemap for a cell under a given condition (alley/txt).
    Entry is the test statistic comparing AvB (labeled A), B vs C (labeled B) and CvA (labeled C)

    """

    X = np.empty((0, (6)*len(population)))
    Y = []


    for alley in beltwayAlleys:

        for visit in range(len(population[list(population.keys())[0]].alleys[alley])):

            stim = population[list(population.keys())[0]].alleys[alley][visit]['metadata']['stimulus']
            label = generateLabel2(target, alley, stim)
            dataRow = np.empty((0))
            
            for unitname, Unit in population.items():
                rm = representationFx(Unit, alley, visit)
                dataRow = np.hstack((dataRow, rm))
        
            Y.append(label)
            X = np.vstack((X, dataRow))

    X[np.where(~np.isfinite(X))] = 0
    #X = preprocessing.StandardScaler().fit_transform(X)

    if shuffle is True:
        if target == 'AlleyXStimulus':
            txt = [i[-1] for i in Y]
            np.random.shuffle(txt)
            a = [i[:-1] for i in Y]
            Y = [f"{x}{y}" for x,y in zip(a,txt)]
        elif target == 'Stimulus':
            np.random.shuffle(Y)
        elif target == 'Alley':
            np.random.shuffle(Y)

    return X, np.asarray(Y)

In [15]:
def runRandomForest(X, Y, runs=300, trees=700):
    oobs, precisions, recalls, f1s, accuracies = [], [], [], [],[]
    avgType = 'macro'
    for i in range(runs):
        clf = RandomForestClassifier(n_estimators=trees, 
                                     oob_score=True,
                                     max_features = None,
                                     max_depth = 4
                                    )
        Xtrain, Xtest, ytrain, ytest = train_test_split(X,Y,shuffle=True,random_state=0)
        clf.fit(Xtrain,ytrain)
        oobs.append(clf.oob_score_)
        yfit = clf.predict(Xtest)
        p = precision_score(ytest, yfit, average=avgType)
        r = recall_score(ytest, yfit, average=avgType)
        f1 = f1_score(ytest, yfit, average=avgType)
        acc = accuracy_score(ytest,yfit)
        precisions.append(p)
        recalls.append(r)
        f1s.append(f1)
        accuracies.append(acc)
    return [oobs, precisions, recalls, f1s, accuracies]

In [16]:
def randomForest_wrapper(target):
    """
    Wrapper function that takes a target variable
    to decode. Creates data and label matrices X,Y
    and runs a random forest classifier with a set
    number of trees and runs. Returns readouts:
    oob score, precision, recall, F1, accuracy for each run
    Plots the histograms of each on a single new figure
    """
    X,Y = setupRFdata(target)
    oobs, precisions, recalls, f1s, accuracies = runRandomForest(X,Y, runs=100)
    fig, ax = plt.subplots()
    ax.hist(precisions, color='b', alpha=0.5)
    ax.hist(recalls, color='r', alpha=0.5)
    ax.hist(f1s, color='g', alpha=0.5)
    ax.hist(accuracies, color='k', alpha=0.5)
    ax.hist(oobs,color='purple',alpha=0.5)
    ax.legend(["Precision", "Recall", "F1 Score", "Accuracy", "OOB"])
    ax.set_ylabel("Frquency")
    ax.set_xlabel("Performance")
    ax.set_title(f"RF Decoding Performance Metrics on {target}, {(lambda x: 'Real' if x == False else 'Shuffle')(shuffle)}")
    plt.show()
    print(target)
    return oobs, precisions, recalls, f1s, accuracies

In [18]:
from ipyparallel import Client
rc = Client()
dv = rc[:]

In [21]:
%qtconsole --style native

In [19]:
with dv.sync_imports():
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.metrics import classification_report, precision_score, recall_score, f1_score, accuracy_score
    from sklearn.model_selection import train_test_split

importing RandomForestClassifier from sklearn.ensemble on engine(s)
importing classification_report,precision_score,recall_score,f1_score,accuracy_score from sklearn.metrics on engine(s)
importing train_test_split from sklearn.model_selection on engine(s)


In [32]:
Xs, Ys = [], []

# choose here which representation function to use
repFx = calcRMSummary
for target in ['Alley', 'Stimulus', 'AlleyXStimulus']:
    X, Y  = setupRFdata(target, repFx)
    Xs.append(X)
    Ys.append(Y)

print("Finished Creating Data Matrices")

beginT = datetime.datetime.now()
results = dv.map_sync(runRandomForest, Xs, Ys)
endT = datetime.datetime.now()
print(f"Finished Decoding, took {round((endT-beginT).total_seconds()/60,2)}min")


  for dir in range(input.ndim)]
  dist = 1.0 - np.dot(u, v) / (norm(u) * norm(v))
  interpolation=interpolation)


Finished Creating Data Matrices
Finished Decoding, took 13.73min


In [33]:
plt.figure()
plt.hist(results[0][0],color='b')
plt.hist(results[1][0],color='g')
plt.hist(results[2][0],color='r')

(array([  2.,   2.,  18.,  53.,  41.,  93.,  63.,  18.,   9.,   1.]),
 array([ 0.195,  0.204,  0.213,  0.222,  0.231,  0.24 ,  0.249,  0.258,
         0.267,  0.276,  0.285]),
 <a list of 10 Patch objects>)