<img src="../resources/cropped-SummerWorkshop_Header.png">  

<h1 align="center">Neuropixels Extracellular Electrophysiology </h1> 
<h2 align="center">Summer Workshop on the Dynamic Brain </h2> 
<h3 align="center">August 2019</h3> 

<img src="../resources/EphysObservatory/neuropixels.png" height="250" width="250"> 

In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from allensdk.brain_observatory.ecephys.ecephys_project_cache import EcephysProjectCache
import allensdk.brain_observatory.ecephys.ecephys_session as ecephys_session
%matplotlib notebook

In [2]:
# Location of the AllenSDK cache manifest. The default is the workshop
# machine's local cache; set ECEPHYS_MANIFEST_PATH to point somewhere else
# (for example the shared copy at
# /allen/aibs/informatics/nileg/swdb_ecephys/cache_dir/manifest.json).
manifest_path = os.environ.get(
    "ECEPHYS_MANIFEST_PATH",
    os.path.join('/local1/storage/allensdk_cache/ecephys_project_cache', 'manifest.json'),
)

# Credentials must not be stored in the notebook: the password is read from
# the environment so it never ends up in version control or shared copies.
lims_password = os.environ.get("LIMS_PASSWORD")
if lims_password is None:
    raise RuntimeError(
        "Set the LIMS_PASSWORD environment variable before running this notebook."
    )

lims_config = {
    # keyword arguments for the LIMS postgres connection
    "pg_kwargs": {
        "dbname": "lims2_nileg",
        "host": "aibsdc-dev-db1",
        "port": 5432,
        "user": "limsreader",
        "password": lims_password,
    },
    # keyword arguments for the LIMS web application
    "app_kwargs": {
        "host": "10.128.50.64:3000"
    },
}

cache = EcephysProjectCache.from_lims(manifest=manifest_path, lims_kwargs=lims_config)

# Exploring an experimental session

In [3]:
sessions = cache.get_sessions()

In [4]:
sessions.head()

Unnamed: 0_level_0,session_type,specimen_id,genotype,gender,age_in_days,project_code,probe_count,channel_count,unit_count,has_nwb,structure_acronyms
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
715093703,brain_observatory_1.1,699733581,Sst-IRES-Cre/wt;Ai32(RCL-ChR2(H134R)_EYFP)/wt,M,118.0,NeuropixelVisualCoding,6,765,1390,False,"[CA, DG, MB, TH, VIS, VISam, VISl, VISp, VISpm..."
719161530,brain_observatory_1.1,703279284,Sst-IRES-Cre/wt;Ai32(RCL-ChR2(H134R)_EYFP)/wt,M,122.0,NeuropixelVisualCoding,6,616,1184,False,"[CA, DG, MB, TH, VISal, VISam, VISl, VISp, VIS..."
721123822,brain_observatory_1.1,707296982,Pvalb-IRES-Cre/wt;Ai32(RCL-ChR2(H134R)_EYFP)/wt,M,125.0,NeuropixelVisualCoding,2,180,265,True,"[CA, DG, MB, TH, VIS, VISal, VISam, VISl, VISp..."
728680079,brain_observatory_1.1,714089558,Sst-IRES-Cre/wt;Ai32(RCL-ChR2(H134R)_EYFP)/wt,M,109.0,NeuropixelVisualCoding,6,636,1126,False,"[CA, DG, MB, TH, VIS, VISp, VISpm, VISrl, None]"
729090175,brain_observatory_1.1,715075382,Pvalb-IRES-Cre/wt;Ai32(RCL-ChR2(H134R)_EYFP)/wt,F,118.0,NeuropixelVisualCoding,6,594,1157,False,"[CA, DG, TH, VISal, VISam, VISl, VISp, VISpm, ..."


***Here's where I'd like to explore some dimensions of the dataset in terms of stimuli, cre lines, other metadata.  Ultimately I want to use this to select a particular session, rather than pull one arbitrarily***

In [5]:
session_id = 797828357 # for example
session = cache.get_session_data(session_id)



In [6]:
sessions.loc[session_id]

