In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
plt.rcParams.update({'font.size': 12})

from scipy import signal 
from scipy import stats
import itertools
import seaborn as sns
import statsmodels.api as sm
import random

In [2]:
import pickle
import json
# Load DataJoint credentials from the local config file. The context manager
# closes the file handle as soon as the JSON has been parsed.
with open('../../dj_local_conf.json', 'r') as json_open:
    config = json.load(json_open)

import datajoint as dj
dj.config['database.host'] = config["database.host"]
dj.config['database.user'] = config["database.user"]
dj.config['database.password'] = config["database.password"]
dj.conn().connect()

from pipeline import lab, get_schema_name, experiment, foraging_model, ephys, foraging_analysis, histology, ccf
from pipeline.plot import unit_psth
from pipeline.plot.foraging_model_plot import plot_session_model_comparison, plot_session_fitted_choice
from pipeline import psth_foraging
from pipeline import util
from pipeline.model import bandit_model

  from cryptography.hazmat.backends import default_backend


Connecting pochen@datajoint.mesoscale-activity-map.org:3306


In [4]:
# select units
def select_unit_qc_region_mouse_session(all_unit_qc, region_annotation="Prelimbic%", mouse="HH09", session=47):
    """Restrict QC-passing units to one brain region, one mouse and one session.

    Parameters
    ----------
    all_unit_qc : DataJoint query expression
        Units that passed quality control; joined here with electrode CCF
        positions and CCF annotations.
    region_annotation : str
        SQL ``LIKE`` pattern matched against the ``annotation`` attribute
        (e.g. ``"Prelimbic%"``).
    mouse : str
        Water-restriction number identifying the mouse (e.g. ``"HH09"``).
    session : int
        Session number for that mouse.

    Returns
    -------
    DataJoint query expression
        Units from ``all_unit_qc`` with their CCF annotation, restricted to the
        requested region, mouse and session.
    """
    unit_qc_region = ((all_unit_qc * histology.ElectrodeCCFPosition.ElectrodePosition * ccf.CCFAnnotation)
                      & 'annotation LIKE "{}"'.format(region_annotation))
    # Dict restrictions let DataJoint quote the values itself, so neither the
    # mouse name nor a numpy-integer session number needs manual SQL formatting.
    mouse_restriction = lab.WaterRestriction & {'water_restriction_number': mouse}
    return unit_qc_region & mouse_restriction & {'session': int(session)}


def session_stats_in_region(all_unit_qc, region_annotation):
    """Print unit and trial counts for every (mouse, session) recorded in a region.

    Parameters
    ----------
    all_unit_qc : DataJoint query expression
        Units that passed quality control.
    region_annotation : str
        SQL ``LIKE`` pattern matched against the CCF ``annotation`` attribute.

    Prints the total unit count in the region, the count per matching
    annotation, the unique session numbers, and, for each (mouse, session)
    pair, the number of units and of ``'foraging_LR_all_noearlylick'`` trials.
    Returns None.
    """
    # get all units in the region
    unit_qc_region = ((all_unit_qc * histology.ElectrodeCCFPosition.ElectrodePosition * ccf.CCFAnnotation)
                      & 'annotation LIKE "{}"'.format(region_annotation))
    print('total number of units in the region {}: {}'.format(region_annotation, len(unit_qc_region)))

    # check subregions; a dict restriction avoids quoting problems in annotation names
    for annotation in np.unique(unit_qc_region.fetch('annotation')):
        print(' {}: {} units'.format(annotation, len(unit_qc_region & {'annotation': annotation})))

    # Session numbers are only unique within a mouse (e.g. both HH08 and HH09
    # have a session 50), so iterate over (subject_id, session) pairs rather
    # than bare session numbers.
    subject_ids, sessions = unit_qc_region.fetch('subject_id', 'session')
    print(' unique sessions: {}'.format(np.unique(sessions)))
    subject_sessions = sorted({(int(s_id), int(sess)) for s_id, sess in zip(subject_ids, sessions)})

    for subject_id, session in subject_sessions:
        mouse = (lab.WaterRestriction() & {'subject_id': subject_id}).fetch('water_restriction_number')[0]
        unit_qc_region_mouse_session = select_unit_qc_region_mouse_session(
            all_unit_qc,
            region_annotation=region_annotation,
            mouse=mouse,
            session=session)
        # get units in the region-mouse-session
        keys2units = unit_qc_region_mouse_session.fetch('KEY')
        # get trials in the region-mouse-session
        trials_all = psth_foraging.TrialCondition().get_trials('foraging_LR_all_noearlylick') & keys2units
        print('  mouse {}, session {}: {} units, {} trials'.format(mouse, session,
                                                                     len(unit_qc_region_mouse_session),
                                                                     len(trials_all)))
    

