In [None]:
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import pandas as pd
import utils.behavioral_utils as behavioral_utils
import utils.information_utils as information_utils
import utils.visualization_utils as visualization_utils
import utils.spike_utils as spike_utils
from matplotlib import pyplot as plt

### Examine one session and plot per-unit MI

In [None]:
# All sessions that passed validity checks for the RPE analysis.
SESSIONS_PATH = "/data/patrick_res/multi_sess/valid_sessions_rpe.pickle"
sessions = pd.read_pickle(SESSIONS_PATH)
# Display the table ordered by session name (the stored frame itself is not re-sorted).
sessions.sort_values(by="session_name")

In [None]:
# Feature dimensions whose mutual information (MI) is analysed.
FEATURE_DIMS = ["Color", "Shape", "Pattern"]
# Directory holding the precomputed per-session MI / null pickles.
OUTPUT_DIR = "/data/patrick_res/information"

SESSIONS_PATH = "/data/patrick_res/multi_sess/valid_sessions_rpe.pickle"

# Example session inspected in detail below.
session = 20180802

# Feature MI, null summary statistics, and raw shuffled-null MI for the example session.
mi, null, shuffled_mis = (
    pd.read_pickle(os.path.join(OUTPUT_DIR, fname))
    for fname in (
        f"{session}_features_mi.pickle",
        f"{session}_features_null_stats.pickle",
        f"{session}_features_null_shuffled.pickle",
    )
)

# The same three tables for the RPE-group variable.
mi_rpe, null_rpe, shuffled_mis_rpe = (
    pd.read_pickle(os.path.join(OUTPUT_DIR, fname))
    for fname in (
        f"{session}_rpe_mi.pickle",
        f"{session}_rpe_null_stats.pickle",
        f"{session}_rpe_null_shuffled.pickle",
    )
)


In [None]:
# Binning parameters that identify the precomputed firing-rate file
# (interval values presumably in ms — TODO confirm).
PRE_INTERVAL = 1300
POST_INTERVAL = 1500
INTERVAL_SIZE = 100
SMOOTH = 1
EVENT = "FeedbackOnset"
fr_path = f"/data/patrick_res/multi_sess/{session}/{session}_firing_rates_{PRE_INTERVAL}_{EVENT}_{POST_INTERVAL}_{INTERVAL_SIZE}_bins_{SMOOTH}_smooth.pickle"
frs = pd.read_pickle(fr_path)
# Unit positions for the example session, annotated by get_manual_structure.
# NOTE(review): find_interesting_units_per_session fills NaNs with "unknown" before this step; this cell does not.
pos = spike_utils.get_manual_structure(spike_utils.get_unit_positions_per_sess(session))

In [None]:
# Bonferroni-corrected null thresholds for the example session, merged onto the null stats.
# NOTE(review): num_hyp=15 is hard-coded, whereas find_interesting_units_per_session uses
# len(time_bins); the shuffled tables here are not filtered by time bin — confirm 15 is the
# intended number of hypotheses.
# Corrected columns left over from a previous run of this cell are dropped before merging, so
# re-running it does not turn them into *_x / *_y duplicates (which would break the
# "MIShuffled...Corrected" column lookups in the plotting cell).
bon_feature_stats = information_utils.calc_corrected_null_stats(shuffled_mis, FEATURE_DIMS, p_val=0.05, num_hyp=15)
null = pd.merge(
    null.drop(columns=bon_feature_stats.columns.difference(["UnitID", "TimeBins"]), errors="ignore"),
    bon_feature_stats,
    on=["UnitID", "TimeBins"],
)

bon_rpe_stats = information_utils.calc_corrected_null_stats(shuffled_mis_rpe, ["RPEGroup"], p_val=0.05, num_hyp=15)
null_rpe = pd.merge(
    null_rpe.drop(columns=bon_rpe_stats.columns.difference(["UnitID", "TimeBins"]), errors="ignore"),
    bon_rpe_stats,
    on=["UnitID", "TimeBins"],
)

In [None]:
# Attach null statistics to each MI table, then put feature and RPE-group columns side by side.
mi_features_df = mi.merge(null, on=["UnitID", "TimeBins"])
mi_rpe_df = mi_rpe.merge(null_rpe, on=["UnitID", "TimeBins"])
mi_df = mi_features_df.merge(mi_rpe_df, on=["UnitID", "TimeBins"])