session_type                                      brain_observatory_1.1
specimen_id                                                   776061251
genotype                Pvalb-IRES-Cre/wt;Ai32(RCL-ChR2(H134R)_EYFP)/wt
gender                                                                M
age_in_days                                                           0
project_code                                     NeuropixelVisualCoding
probe_count                                                           6
channel_count                                                       604
unit_count                                                         1076
has_nwb                                                            True
structure_acronyms    [CA, DG, MB, TH, VISal, VISam, VISl, VISp, VIS...
Name: 797828357, dtype: object

## Getting data for a session

Start by exploring tab completion...

In [7]:
# session.

### Units

In [8]:
session.units.head()

AttributeError: 'ElectrodeGroup' object has no attribute 'sampling_rate'

In [None]:
session.units.sampling_rate.unique()

In [None]:
#what is firing_rate? mean firing rate across the entire session? baseline firing rate during some window?

How many units are in this session?

In [None]:
len(session.units)

Which areas (structures) are they from?

In [None]:
print(session.units.structure_acronym.unique())

***Nile add link to brainmap.org for a specific structure acronym***

how many units per area are there?

In [None]:
session.units.structure_acronym.value_counts()

### Spike times

In [None]:
spike_times = session.spike_times

What type of object is this?

In [None]:
type(spike_times)

In [None]:
len(spike_times)

In [None]:
len(session.units)

In [None]:
list(spike_times.keys())[:5]

Use the unit_id for the first unit to get the spike times for that unit. How many spikes does it have in the entire session?

In [None]:
spike_times[session.units.index[0]]

In [None]:
print(len(spike_times[session.units.index[0]]))

Make a raster plot for the first 100 units

In [None]:
plt.figure(figsize=(20,10))
# Slice the unit index instead of hard-coding range(100), so a session with
# fewer than 100 units plots what it has rather than raising an IndexError.
for row, unit_id in enumerate(session.units.index[:100]):
    unit_spikes = spike_times[unit_id]
    # one row of tick marks per unit, at height `row`
    plt.plot(unit_spikes, np.repeat(row, len(unit_spikes)), '|')#, color='gray')
plt.xlabel('Time (sec)')
plt.ylabel('Unit #')
# plt.xlim(2000,2250)

### Stimulus presentations

What else can we learn about the session?

In [None]:
session.stimulus_names

In [None]:
stim_pres = session.stimulus_presentations
stim_pres.head()

***need to fix the stimulus names?***

***Explain epochs vs presentations***

In [None]:
#where is the stimulus epochs? replace function below when function exists in sdk

In [None]:
def get_stimulus_epochs(stim_presentations=None, min_spontaneous_duration=90):
    """Collapse a table of stimulus presentations into contiguous epochs.

    Presentations are grouped by ``stimulus_block``. Rows whose block is NaN
    are treated as one extra block numbered after the highest existing block.
    A block whose first presentation is ``'spontaneous_activity'`` contributes
    one epoch per presentation longer than ``min_spontaneous_duration``; any
    other block contributes a single epoch spanning its first ``start_time``
    to its last ``stop_time``.

    Parameters
    ----------
    stim_presentations : pandas.DataFrame, optional
        Table with ``stimulus_block``, ``stimulus_name``, ``start_time``,
        ``stop_time`` and ``duration`` columns. Defaults to
        ``session.stimulus_presentations`` of the notebook's global ``session``.
    min_spontaneous_duration : float, optional
        Spontaneous presentations are kept only if their ``duration`` exceeds
        this value (same units as the ``duration`` column). Default 90.

    Returns
    -------
    pandas.DataFrame
        Columns ``stimulus``, ``start`` and ``end``, sorted by ``start``.
    """
    if stim_presentations is None:
        stim_presentations = session.stimulus_presentations

    # Work on a copy: filling in the missing block ids must not write back
    # into the caller's table (which may be the session's own table).
    presentations = stim_presentations.copy()
    unlabeled = presentations.stimulus_block.isna()
    if unlabeled.any():
        highest_block = presentations.stimulus_block.max()
        # max() is NaN when no row has a block id at all; start from 0 then
        next_block = 0 if pd.isna(highest_block) else highest_block + 1
        presentations.loc[unlabeled, 'stimulus_block'] = next_block

    # Collect plain tuples and build the frame once; appending to a DataFrame
    # inside the loop is quadratic and DataFrame.append no longer exists in
    # pandas 2.x.
    records = []
    for block in presentations.stimulus_block.unique():
        block_rows = presentations[presentations.stimulus_block == block]
        stimulus_name = block_rows.stimulus_name.iloc[0]
        if stimulus_name == 'spontaneous_activity':
            long_rows = block_rows[block_rows.duration > min_spontaneous_duration]
            records.extend(
                zip(long_rows.stimulus_name, long_rows.start_time, long_rows.stop_time)
            )
        else:
            records.append(
                (stimulus_name, block_rows.start_time.iloc[0], block_rows.stop_time.iloc[-1])
            )

    stimulus_epochs = pd.DataFrame(records, columns=['stimulus', 'start', 'end'])
    return stimulus_epochs.sort_values(by=['start'])
    

In [None]:
stimulus_epochs = get_stimulus_epochs()

In [None]:
stimulus_epochs.head()

In [None]:
len(stimulus_epochs.stimulus.unique())

Shade each stimulus with a unique color. The plt.axvspan() is a useful function for this.

In [None]:
plt.figure(figsize=(20,10))
# Slice the index so sessions with fewer than 100 units still plot.
for row, unit_id in enumerate(session.units.index[:100]):
    unit_spikes = spike_times[unit_id]
    plt.plot(unit_spikes, np.repeat(row, len(unit_spikes)), '|', alpha=0.5)#, color='gray')

colors = ['blue','orange','green','red','yellow','purple','magenta','gray','lightblue']
for c, stim_name in enumerate(stimulus_epochs.stimulus.unique()):
    stim = stimulus_epochs[stimulus_epochs.stimulus==stim_name]
    # Wrap around the palette so a session with more stimuli than colors
    # reuses colors instead of raising an IndexError.
    shade = colors[c % len(colors)]
    for epoch_start, epoch_end in zip(stim['start'], stim['end']):
        plt.axvspan(xmin=epoch_start, xmax=epoch_end, color=shade, alpha=0.1)
plt.xlabel('Time (sec)')
plt.ylabel('Unit #')
# plt.xlim(6000,7000)

### Get the running speed

In [None]:
# Open a new figure first: with the `%matplotlib notebook` backend the current
# figure stays active across cells, so a bare plt.plot would draw the running
# speed on top of the raster above.
plt.figure(figsize=(20,4))
plt.plot(session.running_speed.end_time, session.running_speed.velocity)
plt.xlabel('Time (sec)')
plt.ylabel('Running speed')

Add the running speed to the visualization. 

In [None]:
plt.figure(figsize=(20,10))
# Slice the index so sessions with fewer than 100 units still plot.
for row, unit_id in enumerate(session.units.index[:100]):
    unit_spikes = spike_times[unit_id]
    plt.plot(unit_spikes, np.repeat(row, len(unit_spikes)), '|', alpha=0.5)#, color='gray')
# Running speed is rescaled (x0.4) and shifted (-20) so the trace sits
# underneath the raster rows instead of overlapping them.
plt.plot(session.running_speed.end_time, (0.4*session.running_speed.velocity)-20)

colors = ['blue','orange','green','red','yellow','purple','magenta','gray','lightblue']
for c, stim_name in enumerate(stimulus_epochs.stimulus.unique()):
    stim = stimulus_epochs[stimulus_epochs.stimulus==stim_name]
    # Wrap around the palette so extra stimuli reuse colors rather than
    # raising an IndexError.
    shade = colors[c % len(colors)]
    for epoch_start, epoch_end in zip(stim['start'], stim['end']):
        plt.axvspan(xmin=epoch_start, xmax=epoch_end, color=shade, alpha=0.1)

plt.xlabel('Time (sec)')
plt.xlim(6000,8000)

In [None]:
plt.figure(figsize=(20,4))
i=0
# spikes of unit `i`, drawn as tick marks at height i
unit_spikes = spike_times[session.units.index[i]]
plt.plot(unit_spikes, np.repeat(i, len(unit_spikes)), '|', alpha=0.5)#, color='gray')
# rescaled and shifted running speed, drawn below the spike row
plt.plot(session.running_speed.end_time, (0.4*session.running_speed.velocity)-20)

colors = ['blue','orange','green','red','yellow','purple','magenta','gray','lightblue']
for c, stim_name in enumerate(stimulus_epochs.stimulus.unique()):
    stim = stimulus_epochs[stimulus_epochs.stimulus==stim_name]
    # Wrap around the palette so extra stimuli reuse colors rather than
    # raising an IndexError.
    shade = colors[c % len(colors)]
    for epoch_start, epoch_end in zip(stim['start'], stim['end']):
        plt.axvspan(xmin=epoch_start, xmax=epoch_end, color=shade, alpha=0.1)

plt.xlabel('Time (sec)')
plt.xlim(6000,8000)

In [None]:
plt.figure(figsize=(20,4))
i=40
# spikes of unit `i`, drawn as tick marks at height 0
unit_spikes = spike_times[session.units.index[i]]
plt.plot(unit_spikes, np.repeat(0, len(unit_spikes)), '|', alpha=0.5)#, color='gray')
# rescaled and shifted running speed, drawn below the spike row
plt.plot(session.running_speed.end_time, (0.4*session.running_speed.velocity)-20)

colors = ['blue','orange','green','red','yellow','purple','magenta','gray','lightblue']
for c, stim_name in enumerate(stimulus_epochs.stimulus.unique()):
    stim = stimulus_epochs[stimulus_epochs.stimulus==stim_name]
    # Wrap around the palette so extra stimuli reuse colors rather than
    # raising an IndexError.
    shade = colors[c % len(colors)]
    for epoch_start, epoch_end in zip(stim['start'], stim['end']):
        plt.axvspan(xmin=epoch_start, xmax=epoch_end, color=shade, alpha=0.1)

plt.xlabel('Time (sec)')
plt.xlim(6000,8000)

In [None]:
plt.figure(figsize=(20,4))
# Two example units: unit 40 drawn at height 5, unit 0 drawn at height 0.
for unit_position, row_height in ((40, 5), (0, 0)):
    unit_spikes = spike_times[session.units.index[unit_position]]
    plt.plot(unit_spikes, np.repeat(row_height, len(unit_spikes)), '|', alpha=0.5)#, color='gray')
# rescaled and shifted running speed, drawn below the spike rows
plt.plot(session.running_speed.end_time, (0.4*session.running_speed.velocity)-20)
plt.xlim(6000,8000)

# Exploring units

## Plotting and sorting units

### create a function to plot the raster plot

***Replace with a call to the API?***

In [None]:
import allensdk.brain_observatory.ecephys.visualization as ecvis
# ecvis.raster_plot()

In [None]:
def plot_raster(spike_times, start, end, ax=None):
    """Draw a spike raster for the window ``[start, end)``.

    Each unit gets one horizontal band of equal height; the bands are stacked
    in the iteration order of ``spike_times`` and together fill the y-range
    0 to 1. A spike is drawn as a vertical line spanning its unit's band.

    Parameters
    ----------
    spike_times : dict
        Maps unit id to an array-like of spike times.
    start, end : float
        Window limits; spikes with ``start <= t < end`` are drawn.
    ax : matplotlib axes, optional
        Axes to draw on. Defaults to the current pyplot axes.
    """
    num_units = len(spike_times)
    if num_units == 0:
        # nothing to draw; also avoids dividing by zero below
        return
    if ax is None:
        ax = plt.gca()

    ystep = 1 / num_units
    for row, unit_spike_times in enumerate(spike_times.values()):
        unit_spike_times = np.asarray(unit_spike_times)
        in_window = (unit_spike_times >= start) & (unit_spike_times < end)
        ax.vlines(unit_spike_times[in_window], ymin=row * ystep, ymax=(row + 1) * ystep)
        

Select a single stimulus presentation

In [None]:
drifting_gratings_presentation_onsets = session.stimulus_presentations.loc[
    session.stimulus_presentations["stimulus_name"] == "drifting_gratings", 
    "start_time"
].values
start, end = drifting_gratings_presentation_onsets[:2]

In [None]:
# Start a fresh figure: with the `%matplotlib notebook` backend the raster
# would otherwise be drawn into whichever figure was plotted last.
plt.figure()
plot_raster(session.spike_times, start, end)
plt.xlabel('Time (sec)')
plt.ylabel('Units')
# the y position only identifies the unit, so hide the numeric tick labels
plt.tick_params(axis="y", labelleft=False, left=False)
plt.show()

### arrange neurons by their firing rate

***Comments: 1) remove or hide the quality column; 2) remove the sampling rate column - it is the same!  3) remove valid_data***

