<a href="https://colab.research.google.com/github/xup5/Computational-Neuroscience-Class/blob/main/Lab%204%20Spike%20Triggered%20Covariance/Spike_Triggered_Covariance.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Spike-triggered Covariance (STC)

This is a tutorial for linear visual filters (receptive fields)
and spike-triggered covariance.

(1) Examples of constructing visual Gabor filters and
filtering an image.

(2) Examples of spike-triggered approaches to find filters.

Odelia Schwartz 2012, transcribed and modified by Xu Pan in 2022.

This is a simplified version of: 
Spike-triggered neural characterization.
Schwartz, Pillow, Rust, Simoncelli,
Journal of Vision, 2006.

## 1. Visual filters and images



### Helper functions
I made you two functions that can generate 2D sinusoidal and Gaussian images.

makeGaussian(size, cov)   (cov is the variance of the Gaussian)

makeSine(imsize, spatialf, ori, phase)

***You don't have to look inside. Just run it.***

In [None]:
import numpy as np

def makeGaussian(size, cov=5):
  """Return a size x size isotropic 2-D Gaussian image whose peak value is 1.

  `cov` plays the role of the variance (sigma**2) in exp(-d**2 / (2*cov)).
  The centre sits at (size/2, size/2); this is for matching old matlab code.
  """
  coords = np.arange(0, size, 1, float)
  centre = size / 2
  # Squared distance of every pixel from the centre: columns vary along the
  # first term, rows along the second (broadcast to size x size).
  sq_dist = (coords[np.newaxis, :] - centre)**2 + (coords[:, np.newaxis] - centre)**2
  g = np.exp(sq_dist / (-2*cov))
  return g / g.max()

def makeSine(imsize=10, spatialf=5, ori=0, phase=0):
  """Return a 2-D sinusoidal grating image.

  Args:
    imsize: image size in pixels. An int gives a square image; a 2-element
      sequence gives [rows, cols].
    spatialf: period of the grating in pixels per cycle (larger values give
      a coarser grating).
    ori: orientation in degrees.
    phase: phase offset in degrees.

  Returns:
    Float array of shape (rows, cols) with values in [-1, 1].
  """
  ori = ori/180*np.pi
  phi = phase/180*np.pi
  # A scalar size means a square image. This is checked explicitly rather than
  # with a bare `except:`, which would also swallow unrelated errors.
  if np.ndim(imsize) == 0:
    imsize = [imsize, imsize]
  # Centre offsets; this is for matching old matlab code.
  x0 = (imsize[0]+1) / 2 - 1
  y0 = (imsize[1]+1) / 2 - 1
  # Row (x) and column (y) pixel indices, broadcast to (rows, cols) instead of
  # filling the image element by element in a Python double loop.
  x = np.arange(imsize[0])[:, np.newaxis]
  y = np.arange(imsize[1])[np.newaxis, :]
  return np.sin(2*np.pi/spatialf*((x0-x)*np.sin(ori)+(y0-y)*np.cos(ori))+phi)


### 1a Gabor filters and images

In [None]:
import matplotlib.pyplot as plt

# Grating parameters: image size (pixels), period (pixels per cycle),
# orientation (degrees) and phase (degrees).
sz, period, direction, phase = 20, 5, 0, 0
theSine = makeSine(sz, spatialf=period, ori=direction, phase=phase)

# Plot the sinusoid
plt.imshow(theSine, cmap='gray')

In [None]:
# Make a 2-D Gaussian envelope and plot it.
# makeGaussian's second argument is a variance (cov), so thesig = 2 means a
# variance of 2, not a standard deviation of 2.
thesig = 2
theGauss = makeGaussian(sz, cov=thesig)
plt.imshow(theGauss, cmap='gray')

In [None]:
# A Gabor filter is a sinusoidal grating windowed by a Gaussian envelope:
# multiply the two images element by element.
theFilt = np.multiply(theSine, theGauss)
plt.imshow(theFilt, cmap='gray')

In [None]:
# Load an image

# download an image from our repository
# NOTE: `!wget` is IPython/Colab shell syntax, so this cell only runs inside a
# notebook; `-O einstein.pgm` saves the download under that file name.
!wget https://github.com/schwartz-cnl/Computational-Neuroscience-Class/blob/main/Lab%204%20Spike%20Triggered%20Covariance/einstein.pgm?raw=true -O einstein.pgm

# Read the downloaded file into an array `im` (used by the convolution cell)
# and display it. Presumably a 2-D grayscale image since it is a .pgm --
# TODO confirm.
from skimage.io import imread
im = imread('einstein.pgm')
plt.imshow(im, cmap='gray')