In [None]:
# Per-unit significance flags for every feature dimension plus RPE group.
unit_sig = information_utils.assess_significance(mi_df, [*FEATURE_DIMS, "RPEGroup"])

In [None]:
# UnitIDs significant for at least one feature dimension AND for RPE group.
sig_units = unit_sig.loc[
    (unit_sig["ColorSig"] | unit_sig["ShapeSig"] | unit_sig["PatternSig"]) & unit_sig["RPEGroupSig"],
    "UnitID",
]

In [None]:
# Number of distinct units in the merged MI table.
mi_df["UnitID"].nunique(dropna=False)

In [None]:
# Number of units passing the feature-and-RPE criterion.
sig_units.shape[0]

In [None]:
# Per-unit MI time courses for units carrying both feature and RPE-group information, with the
# Bonferroni-corrected null threshold overlaid.
# TimeBins appear to be in seconds from window start; subtracting PRE_INTERVAL / 1000
# (= 1.3 with the settings above) puts feedback onset at 0.
fb_offset = PRE_INTERVAL / 1000
for unit in sig_units:
    fig, axs = plt.subplots(1, 4, figsize=(18, 5))
    # The region label does not depend on the feature, so look it up once per unit;
    # fall back to "unknown" instead of raising IndexError if the unit has no position row.
    unit_regions = pos.loc[pos.UnitID == unit, "manual_structure"].unique()
    unit_pos = unit_regions[0] if len(unit_regions) > 0 else "unknown"
    unit_mi = mi[mi.UnitID == unit]
    unit_null = null[null.UnitID == unit]
    for i, feature in enumerate(FEATURE_DIMS):
        # Plot each series against its own TimeBins: matplotlib pairs x/y positionally, so
        # reusing the MI frame's bins for the null curve misaligns them if row order differs.
        axs[i].plot(unit_mi.TimeBins - fb_offset, unit_mi[f"MI{feature}"], label="MI")
        axs[i].plot(
            unit_null.TimeBins - fb_offset,
            unit_null[f"MIShuffled{feature}Corrected"],
            label="p < 0.05 corrected",
        )
        axs[i].set_title(f"Unit {unit} ({unit_pos}) {feature}")
        axs[i].set_xlabel("Time from feedback (s)")
        axs[i].legend()
    unit_rpe_mi = mi_rpe[mi_rpe.UnitID == unit]
    unit_rpe_null = null_rpe[null_rpe.UnitID == unit]
    axs[3].plot(unit_rpe_mi.TimeBins - fb_offset, unit_rpe_mi["MIRPEGroup"], label="MI")
    axs[3].plot(
        unit_rpe_null.TimeBins - fb_offset,
        unit_rpe_null["MIShuffledRPEGroupCorrected"],
        label="p < 0.05 corrected",
    )
    axs[3].set_title(f"Unit {unit} RPE Group")
    axs[3].set_xlabel("Time from feedback (s)")
    axs[3].legend()
    axs[0].set_ylabel("MI")
    # Render each unit's figure now and release it so a long unit list does not keep
    # every figure open until the end of the cell.
    plt.show()
    plt.close(fig)

### Repeat for every session to form a sub-population

In [None]:
def both_sig(unit_sig):
    """Rows of ``unit_sig`` significant for at least one feature dimension and for RPE group."""
    any_feature = unit_sig["ColorSig"] | unit_sig["ShapeSig"] | unit_sig["PatternSig"]
    return unit_sig.loc[any_feature & unit_sig["RPEGroupSig"]]

def feature_sig(unit_sig):
    """Rows of ``unit_sig`` significant for any of Color, Shape or Pattern (RPE ignored)."""
    is_feature_sig = unit_sig["ColorSig"] | unit_sig["ShapeSig"] | unit_sig["PatternSig"]
    return unit_sig.loc[is_feature_sig]

def rpe_sig(unit_sig):
    """Rows of ``unit_sig`` significant for RPE group (feature flags ignored)."""
    is_rpe_sig = unit_sig["RPEGroupSig"]
    return unit_sig.loc[is_rpe_sig]