In [None]:
session.units.sort_values(by="firing_rate", ascending=False).head()

In [None]:
by_fr = session.units.sort_values(by="firing_rate", ascending=False)
# Dicts keep insertion order, so plot_raster stacks the units in this order:
# the highest firing rate ends up in the bottom band.
spike_times_by_firing_rate = {
    uid: session.spike_times[uid] for uid in by_fr.index.values
}

# Start a fresh figure so the sorted raster is not drawn on top of the
# previous plot (the `%matplotlib notebook` backend keeps it current).
plt.figure()
plot_raster(spike_times_by_firing_rate, start, end)
plt.ylabel('Units')
plt.xlabel('Time (sec)')
plt.show()


## QC Metrics

### show the qc metrics of different units

In [None]:
session.units.head()

### narrow down the session parameters

In [None]:
session_params = session.units.loc[:, ["structure_acronym", "probe_id","firing_rate", "isi_violations", "snr",'probe_vertical_position']]
session_params.head()

### sort the current session parameters by ISI violations

In [None]:
session_params.sort_values(by="isi_violations").head()

### describe what an ISI violation is and write down the equation for it
Metrics

1) ISI violation: equation, then plot spike train with the clear spike violation!
2) from Josh: get the ISI, FR etc as metrics where the data is good! what are the thresholds for the good data? SNR + ISI violations

