In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

## Loading the data

In [None]:
# Pre-processed retina recordings: a dict-like object whose 'info', 'stimulus'
# and 'spikes' entries are used throughout this notebook.
# NOTE(review): pd.read_pickle unpickles a file fetched over the network, and
# unpickling can execute arbitrary code -- only load this URL if you trust it.
data_url = 'https://raw.githubusercontent.com/regevti/sensory-systems-workshop/master/retina/retina_data.pkl'
data = pd.read_pickle(data_url)

## Data Exploration

In [None]:
# recording metadata; its 'Ncell' entry (number of ganglion cells) is used further below
data['info']

In [None]:
# stimulus parameters, looked up below with .loc[rec_id, <column>] -- presumably one row per recording
data['stimulus']

In [None]:
rec_id = 0 # recording ID (0-based; plot titles show it as rec_id + 1)
# data['spikes'] is indexed [neuron, recording]; each element of `spikes` is one
# neuron's array of spike times (presumably in seconds) for this recording
spikes = data['spikes'][:, rec_id]
# stimulus onset of this recording; the analysis windows below start here
start_time = data['stimulus'].loc[rec_id, 'onset']

## Exercises

1. Create a variable `frame_dt` that stores the duration of a single frame in the first record.


In [None]:
# duration of one stimulus frame: the 'frame' column of this recording's row
frame_dt = data['stimulus'].at[rec_id, 'frame']

2. Create a variable `n_frames` that stores the number of frames shown in the first record.




In [None]:
# number of stimulus frames shown in this recording ('Nframes' column)
n_frames = data['stimulus'].at[rec_id, 'Nframes']

3. Create a variable `rec_duration` that stores the computed duration of the first record.

In [None]:
# duration of the record = number of frames x duration of one frame
rec_duration = n_frames * frame_dt

4. Which neuron fired the most during the 1st record?

In [None]:
# your code goes here

## Raster Plot

> There are two ways to plot a raster:
1. build an image in which each pixel is the spike count of one neuron in one time bin, and show it with `plt.imshow()`
2. draw a scatter plot of spike times against neuron index


In [None]:
# raster plotting using plt.imshow()
# Each row is a ganglion and each column a time bin over the first
# `raster_duration` seconds after stimulus onset; the pixel value is the number
# of spikes that neuron fired in that bin.

raster_duration = 10  # seconds shown after the stimulus onset
bin_width = 0.01  # seconds
end_time = start_time + raster_duration
# linspace gives edges that end exactly at end_time; np.arange with a float
# step can stop one bin short because of rounding
n_bins = int(round(raster_duration / bin_width))
bins = np.linspace(start_time, end_time, n_bins + 1)
# the edges are the same for every neuron, so build them once and histogram
# each spike train against them
raster = np.vstack([np.histogram(neuron, bins=bins)[0] for neuron in spikes])

plt.figure(figsize=(20, 7))
# extent maps image columns to real time (so the x-axis really is in seconds,
# not bin indices) and centres row i on y = i
plt.imshow(raster, aspect='auto', cmap='Greys', interpolation='none',
           extent=[bins[0], bins[-1], raster.shape[0] - 0.5, -0.5])
plt.xlabel('time [sec]')
plt.ylabel('Ganglion #')
plt.title(f'Raster Plot for Record {rec_id + 1}')
plt.show()

In [None]:
# raster plot using scatter plot: one dot per spike at (spike time, ganglion index)

plt.figure(figsize=(20,7))
for neuron_id, neuron_spike_times in enumerate(spikes):
  plt.scatter(neuron_spike_times, np.full(np.shape(neuron_spike_times), neuron_id), c='black', s=0.1)

# show the same window as the imshow raster (start_time .. end_time from the
# previous cell) instead of starting the axis at t = 0
plt.xlim([start_time, end_time])
# put ganglion 0 at the top, as imshow does
plt.gca().invert_yaxis()
plt.xlabel('time [sec]')
plt.ylabel('Ganglion #')
plt.title(f'Raster Plot for Record {rec_id + 1}')
plt.show()

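## Spike-Triggered Average (STA)

The STA is the average stimulus that preceded a neuron's spikes; it estimates the cell's spatio-temporal receptive field.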

In [None]:
!wget https://raw.githubusercontent.com/regevti/sensory-systems-workshop/master/retina/ran1.bin

### Reading the Stimulus

In [None]:
# Squares/bars along each axis: column x divided by column dx (and y by dy),
# truncated to an integer.
nx, ny = (int(data['stimulus'].loc[rec_id, extent] / data['stimulus'].loc[rec_id, step])
          for extent, step in (('x', 'dx'), ('y', 'dy')))

# squares/bars in one frame, and in all frames of the record
frame_squares = nx * ny
total_squares = frame_squares * n_frames


def bytes_to_bits(byts):
  """Helper function for converting bytes array to bits (format *ubit1).

  Unpacks a uint8 array (flattened) into single bits, least-significant bit
  first within each byte.

  Args:
    byts: uint8 array.

  Returns:
    int8 array of ``8 * byts.size`` values, each 0 or 1. The dtype is signed so
    callers can map bits to -1/+1 without unsigned wrap-around.
  """
  # bitorder='little' emits each byte LSB first in a single call
  return np.unpackbits(byts, bitorder='little').astype(np.int8)


def read_stimulus(filename, frames=None, squares_per_frame=None):
  """Read the binary random stimulus of the current record.

  The file is a packed bit stream with one bit per square/bar per frame
  (LSB first within each byte); bits are mapped 0 -> -1 and 1 -> +1.

  Args:
    filename: path of the stimulus file (e.g. the downloaded ``ran1.bin``).
    frames: number of frames to read; defaults to the notebook-level ``n_frames``.
    squares_per_frame: squares/bars per frame; defaults to the notebook-level
      ``frame_squares``.

  Returns:
    int8 array of shape ``(squares_per_frame, frames)``; column k is the k-th
    frame that was displayed to the retina.

  Raises:
    ValueError: if the file holds fewer bits than the record needs.
  """
  # fall back to the notebook-level values so `read_stimulus(path)` keeps working
  frames = int(n_frames if frames is None else frames)
  squares_per_frame = int(frame_squares if squares_per_frame is None else squares_per_frame)
  n_bits = frames * squares_per_frame

  # read the amount of squares needed for the record. Ceil division: a trailing
  # partial byte still carries needed bits (truncating would lose them and the
  # reshape below would fail).
  n_bytes = -(-n_bits // 8)
  rand_bytes = np.fromfile(filename, dtype='uint8', count=n_bytes)
  if rand_bytes.size < n_bytes:
    raise ValueError(f'{filename} holds {rand_bytes.size} bytes, but {n_bytes} are needed '
                     f'for {frames} frames x {squares_per_frame} squares')
  # drop the unused padding bits of the last byte
  bits = bytes_to_bits(rand_bytes)[:n_bits]

  # convert [0, 1] values to [-1, 1]
  stim = 2 * bits - 1

  # reshaping the array, such that each column represents a frame that was displayed to the retina.
  stim = stim.reshape((frames, squares_per_frame)).T
  print(f'stimulus shape: {stim.shape}')
  return stim


# wget saved ran1.bin in the current working directory (/content on Colab);
# a relative path works both on Colab and locally
rand_stim = read_stimulus('ran1.bin')

### Stimulus Example

In [None]:
# Show the first few frames of the random stimulus on the nx-by-ny grid.
n_example_frames = 4
# squeeze=False keeps `axes` 2-D, so the loop also works for a single frame
fig, axes = plt.subplots(1, n_example_frames, figsize=(n_example_frames * 5, 5), squeeze=False)
for frame_idx, ax in enumerate(axes[0]):
  ax.imshow(rand_stim[:, frame_idx].reshape(nx, ny).T, aspect='auto', cmap='Greys', interpolation='none')
  ax.set_title(f'Frame {frame_idx + 1}')
  ax.set_xlabel('x [square #]')
  ax.set_ylabel('y [square #]')
fig.tight_layout()
plt.show()

### Spike-Triggered Average – Single Ganglion

In [None]:
# --- STA parameters ---------------------------------------------------------

# how much stimulus history to average before each spike, in seconds and in frames
window_length = 0.5 # seconds
window_n_frames = int(np.round(window_length / frame_dt))

# spatial axes in mm: square index * square size (dx/dy) * pixel size
dx = int(data['stimulus'].loc[rec_id, 'dx'])
dy = int(data['stimulus'].loc[rec_id, 'dy'])
pixel_size = int(data['stimulus'].loc[rec_id, 'pixelsize']) / 1000 # mm
sx = np.arange(nx) * dx * pixel_size
sy = np.arange(ny) * dy * pixel_size

# STA time axis: start of each window frame relative to the start of the event frame
t = frame_dt * np.arange(-window_n_frames, 0)

In [None]:
def spike_triggered_average(gang_id):
  """Calculate the spike-triggered average (STA) stimulus of one ganglion.

  The record is split into stimulus frames; every frame in which the ganglion
  fired at least once is an "event". The STA is the mean of the
  ``window_n_frames`` stimulus frames preceding each event frame (the event
  frame itself is excluded). Each event frame counts once, however many spikes
  fell into it.

  Reads the notebook-level ``spikes``, ``start_time``, ``frame_dt``,
  ``n_frames``, ``window_n_frames``, ``frame_squares`` and ``rand_stim``.

  Args:
    gang_id: 0-based index of the ganglion in ``spikes``.

  Returns:
    array of shape ``[squares, window_frames]``, i.e.
    ``(frame_squares, window_n_frames)``. Column j is the mean stimulus
    ``window_n_frames - j`` frames before the event frame. All-NaN when the
    ganglion has no events (the average is undefined).
  """
  gang_spikes = spikes[gang_id]
  # n_frames + 1 edges -> exactly one bin per frame; with only n_frames edges
  # there are n_frames - 1 bins and spikes in the last frame are dropped
  frame_edges = np.arange(n_frames + 1) * frame_dt + start_time
  # histogram for spikes during each frame
  spikes_per_frame, _ = np.histogram(gang_spikes, bins=frame_edges)
  # delete the data from the first window, since we have no information of what
  # was projected to the ganglion cells before the stimulus onset; this also
  # keeps every window below inside rand_stim
  spikes_per_frame[:window_n_frames] = 0
  # Find the events (frames during which the neuron fired)
  event_idx = np.flatnonzero(spikes_per_frame > 0)

  sta = np.zeros((frame_squares, window_n_frames))
  if event_idx.size == 0:
    # no events: return NaN explicitly instead of dividing 0 by 0
    return np.full_like(sta, np.nan)

  for evi in event_idx:
    # the window_n_frames frames just before the event frame (a view, no copy)
    sta += rand_stim[:, evi - window_n_frames:evi]

  # normalizing the results
  return sta / event_idx.size


def plot_spatio_temporal(gang_id, ax=None):
  """Plot the spatio-temporal STA of a ganglion before its spikes.

  x-axis: time before the spike [sec] (notebook-level ``t``); y-axis: position
  [mm] (notebook-level ``sx``).
  NOTE(review): the y-axis spans the nx positions in ``sx``, while STA rows are
  all nx * ny squares, so the axis is only exact for bar stimuli (ny == 1) --
  confirm before using with 2-D checkerboards.

  Args:
    gang_id: 0-based ganglion index (the title shows gang_id + 1).
    ax: matplotlib Axes to draw on; a new 4x4-inch figure is created if None.
  """
  sta = spike_triggered_average(gang_id)
  if ax is None:
    _, ax = plt.subplots(1, 1, figsize=(4, 4))

  # origin='lower' puts STA row 0 (the square at sx[0]) at the bottom, matching
  # the extent; with the default origin='upper' row 0 would be drawn at sx[-1]
  # and the space axis would be flipped
  ax.imshow(sta, cmap='gray', aspect=t.size/sx.size, origin='lower',
            extent=[t[0], t[-1], sx[0], sx[-1]])
  ax.set_xlabel('Time before spike [sec]')
  ax.set_ylabel('space [mm]')
  ax.set_title(f'Ganglion #{gang_id+1}')

In [None]:
# STA matrix of the first ganglion (index 0, titled "Ganglion #1" in the plots)
sta0 = spike_triggered_average(0)
sta0.shape

In [None]:
# spatio-temporal STA of the first ganglion (index 0 -> "Ganglion #1")
plot_spatio_temporal(gang_id=0)

In [None]:
# spatio-temporal STA of the second ganglion (index 1 -> "Ganglion #2")
plot_spatio_temporal(gang_id=1)

In [None]:
# All ganglions: one spatio-temporal STA panel per ganglion, `cols` per row
n_ganglions = int(data['info']['Ncell'])
cols = 6
rows = int(np.ceil(n_ganglions / cols))
# squeeze=False keeps `axes` 2-D regardless of the grid size
fig, axes = plt.subplots(rows, cols, figsize=(20, 3 * rows), squeeze=False)
axes = axes.flatten()
for gang_id in range(n_ganglions):
  plot_spatio_temporal(gang_id, axes[gang_id])
# hide the empty panels left over in the last row
for unused_ax in axes[n_ganglions:]:
  unused_ax.axis('off')
fig.tight_layout()
plt.show()