def find_interesting_units_per_session(session, time_bins, sig_criteria=both_sig):
    """Return position rows of one session's units that pass ``sig_criteria``.

    Loads the session's feature and RPE-group MI tables and their shuffled nulls from
    OUTPUT_DIR. Keeps only rows whose TimeBins are in ``time_bins`` and computes corrected
    null thresholds (p=0.05) with ``len(time_bins)`` hypotheses. Then runs
    ``information_utils.assess_significance`` over FEATURE_DIMS + ["RPEGroup"].

    Args:
        session: session identifier used in the pickle file names.
        time_bins: TimeBins values to keep; its length is passed as ``num_hyp``.
        sig_criteria: callable that takes the significance table and returns the selected
            rows (e.g. both_sig, feature_sig, rpe_sig); its ``UnitID`` column is used.

    Returns:
        The session's unit positions (NaNs filled with "unknown", then passed through
        ``spike_utils.get_manual_structure``) restricted to the selected UnitIDs.
    """
    def _load_in_window(file_name):
        # Read one precomputed table and keep only the requested time bins.
        table = pd.read_pickle(os.path.join(OUTPUT_DIR, file_name))
        return table[table.TimeBins.isin(time_bins)]

    feature_mis = _load_in_window(f"{session}_features_mi.pickle")
    shuffled_feature_mis = _load_in_window(f"{session}_features_null_shuffled.pickle")
    rpe_mis = _load_in_window(f"{session}_rpe_mi.pickle")
    shuffled_rpe_mis = _load_in_window(f"{session}_rpe_null_shuffled.pickle")

    keys = ["UnitID", "TimeBins"]
    n_tests = len(time_bins)
    feature_thresholds = information_utils.calc_corrected_null_stats(
        shuffled_feature_mis, FEATURE_DIMS, p_val=0.05, num_hyp=n_tests
    )
    rpe_thresholds = information_utils.calc_corrected_null_stats(
        shuffled_rpe_mis, ["RPEGroup"], p_val=0.05, num_hyp=n_tests
    )

    # Feature MI + thresholds, joined with RPE MI + thresholds, on unit and time bin.
    combined = feature_mis.merge(feature_thresholds, on=keys).merge(
        rpe_mis.merge(rpe_thresholds, on=keys), on=keys
    )

    significance = information_utils.assess_significance(combined, FEATURE_DIMS + ["RPEGroup"])
    selected_ids = sig_criteria(significance).UnitID

    positions = spike_utils.get_manual_structure(
        spike_utils.get_unit_positions_per_sess(session).fillna("unknown")
    )
    return positions[positions.UnitID.isin(selected_ids)]

In [None]:
SESSIONS_PATH = "/data/patrick_res/multi_sess/valid_sessions_rpe.pickle"
valid_sess = pd.read_pickle(SESSIONS_PATH)
# Feedback onset sits PRE_INTERVAL ms into the window. TimeBins appear to be in seconds
# (the per-unit plots subtract 1.3), so bins strictly after PRE_INTERVAL / 1000 are post-feedback.
# Bins come from the example session's `mi` and are assumed identical across sessions.
time_bins = mi[mi.TimeBins > PRE_INTERVAL / 1000].TimeBins.unique()
# One DataFrame per session, concatenated in session order.
interesting_after_fb = pd.concat(
    [find_interesting_units_per_session(sess, time_bins) for sess in valid_sess.session_name]
)

In [None]:
# Same selection restricted to bins at or before feedback onset (PRE_INTERVAL / 1000, assuming
# TimeBins are in seconds); bins taken from the example session's `mi`.
time_bins = mi[mi.TimeBins <= PRE_INTERVAL / 1000].TimeBins.unique()
interesting_before_fb = pd.concat(
    [find_interesting_units_per_session(sess, time_bins) for sess in valid_sess.session_name]
)

In [None]:
# Units selected before feedback whose PseudoUnitID is absent from the after-feedback set.
only_before = interesting_before_fb.loc[
    ~interesting_before_fb["PseudoUnitID"].isin(interesting_after_fb["PseudoUnitID"])
]

In [None]:
# Units selected after feedback whose PseudoUnitID is absent from the before-feedback set.
only_after = interesting_after_fb.loc[
    ~interesting_after_fb["PseudoUnitID"].isin(interesting_before_fb["PseudoUnitID"])
]