### plot the ISI violation distribution

In [None]:
# Start a fresh figure so the histogram is not drawn into the previous plot
# (the `%matplotlib notebook` backend keeps the last figure current).
plt.figure()
# log10(1 + x) compresses the long tail while keeping x = 0 finite
plt.hist(np.log10(1 + session_params["isi_violations"].values), bins=100)
plt.xlabel('log10(1 + isi_violations)')
plt.ylabel('unit count')
plt.title('distribution of the isi violations')
plt.show()

In [None]:
good_units = session_params[
    (session_params["isi_violations"] < 0.2)
    & (session_params["snr"] > 2)
]

print('Number of units with reasonable ISI and SNR:')
print(good_units.shape[0])

## Locations of units

### Brain structures

#### describe the area with reasonable qc metrics

In [None]:
good_units.structure_acronym.value_counts()


#### plot the firing rate of the units with regards to structures


In [None]:
gb = session.units.groupby("structure_acronym")

# one array of firing rates per structure, in groupby order
structures = []
data = []
for structure, current_data in gb:
    structures.append(structure)
    data.append(current_data["firing_rate"].values)

# Create dedicated axes rather than plt.gca(): with the `%matplotlib notebook`
# backend gca() would return the axes of the previous figure.
fig, axs = plt.subplots()
axs.violinplot(data)

