In [None]:
import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from allensdk.brain_observatory.ecephys.ecephys_project_cache import EcephysProjectCache
import scipy.stats as stats

In [None]:
# Directory holding the AllenSDK ecephys cache. Set the ECEPHYS_CACHE_DIR
# environment variable to point somewhere else; when it is unset the original
# machine-specific location is used, so existing setups keep working.
cache_dir = os.environ.get(
    'ECEPHYS_CACHE_DIR',
    '/media/seb/ee7d6f3e-3390-444a-b0b3-131b80f2a7f8/ecephys_cache_dir/'
)
# Structure acronym used both to pick sessions and to pick units below.
brain_area = 'VISp'

# Load cache file
print('Loading cache dir, will download if it does not already exist')
manifest_path = os.path.join(cache_dir, 'manifest.json')
cache = EcephysProjectCache.from_warehouse(manifest=manifest_path)

# Get sessions
sessions = cache.get_session_table()

# Keep 'brain_observatory_1.1' sessions whose recorded structures include
# `brain_area`. Both conditions are index-aligned boolean Series: combining a
# Series with a plain Python list via `&` is deprecated in pandas >= 2.1.
is_brain_observatory = sessions.session_type == 'brain_observatory_1.1'
records_brain_area = sessions.ecephys_structure_acronyms.apply(
    lambda acronyms: brain_area in acronyms
).astype(bool)
filtered_sessions = sessions[is_brain_observatory & records_brain_area]

filtered_sessions_idxs = filtered_sessions.index.values
# The analysis uses the second matching session (position 1); this raises
# IndexError if fewer than two sessions pass the filter.
session_id = filtered_sessions_idxs[1]

# Load up a particular recording session
print('\n\nLoading session data for session', session_id)
session = cache.get_session_data(session_id)

# Get the units whose structure acronym matches the brain area
units = session.units[session.units['ecephys_structure_acronym'] == brain_area]

print('\n\nGet grating table')
grating_stim_table = session.get_stimulus_table("drifting_gratings")



In [None]:
# Spikes of the selected units during each drifting-grating presentation
# (one row per spike).
spikes = session.presentationwise_spike_times(
    stimulus_presentation_ids=grating_stim_table.index.values,
    unit_ids=units.index.values
)

# Number of spikes per (presentation, unit) pair. Only pairs with at least one
# spike appear here. `spikes` itself is left untouched (no scratch column).
spikes_count = (
    spikes.groupby(["stimulus_presentation_id", "unit_id"])
    .size()
    .to_frame("count")
)

# Reshape into a matrix with presentation ids as rows and unit ids as columns.
# Reindexing against the full stimulus table and the full unit list guarantees
# a row for every presentation and a column for every unit, filled with 0 where
# nothing fired. Without this, silent units have no column and silent
# presentations have no row.
stimulus_x_unit_id = (
    spikes_count["count"]
    .unstack("unit_id", fill_value=0)
    .reindex(index=grating_stim_table.index, columns=units.index, fill_value=0)
    .sort_index()
)

# Add the grating orientation of each presentation, aligned on presentation id
# rather than on row position. Anything that is not a number (and NaN) becomes
# the sentinel -1; np.float64 values are accepted, unlike a `type(v) == float`
# check.
stimulus_x_unit_id['orientation'] = pd.to_numeric(
    grating_stim_table['orientation'], errors='coerce'
).fillna(-1)

# Sort rows by orientation; the stable sort keeps presentation-id order within
# each orientation so the result is reproducible.
stimulus_x_unit_id = stimulus_x_unit_id.sort_values('orientation', kind='stable')



In [None]:
def plot_orientation (ax, orientations, orientation_tuning_curve):
    """Draw a closed orientation tuning curve on a polar axes.

    Parameters
    ----------
    ax
        Axes to draw on; it must provide ``plot``, ``set_theta_zero_location``,
        ``set_theta_direction`` and ``tick_params`` (a matplotlib polar axes).
    orientations
        Angles passed as the first argument of ``ax.plot``. Any iterable is
        accepted (list, tuple, array) and it is never modified.
    orientation_tuning_curve
        Response for each angle, same length as ``orientations``. Never
        modified.

    Raises
    ------
    ValueError
        If the two inputs have different lengths.
    """
    # Work on copies: appending to the caller's lists would make them grow on
    # every call and would fail for inputs without ``.append``.
    thetas = list(orientations)
    responses = list(orientation_tuning_curve)

    if len(thetas) != len(responses):
        raise ValueError(
            "orientations and orientation_tuning_curve must have the same length "
            f"(got {len(thetas)} and {len(responses)})"
        )

    # Repeat the first point at the end so the curve forms a closed loop.
    # Empty input is left empty and simply draws nothing.
    if thetas:
        thetas.append(thetas[0])
        responses.append(responses[0])

    ax.plot(thetas, responses, c='gray')
    # Put angle 0 at the top ("N") and make angles increase clockwise.
    ax.set_theta_zero_location("N")
    ax.set_theta_direction(-1)
    ax.tick_params(labelsize=20)

# Positions (within `units`) of the two example units to plot, left to right.
EXAMPLE_UNIT_POSITIONS = [47, 0]
if max(EXAMPLE_UNIT_POSITIONS) >= len(units):
    raise IndexError(
        f"example unit position {max(EXAMPLE_UNIT_POSITIONS)} is out of range: "
        f"only {len(units)} units are available"
    )

fig, axs = plt.subplots(nrows=1, ncols=2, dpi=100, subplot_kw={'projection': 'polar'}, figsize=[6, 4])

# Presentations whose orientation is the sentinel -1 have no usable grating
# orientation and are left out of the tuning curves.
oriented_trials = stimulus_x_unit_id[stimulus_x_unit_id['orientation'] != -1]

for unit_id, ax in zip(units.index.values[EXAMPLE_UNIT_POSITIONS], axs):
    # orientation (degrees) -> spike counts of this unit, one per presentation.
    # groupby yields the orientations in ascending order. The loop no longer
    # reuses the name `spikes`, which would clobber the spikes DataFrame.
    tuning_curve = {
        ori: counts.tolist()
        for ori, counts in oriented_trials.groupby('orientation')[unit_id]
    }

    orientations = [np.deg2rad(d) for d in tuning_curve]
    # Mean spike count per presentation at each orientation.
    orientation_tuning_curve = [np.mean(v) for v in tuning_curve.values()]

    plot_orientation(ax, orientations, orientation_tuning_curve)

plt.tight_layout()
plt.show()