In [None]:
import os
import math
import numpy as np
from numpy import genfromtxt
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from sklearn.decomposition import PCA
from mpl_toolkits import mplot3d
from numpy import random as rd
import time
import pylab as pl
from IPython.display import display, HTML

# Widen the notebook container and center cell outputs.
# NOTE(review): these selectors target the classic Notebook UI; JupyterLab / Notebook 7
# use different class names, so this styling may be a no-op there.
display(HTML("""
<style>
.container { width:100% !important; }
.output {
    display: flex;
    align-items: center;
    text-align: center;
}
</style>
"""))

# Rebind `display` to the IPython.display *module*: the animation cell calls
# display.clear_output(...) and display.display(...) on it.
from IPython import display

### Load traces, normalize them, and plot them
> #### Specify path to data and whether the data should be normalized

In [None]:
%matplotlib inline
firingRates = np.genfromtxt("/media/joezaki/CSstorage/Joe/RLI_Experiment/RLI2/RLI2_Habituation/RLI2-3_HabituationDay2AndSetHeight/RLI2-3_HabituationDay2AndSetHeightRLI2_Habituation_minian_spikes.csv", delimiter=',')
normalize = True
Time = np.arange(0, firingRates.shape[1]/30, 1/30)

if normalize:
    for i in range(firingRates.shape[0]):
        firingRates[i,:]  = firingRates[i,:]/firingRates[i,:].max()

# Plot all currents as heatmap
plt.figure(figsize=(25,8))
plt.imshow(firingRates, aspect='auto', cmap='viridis')
plt.colorbar()
plt.show()

# Plot individual currents
plt.figure(figsize=(20,8))
for i in range(30):
    plt.plot(firingRates[i,:] + i*firingRates[i,:].max(), linewidth=1)
plt.ylabel("Current")
plt.xlabel("Time")
plt.show()

### Run PCA, project rates onto first PCs, plot sorted eigenvalues & plot first eigenvectors
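> - #### firingRates is transposed (`firingRates.copy().T` at the top of the cell) so that time points are the observations and cells are the features. PCA therefore collapses across the cell dimension (state-space analysis). Remove the `.T` to treat cells as observations instead.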
> - #### *numProjections*: Specify number of PCs to project data onto (usually 3; the 3D plots below need at least 3)
> - #### *numvectors*: Specify number of eigenvectors to plot

In [None]:
%matplotlib inline
rates = firingRates.copy().T
rates -= np.mean(rates, axis=0)
covMat = (1.0/(rates.shape[0]-1))*(rates.T @ rates) # covariance matrix
#evalues, evectors = np.linalg.eig(covMat) # eigenvalues and eigenvectors
evectors, evalues, V = np.linalg.svd(covMat)
numProjections = 3 # number of dimensions
projections = np.zeros((numProjections,rates.shape[0]))
for x in range(numProjections):
    projections[x,:] = np.dot(rates, evectors.T[x]) # projections of principal components onto firing rates

numvectors = 3
plt.figure(figsize=(20,8))
plt.scatter(x=np.arange(len(evalues)), y=evalues, marker='o', s=20)
plt.xlabel('Eigenvalues (Principal Components)', fontsize=20)
plt.ylabel('Explained Variance', fontsize=20)
plt.title('Explained Variance Per Principal Component', fontsize=30)
plt.show()
plt.figure(figsize=(30,10))
for n in range(numvectors):
    plt.scatter(range(evectors.shape[1]), evectors.T[n] - n, s=10)
    plt.vlines(x=range(evectors.shape[1]), ymin=-n, ymax=-n+evectors.T[n], linewidth=2, alpha=0.5)
    plt.hlines(y=-n, xmin=0, xmax=evectors.shape[1], linewidth=1, alpha=0.5)
    plt.text(x=-4, y=-n, s='PC ' + str(n+1), fontsize=15, c='black')
plt.vlines(x=range(evectors.shape[1]), ymin=-numvectors+1, ymax=0, linestyles=':', alpha=0.2)
unitnames = [str(x) for x in range(evectors.shape[1])]
for txt in range(evectors.shape[1]):
    plt.text(y = -numvectors/2, x = np.arange(evectors.shape[1])[txt], s = unitnames[txt], fontsize=7, horizontalalignment='center')
plt.ylabel('Loadings', fontsize=20)
plt.xlabel('Units', fontsize=20)
plt.title('Eigenvectors for first ' + str(numvectors) + " PCs", fontsize=30)
plt.show()

### Plot 3D PC space color-coded based on the progression of time

In [None]:
%matplotlib notebook
fig = plt.figure(figsize=(25,10))
ax = fig.add_subplot(111, projection='3d')
ax.plot(projections[0,:], projections[1,:], projections[2,:], c='grey', alpha=0.4, linewidth=1)
ax.scatter(projections[0,0], projections[1,0], projections[2,0], marker='*', c='green', s=100) # start timepoint
stimScatterthroughTime = ax.scatter(projections[0,:], projections[1,:], projections[2,:], '.', c = Time, cmap='viridis', s=30, alpha=0.4)
fig.colorbar(stimScatterthroughTime)
ax.set_xlabel('First Principal Component', fontsize=10)
ax.set_ylabel('Second Principal Component', fontsize=10)
ax.set_zlabel('Third Principal Component', fontsize=10)
fig.show()

### Plot animation of progression of time in 3D PC space
> #### To stop the animation, interrupt the kernel: press `I` twice in command mode (or use Kernel → Interrupt).

In [None]:
%matplotlib inline

for i in range(projections.shape[1]-1):
    fig = plt.figure(figsize=(22,10))
    ax = fig.add_subplot(111, projection='3d')

    ax.plot(projections[0,:i+1], projections[1,:i+1], projections[2,:i+1], c='grey', alpha=0.4, linewidth=1) # grey line
    ax.scatter(projections[0,0], projections[1,0], projections[2,0], marker='*', c='green', s=100) # start timepoint as green star
    stimScatterthroughTime = ax.scatter(projections[0,:i+1], projections[1,:i+1], projections[2,:i+1], '.', c='salmon', s=30, alpha=0.4) # add each next point in salmon
    ax.scatter(projections[0,i], projections[1,i], projections[2,i], '.', c='darkred', s=30) # mark current point as dark red

    timepatch = mpatches.Patch(color='darkred',label=("Time: " + str(np.round(i/30, decimals=3)) + 'sec'))
    coordpatch = mpatches.Patch(color='darkred', label="Coord: " + str(np.round(projections[0,i+1], decimals=3)) + ',' +
                                str(np.round(projections[1,i+1], decimals=3)) + ',' +
                                str(np.round(projections[2,i+1], decimals=3)))
    plt.legend(handles = [coordpatch, timepatch],loc='upper right',fontsize=15)
    # Fix axes in place from the start to the min and max of each axes
    ax.set_xlim(projections[0,:].min(),projections[0,:].max())
    ax.set_ylim(projections[1,:].min(),projections[1,:].max())
    ax.set_zlim(projections[2,:].min(),projections[2,:].max())
    ax.set_xlabel('First Principal Component', fontsize=10)
    ax.set_ylabel('Second Principal Component', fontsize=10)
    ax.set_zlabel('Third Principal Component', fontsize=10)

    display.clear_output(wait=True)
    display.display(pl.gcf())
    time.sleep(0.00001)
    plt.close(fig)

### Plot cross correlation of rates

In [None]:
%matplotlib inline
plt.figure(figsize=(12,10))
plt.imshow(np.corrcoef(firingRates))
plt.colorbar()
plt.show()