# violinplot places the bodies at x = 1..N, so the tick positions are shifted by one
axs.set_xticks(np.arange(len(structures))+1)
axs.set_xticklabels(structures)
axs.set_ylabel('Firing-rate (Hz)')
plt.show()


### Locations on probe

In [None]:
plt.subplots()
plt.hist(session_params["probe_vertical_position"].values, bins=100)
plt.xlabel('probe_vertical_position (mm)')
plt.ylabel('unit count')
plt.show()

## Unit waveforms

We store precomputed mean waveforms for each unit in the `mean_waveforms` attribute on the `EcephysSession` object. This is a dictionary which maps unit ids to xarray dataarrays. These have channel and time (seconds, aligned to the detected event times) dimensions. The data values are millivolts, as measured at the recording site.

In [None]:
waveforms = session.mean_waveforms
type(waveforms)

### Plot all waveforms for one unit

In [None]:
# arbitrary example unit (position 400 in the units table)
unit = session.units.index.values[400]
wf = session.mean_waveforms[unit]

fig, ax = plt.subplots()
# pcolormesh takes the coordinates positionally as (X, Y, C); it has no `X=`
# keyword. Order the dimensions explicitly so rows are channels and columns
# are time samples (dimension names as used elsewhere in this notebook).
wf_image = wf.transpose("channel_id", "time")
ax.pcolormesh(
    wf_image["time"].values,
    np.arange(wf_image.sizes["channel_id"]),
    wf_image.values,
)
# the x axis uses the waveform's time coordinate, not sample indices
ax.set_xlabel('Time (s)')
ax.set_ylabel('Channel #')

We can figure out where each channel is located in the brain using the function `ecephys_session.intervals_structures`, which will identify channels that serve as reference points for the boundaries between identified brain regions.

In [None]:
# pass in the list of channels from the CSD
channels = session.channels.loc[csd.channel]
structure_acronyms, intervals = ecephys_session.intervals_structures(channels)
interval_midpoints = [ (aa + bb) / 2 for aa, bb in zip(intervals[:-1], intervals[1:])]
print(structure_acronyms)
print(interval_midpoints)

In [None]:
# New figure and axes for this cell, so the colorbar and tick labels below
# are attached to the image drawn here and not to an older figure.
fig, ax = plt.subplots()

# pcolormesh takes the coordinates positionally as (X, Y, C); it has no `X=` keyword.
wf_image = wf.transpose("channel_id", "time")
mesh = ax.pcolormesh(
    wf_image["time"].values,
    np.arange(wf_image.sizes["channel_id"]),
    wf_image.values,
)
fig.colorbar(mesh, ax=ax)

ax.set_xlabel("time (s)")
# major ticks mark the structure boundaries, minor ticks carry the labels
ax.set_yticks(intervals)
ax.set_yticks(interval_midpoints, minor=True)
ax.set_yticklabels(structure_acronyms, minor=True)
# make the long divider lines between intervals
ax.tick_params("y", which="major", labelleft=False, length=40)

plt.show()

Let's see if this matches the structure information saved in the units table:

In [None]:
session.units.loc[unit]

### plot peak channels for all units recorded in the dentate gyrus (DG)

In [None]:
fig, ax = plt.subplots()

# ids of all units assigned to the dentate gyrus
dg_unit_ids = session.units[session.units["structure_acronym"] == "DG"].index.values

# for each unit keep only the waveform recorded on its peak channel
peak_waveforms = []
for unit_id in dg_unit_ids:
    peak_ch = session.units.loc[unit_id, "peak_channel_id"]
    unit_mean_waveforms = session.mean_waveforms[unit_id]
    peak_waveforms.append(unit_mean_waveforms.loc[{"channel_id": peak_ch}])

if peak_waveforms:
    # Take the time axis from a collected waveform instead of relying on the
    # loop variable still being defined after the loop.
    time_domain = peak_waveforms[-1]["time"]
    peak_waveforms = np.array(peak_waveforms)
    ax.pcolormesh(peak_waveforms)
    ax.set_xlabel('Time steps')
    ax.set_ylabel('DG unit #')
else:
    peak_waveforms = np.array(peak_waveforms)
    print('No DG units in this session')

### show the pca of the average waveforms to make sure the units make sense

***Check whether we want this here or in exercises***

In [None]:
# apply pca to the averaged waveforms

# from sklearn import decomposition
# pca = decomposition.PCA(n_components=2)
# pca.fit(peak_waveforms)


In [None]:
# fig, ax = plt.subplots()
# plt.plot(time_domain, pca.components_.T)
# plt.title('2 PCA components')


# print('Explained variance of 2 components')
# print(pca.explained_variance_ratio_)

# Spike histograms and stimulus coding

***Maybe modify to use stim rather than spontaneous to transition into stim coding?***

### create the histograms

In [None]:
spon = session.stimulus_presentations.loc[
    session.stimulus_presentations["stimulus_name"] == "spontaneous_activity", 
    ["start_time", "stop_time"]
]

In [None]:
# Duration of the first spontaneous presentation; it sets the range of the
# bin edges used for every presentation below.
first_spon_id = spon.index.values[0]
first_spon_duration = spon.loc[first_spon_id, "stop_time"] - spon.loc[first_spon_id, "start_time"]

# bin width: 1/100 of the time unit of start_time/stop_time (10 ms if those are seconds)
time_step = 1 / 100
time_domain = np.arange(0.0, first_spon_duration + time_step, time_step)

# Spike counts per (presentation, time bin, unit). Note that all spontaneous
# presentations are requested, but the bin edges only span the duration of
# the first one. unit_ids=None is presumably "all units" — confirm in the SDK docs.
histograms = session.presentationwise_spike_counts(
    bin_edges=time_domain,
    stimulus_presentation_ids=spon.index,
    unit_ids=None
)

print(histograms)

### plot the firing rate of neurons in different units

In [None]:
# spike_counts = histograms.spike_counts.values
time = histograms["time_relative_to_stimulus_onset"]

hist_train_1 = histograms[0,:,0]
hist_train_2 = histograms[0,:,1]

fig, ax = plt.subplots()

plt.plot(time, hist_train_1, '.')
plt.plot(time, hist_train_2, '.')
plt.xlabel('Time (sec)')

### compute the mean of the histograms

In [None]:
mean_histograms = histograms.mean(dim="stimulus_presentation_id")

In [None]:
mean_histograms.coords

In [None]:
import xarray.plot as xrplot
# Start a fresh figure: without an explicit figure, the image would be drawn
# into the axes of the previous plot (the `%matplotlib notebook` backend
# keeps that figure current).
plt.figure()
xrplot.imshow(darray=mean_histograms, x="time_relative_to_stimulus_onset",
                                      y="unit_id")

### compute the correlation matrix

***Move this part to exercises?? operate on means or by trial??***

***also this is very slow!***

In [None]:
# spike_counts = mean_histograms
# num_units = spike_counts.shape[1]

# correlations = np.zeros((num_units, num_units))

# for ii in range(num_units):
#     for jj in range(num_units):
#         # normalize spike trains before computation
#         spike_train_1=spike_counts[:, ii]
# #        spike_train_1=(spike_train_1-np.mean(spike_train_1))/np.std(spike_train_1)/len(spike_train_1)
#         spike_train_2=spike_counts[:, jj]
# #        spike_train_2=(spike_train_2-np.mean(spike_train_2))/np.std(spike_train_2)/len(spike_train_2)
#         correlations[ii, jj] = np.correlate(spike_train_1, spike_train_2)
# #        np.correlate(spike_counts[:, ii], spike_counts[:, jj])