### 1b Convolve the image with the filter

In [None]:
from scipy import signal

# Filter the image with the Gabor. 'valid' keeps only the output positions
# where the filter lies entirely inside the image.
response = signal.convolve2d(im, theFilt, 'valid')
plt.imshow(response, cmap='gray')

### To do:
1. Try making different Gabor filters by varying the parameters above
(e.g., the direction, period, and phase of the grating, and thesig of the Gaussian).

## 2. Spike-triggered approaches
We have constructed a few model neurons in advance. We will use spike-triggered approaches to figure out the receptive field properties of these neurons.

### Neuron models
This section has 3 functions: ClassModel1, ClassModel2, ClassModel3. You can think of them as "neurons" that take a stimulus as input and give a spike response as output (just like what we did in Lab 2).

***You don't have to look inside. Just run it.***

In [None]:
def _makeModelFilter(ori, phase, kernelX=8, kernelT=6, itau=1.2, sig=1.6, per=4.5):
  """Build one flattened space-time filter shared by the model neurons.

  The filter is the product of:
  - a Gaussian envelope along space (x, kernelX samples);
  - a y**2 * exp(-itau*y) profile along time (y, kernelT samples);
  - a makeSine grating of period `per` pixels, orientation `ori` and phase
    `phase` (both in degrees).

  It is flattened row-major to length kernelT*kernelX and scaled to unit
  sample variance (ddof=1).
  """
  x = np.arange(1, kernelX+1, 1, float) - (kernelX+1)/2
  y = np.arange(kernelT, 0, -1, float)[:, np.newaxis]
  v = np.exp(-x**2/(2*sig**2)) * np.exp(-itau*y) * y**2 * makeSine([kernelT, kernelX], per, ori, phase)
  v = v.flatten()
  return v/np.sqrt(np.var(v, ddof=1))


def _drawSpikes(linResp, kernelT=6):
  """Turn a per-frame response into a boolean spike train.

  A frame spikes when its response exceeds an independent uniform [0, 1)
  draw, so responses >= 1 always spike. The first kernelT-1 frames are forced
  to no-spike.
  """
  spikeResp = (linResp > np.random.rand(linResp.shape[0]))
  spikeResp[0:(kernelT-1)] = 0  # can't use these
  return spikeResp


def ClassModel1(allStim):
  """Model neuron 1: one half-wave-rectified, squared linear filter.

  allStim: (nFrames, 48) array, one flattened 6x8 stimulus frame per row.
  Returns a boolean array of length nFrames (True = spike).
  """
  p = 2
  th = 180/4
  rate = 1/6
  base = 0
  v1 = _makeModelFilter(th, 0)

  linResp = base + rate * np.maximum(np.matmul(allStim, v1), 0)**p
  linResp = 20*linResp/np.max(linResp)  # peak frame scaled to 20
  return _drawSpikes(linResp)

###############################################################################
def ClassModel2(allStim):
  """Model neuron 2: sum of squared responses of two filters 90 degrees apart in phase.

  allStim: (nFrames, 48) array, one flattened 6x8 stimulus frame per row.
  Returns a boolean array of length nFrames (True = spike).
  """
  p = 2
  th = 180/4
  rate = 1/12
  base = 0
  v1 = _makeModelFilter(th, 0)
  v2 = _makeModelFilter(th, 90)

  linResp = base + rate * (np.abs(np.matmul(allStim, v1))**p + np.abs(np.matmul(allStim, v2))**p)
  linResp = linResp/np.max(linResp)  # peak frame scaled to 1
  return _drawSpikes(linResp)

###############################################################################
def ClassModel3(allStim):
  """Model neuron 3: half-squared excitation divided by squared suppression.

  Excitation comes from filter v1. Suppression comes from v2 (same
  orientation, phase + 90) and v3 (orientation + 90).

  allStim: (nFrames, 48) array, one flattened 6x8 stimulus frame per row.
  Returns a boolean array of length nFrames (True = spike).
  """
  p = 2
  th = 180/4
  rate = 0.25
  v1 = _makeModelFilter(th, 0)
  v2 = _makeModelFilter(th, 90)
  v3 = _makeModelFilter(th+90, 0)

  r1 = np.matmul(allStim, v1)  # computed once; reused for the rectification mask
  l1 = (r1 > 0)*r1**p          # half squared
  l2 = np.matmul(allStim, v2)**p
  l3 = np.matmul(allStim, v3)**p

  linResp = (1+l1)/(1+0.03*l2+0.05*l3)
  linResp = 15*rate*linResp/np.max(linResp)
  return _drawSpikes(linResp)