In [None]:
# Units selected in both windows (rows taken from the before-feedback table).
both = interesting_before_fb.loc[
    interesting_before_fb["PseudoUnitID"].isin(interesting_after_fb["PseudoUnitID"])
]

In [None]:
# Number of units selected both before and after feedback.
both.shape[0]

In [None]:
# Positions of every unit in the valid sessions; grouped by manual_structure below and used as
# the denominator for the per-region proportions.
all_pos = spike_utils.get_unit_positions(valid_sess)

In [None]:
# Total number of units across the valid sessions.
all_pos.shape[0]

In [None]:
# Per-region fraction of all recorded units that carry both feature and RPE-group information,
# for the pre-feedback window, the post-feedback window, and units found in both windows.
units_per_region = all_pos.groupby("manual_structure")["UnitID"].count()


def portion_by_region(units):
    """Fraction of each region's units (from ``all_pos``) present in ``units``, sorted descending.

    Regions with no such units get 0 instead of the NaN a plain index-aligned division gives.
    """
    selective = units.groupby("manual_structure")["UnitID"].count()
    return (selective.reindex(units_per_region.index, fill_value=0) / units_per_region).sort_values(ascending=False)


before_portions = portion_by_region(interesting_before_fb)
after_portions = portion_by_region(interesting_after_fb)
both_portions = portion_by_region(both)

# One panel per window so each title matches the data it shows.
fig, axs = plt.subplots(1, 3, figsize=(15, 6), sharey=True)
for ax, window_portions, window in zip(
    axs,
    (before_portions, after_portions, both_portions),
    ("before feedback", "after feedback", "both before and after feedback"),
):
    window_portions.plot.bar(ax=ax)
    ax.set_title(f"Proportion of neurons with both information\nby region ({window})")
    ax.set_ylabel("Proportion of units")
fig.tight_layout()

In [None]:
# `filtered_pos` otherwise exists only as a local inside find_interesting_units_per_session, so
# build it at notebook scope before saving. It holds units meeting both_sig (feature AND RPE
# group) when every time bin of the window is tested. Bins come from the example session's
# `mi` and are assumed to be shared by all sessions.
all_time_bins = mi.TimeBins.unique()
filtered_pos = pd.concat(
    [
        find_interesting_units_per_session(sess, all_time_bins, sig_criteria=both_sig)
        for sess in valid_sess.session_name
    ]
)
filtered_pos.to_pickle("/data/patrick_scratch/information/subpops/feature_and_rpe_units.pickle")

In [None]:
# Number of valid sessions.
valid_sess.shape[0]

In [None]:
# Units significant for RPE group (regardless of feature information) over the whole window.
# rpe_sig must be passed as sig_criteria: passed positionally it lands in the time_bins slot
# and the TimeBins.isin(...) filter fails. Bins come from the example session's `mi`.
all_time_bins = mi.TimeBins.unique()
just_rpe_pos = pd.concat(
    [
        find_interesting_units_per_session(sess, all_time_bins, sig_criteria=rpe_sig)
        for sess in valid_sess.session_name
    ]
)

In [None]:
# Units significant for at least one feature dimension (regardless of RPE) over the whole window.
# feature_sig must be passed as sig_criteria: passed positionally it lands in the time_bins slot.
# Bins come from the example session's `mi`.
all_time_bins = mi.TimeBins.unique()
just_feature_pos = pd.concat(
    [
        find_interesting_units_per_session(sess, all_time_bins, sig_criteria=feature_sig)
        for sess in valid_sess.session_name
    ]
)

In [None]:
# Number of RPE-group-selective units across sessions.
just_rpe_pos.shape[0]

In [None]:
# Number of feature-selective units across sessions.
just_feature_pos.shape[0]

In [None]:
# Recomputes the same all-unit position table as the earlier get_unit_positions cell (same call,
# same input) — presumably so the plotting cells below can be run from this point.
all_pos = spike_utils.get_unit_positions(valid_sess)

### Plot the positions of all units and of units selective for both features and RPE