***add structure boundaries to plot here?***

In [None]:
# fig, ax = plt.subplots()
# plt.imshow(np.log10(correlations+1))

## Stimulus coding

***Construct a basic tuning curve***

## Local Field Potential (LFP)

The final aspect of a Neuropixels probe recording we will investigate is the local field potential (LFP). An LFP signal is a direct recording of extracellular voltage from which individual spike contributions have been removed by low-pass filtering. The remaining signal reflects the population activity of a large number of cells in the vicinity of the probe, primarily through the electrical field effects of synaptic currents (along with other trans-membrane currents).

LFP can be especially informative for understanding rhythmic activity or oscillations in neural circuits, which can be identified by some simple time-series analysis of the LFP signals.

### Accessing data

We'll start by loading the LFP data from one of the probes in our session.

We need to provide this function with a probe id, which we can pull out of the `session.probes` table. 

(Note that the "id" column is the index of the dataframe, and thus must be accessed differently than other columns.)

In [None]:
probe_id = session.probes.index[0]
lfp = session.get_lfp(probe_id)
print(lfp)

### Plot the LFP time series

To visualize this data, we'll first use the built-in xarray plotting to generate a quick plot. This is too much data to plot all at once, so we select a subset first. Just as in pandas, we use the `loc` property, but since xarray has named dimensions, we can specify our selections by name rather than by order, using a dict.

In [None]:
channel = lfp.channel[0]
subset = lfp.loc[dict(channel=channel, time=slice(5,15))]

plt.figure(figsize=(12,3))
subset.plot()

We might also want to visualize a specific frequency band by filtering. To do this we'll want to convert our data into standard numpy arrays for easier processing using the DataArray object's `values` property.

In [None]:
t = lfp.time.values
v = lfp.isel(channel=0).values

In [None]:
import scipy.signal
freq_window = (4, 15)
filt_order = 3
fs = 1/(t[1]-t[0])
b, a = scipy.signal.butter(filt_order, freq_window, btype='bandpass', fs=fs)
v_alpha = scipy.signal.lfilter(b, a, v)


window = [5, 15]
idx = np.logical_and(t>=window[0], t<window[1])
plt.figure(figsize=(12,3))
plt.plot(t[idx], v[idx])
plt.plot(t[idx], v_alpha[idx],'k')

## Spectral analysis


Next we're going to analyze some spectral properties of this signal using the `scipy.signal` library. "Spectral" refers to decomposing a signal into a sum of simpler components identified by their frequencies. The set of frequencies of the components forms a *spectrum* that tells us about the complete signal. You can see a full list of spectral analysis functions in scipy here: https://docs.scipy.org/doc/scipy/reference/signal.html#spectral-analysis

### Power spectral density (PSD)

We first import the package, and inspect the `periodogram` function, which estimates the size of the different frequency components of the signal.

**Note: maybe we want to compute this directly from an FFT? But only if that concept is already meaningful, so maybe not...**

In [None]:
import scipy.signal
help(scipy.signal.periodogram)

There are a number of options that we won't go into here for refining the analysis. The one piece of information we do need is `fs`, the sampling frequency. If we used the default value `fs=1.0` our results would not match the true frequencies of the signal.

In [None]:
fs = 1/(t[1]-t[0])
window = [10, 100]
idx = np.logical_and(t>=window[0], t<window[1])

f, psd = scipy.signal.periodogram(v[idx], fs)

We'll plot the power spectrum on a semilog plot, since power can vary over many orders of magnitude across frequencies.

In [None]:
plt.figure(figsize=(6,3))
plt.semilogy(f,psd,'k')
plt.xlim((0,100))
plt.yticks(size=15)
plt.xticks(size=15)
plt.ylabel('Power ($uV^{2}/Hz$)',size=20)
plt.xlabel('Frequency (Hz)',size=20)
plt.show()

We see that this representation of the power spectrum is extremely noisy. Luckily, many people have come up with solutions to this problem. Scipy includes a function for Welch's method, which averages out noise by computing many estimates of the power spectrum from overlapping windows of the data. You can find some more references for this approach in the Scipy documentation: https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.welch.html#scipy.signal.welch

In [None]:
f, psd = scipy.signal.welch(v[idx], fs, nperseg=1000)

plt.figure(figsize=(6,3))
plt.semilogy(f,psd,'k')
plt.xlim((0,100))
plt.yticks(size=15)
plt.xticks(size=15)
plt.ylabel('Power ($uV^{2}/Hz$)',size=20)
plt.xlabel('Frequency (Hz)',size=20)
plt.show()