### 2a. Generate random stimuli to "probe" the neuron with

In [None]:
# White-noise stimulus: nFrames independent rows of standard-normal values,
# each row later viewed as a kernelT x kernelX frame.
nFrames = 500000
xDim = 8
kernelX, kernelT = xDim, 6     # spatial / temporal size of noise stimulus
kernelSize = kernelT * kernelX
allStim = np.random.randn(nFrames, kernelSize)

In [None]:
# Show example frames of the white noise stimuli.
# Each row of allStim is reshaped to kernelT x kernelX (the sizes defined when
# the stimulus was generated) rather than a hard-coded (6, 8).

fig, _ = plt.subplots(4, 4, constrained_layout=True, figsize=(8, 6))
for i, ax in enumerate(fig.axes):
  ax.imshow(np.reshape(allStim[i, :], (kernelT, kernelX)), cmap='gray')

### 2b. Generate spikes from a model neuron

In [None]:
# Choose which model neuron to probe by setting whichModel to 1, 2 or 3.
whichModel = 1
modelNeurons = {1: ClassModel1, 2: ClassModel2, 3: ClassModel3}
spikeResp = modelNeurons[whichModel](allStim)

In [None]:
# Plot the spiking activity for the first 100 frames.
# Python indexing starts at 0, so the first 100 frames are [0:100]; [1:100]
# would skip frame 0 and show only 99 frames.

plt.plot(spikeResp[0:100], 'o')
plt.title('Spikes', fontsize=16)
plt.xlabel('Time (ms)', fontsize=16)

### 2c. Spike-triggered average

In [None]:
# Compute the spike-triggered average

# First find the frames for which the model neuron spiked:
# the indices of all frames whose response exceeds 0.5.
spikeInd = np.flatnonzero(spikeResp > 0.5)

In [None]:
# Then collect the spike-triggered stimuli, i.e., the stimulus frames on which
# the neuron spiked, and count how many there were.
numspikes = len(spikeInd)
spikeStim = allStim[spikeInd]

In [None]:
# Plot some example stimulus frames of the spike-triggered stimuli
# Can you tell by eye what in the stimulus is triggering a spike?
# Frames are reshaped to kernelT x kernelX, the stimulus sizes defined above.

fig, _ = plt.subplots(4, 4, constrained_layout=True, figsize=(8, 6))
for i, ax in enumerate(fig.axes):
  if i >= spikeStim.shape[0]:
    ax.axis('off')  # fewer spikes than panels: leave this panel empty
    continue
  ax.imshow(np.reshape(spikeStim[i, :], (kernelT, kernelX)), cmap='gray')

In [None]:
# We'll plot the spike-triggered average (STA): the mean of the stimulus
# frames that elicited a spike, shown as a kernelT x kernelX frame.
# Is it a structured receptive field?

sta = np.mean(spikeStim, axis=0)
plt.imshow(np.reshape(sta, (kernelT, kernelX)), cmap='gray')

Are there other receptive fields this model neuron is using to compute its spikes?

### 2d. Spike-triggered covariance

In [None]:
# The spike-triggered average reveals changes in the mean.
# We would like richer characterizations of the neurons by looking
# for changes in the variance.
# We'll do a simple version of a spike-triggered covariance.
# This is a Principal Component Analysis, computing the eigenvalues
# (variances along each receptive field axis) and the eigenvectors
# (the receptive field axes).
# Technical note: In papers, we usually first project out the STA (which we
# did not do here for simplicity). The mean is also not subtracted, so thecov
# is the spike-triggered second-moment matrix.

thecov = np.matmul(spikeStim.T, spikeStim)/(numspikes-1)

# thecov is real and symmetric, so use eigh. It guarantees real eigenvalues
# and orthonormal eigenvectors; np.linalg.eig does not. That guarantee matters
# here because the unstructured directions typically have nearly equal
# eigenvalues.
(eigval, eigvec) = np.linalg.eigh(thecov)

# Order the eigval and eigvec from largest to smallest
# (eigh returns eigenvalues in ascending order).
idx = eigval.argsort()[::-1]
eigval = eigval[idx]
eigvec = eigvec[:, idx]

# Plot the (sorted) eigenvalues
# This tells you which eigenvalues have variance that
# is significantly higher or lower than the rest.
plt.plot(eigval, 'o')
plt.ylabel('Variance', fontsize=16)
plt.xlabel('Ordered Eigenvalues', fontsize=16)

How many appear significant?

