In [1]:
import sklearn as skl
import matplotlib.pyplot as plt
from sklearn import svm, preprocessing, metrics
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.grid_search import GridSearchCV
from sklearn.metrics import classification_report, precision_score, recall_score, f1_score, accuracy_score
from sklearn import svm, datasets
from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import label_binarize
from sklearn.multiclass import OneVsRestClassifier
from scipy import interp
from sklearn.metrics import confusion_matrix
from sklearn.utils.multiclass import unique_labels

import numpy as np, random, json, pickle, datetime, copy, socket, os, sys, scipy
from scipy.stats import sem
import matplotlib.colors as colors
from importlib import reload

# Locate the lab code directory so the project modules imported below can be found.
# Known machines map hostname -> base folder; the RATTERDAM_CODE_BASE environment
# variable, when set, overrides the per-host default (e.g. for a new machine).
_codeDirBaseByHost = {
    'Tolman': 'C:\\Users\\whockei1\\Google Drive',
    'DESKTOP-BECTOJ9': 'C:\\Users\\whock\\Google Drive',
}
codeDirBase = os.environ.get('RATTERDAM_CODE_BASE') or _codeDirBaseByHost.get(socket.gethostname())
if codeDirBase is None:
    # Fail with an actionable message instead of a NameError on the sys.path line below.
    raise RuntimeError(
        f"Unknown host {socket.gethostname()!r}: set RATTERDAM_CODE_BASE to the folder "
        "that contains KnierimLab\\Ratterdam\\Code")

sys.path.insert(0, os.path.join(codeDirBase, 'KnierimLab', 'Ratterdam', 'Code'))
import utility_fx as util
import ratterdam_ParseBehavior as Parse
import ratterdam_CoreDataStructures as Core
import ratterdam_PermutationTests as Perm
import ratterdam_Defaults as Def

from itertools import cycle



In [2]:
# Open a Qt console attached to this same kernel (using the 'native' colour style).
%qtconsole --style native
# Render matplotlib figures in separate Qt5 windows instead of inline.
# NOTE(review): figures therefore do not appear in the saved notebook outputs.
%matplotlib qt5

### Load population data into a dict mapping unit name → `UnitData` object

In [3]:
# Session to decode (the commented path is an alternate, currently disabled data location).
#datafile = "E:\\Ratterdam\\R781\\Beltway_D3_190307\\"
datafile = "C:\\Users\\whock\\Google Drive\\KnierimLab\\Ratterdam\\Data\\R781\\Beltway_D3\\"
expCode = "BRD3"

# Parse the session's behavioral data once; the same objects are handed to every unit below.
alleyTracking, alleyVisits, txtVisits, p_sess, ts_sess = Parse.getDaysBehavioralData(datafile, expCode)

population = {}
for subdir, dirs, fs in os.walk(datafile):
    for f in fs:
        # Keep only 'cl-maze1' cluster files whose names mention neither 'OLD' nor 'Undefined'.
        if 'cl-maze1' not in f or any(tag in f for tag in ('OLD', 'Undefined')):
            continue
        # Cluster name = path from the "TT.." directory down to the file, e.g. TT11\cl-maze1.1
        clustname = "\\".join((subdir[subdir.index("TT"):], f))
        print(clustname)
        unit = Core.UnitData(clustname, datafile, expCode, Def.alleyBounds,
                             alleyVisits, txtVisits, p_sess, ts_sess)
        unit.loadData_raw(includeRewards=False)
        population[unit.name] = unit

TT11\cl-maze1.1


  n = (hs*np.reciprocal(ho))*30
  n = (hs*np.reciprocal(ho))*30
  n = (ls* np.reciprocal(lo)) * 30
  n = (ls* np.reciprocal(lo)) * 30
  Z=VV/WW


TT11\cl-maze1.2
TT3\cl-maze1.1
TT3\cl-maze1.2
TT3\cl-maze1.3
TT3\cl-maze1.4
TT3\cl-maze1.5
TT3\cl-maze1.6
TT3\cl-maze1.7
TT5\cl-maze1.1
TT5\cl-maze1.2
TT5\cl-maze1.3
TT5\cl-maze1.4
TT6\cl-maze1.1
TT6\cl-maze1.2
TT6\cl-maze1.3
TT6\cl-maze1.4
TT6\cl-maze1.5
TT6\cl-maze1.6
TT6\cl-maze1.7
TT6\cl-maze1.8
TT6\cl-maze1.9
TT9\cl-maze1.1
TT9\cl-maze1.2
TT9\cl-maze1.3
TT9\cl-maze1.4
TT9\cl-maze1.5
TT9\cl-maze1.6
TT9\cl-maze1.7