# fetch keys
def gen_keys2units(unit_qc_region_mouse_session):
    """Fetch the primary keys of the given units and print how many there are.

    Parameters
    ----------
    unit_qc_region_mouse_session : DataJoint query expression
        Units to enumerate.

    Returns
    -------
    Whatever ``unit_qc_region_mouse_session.fetch('KEY')`` returns
    (in DataJoint, a list of primary-key dicts).
    """
    unit_keys = unit_qc_region_mouse_session.fetch('KEY')
    n_units = len(unit_keys)
    print('num of units: {}'.format(n_units))
    return unit_keys

In [6]:
# fetch unit data
# after unit qc
foraging_session = experiment.Session & 'username = "hh"'
# Every QC criterion must hold: chained `&` restrictions combine with AND.
all_unit_qc = ((ephys.Unit * ephys.ClusterMetric * ephys.UnitStat)
               & foraging_session
               & 'presence_ratio > 0.95'
               & 'amplitude_cutoff < 0.1'
               & 'isi_violation < 0.5'
               & 'unit_amp > 70')

# Number of QC-passing units per CCF annotation, displayed largest first.
unit_count_per_annotation = dj.U('annotation').aggr(
    (ephys.Unit & all_unit_qc.proj()) * histology.ElectrodeCCFPosition.ElectrodePosition * ccf.CCFAnnotation,
    count='count(*)')
unit_count_per_annotation.fetch(format='frame', order_by='count desc')

Unnamed: 0_level_0,count
annotation,Unnamed: 1_level_1
"Lateral septal nucleus, rostral (rostroventral) part",801
Caudoputamen,473
"Secondary motor area, layer 6a",401
"Secondary motor area, layer 5",294
"Prelimbic area, layer 5",286
...,...
Ethmoid nucleus of the thalamus,1
"Retrosplenial area, ventral part, layer 6a",1
Triangular nucleus of septum,1
"Prelimbic area, layer 6b",1


In [8]:
# check session stats for a region
# fetch data from brain region
# Short region name -> SQL LIKE pattern matched against CCF annotation names.
region_ann_lut = dict(
    # premotor
    ALM="Secondary motor area%",
    # isocortex, PFC
    PL="Prelimbic%",
    ACA="Anterior cingulate area%",
    ILA="Infralimbic%",
    ORB='%orbital%',
    FRP='%frontal%',
    RSP="Retrosplenial area%",
    # thalamus
    VM='Ventral medial%',
    MD='Mediodorsal%',
    # striatum
    LSN="Lateral septal nucleus%",
    CP="Caudoputamen%",
    NA="Nucleus accumbens%",
    striatum="striatum%",
    # Pallidum
    PALv="Substantia innominata%",
    # Olfactory
    OLF="%olfactory%",
)

# region = 'ALM'
# region_annotation = region_ann_lut[region]
# session_stats_in_region(all_unit_qc, region_annotation)


# Choose ONE (region, mouse, session) by leaving exactly one assignment below
# uncommented; `region` must be a key of `region_ann_lut` (case-sensitive).

# ALM
# region, mouse, session = 'ALM', "HH13", 36
# region, mouse, session = 'ALM', "HH09", 50

# PL
region, mouse, session = 'PL', "HH13", 36
# region, mouse, session = 'PL', "HH09", 47
# ACA
# region, mouse, session = 'ACA', "HH13", 36
# region, mouse, session = 'ACA', "HH09", 50
# ILA
# region, mouse, session = 'ILA', "HH08", 49
# region, mouse, session = 'ILA', "HH08", 50
# ORB
# region, mouse, session = 'ORB', "HH13", 36
# FRP
# region, mouse, session = 'FRP', "HH13", 33
# RSP
# region, mouse, session = 'RSP', "HH09", 59  # no ipsi choice under positive deltaQ

# VM
# region, mouse, session = 'VM', "HH13", 42
# region, mouse, session = 'VM', "HH13", 43
# MD
# region, mouse, session = 'MD', "HH09", 60
# region, mouse, session = 'MD', "HH09", 57  # no ipsi choice under positive deltaQ

# LSN
# region, mouse, session = 'LSN', "HH13", 36
# region, mouse, session = 'LSN', "HH08", 50
# CP
# region, mouse, session = "CP", "HH13", 45
# region, mouse, session = 'CP', "HH08", 51
# NA
# region, mouse, session = 'NA', "HH13", 37
# region, mouse, session = 'NA', "HH09", 50
# striatum
# region, mouse, session = 'striatum', "HH08", 49