In [None]:
# Plot a corresponding eigenvector that appears significant (e.g., here for
# ClassModel1 set to the first, which is index 0).
# This eigenvector corresponds to a filter/feature/receptive field that contributes
# to the model neuron response.
# Some model neurons may have more than one such receptive field (the ordered eigenvalues
# above tell you which are significant!)
# In one of the models, the last two eigenvalues are significant!
# For that model, change thenum1 and thenum2 to reflect the last two eigenvalues
# (e.g., 46 and 47)
# Technical note: If the STA was structured, the first eigenvector could just be the
# STA receptive field (possibly negated)
# Eigenvectors are reshaped to kernelT x kernelX, the stimulus frame size.

thenum1 = 0
plt.imshow(np.reshape(eigvec[:, thenum1], (kernelT, kernelX)), cmap='gray')

In [None]:
# Plot another eigenvector
# Here set to the second, but change as needed...
# The second may or may not be significant in terms of the variance,
# depending on the model. In one of the models, the last two are significant!
# For that model, change thenum1 and thenum2 to reflect the last two eigenvalues
# (e.g., 46 and 47)
# Eigenvectors are reshaped to kernelT x kernelX, the stimulus frame size.

thenum2 = 1
plt.imshow(np.reshape(eigvec[:, thenum2], (kernelT, kernelX)), cmap='gray')

Is it structured? Do we expect it to be based on the eigenvalues?

In [None]:
# Look at scatter plots onto two eigenvectors or receptive fields.
# We compare the responses to the spike-triggered stimuli with those to the
# full stimulus set, plotting the same number of points of each for
# readability.

# The two receptive fields (columns of eigvec)
basis2 = eigvec[:, thenum2]
basis1 = eigvec[:, thenum1]
# Projections as [onto receptive field 2, onto receptive field 1], i.e. the
# x and y coordinates of the scatter plot.
allProj = [allStim @ basis2, allStim @ basis1]        # all stimuli
spikeProj = [spikeStim @ basis2, spikeStim @ basis1]  # spike-triggered stimuli

thenum = min(20000, numspikes)
plt.figure(figsize=(6, 6))
for proj, colour, label in ((allProj, 'b', 'All stim'), (spikeProj, 'r', 'Spike stim')):
  plt.scatter(proj[0][:thenum], proj[1][:thenum],
              facecolors='none', edgecolors=colour, label=label)
plt.xlim([-5, 5])
plt.ylim([-5, 5])
plt.ylabel('Receptive field 1', fontsize=16)
plt.xlabel('Receptive field 2', fontsize=16)
plt.legend()

In [None]:
# Plot ellipses signifying the variances found by the Principal Component Analysis
# Technical note: model 1 has an asymmetric nonlinearity, so we only see a change
# in the mean for the spike-triggered (red) stimuli. For the other models, you
# should see a change in the variance

angles = np.linspace(0, 2*np.pi, 100)

# Variance along the 2 receptive fields (radius = 3*sqrt(eigenvalue) on each axis)
ellipse = [3*np.sqrt(eigval[thenum2])*np.cos(angles), 3*np.sqrt(eigval[thenum1])*np.sin(angles)]

# Variance along 2 other directions that are not structured, taken from the
# middle of the sorted spectrum (indices 24 and 25 for 48-dimensional stimuli)
mid = len(eigval) // 2
ellipse_other = [3*np.sqrt(eigval[mid])*np.cos(angles), 3*np.sqrt(eigval[mid+1])*np.sin(angles)]

# Plot the ellipses over the projections
plt.figure(figsize=(6, 6))
plt.scatter(allProj[0][0:thenum], allProj[1][0:thenum], facecolors='none', edgecolors='b', label='All stim')
plt.scatter(spikeProj[0][0:thenum], spikeProj[1][0:thenum], facecolors='none', edgecolors='r', label='Spike stim')
plt.plot(ellipse[0], ellipse[1], 'r', linewidth=3)
plt.plot(ellipse_other[0], ellipse_other[1], 'b', linewidth=3)
plt.xlim([-5,5])
plt.ylim([-5,5])
plt.ylabel('Receptive field 1', fontsize=16)
plt.xlabel('Receptive field 2', fontsize=16)
plt.legend()  # the scatter labels were set but never shown

## Question:
Go through each of the model neurons in this tutorial, and describe what you found. Plot the spike-triggered average (STA). In the spike-triggered covariance analysis, which eigenvectors (receptive fields) had a strikingly high or low variance relative to the rest? Plot them. What did the scatter plot signify? Hint: we talked about similar model neuron examples in class when we discussed the spike-triggered covariance!