### Define Helper Functions for Decoding (not general enough to warrant inclusion in utility_fx.py)

In [4]:
def compute_epoch(val, size, nEpochs=4):
    """Return which equal-width epoch of the session a visit falls in, as a string label.

    The session's `size` visits are split into `nEpochs` contiguous, equal-width,
    half-open epochs (quarters by default). Visit index `val` gets the label
    str(floor(nEpochs * val / size)), i.e. one of '0' .. str(nEpochs - 1).

    Parameters
    ----------
    val : int
        Zero-based index of the visit within the session (0 <= val < size).
    size : int
        Total number of visits in the session.
    nEpochs : int, optional
        Number of epochs to divide the session into (default 4).

    Returns
    -------
    str
        Epoch label.

    Raises
    ------
    ValueError
        If `size` or `nEpochs` is not positive, or `val` is outside [0, size).
    """
    if size <= 0 or nEpochs <= 0:
        raise ValueError(f"size and nEpochs must be positive (got size={size}, nEpochs={nEpochs})")
    if not 0 <= val < size:
        raise ValueError(f"val must lie in [0, {size}) (got {val})")
    # Integer floor division keeps the bin edges exact (no float rounding at k/nEpochs),
    # and uniform half-open bins give every epoch the same share of visits.
    return str(int(nEpochs * val // size))

In [5]:
def checkRM(ratemap, thresh=None):
    """Return True if a 1-D linear ratemap is usable for decoding.

    May 2019 criterion: the ratemap is a non-empty numpy array containing at
    least one non-NaN value, and its NaN-ignoring maximum exceeds the firing-rate
    threshold.

    Parameters
    ----------
    ratemap : object
        Candidate ratemap; anything that is not an np.ndarray is invalid.
    thresh : float, optional
        Firing-rate threshold (Hz). Defaults to the notebook-level `frThresh`,
        looked up at call time.

    Returns
    -------
    bool
    """
    if thresh is None:
        thresh = frThresh
    if not isinstance(ratemap, np.ndarray) or ratemap.size == 0:
        return False
    # Drop NaNs explicitly: np.nanmax raises on an empty array and warns on an all-NaN one.
    nonNan = ratemap[~np.isnan(ratemap)]
    return bool(nonNan.size > 0 and nonNan.max() > thresh)

In [6]:
def generateLabel(target, alley, stimulus, epoch):
    """Return the class label for one alley visit under the given decoding target.

    Single-factor targets return that factor ('Alley' is converted to str;
    'Stimulus' and 'Epoch' are returned as passed). Combined targets
    ('AlleyXStimulus', 'AlleyXEpoch', 'StimulusXEpoch', 'AlleyXStimulusXEpoch')
    concatenate the factors' string forms in that order, with no separator.

    Raises
    ------
    ValueError
        If `target` is not one of the names above.
    """
    labels = {
        'Alley': str(alley),
        'Stimulus': stimulus,
        'Epoch': epoch,
        'AlleyXStimulus': f"{alley}{stimulus}",
        'AlleyXEpoch': f"{alley}{epoch}",
        'StimulusXEpoch': f"{stimulus}{epoch}",
        'AlleyXStimulusXEpoch': f"{alley}{stimulus}{epoch}",
    }
    try:
        return labels[target]
    except KeyError:
        raise ValueError(f"Unknown decoding target {target!r}; expected one of {sorted(labels)}") from None

In [7]:
def generateLabel_rand(target, alley, stimulus, epoch, stimuli=('A', 'B', 'C')):
    """Like generateLabel, but draws a random stimulus for the 'AlleyXStimulus' target.

    For 'AlleyXStimulus' the true `stimulus` is ignored and replaced by a
    uniform draw from `stimuli` using NumPy's global RNG (presumably a
    shuffled-label control; confirm). Every other target returns the true label.
    NOTE(review): 'Stimulus', 'StimulusXEpoch' and 'AlleyXStimulusXEpoch' are
    NOT randomized; extend this if a control is needed for those targets.

    Parameters
    ----------
    stimuli : sequence of str, optional
        Pool the random stimulus is drawn from (default 'A', 'B', 'C').

    Raises
    ------
    ValueError
        If `target` is not a recognised decoding target.
    """
    if target == 'AlleyXStimulus':
        # Only this branch consumes randomness, so other targets leave the RNG stream untouched.
        return f"{alley}{np.random.choice(stimuli)}"
    labels = {
        'Alley': str(alley),
        'Stimulus': stimulus,
        'Epoch': epoch,
        'AlleyXEpoch': f"{alley}{epoch}",
        'StimulusXEpoch': f"{stimulus}{epoch}",
        'AlleyXStimulusXEpoch': f"{alley}{stimulus}{epoch}",
    }
    try:
        return labels[target]
    except KeyError:
        raise ValueError(
            f"Unknown decoding target {target!r}; expected 'AlleyXStimulus' or one of {sorted(labels)}"
        ) from None

### Define Defaults

In [71]:
frThresh = 0.0  # Hz; checkRM keeps a ratemap only if its peak exceeds this. Pick something close to 0, or 0 itself.
# Decoding target; one of: Alley, Stimulus, Epoch, AlleyXStimulus, AlleyXEpoch,
# StimulusXEpoch, AlleyXStimulusXEpoch (the names generateLabel accepts).
target = 'AlleyXStimulus'
beltwayAlleys = [16, 17, 3, 1, 5, 7, 8, 10, 11]  # beltway alley IDs in terms of their full track, 17-alley ID
# Number of bins in each 1-D alley ratemap (presumably singleAlleyBins holds bin edges; confirm in ratterdam_Defaults).
nbins = Def.singleAlleyBins[0] - 1
avgType = 'macro'  # for signal detection / performance metrics which are not inherently multiclass (e.g. all but accuracy), pick how to aggregate individual class results
nRuns = 200  # number of repeated random train/test splits (multiple subsampling)

# SVC params
C = 1e7  # found via gridsearch
gamma = 0.01  # found via gridsearch
kernel = 'rbf'
split_size = 0.75  # fraction of samples used for training (0-1); the remainder forms the test set

# Seed NumPy's global RNG, which train_test_split (random_state=None) and np.random.choice
# draw from, so that Restart & Run All reproduces the same splits and scores.
randomSeed = 0
np.random.seed(randomSeed)

### Create Data Matrix X and label vector Y

In [79]:
# Build the decoding data set (X: samples x features, Y: one label per sample).
# Default ("single-unit") mode: every valid 1-D ratemap of every unit on every visit to a
# beltway alley becomes its own sample, labelled by generateLabel for that visit.
# Set populationVector = True to instead make one sample per visit by concatenating all
# units' ratemaps, with zeros standing in for ratemaps that fail checkRM.
populationVector = False

# Visit lists and stimuli are the same for every unit, so read them from the first unit.
refUnit = next(iter(population.values()))

rows = []
Y = []
for alley in beltwayAlleys:
    visits = refUnit.alleys[alley]
    visitSize = len(visits)
    for visitNum in range(visitSize):
        epoch = compute_epoch(visitNum, visitSize)
        stimulus = visits[visitNum]['metadata']['stimulus']
        label = generateLabel(target, alley, stimulus, epoch)
        unitRMs = [Unit.alleys[alley][visitNum]['ratemap1d'] for Unit in population.values()]
        if populationVector:
            rows.append(np.concatenate([rm if checkRM(rm) else np.zeros(nbins) for rm in unitRMs]))
            Y.append(label)
        else:
            for rm in unitRMs:
                if checkRM(rm):
                    rows.append(rm)
                    Y.append(label)

nFeatures = nbins * len(population) if populationVector else nbins
# Stack once at the end: calling np.vstack inside the loop copies X every time (quadratic).
X = np.vstack(rows).astype(float) if rows else np.empty((0, nFeatures))
X[~np.isfinite(X)] = 0  # replace NaN / inf bins with 0
# NOTE(review): the scaler is fit on ALL samples before the train/test splits below, so
# test-set statistics leak into the features; consider scaling inside a per-split Pipeline.
X = preprocessing.StandardScaler().fit_transform(X)

### Classification: Support Vector Machine

In [81]:
# Repeatedly re-split the data at random and refit the SVM, collecting per-split metrics.
# Precision / recall / F1 are aggregated over classes according to `avgType`.
precisions, recalls, f1s, accuracies = [], [], [], []

for i in range(nRuns):
    print(f"run {i + 1}/{nRuns}")
    # train_size=split_size: the remaining (1 - split_size) of the samples form the test set.
    Xtrain, Xtest, ytrain, ytest = train_test_split(X, Y, shuffle=True, train_size=split_size)
    svc = SVC(C=C, gamma=gamma, kernel=kernel)
    svc.fit(Xtrain, ytrain)
    yfit = svc.predict(Xtest)
    # One call yields the same values as precision_score / recall_score / f1_score
    # computed separately (support is None when `average` is set).
    p, r, f1, support = metrics.precision_recall_fscore_support(ytest, yfit, average=avgType)
    precisions.append(p)
    recalls.append(r)
    f1s.append(f1)
    accuracies.append(accuracy_score(ytest, yfit))

0




1
2
3
4


KeyboardInterrupt: 

### Basic Visualization of Signal Detection / Performance Metrics

In [82]:
# Overlaid histograms of the per-split metrics from the SVM runs above. All four
# histograms share one set of bin edges so their distributions are directly comparable.
metricRuns = [(precisions, 'b', "Precision"),
              (recalls, 'r', "Recall"),
              (f1s, 'g', "F1 Score"),
              (accuracies, 'k', "Accuracy")]
sharedBins = np.histogram(np.concatenate([m[0] for m in metricRuns]), bins=20)[1]

fig, ax = plt.subplots()
for values, color, name in metricRuns:
    ax.hist(values, bins=sharedBins, color=color, alpha=0.5, label=name)
ax.legend()
# NOTE(review): an earlier version drew a reference line at 0.093 (plt.vlines); its meaning is
# not recorded here. Re-add with ax.axvline(...) if a chance-level marker is wanted.
ax.set(xlabel="Performance", ylabel="Frequency",
       title=f"SVM Decoding Performance Metrics on {target}");

Text(0.5,1,'SVM Decoding Performance Metrics on AlleyXStimulus')

#### Confusion Matrices

In [44]:
# Confusion matrix of the SVM decoder, summed over `nCmRuns` random train/test splits.
# Set `normalize = True` to show each row as the fraction of that true class assigned
# to each predicted label.
# Plotting adapted from the scikit-learn docs confusion-matrix example; modified by WH.
normalize = False
nCmRuns = 1  # number of random splits whose confusion matrices are summed

classes = np.unique(Y)
allcms = []
for run in range(nCmRuns):
    Xtrain, Xtest, ytrain, y_true = train_test_split(X, Y, shuffle=True, train_size=split_size)
    svc = SVC(C=C, gamma=gamma, kernel=kernel)
    svc.fit(Xtrain, ytrain)
    y_pred = svc.predict(Xtest)
    # Pass the full class list so every matrix has the same shape and the same row/column
    # order as `classes` (used for the tick labels), even when a class is absent from a split.
    allcms.append(confusion_matrix(y_true, y_pred, labels=classes))

cm = np.sum(allcms, axis=0)

if normalize:
    rowTotals = cm.sum(axis=1, keepdims=True)
    # Rows with no true samples stay 0 instead of becoming NaN (0/0).
    cm = np.divide(cm, rowTotals, out=np.zeros(cm.shape), where=rowTotals > 0)

fig, ax = plt.subplots()
im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
fig.colorbar(im, ax=ax)
# Show every tick and label it with its class.
ax.set(xticks=np.arange(cm.shape[1]),
       yticks=np.arange(cm.shape[0]),
       xticklabels=classes, yticklabels=classes,
       title=f"SVM Confusion Matrix on {target}",
       ylabel='True label',
       xlabel='Predicted label')

# Diagonal guide line through the correct-prediction cells.
ax.plot(range(classes.shape[0]), range(classes.shape[0]))

# Rotate the x tick labels and anchor them so long class names stay readable.
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

# Annotate each cell with its count (or fraction); white text on dark cells for legibility.
fmt = '.2f' if normalize else 'd'
thresh = cm.max() / 2.
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        ax.text(j, i, format(cm[i, j], fmt),
                ha="center", va="center",
                color="white" if cm[i, j] > thresh else "black")
fig.tight_layout()