# PALv
# region, mouse, session = 'PALv', "HH13", 45
# region, mouse, session = 'PALv', "HH09", 57  # no ipsi choice under positive deltaQ

# OLF
# region, mouse, session = 'OLF', "HH09", 50


# Restrict the QC-passing units to the selected region, mouse and session.
selection_kwargs = dict(region_annotation=region_ann_lut[region], mouse=mouse, session=session)
unit_qc_region_mouse_session = select_unit_qc_region_mouse_session(all_unit_qc, **selection_kwargs)
n_selected_units = len(unit_qc_region_mouse_session)
print('total number of units in {} by mouse {} by session {}: {}'.format(region, mouse, session, n_selected_units))

# Primary keys of the selected units (gen_keys2units also prints their count).
keys2units = gen_keys2units(unit_qc_region_mouse_session)

# laterality
laterality_condition = {
    'right': 'ml_location > 0',
    'left': 'ml_location < 0',
}
hemi_choose = 'left'  # used when the selected units span both hemispheres


def infer_hemisphere(ml_locations, fallback):
    """Infer the recorded hemisphere from probe-insertion ML coordinates.

    Parameters
    ----------
    ml_locations : array-like of float
        Medio-lateral insertion coordinates. The sign convention matches
        ``laterality_condition``: > 0 is 'right', < 0 is 'left'.
    fallback : str
        Hemisphere returned when insertions lie on both sides.

    Returns
    -------
    str
        'right', 'left', or ``fallback``.

    Raises
    ------
    Exception
        If ``ml_locations`` is empty.
    ValueError
        If every ML location is exactly 0.
    """
    ml_locations = np.asarray(ml_locations)
    if ml_locations.size == 0:
        raise Exception('No ProbeInsertion.InsertionLocation available')

    has_right = bool((ml_locations > 0).any())
    has_left = bool((ml_locations < 0).any())

    if has_right and has_left:
        print('The specified units belongs to both hemispheres, use pre-specified laterality')
        print(' manually choose laterality: {}'.format(fallback))
        return fallback
    if not (has_right or has_left):
        raise ValueError('Ambiguous hemisphere: ML locations are all 0...')
    if (ml_locations == 0).any():
        # Mix of midline (0) and one-sided insertions: resolve to the non-zero
        # side; units on the midline insertion fail the laterality filter below.
        print(' midline insertions (ml_location == 0) are excluded by the laterality filter')
    return 'right' if has_right else 'left'


ml_locations = np.unique((ephys.ProbeInsertion.InsertionLocation & keys2units).fetch('ml_location'))
hemi = infer_hemisphere(ml_locations, hemi_choose)

print('laterality: {}'.format(hemi))
keys_laterality = ephys.ProbeInsertion.InsertionLocation & keys2units & laterality_condition[hemi]
unit_qc_region_mouse_session_laterality = unit_qc_region_mouse_session & keys_laterality
keys2units = gen_keys2units(unit_qc_region_mouse_session_laterality)

print(len(keys2units))

total number of units in PL by mouse HH13 by session 36: 19
num of units: 19
laterality: left
num of units: 19
19


In [14]:
# Period activity of the first selected unit. 'iti_all' is presumably the full
# inter-trial interval; see psth_foraging for the exact definition.
unit_key, period = keys2units[0], 'iti_all'
period_activity = psth_foraging.compute_unit_period_activity(unit_key, period)
# One value per trial, judging from the (470,) output; TODO confirm.
firing_rate = period_activity['firing_rates']
print(firing_rate.shape)

(470,)


In [None]:
# Fitted per-trial action values (Q) for one mouse under one foraging model.
# NOTE(review): subject_id is hard-coded and not derived from the `mouse`
# selected above; confirm it refers to the intended animal.
subject_id = 482353
model_id = 10
q_latent_variable = (foraging_model.FittedSessionModel.TrialLatentVariable
                     & {'subject_id': subject_id,
                        'model_id': model_id})

df_Q = pd.DataFrame(q_latent_variable.fetch())

# Keep only session, trial and the action value for each port, renamed per port.
df_Q_left = (df_Q.loc[df_Q['water_port'] == 'left', ['session', 'trial', 'action_value']]
             .rename(columns={'action_value': 'Q_left'}))
df_Q_right = (df_Q.loc[df_Q['water_port'] == 'right', ['session', 'trial', 'action_value']]
              .rename(columns={'action_value': 'Q_right'}))

# One row per (session, trial) with both Q values. `validate` raises instead of
# silently duplicating rows if the restriction ever leaves more than one
# latent-variable row per trial and port.
df_Qs = df_Q_left.merge(df_Q_right, on=['session', 'trial'], validate='one_to_one')