### Calculate and plot the time-frequency profile ("spectrogram")

We might also be interested in how the frequency content of the signal varies over time. In a neural context, power in different frequency bands is often linked to specific types of processing, so we might explore whether changes in the spectrum coincide with specific behaviors or stimuli.

The *spectrogram* is essentially an estimate of the power spectrum computed in a sliding time window, producing a 2D representation of the signal power across frequency and time.

In [None]:
# time window to analyse, in the units of `t`
window = [10, 20]
idx = np.where(np.logical_and(t>=window[0],t<window[1]))[0]

# 'hann' is the canonical SciPy name for this window; the old 'hanning'
# alias is not accepted by newer SciPy releases.
# noverlap = nperseg - 1 slides the analysis window one sample at a time.
f, t_spec, spec = scipy.signal.spectrogram(v[idx], fs=fs, window='hann',
                            nperseg=1000, noverlap=1000-1, mode='psd')
# Scipy assumes our signal starts at time=0, so we need to provide the offset.
# The offset is the time of the first selected sample, t[idx[0]];
# t[window[0]] would be the time of sample number 10, not of the window start.
t_spec = t_spec + t[idx[0]]

We'll use the matplotlib `pcolormesh` function to visualize this data as an image. We can pass this function grids of x and y coordinates to get the axis labeling right. We also log-transform the power spectrum and restrict to frequencies below `fmax` (80 Hz here).

In [None]:
fmax = 80
x_mesh, y_mesh = np.meshgrid(t_spec, f[f<fmax])
plot_data = np.log10(spec[f<fmax])

We'll plot the spectrum together with the raw signal in subplots. Note that we explicitly set the x-axis limits to align the plots. (Alternatively, it's possible to directly couple the limits of different subplots.)

In [None]:
from matplotlib import cm
plt.figure(figsize=(10,4))

plt.subplot(2,1,1)
plt.pcolormesh(x_mesh, y_mesh, plot_data, cmap=cm.jet)
plt.xlim(window)
plt.ylabel('Frequency (Hz)')

plt.subplot(2,1,2)
plt.plot(t[idx], v[idx], 'k')
plt.xlim(window)
plt.xlabel('Time (s)')
plt.ylabel('Voltage (a.u.)')
plt.show()

## Current source density (CSD) analysis

Physically, the LFP is made up of the electric fields from specific current sources (or sinks) in brain tissue, namely individual trans-membrane currents. Under certain simplifying assumptions, this transformation of spatial current distribution into field potential can be inverted to infer the distribution of currents underlying a measurement. This is called the current source density, or CSD. Spatial properties of the LFP are generally better studied in this representation.

We have pre-calculated estimates of CSD for each probe during a subset of stimulus presentations, which we access below. Note that the CSD array contains data for 186 channels (half the total), in contrast to the LFP which is only provided for approximately one quarter of the contacts.

In [None]:
csd = session.get_current_source_density(probe_id)
csd

***Maybe move this elsewhere***

We can figure out where each LFP channel is located in the brain using the function `ecephys_session.intervals_structures`, which will identify channels that serve as reference points for the boundaries between identified brain regions.

In [None]:
# pass in the list of channels from the CSD
channels = session.channels.loc[csd.channel]
structure_acronyms, intervals = ecephys_session.intervals_structures(channels)
interval_midpoints = [ (aa + bb) / 2 for aa, bb in zip(intervals[:-1], intervals[1:])]
print(structure_acronyms)
print(interval_midpoints)

In [None]:
fig, ax = plt.subplots()

xmesh, ymesh = np.meshgrid(csd.time, range(len(csd.channel)))
plt.pcolormesh(xmesh, ymesh, csd, vmin=-1e5, vmax=1e5)
plt.colorbar(ax=ax)

ax.set_xlabel("time (s)")
ax.set_yticks(intervals)
ax.set_yticks(interval_midpoints, minor=True)
ax.set_yticklabels(structure_acronyms, minor=True)

# make the long divider lines between intervals
plt.tick_params("y", which="major", labelleft=False, length=40)
plt.show()

***Equivalent plot using xarray***

In [None]:
fig, ax = plt.subplots()
csd.coords['ichannel']=('channel',range(len(csd.channel)))
csd.plot(x='time', y='ichannel', robust=True, cmap=cm.jet)

ax.set_yticks(intervals)
plt.tick_params("y", which="major", labelleft=False, length=40)
ax.set_yticks(interval_midpoints, minor=True)
ax.set_yticklabels(structure_acronyms, minor=True)
plt.show()