<a href="https://colab.research.google.com/github/xup5/Computational-Neuroscience-Class/blob/main/Lab%202%20Neural%20Box/neuralBox2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Neural Box 2

Odelia Schwartz, Xu Pan, Alexander Claman

Partly based on the Cold Spring Harbor white-noise tutorial by Chichilnisky.
This is a tutorial for figuring out the response properties
of a neuron by presenting many random stimuli.

This tutorial specifically focuses on the spike-triggered average,
estimating the linear filter, and estimating the nonlinearity.
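
As a compact summary of these three pieces, the model neurons used below can be written as a linear filter followed by a nonlinearity. The notation here is ours and is only a sketch: $s$ is the stimulus, $k$ is the linear filter with $K$ time bins, $f$ is the nonlinearity, $r$ is the spiking drive, and $t_1, \dots, t_N$ are the times of the $N$ spikes.

$$
L(t) = \sum_{\tau=0}^{K-1} k(\tau)\, s(t-\tau), \qquad
r(t) = f\big(L(t)\big), \qquad
\mathrm{STA}(\tau) = \frac{1}{N}\sum_{i=1}^{N} s(t_i - \tau)
$$

The rest of the tutorial estimates $k$ from the spike-triggered average and $f$ from spike counts binned by linear response.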

The tutorial mostly includes model neurons (synthetic data) and shows the estimation process. We also examine the spike-triggered average for real neural data from an H1 neuron in the fly. Data is from the Dayan and Abbott book. 

At the end of the tutorial, there are some questions.


## Part 1. Use synthesized data

The neuron here is a model simulation, but the same approach is used for understanding real neurons. We probe the model neuron as if it were a black box whose properties we do not know. In Part 2, you will use this approach to
understand a real neuron.

### Construct stimuli

In [None]:
# We want to first choose experimental stimuli that are random
# At each frame, the intensity of the uniform screen changes:
# it is drawn randomly, here from a Gaussian distribution.

import numpy as np

numsamples = 25000
stimulus = (1/3*np.random.randn(numsamples))

In [None]:
# Plot the stimulus

import matplotlib.pyplot as plt

thelen = min(1000, numsamples)
plt.plot(stimulus[0:thelen])

plt.xlabel('Time (ms)', fontsize=16)
plt.ylabel('Stimulus strength', fontsize=16)
plt.title('First 1 second of stimulus', fontsize=16)

In [None]:
# Let's check that it is Gaussian

plt.hist(stimulus, bins=50, density=True, align='left');
plt.xlabel('Stimulus strength', fontsize=16)
plt.ylabel('Probability density', fontsize=16)

### Neuron models

This cell has 5 functions: getLinear1, getLinear2, getLinear3, getNonlinear1, and getNonlinear2. You can think of them as "neurons" that take a stimulus as input and give a response as output. The filter and nonlinearity are hidden, as they would be in a biological experiment.

***You don't have to look inside. Just run it.***

In [None]:
def getLinear1(stimulus, kernelSize):
  # Compute the linear response using a single exponential
  # lowpass filter.  You could substitute other linear filters here if you
  # wanted to, but this one is pretty simple.
  tau = 10 # time constant
  linearResp = np.zeros(len(stimulus))
  
  for i in range(len(stimulus)-1):
    # Solve the differential equation
    linearResp[i+1] = linearResp[i] + (1/tau)*(stimulus[i]-linearResp[i]) 
  
  # get the impulse response function which is also the "filter"
  impulse = np.zeros(1000) # make an impulse stimulus
  impulse[0] = 1
  impulseResp = np.zeros(len(impulse))
  for i in range(len(impulse)-1):
    # Solve the differential equation
    impulseResp[i+1] = impulseResp[i] + (1/tau)*(impulse[i]-impulseResp[i])
  impulseResp = impulseResp[0:kernelSize]
  filter = np.flipud(impulseResp)

  return (linearResp, filter)

###############################################################################
def getLinear2(stimulus, kernelSize):
  # Compute the linear response using a single exponential
  # lowpass filter.  You could substitute other linear filters here if you
  # wanted to, but this one is pretty simple.
  tau = 5 # time constant
  linearResp = np.zeros(len(stimulus))
  
  for i in range(len(stimulus)-1):
    # Solve the differential equation
    linearResp[i+1] = linearResp[i] + (1/tau)*(stimulus[i]-linearResp[i]) 
  
  # get the impulse response function which is also the "filter"
  impulse = np.zeros(1000) # make an impulse stimulus
  impulse[0] = 1
  impulseResp = np.zeros(len(impulse))
  for i in range(len(impulse)-1):
    # Solve the differential equation
    impulseResp[i+1] = impulseResp[i] + (1/tau)*(impulse[i]-impulseResp[i])
  impulseResp = impulseResp[0:kernelSize]
  filter = np.flipud(impulseResp)

  linearResp = -linearResp
  filter = -filter

  return (linearResp, filter)

###############################################################################
def getLinear3(stimulus, kernelSize):
  # Compute the linear response using a 3-stage cascade of exponential
  # lowpass filters.  You could substitute other linear filters here if you
  # wanted to, but this one is pretty simple.
  tau = 3 # time constant
  linearResp = np.zeros((len(stimulus),3))
  
  for i in range(len(stimulus)-1):
    # Solve the differential equation
    linearResp[i+1,0] = linearResp[i,0] + (1/tau)*(stimulus[i]-linearResp[i,0])
    linearResp[i+1,1] = linearResp[i,1] + (1/tau)*(linearResp[i,0]-linearResp[i,1])
    linearResp[i+1,2] = linearResp[i,2] + (1/tau)*(linearResp[i,1]-linearResp[i,2])
  
  # Getting rid of the first- and second-order filtered signals, we only
  # want the third one.
  linearResp = linearResp[:,2]
  
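  # get the impulse response function which is also the "filter"
  impulse = np.zeros(1000) # make an impulse stimulus
  impulse[0] = 1
  # Size the buffer by the impulse, not the stimulus; otherwise the loop
  # below runs out of bounds when the stimulus has fewer than 1000 samples.
  impulseResp = np.zeros((len(impulse),3))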
  for i in range(len(impulse)-1):
    # Solve the differential equation
    impulseResp[i+1,0] = impulseResp[i,0] + (1/tau)*(impulse[i]-impulseResp[i,0])
    impulseResp[i+1,1] = impulseResp[i,1] + (1/tau)*(impulseResp[i,0]-impulseResp[i,1])
    impulseResp[i+1,2] = impulseResp[i,2] + (1/tau)*(impulseResp[i,1]-impulseResp[i,2])

  # Getting rid of the first- and second-order filtered signals, we only
  # want the third one.
  impulseResp = impulseResp[:,2]

  impulseResp = impulseResp[0:kernelSize]
  filter = np.flipud(impulseResp)

  return (linearResp, filter)

###############################################################################
def getNonlinear1(linearResp):
  nonlinearResp = np.zeros(len(linearResp))
  theind = np.where(linearResp>0)
  nonlinearResp[theind] = linearResp[theind]**2
  return nonlinearResp

###############################################################################
def getNonlinear2(linearResp):
  return linearResp**2
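
To see why the impulse response can be called "the filter", the sketch below re-implements the same forward-Euler low-pass update in a standalone helper and checks that it matches convolving the stimulus with the impulse response. The helper name `lowpass`, the seed, and the array sizes are our own choices for illustration and are not part of the tutorial code.

```python
import numpy as np

def lowpass(stimulus, tau):
    # Same forward-Euler update as the exponential low-pass model neurons.
    resp = np.zeros(len(stimulus))
    for i in range(len(stimulus) - 1):
        resp[i + 1] = resp[i] + (1 / tau) * (stimulus[i] - resp[i])
    return resp

rng = np.random.default_rng(0)
demo_stimulus = rng.standard_normal(500) / 3
tau = 10

# Impulse response: the output for a single unit input at time 0.
impulse = np.zeros(200)
impulse[0] = 1
impulse_resp = lowpass(impulse, tau)

# Running the recursion directly ...
direct = lowpass(demo_stimulus, tau)
# ... should agree with convolving the stimulus with the impulse response.
via_convolution = np.convolve(demo_stimulus, impulse_resp, mode='full')[:len(demo_stimulus)]

max_difference = np.max(np.abs(direct - via_convolution))
print(max_difference)
```

The two results should differ only by the small tail of the impulse response that is cut off after 200 samples.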

### Simulate a model neuron

In [None]:
# We're now going to simulate a model neuron
# For purposes of this demo, we are constructing the model
# neurons and so know their filters and nonlinearity
# (in an experiment with real neurons, we would be handed 
# the spike trains and would not know this!)

# We've made several versions of model neurons.
# We have 3 possible linear filters.
# Toggle between

kernelSize = 60
(linearResp, filter) = getLinear1(stimulus, kernelSize)
# (linearResp, filter) = getLinear2(stimulus, kernelSize)
# (linearResp, filter) = getLinear3(stimulus, kernelSize)

In [None]:
# Let's look at the filter (which we usually would not know, but here we do
# because we made up the model neurons)

plt.plot(filter, 'o-');
plt.title('Actual model filter', fontsize=16)
plt.xlabel('Time (ms)', fontsize=16);

In [None]:
# We also have two versions of nonlinearities for our model neurons.

nonlinearResp = getNonlinear1(linearResp)
# Toggle between
# nonlinearResp = getNonlinear2(linearResp)

In [None]:
# We can use this non-linear response to simulate a 
# Poisson-ish spike train... as per last class!

xr = np.random.rand(len(nonlinearResp))
spikeResponse = nonlinearResp > .05*xr
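
Because `xr` is uniform on [0, 1), the comparison above produces a spike in a given bin with probability `min(1, nonlinearResp/0.05)`. The sketch below checks this for a constant drive; the value `drive = 0.02`, the seed, and the number of bins are hypothetical and chosen only for illustration.

```python
import numpy as np

rng = np.random.default_rng(1)
threshold_scale = 0.05   # same constant as in the spike-generation cell
drive = 0.02             # hypothetical constant nonlinear response
n_bins = 200000

# Same rule as above: spike when the drive exceeds a scaled uniform draw.
spikes = drive > threshold_scale * rng.random(n_bins)

empirical_rate = spikes.mean()
expected_rate = min(1.0, drive / threshold_scale)
print(empirical_rate, expected_rate)
```

One consequence is that the spike probability saturates at 1 once the nonlinear response reaches 0.05.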


In [None]:
# So far, we constructed a model neuron and its response to experimental
# stimuli. Here's the first second of each step:

fig, axs = plt.subplots(3, constrained_layout=True, figsize=(6, 8))
axs[0].plot(linearResp[0:1000])
axs[0].set_title('Linear response', fontsize=16) 
axs[1].plot(nonlinearResp[0:1000], color='r')
axs[1].set_title('Nonlinear function of linear response', fontsize=16)
axs[2].stem(spikeResponse[0:1000], basefmt=" ")
axs[2].set_title('# of Spikes (1 ms bins)', fontsize=16)
axs[2].set_xlabel('Time (ms)', fontsize=16)

### Estimate linear filter (Spike-triggered average)

In [None]:
# Now we compute the spike-triggered average stimulus.  This is accomplished
# by taking the 60 milliseconds of stimulus immediately preceding each spike
# and adding them together.  This sum is then divided by the total number 
# of spikes fired over the course of the entire experiment to determine the 
# average stimulus preceding a spike.
# This spike-triggered average is, in a sense, a template for what the neuron
# is "looking for" in the stimulus.

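kernelSize = 60
# Count only spikes that have a full kernelSize-long stimulus history,
# because the loop below skips the first kernelSize-1 samples.
totalCount = np.sum(spikeResponse[kernelSize-1:])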

sta = np.zeros(kernelSize)

for i in range(kernelSize-1,len(spikeResponse)):
  if spikeResponse[i] == 1:                     # if there is a spike
    sta = sta + stimulus[i-kernelSize+1:i+1]    # add stimulus preceding spike

sta = sta/totalCount                    # average of stimuli that led to spike
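
The same average can also be computed without an explicit loop. The sketch below is one possible vectorized version, assuming NumPy 1.20 or later for `sliding_window_view`; the function name `spike_triggered_average` and the random demo spike train are ours and only illustrate the indexing.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def spike_triggered_average(stimulus, spikes, kernel_size):
    # Row j of `windows` is stimulus[j : j + kernel_size], so it ends at
    # sample j + kernel_size - 1.
    windows = sliding_window_view(stimulus, kernel_size)
    # Spikes before sample kernel_size - 1 have no complete window; drop them.
    valid_spikes = np.asarray(spikes[kernel_size - 1:], dtype=bool)
    # Average the windows that end on a spike.
    return windows[valid_spikes].mean(axis=0)

rng = np.random.default_rng(0)
demo_stimulus = rng.standard_normal(2000) / 3
demo_spikes = rng.random(2000) < 0.1      # hypothetical random spike train
demo_sta = spike_triggered_average(demo_stimulus, demo_spikes, kernel_size=60)
print(demo_sta.shape)
```

With the tutorial's variables, a call such as `spike_triggered_average(stimulus, spikeResponse, kernelSize)` would use the same window convention as the loop (each window ends at the spike bin).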

In [None]:
# We'll first look at the answer; then unpack what we did

plt.plot(np.arange(-kernelSize,0), sta, marker='o')  # model neuron uses 1 ms bins
plt.title('Estimated Spike-triggered average', fontsize=16)
plt.xlabel('Time (ms)')

In [None]:
# Because this is a tutorial, we *know* exactly what filtering 
# was done on the stimulus to get the linear response ("linearResp"). 
# Below, we compare the spike-triggered average to the filter we used.
# They have a similar shape, up to a constant scale factor, so we normalize
# each curve by its sum before plotting.


plt.plot(np.arange(-kernelSize,0), sta/np.sum(sta), marker='o', label='estimated filter')
plt.plot(np.arange(-kernelSize,0), filter/np.sum(filter), marker='o', label='actual model filter')
plt.title('Spike-triggered average', fontsize=16)
plt.xlabel('Time (ms)', fontsize=16)
plt.legend()

In [None]:
# Remember we summed together stimuli that led to a spike.
# We can look at individual such stimuli and the average 
# as we have more samples.
# Let's visualize this process.

from IPython.display import clear_output

sta_temp = np.zeros(kernelSize)
total_temp = 0
time = np.arange(-kernelSize,0)  # model neuron uses 1 ms bins

for i in range(kernelSize-1,min(10000, numsamples)): # we plot the first 10s
  if spikeResponse[i] == 1:
    sta_temp = sta_temp + stimulus[i-kernelSize+1:i+1]
    total_temp = total_temp + 1
    if total_temp % 50 == 0: # update the plot every 50 spikes
      clear_output(wait=True)
      fig, axs = plt.subplots(2, constrained_layout=True, figsize=(6, 5))
      axs[0].plot(time, stimulus[i-kernelSize+1:i+1])
      axs[0].set_title('Stimulus that resulted in spike', fontsize=16)
      axs[1].plot(time, sta_temp/total_temp)
      axs[1].set_title('Spike-triggered average', fontsize=16)
      plt.show()

In [None]:
# Extra intuition
# To get intuition about why averaging the spike-triggered stimuli
# works, we can look at how the linear response relates to
# spikes versus no spikes. This allows us to differentiate 
# between stimuli that lead to spikes or no spikes.
# When the estimated linear response is higher, the model
# neuron is more likely to spike...

linearEst = np.zeros(len(spikeResponse))

for i in range(kernelSize-1, len(spikeResponse)):  # same window convention as the STA loop
  linearEst[i] = np.dot(sta, stimulus[i-kernelSize+1:i+1])

plt.scatter(linearEst, spikeResponse, facecolors='none', edgecolors='C0')
plt.xlabel('Linear response', fontsize=16)
plt.ylabel('Spikes', fontsize=16)
plt.title('Relation between estimated linear responses and spikes', fontsize=16)

### Estimate nonlinearity (Extra)

In [None]:
# We can also actually estimate the nonlinearity.
# We can plot the "average" number of spikes fired in response to similar 
# linear responses.

# First we decide on linear response bins...

plt.hist(linearResp, bins=50)
plt.xlabel('Linear response', fontsize=16)
plt.ylabel('Count', fontsize=16)

In [None]:
# -.2 to .3 looks like a good range.
linear_response_bin = np.arange(-.2,.3,.05)

mean_spikes = np.zeros(len(linear_response_bin))
sem_spikes = np.zeros(len(linear_response_bin))

for i in range(len(linear_response_bin)):
  # Find the time indices at which the linear response falls in this bin:
  ind_in_bin = np.where(np.logical_and(linearResp>linear_response_bin[i], linearResp<linear_response_bin[i]+0.05))
  # Calculate the mean of the spike count over time points whose linear responses are in this bin.
  mean_spikes[i] = np.mean(spikeResponse[ind_in_bin])
  sem_spikes[i] = np.std(spikeResponse[ind_in_bin], ddof=1) / np.sqrt(len(spikeResponse[ind_in_bin]))

plt.errorbar(linear_response_bin+0.025, mean_spikes, yerr=sem_spikes, capsize=3)
plt.title('Estimated nonlinear function', fontsize=16)
plt.xlabel('Linear response component', fontsize=16)
plt.ylabel('Mean spike count', fontsize=16)
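
The binning step can also be sketched with `np.digitize`, which assigns each linear response to a bin in one call. This is an alternative illustration rather than the tutorial's method; the helper name `binned_mean` and the toy arrays are hypothetical. Empty bins are returned as NaN so they are not mistaken for a zero spike rate.

```python
import numpy as np

def binned_mean(x, y, edges):
    # np.digitize returns b + 1 when edges[b] <= x < edges[b + 1].
    which_bin = np.digitize(x, edges) - 1
    means = np.full(len(edges) - 1, np.nan)
    for b in range(len(edges) - 1):
        in_bin = which_bin == b
        if in_bin.any():               # leave empty bins as NaN
            means[b] = y[in_bin].mean()
    return means

# Toy data: linear responses and whether a spike occurred (hypothetical values).
x = np.array([0.01, 0.02, 0.06, 0.07, 0.30])
y = np.array([0, 1, 1, 1, 1])
edges = np.array([0.0, 0.05, 0.10, 0.15])

means = binned_mean(x, y, edges)
print(means)
```

Values outside the range covered by `edges` (here 0.30) are simply ignored.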

In [None]:
# We can compare this to the nonlinearity that we know because we constructed
# the model simulation - but usually would not know.
# That is, we can superimpose the non-linear function that we actually used to
# determine the spike firing probabilities.  The plot of linear response versus
# mean spike count should have the same shape as this function, but, there
# is an arbitrary scale factor relating these two quantities.  Below, we estimate
# this scale factor using least-squares.

xx = getNonlinear1(linear_response_bin+0.025)        # this we know; evaluated at the bin centers
gamma = 1/np.dot(xx, xx) * np.dot(xx, mean_spikes)   # find the scale factor (least squares solution)

vals = np.arange(-.3,.3,.05)
Nth =  getNonlinear1(vals)

plt.errorbar(linear_response_bin+0.025, mean_spikes, yerr=sem_spikes, capsize=3)
plt.plot(vals, Nth*gamma)
plt.title('Estimated nonlinear function', fontsize=16)
plt.xlabel('Linear response component', fontsize=16)
plt.ylabel('Mean spike count', fontsize=16)
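
The scale factor used above is the least-squares solution for a single coefficient: minimizing the squared error between `gamma*xx` and `mean_spikes` gives `gamma = dot(xx, mean_spikes) / dot(xx, xx)`. The sketch below checks this closed form against `np.linalg.lstsq` on made-up numbers; the arrays `x` and `y` are hypothetical.

```python
import numpy as np

# Hypothetical data: y is roughly 3 * x plus small deviations.
x = np.array([0.0, 0.01, 0.04, 0.09])
y = 3.0 * x + np.array([0.001, -0.002, 0.001, 0.0])

# Closed-form least-squares scale factor for the model y ~ gamma * x.
gamma = np.dot(x, y) / np.dot(x, x)

# The same problem solved by a general least-squares routine.
gamma_lstsq = np.linalg.lstsq(x[:, None], y, rcond=None)[0][0]

print(gamma, gamma_lstsq)
```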

## Part 2. Use real data

The neuron here is an H1 neuron in the fly. Data is from the Dayan and Abbott book. The mat file is available here: http://www.gatsby.ucl.ac.uk/~dayan/book/exercises/c1/data/c1p8.mat


### Prepare data

In [None]:
# Download data file

!wget http://www.gatsby.ucl.ac.uk/~dayan/book/exercises/c1/data/c1p8.mat

In [None]:
# Use scipy package to load mat (Matlab) files.

import scipy.io
import numpy as np

mat = scipy.io.loadmat('c1p8.mat')
rho = mat['rho']
stim = mat['stim']

print(np.shape(rho))
print(np.shape(stim))

# The rho and stim in mat file have shape (600000,1).
# Since we don't need a second dimension, we use np.squeeze
# to get rid of the second dimension.

rho = np.squeeze(rho)
stim = np.squeeze(stim)
print(np.shape(rho))
print(np.shape(stim))

t = 2 * np.arange(len(rho)) # Sampling interval is 2 ms.
print(t[0:50])     # print first 50 time values

In [None]:
# Let's first plot the stimulus and spike response

import matplotlib.pyplot as plt

fig, axs = plt.subplots(2, constrained_layout=True, figsize=(6, 5))

axs[0].plot(t[0:1000],stim[0:1000])
axs[0].set_title('Stimulus', fontsize=16)

axs[1].stem(t[0:1000],rho[0:1000], basefmt=" ")
axs[1].set_title('Spikes', fontsize=16)
axs[1].set_xlabel('Time (ms)', fontsize=16)

In [None]:
# Check the distribution of the stimulus

plt.hist(stim, bins=100, density=True, align='left');
plt.xlabel('Stimulus value', fontsize=16)
plt.ylabel('Probability density', fontsize=16)

### Estimate linear filter (Spike-triggered average)

In [None]:
# TODO: Write down the code to estimate the fly filter (we'll call it sta again)
# See how we did it for the model neurons
# use kernelSize = 150
# remember that rho tells us the fly spiking response, that is, when the fly neuron spiked

kernelSize = 150
totalCount = sum(rho)

#
# FILL IN THE REST OF THE CODE
#

In [None]:
# Plot the STA.

plt.plot(2*np.arange(-kernelSize,0), sta, marker='o')
plt.title('Spike-triggered average', fontsize=16)
plt.xlabel('Time (ms)', fontsize=16)

### Estimate nonlinearity (Extra)

In [None]:
# We can also actually estimate the nonlinearity.
# We can plot the "average" number of spikes fired in response to similar 
# linear responses.

# First we decide on linear response bins...

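# Project the stimulus onto the STA. np.convolve flips its second argument,
# so we flip sta first; keeping the first len(stim) samples of the 'full'
# output makes linear_response[i] equal np.dot(sta, stim[i-kernelSize+1:i+1])
# (the first kernelSize-1 samples use a truncated window).
linear_response = np.convolve(stim, sta[::-1], mode='full')[:len(stim)]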

plt.hist(linear_response, bins=50)
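
When a linear response is computed with `np.convolve`, the alignment with the spike train matters, because `np.convolve` reverses its second argument. The sketch below uses random stand-ins for the stimulus and filter (`demo_stim` and `demo_sta` are hypothetical) and checks that convolving with the flipped filter, then keeping the first `len(demo_stim)` samples, reproduces the sliding dot product used for the model neuron in Part 1.

```python
import numpy as np

rng = np.random.default_rng(2)
demo_stim = rng.standard_normal(300)   # stand-in for the stimulus
demo_sta = rng.standard_normal(20)     # stand-in for the STA (oldest sample first)
k = len(demo_sta)

# Reference: dot product of the filter with the k samples ending at time i.
loop_resp = np.zeros(len(demo_stim))
for i in range(k - 1, len(demo_stim)):
    loop_resp[i] = np.dot(demo_sta, demo_stim[i - k + 1:i + 1])

# Convolution with the time-reversed filter, truncated to the causal part.
conv_resp = np.convolve(demo_stim, demo_sta[::-1], mode='full')[:len(demo_stim)]

# Compare where the full window is available.
max_difference = np.max(np.abs(loop_resp[k - 1:] - conv_resp[k - 1:]))
print(max_difference)
```

With this alignment, sample `i` of the linear response depends only on the stimulus up to and including time `i`, so it can be compared directly with the spike train at the same index.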

In [None]:
# -30000 to 30000 looks like a good range.
linear_response_bin = np.arange(-30000,30000,5000)

mean_spikes = np.zeros(len(linear_response_bin))
sem_spikes = np.zeros(len(linear_response_bin))

for i in range(len(linear_response_bin)):
  # Find the time indices at which the linear response falls in this bin:
  ind_in_bin = np.where(np.logical_and(linear_response>linear_response_bin[i], linear_response<linear_response_bin[i]+5000))
  # Calculate the mean of the spike count over time points whose linear responses are in this bin.
  mean_spikes[i] = np.mean(rho[ind_in_bin])
  sem_spikes[i] = np.std(rho[ind_in_bin], ddof=1) / np.sqrt(len(rho[ind_in_bin]))

plt.errorbar(linear_response_bin+2500, mean_spikes, yerr=sem_spikes, capsize=3)
plt.title('Estimated nonlinear function', fontsize=16)
plt.xlabel('Linear responses', fontsize=16)
plt.ylabel('Mean spike count', fontsize=16)

## Part 3. Questions

Synthetic data:

1. Try changing the linear function (choosing between getLinear1,
getLinear2, getLinear3; see toggle comment). Can we recover each of
the linear filters properly?

2. Try changing the nonlinear function (choose between
getNonlinear1 and getNonlinear2) while keeping getLinear1.
Can we recover the linear filter of the neuron for each 
of these nonlinearities? If not, then why?

3. Keep the original getLinear1 and getNonlinear1 for the model neuron. Try lowering numsamples in the code above
(for example, to 200, 700, and 2000). Plot the
filter estimates for these. How good are the filter estimates compared to the actual model filter? Why does the quality of the filter estimate look worse in some of these cases?

Real data:

4. Plot the linear filter that you found for the fly neuron based on the STA. Explain how it is similar to or different from the example model filters we used.