In [None]:
# Glass brain of every unit in the valid sessions, coloured by region. Written straight to
# HTML because displaying it inside the notebook ran into performance issues.
fig = visualization_utils.generate_glass_brain(
    all_pos,
    "manual_structure",
    name_to_color=visualization_utils.REGION_TO_COLOR,
)
fig.write_html("/data/patrick_scratch/information/figs/units_glass_brain_rpe_sessions.html")

In [None]:
# Glass brain of the both-selective units (expects `filtered_pos` at notebook scope). Written
# straight to HTML because displaying it inside the notebook ran into performance issues.
fig = visualization_utils.generate_glass_brain(
    filtered_pos,
    "manual_structure",
    name_to_color=visualization_utils.REGION_TO_COLOR,
)
fig.write_html("/data/patrick_scratch/information/figs/units_glass_brain_both_selective.html")

In [None]:
# Glass brain of units selected in the pre-feedback window. Written straight to HTML because
# displaying it inside the notebook ran into performance issues.
fig = visualization_utils.generate_glass_brain(
    interesting_before_fb,
    "manual_structure",
    name_to_color=visualization_utils.REGION_TO_COLOR,
)
fig.write_html("/data/patrick_scratch/information/figs/units_glass_brain_selective_before_fb.html")

In [None]:
# Glass brain of units selected in the post-feedback window. Written straight to HTML because
# displaying it inside the notebook ran into performance issues.
fig = visualization_utils.generate_glass_brain(
    interesting_after_fb,
    "manual_structure",
    name_to_color=visualization_utils.REGION_TO_COLOR,
)
fig.write_html("/data/patrick_scratch/information/figs/units_glass_brain_selective_after_fb.html")

In [None]:
# Glass brain of units selected in both the pre- and post-feedback windows. Written straight to
# HTML because displaying it inside the notebook ran into performance issues.
fig = visualization_utils.generate_glass_brain(
    both,
    "manual_structure",
    name_to_color=visualization_utils.REGION_TO_COLOR,
)
fig.write_html("/data/patrick_scratch/information/figs/units_glass_brain_selective_before_and_after_fb.html")

### Look at proportion of units by region

In [None]:
# Number of units per region across all valid sessions. This gives a single count column (the
# same count()["UnitID"] the proportion cells use) instead of a non-null count for every
# column of all_pos.
all_pos.groupby("manual_structure")["UnitID"].count()

In [None]:
# Number of both-selective units per region. This gives a single count column (the same
# count()["UnitID"] the proportion cells use) instead of a non-null count for every column.
filtered_pos.groupby("manual_structure")["UnitID"].count()

In [None]:
# Per-region fraction of all units significant for both a feature and RPE group.
# Reindexing onto every region gives 0 (not NaN) for regions without such units.
units_per_region = all_pos.groupby("manual_structure")["UnitID"].count()
portions = (
    filtered_pos.groupby("manual_structure")["UnitID"].count().reindex(units_per_region.index, fill_value=0)
    / units_per_region
).sort_values(ascending=False)
ax = portions.plot.bar(figsize=(5, 6))
ax.set_ylabel("Proportion of units")
ax.set_title("Proportion of neurons with both information by region")

In [None]:
# Per-region fraction of all units significant for RPE group.
# Reindexing onto every region gives 0 (not NaN) for regions without such units.
units_per_region = all_pos.groupby("manual_structure")["UnitID"].count()
portions = (
    just_rpe_pos.groupby("manual_structure")["UnitID"].count().reindex(units_per_region.index, fill_value=0)
    / units_per_region
).sort_values(ascending=False)
ax = portions.plot.bar(figsize=(5, 6))
ax.set_ylabel("Proportion of units")
ax.set_title("Proportion of neurons with RPE group information by region")

In [None]:
# Per-region fraction of all units significant for at least one feature dimension.
# Reindexing onto every region gives 0 (not NaN) for regions without such units.
units_per_region = all_pos.groupby("manual_structure")["UnitID"].count()
portions = (
    just_feature_pos.groupby("manual_structure")["UnitID"].count().reindex(units_per_region.index, fill_value=0)
    / units_per_region
).sort_values(ascending=False)
ax = portions.plot.bar(figsize=(5, 6))
ax.set_ylabel("Proportion of units")
ax.set_title("Proportion of neurons with feature information by region")