In [1]:
import pathlib

import matplotlib as mpl
import matplotlib.image
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.signal
import scipy.stats
from matplotlib.axes import Axes
from matplotlib.ticker import MultipleLocator
from scipy.fftpack import fft  # fourier transform

from powltools.analysis.recording import group_by_param
from mytools import (
    plot_bracket,
    anova_tukey,
    condition_batch,
    subplot_indicator,
    figure_add_axes_inch,
    figure_add_axes_group_inch,
)
from xcorr_tools import binary_spiketrain, make_psth_bins, get_peak
from xcorr_tools import wrap_to_pi

In [2]:
# Colour palettes: six identical entries each. The boxplot functions zip a
# palette with the drawn boxes, so every box of a panel gets the same colour.
c_forebrain = ["orange"] * 6
c_ot = ["b"] * 6

c_forebrain_am55 = ["mediumseagreen"] * 6
c_forebrain_am75 = ["cornflowerblue"] * 6

# Output directory, resolved against the current working directory.
# Previous location: pathlib.Path("./forebrain_figures_output").absolute()
OUTDIR = pathlib.Path("./forebrain_output_resubmission").absolute()

# Create it on first use; an already existing directory is not an error.
OUTDIR.mkdir(exist_ok=True)

In [3]:
# matplotlib settings, applied in a single pass (entries are set in order).
mpl.rcParams.update(
    {
        "mathtext.default": "regular",
        "font.family": "sans-serif",
        "font.sans-serif": "Arial",
        # Reproducible SVG files (including xml ids):
        "svg.hashsalt": "123",
        "savefig.dpi": "300",
        # Font sizes:
        "font.size": 12,
        "axes.labelsize": 12,  # default: 'medium' == 10
        "xtick.labelsize": 10,  # default: 'medium' == 10
        "ytick.labelsize": 10,  # default: 'medium' == 10
        "legend.fontsize": 10,
        # Tick lengths:
        "xtick.major.size": 2,
        "ytick.major.size": 2,
    }
)

# Load Data

In [4]:
# Directory (relative to the working directory) that holds the intermediate
# results stored as feather files.
FEATHER_DIR = pathlib.Path("./forebrain_intermediate_results")
# Create the directory if it is missing; an existing directory is not an error.
FEATHER_DIR.absolute().mkdir(exist_ok=True)

In [5]:
# Index levels of the per-unit tables and of the per-unit-pair tables.
_UNIT_INDEX = ["date", "owl", "channel"]
_PAIR_INDEX = ["date", "owl", "channel1", "channel2"]


def _load_indexed(filename, index_cols):
    """Read ``FEATHER_DIR / filename`` and index the table by ``index_cols``."""
    return pd.read_feather(FEATHER_DIR / filename).set_index(index_cols)


## Flat Noise Competition

# Single-stimulus tables
single_srf = _load_indexed("single_srf_dualregion.feather", _UNIT_INDEX)
single_rlf = _load_indexed("single_rlf_dualregion.feather", _UNIT_INDEX)
single_rlf_out = _load_indexed("single_rlf_out_dualregion.feather", _UNIT_INDEX)
single_ccg = _load_indexed("single_ccg_dualregion.feather", _PAIR_INDEX)
single_ccg_out = _load_indexed("single_ccg_dualregion_out.feather", _PAIR_INDEX)
single_gamma_power = _load_indexed(
    "single_gamma_power_dualregion.feather", _UNIT_INDEX
)

# Two-stimulus tables
within_area_sfc = _load_indexed(
    "twostim_sfc_within_area_dualregion.feather", _UNIT_INDEX
)
twostim_rlf = _load_indexed("twostim_rlf_dualregion.feather", _UNIT_INDEX)
twostim_rlf_switch = _load_indexed(
    "twostim_rlf_switch_dualregion.feather", _UNIT_INDEX
)
twostim_ccg = _load_indexed("twostim_ccg_dualregion.feather", _PAIR_INDEX)
twostim_ccg_switch = _load_indexed(
    "twostim_ccg_switch_dualregion.feather", _PAIR_INDEX
)
cross_region_sfc = _load_indexed(
    "sfc_cross_region_df_dualregion.feather", _PAIR_INDEX
)
twostim_gamma_power = _load_indexed(
    "twostim_gamma_power_dualregion.feather", _UNIT_INDEX
)


## AM Competition

# Single-stimulus tables
am_single_rlf = _load_indexed("am_single_rlf_dualregion.feather", _UNIT_INDEX)
am_single_ccg = _load_indexed("am_single_ccg_dualregion.feather", _PAIR_INDEX)
am_single_stim_phaselocking = _load_indexed(
    "am_single_stim_phaselocking_dualregion.feather", _UNIT_INDEX
)
cross_region_gamma = _load_indexed(
    "cross_region_gamma_dualregion.feather", _PAIR_INDEX
)
am_single_gamma_power = _load_indexed(
    "am_single_gamma_power_dualregion.feather", _UNIT_INDEX
)

# Two-stimulus tables
am_twostim_rlf = _load_indexed("am_twostim_rlf_dualregion.feather", _UNIT_INDEX)
am_twostim_ccg = _load_indexed("am_twostim_ccg_dualregion.feather", _PAIR_INDEX)
am_twostim_stim_phaselocking = _load_indexed(
    "am_twostim_stim_phaselocking_dualregion.feather", _UNIT_INDEX
)
am_twostim_gamma_power = _load_indexed(
    "am_twostim_gamma_power_dualregion.feather", _UNIT_INDEX
)

In [6]:
## example data


def _read_example(filename):
    """Read one example table from ``FEATHER_DIR`` (index left untouched)."""
    return pd.read_feather(FEATHER_DIR / filename)


# Example receptive field and spike trains
example_srf = _read_example("example_srf_20240315_256.feather")
example_spiketrains_forebrain_flat = _read_example(
    "example_spiketrains_forebrain_flat_20240315_256.feather"
)
example_spiketrains_am55 = _read_example(
    "example_spiketrains_forebrain_driver55_20240330_40.feather"
)
example_spiketrains_am75 = _read_example(
    "example_spiketrains_forebrain_driver75_20240330_40.feather"
)

# Example receptive fields (OT / Forebrain) for the four "corr" categories
lowcorr_srf_ot = _read_example("example_srf_lowcorr_OT_20240420_54.feather")
lowcorr_srf_fb = _read_example("example_srf_lowcorr_Forebrain_20240420_54.feather")

midlowcorr_srf_ot = _read_example("example_srf_midlowcorr_OT_20240423_54.feather")
midlowcorr_srf_fb = _read_example(
    "example_srf_midlowcorr_Forebrain_20240423_54.feather"
)

midhighcorr_srf_ot = _read_example("example_srf_midhighcorr_OT_20240420_54.feather")
midhighcorr_srf_fb = _read_example(
    "example_srf_midhighcorr_Forebrain_20240420_54.feather"
)

highcorr_srf_ot = _read_example("example_srf_highcorr_OT_20240420_54.feather")
highcorr_srf_fb = _read_example("example_srf_highcorr_Forebrain_20240420_54.feather")

# Functions

In [7]:
def confidence_intervale(sample_data):
    """Print and return the 95 % confidence interval of the sample mean.

    The interval is taken from Student's t-distribution with ``n - 1``
    degrees of freedom, centred on the sample mean and scaled by the standard
    error of the mean (sample standard deviation with ``ddof=1`` over
    ``sqrt(n)``).

    Parameters
    ----------
    sample_data : sequence of numbers
        Observations; anything accepted by ``len`` and ``np.mean``.

    Returns
    -------
    tuple
        ``(lower, upper)`` bounds as returned by ``scipy.stats.t.interval``.
    """
    n_obs = len(sample_data)
    center = np.mean(sample_data)
    # Population standard deviation is unknown -> estimate it from the sample.
    spread = np.std(sample_data, ddof=1) / np.sqrt(n_obs)

    confidence_interval = scipy.stats.t.interval(
        0.95, n_obs - 1, loc=center, scale=spread
    )

    print(f"Confidence Interval: {confidence_interval}")
    return confidence_interval

In [8]:
def phase_properties(phases):
    """Circular statistics of phases grouped into eight bins over [-pi, pi].

    The phases are histogrammed into 8 equal-width bins and the mean resultant
    vector is computed from the bin counts placed at the bin *centres*.
    (Placing the counts at the left bin edges, as an earlier version did,
    rotates the mean angle by half a bin width, i.e. -pi/8.)

    Parameters
    ----------
    phases : array-like of float
        Phase angles in radians. Values outside [-pi, pi] are not counted by
        the histogram but still contribute to ``len(phases)``, so the input
        is assumed to be wrapped already -- TODO confirm against callers.

    Returns
    -------
    dict
        ``"vector_strength"``: resultant length, multiplied by the
        grouped-data correction factor ``(d / 2) / sin(d / 2)`` (``d`` = bin
        width) and capped at 1.
        ``"mean_angle"``: angle of the mean resultant vector in radians
        (``np.arctan2`` range).
        ``"pval"``: p-value from the closed-form expression in ``n`` and
        ``R = n * vector_strength`` evaluated below.
        All three are NaN when no phase falls inside the histogram range.
    """
    nan_result = {"vector_strength": np.nan, "mean_angle": np.nan, "pval": np.nan}

    pi_bin = np.linspace(-np.pi, np.pi, 9)
    counts, bins = np.histogram(phases, pi_bin)
    n_binned = counts.sum()
    if n_binned == 0:
        # Nothing to average; avoid the 0 / 0 division below.
        return nan_result

    d = np.diff(bins)[0]
    centers = bins[:-1] + d / 2

    # Mean resultant vector of the frequency-grouped data.
    y_angle = np.sum(counts * np.sin(centers)) / n_binned
    x_angle = np.sum(counts * np.cos(centers)) / n_binned
    mean_ang = np.arctan2(y_angle, x_angle)

    # Resultant length with the correction for grouping into bins of width d.
    corr_fac = d / 2 / np.sin(d / 2)
    r_val = corr_fac * np.sqrt((y_angle**2) + (x_angle**2))
    # The correction can push r slightly above 1 when all phases share one
    # bin; that would put a negative number under the square root below.
    r_val = np.minimum(r_val, 1.0)

    total_n = len(phases)
    rayleigh_r = total_n * r_val
    pval = np.exp(
        np.sqrt(1 + 4 * total_n + 4 * (total_n**2 - rayleigh_r**2)) - (1 + 2 * total_n)
    )
    return {"vector_strength": r_val, "mean_angle": mean_ang, "pval": pval}

# Plotting + Analysis Functions

## 2.1 example spatial receptive field

In [9]:
def figure_spatial_receptive_field(
    srf_data: pd.DataFrame,
    ax=None,
    cax=None,
):
    """Scatter plot of a spatial receptive field, coloured by response.

    Parameters
    ----------
    srf_data : pd.DataFrame
        Needs the columns ``"azimuth"``, ``"elevation"`` and
        ``"response_mean"``; one marker is drawn per row.
    ax : matplotlib Axes, optional
        Axes to draw into. A new figure and axes are created when omitted.
    cax : matplotlib Axes, optional
        Axes that receives the colorbar (forwarded to ``Figure.colorbar``).

    Returns
    -------
    None
    """
    if ax is not None:
        fig = ax.figure
    else:
        fig, ax = plt.subplots()

    # Major ticks every 30 degrees on both axes.
    ax.xaxis.set_major_locator(MultipleLocator(30))
    ax.yaxis.set_major_locator(MultipleLocator(30))
    ax.set_xlabel("Azimuth [deg]")
    ax.set_ylabel("Elevation [deg]")

    # Grey scale anchored at zero response.
    markers = ax.scatter(
        x=srf_data["azimuth"].values,
        y=srf_data["elevation"].values,
        c=srf_data["response_mean"].values,
        cmap=mpl.colormaps["Greys"],
        edgecolors="0.7",
        vmin=0.0,
        linewidths=0.5,
    )
    fig.colorbar(markers, cax=cax).set_label("Response [spks/stim]")
    return

## Azimuth tuning heatmap + population average

In [10]:
def azimuth_heatplot_alt(single_srf_df, axs=None, caxs=None):
    """Heatmap of per-unit azimuth tuning plus the population mean curve.

    For every unique index entry of ``single_srf_df`` the maximum
    ``"norm_resp"`` per ``"azimuth"`` is taken; the resulting curve forms one
    heatmap row. Curves of units whose first row has ``"hemisphere"`` equal
    to ``"right"`` are mirrored before they are stored.

    Parameters
    ----------
    single_srf_df : pd.DataFrame
        Needs the columns ``"azimuth"``, ``"norm_resp"`` and ``"hemisphere"``.
        Every unit must provide exactly 21 distinct azimuths, because each
        curve is written into a 21-column row that is plotted against
        ``np.linspace(-100, 100, 21)``.
    axs : sequence of two matplotlib Axes, optional
        ``axs[0]`` receives the heatmap, ``axs[1]`` the mean +/- SEM curve.
        A new 1x2 figure is created when omitted.
    caxs : matplotlib Axes, optional
        Axes for the heatmap colorbar.

    Returns
    -------
    None
    """
    if axs is None:
        _, axs = plt.subplots(1, 2)

    itd_range = np.linspace(-100, 100, 21)
    unit_ids = np.unique(single_srf_df.index)
    itd_response_hm = np.empty([len(unit_ids), len(itd_range)])

    for krow, unit_id in enumerate(unit_ids):
        unit_df = single_srf_df.loc[unit_id]
        spkresponse = unit_df.groupby("azimuth")["norm_resp"].max().values

        if unit_df.iloc[0]["hemisphere"] == "right":
            # Mirror the curve for units flagged as right hemisphere.
            spkresponse = spkresponse[::-1]
        itd_response_hm[krow, :] = spkresponse
    print(f"nRT units {itd_response_hm.shape[0]}")

    # Colormap registry lookup; `matplotlib.cm.get_cmap` was removed in
    # Matplotlib 3.9.
    cmap = mpl.colormaps["hot"]
    xticks = [int(val) for val in np.linspace(-100, 100, 5)]
    sc = axs[0].imshow(
        itd_response_hm,
        extent=[-100, 100, 0, itd_response_hm.shape[0]],
        aspect="auto",
        cmap=cmap,
        vmin=0,
        vmax=1.0,
    )
    axs[0].set_xticks(xticks, xticks)
    axs[0].set_ylabel("Units")
    axs[0].set_xlabel("Azimuth")
    # Attach the colorbar to the figure that owns the heatmap axes rather
    # than to pyplot's "current" figure, which may be a different one.
    cb = axs[0].figure.colorbar(sc, cax=caxs)
    cb.set_label("Norm. Response")

    mean_azimuth_curve = np.mean(itd_response_hm, axis=0)
    sem = scipy.stats.sem(itd_response_hm, axis=0)

    axs[1].errorbar(itd_range, mean_azimuth_curve, sem, color="k")
    axs[1].set_ylim(bottom=0.30, top=0.75)
    axs[1].set_xlabel("Azimuth")
    axs[1].set_ylabel("Norm. Response")
    axs[1].spines["top"].set_visible(False)
    axs[1].spines["right"].set_visible(False)

## Elevation tuning heatmap and population average

In [11]:
def elevation_heatplot_alt(single_srf_df, axs=None, caxs=None):
    """Heatmap of per-unit elevation tuning plus the population mean curve.

    For every unique index entry of ``single_srf_df`` the maximum
    ``"norm_resp"`` per ``"elevation"`` is taken; the resulting curve forms
    one heatmap row.

    Parameters
    ----------
    single_srf_df : pd.DataFrame
        Needs the columns ``"elevation"`` and ``"norm_resp"``. Every unit must
        provide exactly 17 distinct elevations, because each curve is written
        into a 17-column row that is plotted against
        ``np.linspace(-80, 80, 17)``.
    axs : sequence of two matplotlib Axes, optional
        ``axs[0]`` receives the heatmap, ``axs[1]`` the mean +/- SEM curve.
        A new 1x2 figure is created when omitted.
    caxs : matplotlib Axes, optional
        Axes for the heatmap colorbar.

    Returns
    -------
    None
    """
    if axs is None:
        _, axs = plt.subplots(1, 2)

    ele_range = np.linspace(-80, 80, 17)
    unit_ids = np.unique(single_srf_df.index)
    ele_response_hm = np.empty([len(unit_ids), len(ele_range)])

    for krow, unit_id in enumerate(unit_ids):
        unit_df = single_srf_df.loc[unit_id]
        ele_response_hm[krow, :] = (
            unit_df.groupby("elevation")["norm_resp"].max().values
        )

    # Colormap registry lookup; `matplotlib.cm.get_cmap` was removed in
    # Matplotlib 3.9.
    cmap = mpl.colormaps["hot"]

    xticks = [int(val) for val in np.linspace(-80, 80, 5)]
    sc = axs[0].imshow(
        ele_response_hm,
        extent=[-80, 80, 0, ele_response_hm.shape[0]],
        aspect="auto",
        cmap=cmap,
        vmin=0,
        vmax=1.0,
    )
    axs[0].set_xticks(xticks, xticks)
    axs[0].set_ylabel("Units")
    axs[0].set_xlabel("Elevation")
    # Attach the colorbar to the figure that owns the heatmap axes rather
    # than to pyplot's "current" figure, which may be a different one.
    cb = axs[0].figure.colorbar(sc, cax=caxs)
    cb.set_label("Norm. Response")

    mean_ele_curve = np.mean(ele_response_hm, axis=0)
    sem = scipy.stats.sem(ele_response_hm, axis=0)

    axs[1].errorbar(ele_range, mean_ele_curve, sem, color="k")
    axs[1].set_xlabel("Elevation")
    axs[1].set_ylabel("Norm. Response")
    axs[1].set_ylim(bottom=0.30, top=0.75)
    axs[1].spines["top"].set_visible(False)
    axs[1].spines["right"].set_visible(False)

## competition boxplot

In [12]:
def figure_competition_boxplot(
    twostim_rlf: pd.DataFrame,
    single_rlf: pd.DataFrame,
    ax=None,
    colors=c_forebrain,
    brackets: dict[tuple[int, int], float] | None = None,
    anova_align="right",
):
    """Boxplot of the change in spike rate as a function of relative level.

    ``twostim_rlf`` is inner-joined with ``single_rlf`` on their index
    (columns of ``single_rlf`` that clash get the suffix ``"_single"``) and
    only rows are kept where ``fixedintensity == intensity``,
    ``fixedazi == azimuth`` and ``fixedele == elevation``. The plotted
    quantity is ``(resp - resp_single) / resp_single``, grouped by the
    ``"relative_level"`` column.

    Parameters
    ----------
    twostim_rlf, single_rlf : pd.DataFrame
        Tables sharing an index whose first two levels identify a session
        (they are used for the printed session count).
    ax : matplotlib Axes, optional
        Axes to draw into; a new figure is created when omitted.
    colors : sequence of colours
        Face colours, zipped with the boxes in plotting order.
    brackets : dict mapping (left, right) -> y, optional
        Zero-based positions of the two groups to connect and the ``y`` value
        handed to ``plot_bracket``. Defaults to every pair of neighbouring
        groups at ``y = 0.8``.
    anova_align : str
        ``"right"`` puts the ANOVA p-value in the top-right corner; any other
        value puts it top-left.

    Returns
    -------
    None
    """
    if ax is None:
        _, ax = plt.subplots()

    merged_rlf = twostim_rlf.join(single_rlf, how="inner", rsuffix="_single")
    merged_rlf = merged_rlf.loc[
        (merged_rlf["fixedintensity"] == merged_rlf["intensity"])
        & (merged_rlf["fixedazi"] == merged_rlf["azimuth"])
        & (merged_rlf["fixedele"] == merged_rlf["elevation"])
    ]

    print("Number of units:", np.unique(merged_rlf.index).size)
    print(
        "Number of sessions:",
        len({idx[:2] for idx in merged_rlf.index.values}),
    )
    merged_rlf = merged_rlf.set_index("relative_level")

    merged_rlf["change_in_response"] = (
        merged_rlf["resp"] - merged_rlf["resp_single"]
    ) / merged_rlf["resp_single"]
    relative_levels = np.unique(merged_rlf.index)
    # List indexer: a level with a single row still yields a Series, so
    # `.values` is always an array.
    groupdata = [
        merged_rlf.loc[[relative_level], "change_in_response"].values
        for relative_level in relative_levels
    ]
    for k, lev in enumerate(relative_levels):
        print(
            f"{lev}: {np.median(groupdata[k])} , sem: {scipy.stats.sem(groupdata[k])}"
        )

    bp = ax.boxplot(
        groupdata,
        patch_artist=True,
        notch=False,
        showfliers=False,
        whis=(5, 95),
    )

    ax.set_xticklabels(relative_levels.astype(int))

    stats = anova_tukey(
        merged_rlf, val_col="change_in_response", group_col="relative_level"
    )

    for patch, color in zip(bp["boxes"], colors):
        patch.set_facecolor(color)

    for median in bp["medians"]:
        median.set(color="white", linewidth=2)

    # Individual data points: sorted values spread across each box.
    width = 0.7
    for k, groupvalues in enumerate(groupdata):
        pos = k + 1
        y = np.sort(groupvalues)
        vals_x = np.linspace(
            pos - width / 3,
            pos + width / 3,
            y.size,
            endpoint=True,
        )
        ax.plot(
            vals_x,
            y,
            color="k",
            ls="None",
            marker=".",
            markersize=2,
            markeredgewidth=0.0,
            zorder=5,
        )

    # Significance brackets labelled with the pairwise p-values.
    brackets_shrink = 1.0
    if brackets is None:
        brackets = {(k, k + 1): 0.8 for k in range(len(stats["tukey"]) - 1)}
        brackets_shrink = 0.8
    for (idx_left, idx_right), y in brackets.items():
        p_text = f"{stats['tukey'].iloc[idx_left, idx_right]:.3f}"
        # Drop the leading zero (".012"), but keep "1.000" intact: blindly
        # slicing off the first character would print it as ".000".
        if p_text.startswith("0"):
            p_text = p_text[1:]
        plot_bracket(
            ax,
            left=idx_left + 1,
            right=idx_right + 1,
            text=p_text,
            y=y,
            shrink=brackets_shrink,
        )

    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    if anova_align == "right":
        ax.text(
            1,
            1,
            f"p = {stats['anova_p']:.3g}",
            horizontalalignment="right",
            verticalalignment="top",
            fontsize=10,
            transform=ax.transAxes,
        )
    else:
        # Top-left corner, nudged 2 pt to the right of the axes edge.
        ax.text(
            0,
            1,
            f"p = {stats['anova_p']:.3g}",
            horizontalalignment="left",
            verticalalignment="top",
            fontsize=10,
            transform=ax.transAxes
            + mpl.transforms.ScaledTranslation(
                +2 / 72,
                0 / 72,
                ax.figure.dpi_scale_trans,
            ),
        )

    # Show the fractional change as a percentage.
    ax.yaxis.set_major_formatter(mpl.ticker.FuncFormatter(lambda x, pos: f"{x:.0%}"))

    ax.axhline(0, color="k", ls="--", zorder=-1)
    ax.set_xlabel("Relative Level [dB]")
    ax.set_ylabel("Change in\nSpike Rate")
    ax.set_ylim(-1, 1.5)

## Raster Plot

In [13]:
def figure_coincident_rasterplot(
    spiketrain_data: pd.DataFrame,
    # filename: str,
    # chan1: int,
    # chan2: int,
    relative_levels=[-15, +10],
    axs=None,
):
    """Raster plot of two units with their coincident spikes highlighted.

    One panel is drawn per entry of ``relative_levels``. Each panel shows up
    to the first ten trials of that level: unit 1 spikes, unit 2 spikes and,
    between them, the 1 ms bins in which both units fired.

    Parameters
    ----------
    spiketrain_data : pd.DataFrame
        Needs the columns ``"spiketrain_unit1"``, ``"spiketrain_unit2"`` and
        ``"relative_level"`` (one row per trial).
    relative_levels : sequence
        Levels to plot, one panel each. The default list is never modified.
    axs : Axes or array of Axes, optional
        Axes to draw into, one per relative level. A new figure is created
        when omitted.

    Returns
    -------
    None
    """
    if axs is None:
        # squeeze=False: always an array, also for a single relative level.
        _, axs = plt.subplots(1, len(relative_levels), squeeze=False)
    axs = np.asarray(axs).flatten()

    trial_relative_levels = spiketrain_data["relative_level"]
    stim_spikes1 = group_by_param(
        np.array(spiketrain_data["spiketrain_unit1"].values, dtype="object"),
        trial_relative_levels,
    )
    stim_spikes2 = group_by_param(
        np.array(spiketrain_data["spiketrain_unit2"].values, dtype="object"),
        trial_relative_levels,
    )

    binsize = 0.001
    psth_bins = make_psth_bins(
        stimdelay=1.0,
        stimduration=1.0,
        binsize=binsize,
        offset=0.000,
    )

    for k, relative_level in enumerate(relative_levels):
        ax = axs[k]
        spiketrains1 = stim_spikes1[relative_level][:10]
        spiketrains2 = stim_spikes2[relative_level][:10]
        # Bins in which both units fired, converted from bin index to time.
        coincidences = [
            np.where(
                binary_spiketrain(st1, bins=psth_bins)
                & binary_spiketrain(st2, bins=psth_bins)
            )[0]
            * binsize
            + psth_bins[0]
            for st1, st2 in zip(spiketrains1, spiketrains2)
        ]
        # Each trial occupies four rows: unit 1 at +0, coincidences at +1,
        # unit 2 at +2.
        ax.eventplot(
            spiketrains1,
            lineoffsets=4 * np.arange(spiketrains1.size),
            linelengths=1,
            color="#00dd33",
            alpha=0.3,
        )
        ax.eventplot(
            spiketrains2,
            lineoffsets=2 + 4 * np.arange(spiketrains2.size),
            linelengths=1,
            color="#ffaa00",
            alpha=0.3,
        )
        ax.eventplot(
            coincidences,
            lineoffsets=1 + 4 * np.arange(len(coincidences)),
            linelengths=1.4,
            color="k",
        )
        ax.axvspan(1.0, 1.050, color="#ddf", zorder=-50)
        ax.set_xlim(left=1.0, right=2.0)
        ax.set_xticks([1.0, 1.5, 2.0], ["0", "0.5", "1"], fontsize=8)
        ax.set_ylim(bottom=-0.5, top=2 + 4 * (spiketrains2.size - 1) + 0.5)

    for ax in axs[1:]:
        ax.set_yticks([])

    axs[0].set_xlabel("Time [s]", fontsize=8)
    ytick_trial = 1
    # (y position, label, vertical alignment, colour) of the row legends.
    row_labels = [
        (4 * ytick_trial, "Unit A", "top", "g"),
        (2 + 4 * ytick_trial, "Unit B", "bottom", "orange"),
        (1 + 4 * (ytick_trial + 2), "Coinc.", "bottom", "k"),
    ]
    axs[0].set_yticks([y for y, *_ in row_labels], ["", "", ""])
    axs[0].tick_params(axis="y", length=4)

    # x in axes coordinates (left edge), y in data coordinates, shifted 4 pt
    # to the left of the axes.
    label_transform = mpl.transforms.blended_transform_factory(
        axs[0].transAxes, axs[0].transData
    ) + mpl.transforms.ScaledTranslation(
        -4 / 72,
        0,
        axs[0].figure.dpi_scale_trans,
    )
    for y, label, valign, color in row_labels:
        axs[0].text(
            0,
            y,
            label,
            fontsize=10,
            ha="right",
            va=valign,
            color=color,
            transform=label_transform,
        )

## CCGs

In [14]:
def figure_xcorr_ccg(
    twostim_ccg: pd.DataFrame, single_ccg: pd.DataFrame, axs=None, colors=c_forebrain
):
    """Mean cross-correlogram per relative level, one panel per level.

    ``twostim_ccg`` is inner-joined with ``single_ccg`` on their index
    (clashing ``single_ccg`` columns get the suffix ``"_single"``) and only
    rows are kept where ``fixedintensity == intensity``,
    ``fixedazi == azimuth`` and ``fixedele == elevation``. For every value of
    ``"relative_level"`` the mean of the ``"ccg"`` traces is drawn for lags
    within +/- 50 ms, together with a horizontal line at the peak of the mean
    single-stimulus CCG and the number of contributing pairs.

    Parameters
    ----------
    twostim_ccg, single_ccg : pd.DataFrame
        Tables sharing an index whose first two levels identify a session.
        Both need ``"ccg"`` and ``"xcorr_peak"`` columns.
    axs : array of matplotlib Axes, optional
        One axes per relative level; a 1x6 figure is created when omitted.
    colors : sequence of colours
        ``colors[0]`` is used for the single-stimulus reference line.

    Returns
    -------
    None
    """
    if axs is None:
        _, axs = plt.subplots(1, 6)
    axs = axs.flatten()

    merged_ccg = twostim_ccg.join(single_ccg, how="inner", rsuffix="_single")
    merged_ccg = merged_ccg.loc[
        (merged_ccg["fixedintensity"] == merged_ccg["intensity"])
        & (merged_ccg["fixedazi"] == merged_ccg["azimuth"])
        & (merged_ccg["fixedele"] == merged_ccg["elevation"])
    ]

    print("Number of unit pairs:", np.unique(merged_ccg.index).size)
    print(
        "Number of sessions:",
        len({idx[:2] for idx in merged_ccg.index.values}),
    )

    # Before reindexing by relative_level: one single-stimulus CCG per unit
    # pair (rows = pairs, columns = lags).
    single_ccgs = np.vstack(
        merged_ccg["ccg_single"].groupby(merged_ccg.index.names).first()
    )
    single_mean_peak = np.max(np.mean(single_ccgs, axis=0))
    # Spread of the per-pair peaks: maximum over lags (axis=1) first, then the
    # standard deviation across pairs. NOTE(review): this used to take the
    # maximum across pairs (axis=0), i.e. the std over lags of the envelope.
    single_std_peak = np.std(np.max(single_ccgs, axis=1))
    print(f"mean: {single_mean_peak = :.4g} std: {single_std_peak}")

    merged_ccg = merged_ccg.set_index("relative_level")

    # A full cross-correlation of two length-N signals has 2N - 1 samples.
    psth_len = (merged_ccg.iloc[0]["ccg_single"].size + 1) // 2
    # Lags in samples divided by 1000 -- presumably 1 ms bins converted to
    # seconds; verify against the code that produced the CCGs.
    lags = scipy.signal.correlation_lags(psth_len, psth_len) / 1000
    lags_mask = np.abs(lags) <= 0.05

    relative_levels = np.unique(merged_ccg.index)

    for k, relative_level in enumerate(relative_levels):
        # List indexer: also a level with a single row yields a DataFrame.
        level_rows = merged_ccg.loc[[relative_level]]
        mean_ccg = np.mean(np.vstack(level_rows["ccg"]), axis=0)
        n_ccg = level_rows["xcorr_peak_single"].count()

        axs[k].axhline(single_mean_peak, lw=1, color=colors[0])
        print(
            f"{relative_level:5} | {np.mean(level_rows['xcorr_peak_single']):.4g} {n_ccg = }"
            f"| {np.mean(level_rows['xcorr_peak']):.4g} {n_ccg = }"
            f"| peak curve: {np.max(mean_ccg)}"
        )
        axs[k].plot(lags[lags_mask], mean_ccg[lags_mask], color="k", ls="-", lw=1)
        axs[k].text(
            0.5,
            1,
            f"{n_ccg}",
            ha="center",
            va="top",
            fontsize=8,
            transform=axs[k].transAxes,
        )
        axs[k].spines["top"].set_visible(False)
        axs[k].spines["right"].set_visible(False)

        axs[k].set_ylim(bottom=0, top=0.000075)
        axs[k].set_xlim(left=-0.060, right=0.060)

    axs[0].set_ylabel("Coinc./spk", labelpad=12)

    def formatter_y(x, pos):
        # Tick labels in units of 1e-5; the exponent is drawn separately.
        if x > 0:
            return f"${{{x*1e5:.0f}}}$"
        else:
            return "0"

    axs[0].set_yticks([0, 5e-5])
    axs[0].yaxis.set_major_formatter(mpl.ticker.FuncFormatter(formatter_y))
    axs[0].set_yticks(np.array([1, 2, 3, 4, 6, 7]) * 1e-5, minor=True)
    # Common exponent, placed at the top-left corner 2 pt outside the axes.
    axs[0].text(
        0,
        1,
        "$10^{-5}$",
        ha="right",
        va="center",
        fontsize=8,
        transform=axs[0].transAxes
        + mpl.transforms.ScaledTranslation(
            -2 / 72,
            0,
            axs[0].figure.dpi_scale_trans,
        ),
    )

    for ax in axs[1:]:
        ax.set_yticks([])
        ax.set_yticks([], minor=True)
        ax.spines["left"].set_visible(False)

    axs[0].set_xticks([-0.05, 0, +0.05], ["-50", "0", "+50"], fontsize=8)
    axs[0].set_xlabel("Lags [ms]", fontsize=8, labelpad=0)
    for ax in axs[1:]:
        ax.set_xticks([-0.05, 0, +0.05], [])

## CCG and power spectra

In [15]:
def power_spectrum(signal, samplerate):
    """One-sided power spectrum of a real-valued signal.

    Parameters
    ----------
    signal : sequence of numbers
        Samples of the signal.
    samplerate : number
        Sampling rate; the frequency axis is expressed in the same unit.

    Returns
    -------
    freq_axis : np.ndarray
        Frequencies ``0, samplerate / n, ...`` of the retained bins.
    power : np.ndarray
        ``|FFT / n| ** 2`` for the first ``ceil((n + 1) / 2)`` bins, with
        every bin that also has a negative-frequency twin doubled.
    """
    n_samples = len(signal)
    n_half = int(np.ceil((n_samples + 1) / 2.0))  # half of the spectrum

    power = (np.abs(fft(signal)[:n_half]) / float(n_samples)) ** 2

    # The DC bin is never doubled. For an even number of samples the last
    # retained bin (Nyquist) has no twin either and is left alone as well.
    stop = n_half if n_samples % 2 > 0 else n_half - 1
    power[1:stop] = power[1:stop] * 2

    freq_axis = np.arange(0, n_half, 1.0) * (samplerate / n_samples)
    return freq_axis, power


def figure_xcorr_pspectra(
    twostim_ccg: pd.DataFrame, single_ccg: pd.DataFrame, axs=None, colors=c_forebrain
):
    """Average power spectrum of the cross-correlograms per relative level.

    ``twostim_ccg`` is inner-joined with ``single_ccg`` on their index
    (clashing ``single_ccg`` columns get the suffix ``"_single"``) and only
    rows are kept where ``fixedintensity == intensity``,
    ``fixedazi == azimuth`` and ``fixedele == elevation``. For every value of
    ``"relative_level"`` each ``"ccg"`` trace is cut to lags strictly within
    +/- 200 ms, mean-subtracted and passed to ``power_spectrum`` with a
    sample rate of 1000; each spectrum is mean-subtracted and the spectra are
    averaged. For relative levels >= 0 a marker is drawn above the largest
    value of the averaged spectrum in the band 50 < f <= 80.

    Parameters
    ----------
    twostim_ccg, single_ccg : pd.DataFrame
        Tables sharing an index whose first two levels identify a session.
        Both need a ``"ccg"`` column.
    axs : array of matplotlib Axes, optional
        One axes per relative level; a 1x6 figure is created when omitted.
    colors : sequence of colours
        ``colors[0]`` is used for the peak marker.

    Returns
    -------
    None
    """
    if axs is None:
        _, axs = plt.subplots(1, 6)
    axs = axs.flatten()

    merged_ccg = twostim_ccg.join(single_ccg, how="inner", rsuffix="_single")
    merged_ccg = merged_ccg.loc[
        (merged_ccg["fixedintensity"] == merged_ccg["intensity"])
        & (merged_ccg["fixedazi"] == merged_ccg["azimuth"])
        & (merged_ccg["fixedele"] == merged_ccg["elevation"])
    ]

    print("Number of unit pairs:", np.unique(merged_ccg.index).size)
    print(
        "Number of sessions:",
        len({idx[:2] for idx in merged_ccg.index.values}),
    )

    merged_ccg = merged_ccg.set_index("relative_level")

    # A full cross-correlation of two length-N signals has 2N - 1 samples.
    psth_len = (merged_ccg.iloc[0]["ccg_single"].size + 1) // 2
    # Lags in samples divided by 1000 -- presumably 1 ms bins converted to
    # seconds; verify against the code that produced the CCGs.
    lags = scipy.signal.correlation_lags(psth_len, psth_len) / 1000
    window = (lags > -0.200) & (lags < 0.200)

    relative_levels = np.unique(merged_ccg.index)

    for k, relative_level in enumerate(relative_levels):
        # List indexer: also a level with a single row yields a Series.
        level_ccgs = np.vstack(merged_ccg.loc[[relative_level], "ccg"])

        norm_psd = []
        for trace in level_ccgs:
            segment = trace[window]
            freq, psd = power_spectrum(segment - np.mean(segment), 1000)
            norm_psd.append(psd - np.mean(psd))
        avg_psd = np.mean(norm_psd, axis=0)

        # Locate the peak inside the band with absolute indices, so that the
        # frequency and the height of the marker refer to the same bin.
        band = np.flatnonzero((freq > 50) & (freq <= 80))
        peak_idx = band[np.argmax(avg_psd[band])]
        freq_max = freq[peak_idx]

        axs[k].plot(freq, avg_psd, color="k", ls="-", lw=1)

        if relative_level >= 0:
            axs[k].plot(
                freq_max,
                avg_psd[peak_idx] + 2.5e-13,
                "v",
                color=colors[0],
                markersize=4,
            )

        axs[k].spines["top"].set_visible(False)
        axs[k].spines["right"].set_visible(False)

        axs[k].set_ylim(bottom=0, top=7.5e-13)
        axs[k].set_xlim(left=20, right=100)

    axs[0].set_ylabel("Power", labelpad=12)

    def formatter_y(x, pos):
        # Tick labels in units of 1e-13; the exponent is drawn separately.
        if x > 0:
            return f"${{{x*1e13:.0f}}}$"
        else:
            return "0"

    axs[0].set_yticks([0, 6.5e-13])
    axs[0].yaxis.set_major_formatter(mpl.ticker.FuncFormatter(formatter_y))
    axs[0].set_yticks(np.array([0, 3.5]) * 1e-13, minor=True)
    # Common exponent, placed at the top-left corner 2 pt outside the axes.
    axs[0].text(
        0,
        1,
        "$10^{-13}$",
        ha="right",
        va="center",
        fontsize=8,
        transform=axs[0].transAxes
        + mpl.transforms.ScaledTranslation(
            -2 / 72,
            0,
            axs[0].figure.dpi_scale_trans,
        ),
    )

    for ax in axs[1:]:
        ax.set_yticks([])
        ax.set_yticks([], minor=True)
        ax.spines["left"].set_visible(False)

    axs[0].set_xticks([50, 75], ["50", "75"], fontsize=8, minor=False)
    axs[0].set_xticks([50, 75], [], fontsize=8, minor=True)

    axs[0].set_xlabel("Freq. [Hz]", fontsize=8, labelpad=0)
    for ax in axs[1:]:
        ax.set_xticks([50, 75], ["50", "75"], fontsize=8, minor=False)

        ax.set_xticks([50, 75], ["50", "75"], minor=True)

## CSI boxplot

In [16]:
def figure_xcorr_boxplot(
    twostim_rlf: pd.DataFrame,
    single_rlf: pd.DataFrame,
    ax=None,
    colors=c_forebrain,
    brackets: dict[tuple[int, int], float] | None = None,
    anova_align="right",
):
    """Boxplot of the CSI as a function of relative level.

    ``twostim_rlf`` is inner-joined with ``single_rlf`` on their index
    (columns of ``single_rlf`` that clash get the suffix ``"_single"``) and
    only rows are kept where ``fixedintensity == intensity``,
    ``fixedazi == azimuth`` and ``fixedele == elevation``. The plotted index
    is ``(xcorr_peak - xcorr_peak_single) / (xcorr_peak + xcorr_peak_single)``,
    grouped by the ``"relative_level"`` column.

    Parameters
    ----------
    twostim_rlf, single_rlf : pd.DataFrame
        Tables sharing an index whose first two levels identify a session
        (they are used for the printed session count). Both need an
        ``"xcorr_peak"`` column.
    ax : matplotlib Axes, optional
        Axes to draw into; a new figure is created when omitted.
    colors : sequence of colours
        Face colours, zipped with the boxes in plotting order.
    brackets : dict mapping (left, right) -> y, optional
        Zero-based positions of the two groups to connect and the ``y`` value
        handed to ``plot_bracket``. Defaults to every pair of neighbouring
        groups at ``y = 0.8``.
    anova_align : str
        ``"right"`` puts the ANOVA p-value in the top-right corner; any other
        value puts it top-left.

    Returns
    -------
    None
    """
    if ax is None:
        _, ax = plt.subplots()

    merged_rlf = twostim_rlf.join(single_rlf, how="inner", rsuffix="_single")
    merged_rlf = merged_rlf.loc[
        (merged_rlf["fixedintensity"] == merged_rlf["intensity"])
        & (merged_rlf["fixedazi"] == merged_rlf["azimuth"])
        & (merged_rlf["fixedele"] == merged_rlf["elevation"])
    ]

    print("Number of units:", np.unique(merged_rlf.index).size)
    print(
        "Number of sessions:",
        len({idx[:2] for idx in merged_rlf.index.values}),
    )
    merged_rlf = merged_rlf.set_index("relative_level")

    merged_rlf["csi"] = (merged_rlf["xcorr_peak"] - merged_rlf["xcorr_peak_single"]) / (
        merged_rlf["xcorr_peak"] + merged_rlf["xcorr_peak_single"]
    )
    relative_levels = np.unique(merged_rlf.index)
    # List indexer: a level with a single row still yields a Series, so
    # `.values` is always an array.
    groupdata = [
        merged_rlf.loc[[relative_level], "csi"].values
        for relative_level in relative_levels
    ]

    bp = ax.boxplot(
        groupdata,
        patch_artist=True,
        notch=False,
        showfliers=False,
        whis=(5, 95),
    )

    ax.set_xticklabels(relative_levels.astype(int))

    stats = anova_tukey(merged_rlf, val_col="csi", group_col="relative_level")

    for patch, color in zip(bp["boxes"], colors):
        patch.set_facecolor(color)

    for median in bp["medians"]:
        median.set(color="white", linewidth=2)

    # Individual data points: sorted values spread across each box.
    width = 0.7
    for k, groupvalues in enumerate(groupdata):
        pos = k + 1
        y = np.sort(groupvalues)
        vals_x = np.linspace(
            pos - width / 3,
            pos + width / 3,
            y.size,
            endpoint=True,
        )
        ax.plot(
            vals_x,
            y,
            color="k",
            ls="None",
            marker=".",
            markersize=2,
            markeredgewidth=0.0,
            zorder=5,
        )

    # Significance brackets labelled with the pairwise p-values.
    brackets_shrink = 1.0
    if brackets is None:
        brackets = {(k, k + 1): 0.8 for k in range(len(stats["tukey"]) - 1)}
        brackets_shrink = 0.8
    for (idx_left, idx_right), y in brackets.items():
        p_text = f"{stats['tukey'].iloc[idx_left, idx_right]:.3f}"
        # Drop the leading zero (".012"), but keep "1.000" intact: blindly
        # slicing off the first character would print it as ".000".
        if p_text.startswith("0"):
            p_text = p_text[1:]
        plot_bracket(
            ax,
            left=idx_left + 1,
            right=idx_right + 1,
            text=p_text,
            y=y,
            shrink=brackets_shrink,
        )

    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    if anova_align == "right":
        ax.text(
            1,
            1,
            f"p = {stats['anova_p']:.3g}",
            horizontalalignment="right",
            verticalalignment="top",
            fontsize=10,
            transform=ax.transAxes,
        )
    else:
        # Top-left corner, nudged 2 pt to the right of the axes edge.
        ax.text(
            0,
            1,
            f"p = {stats['anova_p']:.3g}",
            horizontalalignment="left",
            verticalalignment="top",
            fontsize=10,
            transform=ax.transAxes
            + mpl.transforms.ScaledTranslation(
                +2 / 72,
                0 / 72,
                ax.figure.dpi_scale_trans,
            ),
        )

    ax.axhline(0, color="k", ls="--", zorder=-1)
    ax.set_xlabel("Relative Level [dB]")
    ax.set_ylabel("CSI")
    ax.set_ylim(-1, 1.45)

## histogram of tuning selectivity

In [17]:
def histogram_selectivity_resubmit(
    twostim_rlf: pd.DataFrame,
    single_rlf: pd.DataFrame,
    single_rlf_out: pd.DataFrame,
    ax=None,
    colors=c_forebrain,
):
    """Histogram of the tuning selectivity index at relative level 0.

    The single-sound responses to a frontal stimulus (``|azimuth| <= 45``)
    and to a lateral stimulus (``|azimuth| > 45``) are paired per index entry
    at equal intensity and turned into::

        selectivity_index = (resp_frontal - resp_lateral)
                            / (resp_frontal + resp_lateral)

    so positive values mean a stronger response to the frontal stimulus.
    The index is then attached to the two-sound rows whose fixed/varying
    positions and fixed intensity match the pair; only rows with
    ``relative_level == 0`` are shown.

    Parameters
    ----------
    twostim_rlf
        Two-sound data with columns ``fixedintensity``, ``fixedazi``,
        ``fixedele``, ``varyingazi``, ``varyingele``, ``relative_level``.
    single_rlf, single_rlf_out
        Single-sound data with columns ``azimuth``, ``elevation``,
        ``intensity`` and ``resp``.
        NOTE(review): presumably driver-alone and competitor-alone
        recordings, all three frames sharing a per-unit index -- confirm.
    ax
        Axes to draw into; a new figure is created when ``None``.
    colors
        Sequence of colors; only ``colors[0]`` is used (histogram bars).

    Side effects
    ------------
    Prints unit/session counts and the mean/std of the selectivity index.
    """
    frontal_max_azimuth = 45  # |azimuth| up to this value counts as frontal

    def _pair_with_selectivity(frontal, lateral):
        # Pair frontal and lateral single-sound rows of the same index entry,
        # keep equal-intensity pairs and add the selectivity index.
        pairs = frontal.join(lateral, how="inner", rsuffix="_out")
        pairs = pairs.loc[pairs["intensity"] == pairs["intensity_out"]].copy()
        pairs["selectivity_index"] = (pairs["resp"] - pairs["resp_out"]) / (
            pairs["resp"] + pairs["resp_out"]
        )
        return pairs.filter(
            [
                "azimuth",
                "elevation",
                "azimuth_out",
                "elevation_out",
                "intensity",
                "resp",
                "selectivity_index",
            ],
            axis=1,
        )

    def _match_twostim(tuning, driver_suffix, competitor_suffix):
        # Keep two-sound rows whose fixed (driver) and varying (competitor)
        # positions and fixed intensity match the single-sound pair.
        joined = twostim_rlf.join(tuning, how="inner", rsuffix="_single")
        keep = (
            (joined["fixedintensity"] == joined["intensity"])
            & (joined["fixedazi"] == joined["azimuth" + driver_suffix])
            & (joined["fixedele"] == joined["elevation" + driver_suffix])
            & (joined["varyingazi"] == joined["azimuth" + competitor_suffix])
            & (joined["varyingele"] == joined["elevation" + competitor_suffix])
        )
        return joined.loc[keep]

    if ax is None:
        _, ax = plt.subplots(1, 1)

    in_azimuth = single_rlf["azimuth"].abs()
    out_azimuth = single_rlf_out["azimuth"].abs()

    # Case 1: frontal stimulus from `single_rlf`, lateral from `single_rlf_out`.
    tuning_index_1 = _pair_with_selectivity(
        single_rlf[in_azimuth <= frontal_max_azimuth],
        single_rlf_out[out_azimuth > frontal_max_azimuth],
    )
    # Case 2: frontal stimulus from `single_rlf_out`, lateral from `single_rlf`.
    tuning_index_2 = _pair_with_selectivity(
        single_rlf_out[out_azimuth <= frontal_max_azimuth],
        single_rlf[in_azimuth > frontal_max_azimuth],
    )

    merged_rlf = pd.concat(
        [
            _match_twostim(tuning_index_1, "", "_out"),
            _match_twostim(tuning_index_2, "_out", ""),
        ]
    )
    merged_rlf = merged_rlf[merged_rlf["relative_level"] == 0]
    selectivity = merged_rlf["selectivity_index"]

    quantiles = np.quantile(selectivity, [0.25, 0.5, 0.75])

    print("Number of units:", np.unique(merged_rlf.index).size)
    print(
        "Number of frontal driver units:",
        np.unique(merged_rlf[selectivity >= quantiles[2]].index).size,
    )
    print(
        "Number of peripheral driver units:",
        np.unique(merged_rlf[selectivity <= quantiles[0]].index).size,
    )
    print(
        "Number of intermediate units:",
        np.unique(
            merged_rlf[
                (selectivity > quantiles[0]) & (selectivity < quantiles[2])
            ].index
        ).size,
    )
    print(
        "Number of sessions:",
        len(set(idx[:2] for idx in merged_rlf.index.values)),
    )

    tuning_bins = np.linspace(-1, 1, 100)
    ax.hist(selectivity, tuning_bins, color=colors[0])

    mean_SI = np.mean(selectivity)
    std_SI = np.std(selectivity)
    print(f"MEAN:{mean_SI:.2f}")
    print(f"STD: {std_SI:.2f}")

    ax.axvline(0, color="k", ls=":")

    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    # No ax.legend() here: nothing on this axes carries a label, so the call
    # only produced a "No artists with labels found" warning. The two text
    # annotations below label the halves of the histogram instead.

    ax.set_xlabel("Tuning Selectivity Index")
    ax.set_ylabel("# Units")
    ax.set_ylim(0, 20)

    ax.text(
        0.8,
        1,
        "Frontal Driver Selective",
        horizontalalignment="right",
        verticalalignment="top",
        fontsize=10,
        transform=ax.transAxes,
    )
    ax.text(
        0.45,
        1,
        "Lateral Driver Selective",
        horizontalalignment="right",
        verticalalignment="top",
        fontsize=10,
        transform=ax.transAxes,
    )

## competition by selectivity index line plot

In [18]:
def figure_rlf_selectivity_resubmit(
    twostim_rlf: pd.DataFrame,
    single_rlf: pd.DataFrame,
    single_rlf_out: pd.DataFrame,
    ax=None,
    colors=c_forebrain,
    brackets: dict[tuple[int, int], float] | None = None,
    anova_align="right",
):
    """Line plot of change in spike rate vs. relative level for
    frontal-selective units, split by driver location.

    A selectivity index ``(resp_frontal - resp_lateral) / (resp_frontal +
    resp_lateral)`` is computed from equal-intensity single-sound pairs
    (frontal: ``|azimuth| <= 45``, lateral: ``|azimuth| > 45``) and attached
    to the matching two-sound rows. For rows with a positive index the
    relative change ``(resp - resp_single) / resp_single`` is averaged per
    relative level (mean +/- SEM) separately for

    * frontal driver + lateral competitor (``colors[0]``), and
    * lateral driver + frontal competitor (blue).

    Each legend entry shows the ANOVA p-value returned by ``anova_tukey``
    for that group.

    Parameters
    ----------
    twostim_rlf
        Two-sound data with columns ``fixedintensity``, ``fixedazi``,
        ``fixedele``, ``varyingazi``, ``varyingele``, ``relative_level``,
        ``resp``.
    single_rlf, single_rlf_out
        Single-sound data with columns ``azimuth``, ``elevation``,
        ``intensity``, ``resp``.
        NOTE(review): presumably driver-alone and competitor-alone
        recordings sharing a per-unit index with ``twostim_rlf`` -- confirm.
    ax
        Axes to draw into; a new figure is created when ``None``.
    colors
        Sequence of colors; only ``colors[0]`` is used.
    brackets, anova_align
        Accepted for signature compatibility; not used by this function.

    Side effects
    ------------
    Prints the number of unique index entries in each of the four
    (selectivity sign x driver location) groups.
    """
    frontal_max_azimuth = 45  # |azimuth| up to this value counts as frontal

    def _pair_with_selectivity(frontal, lateral, resp_column):
        # Pair frontal and lateral single-sound rows of the same index entry
        # at equal intensity, add the selectivity index, and expose the
        # driver-alone response (`resp_column`) under the name "resp".
        pairs = frontal.join(lateral, how="inner", rsuffix="_out")
        pairs = pairs.loc[pairs["intensity"] == pairs["intensity_out"]].copy()
        pairs["selectivity_index"] = (pairs["resp"] - pairs["resp_out"]) / (
            pairs["resp"] + pairs["resp_out"]
        )
        tuning = pairs.filter(
            [
                "azimuth",
                "elevation",
                "azimuth_out",
                "elevation_out",
                "intensity",
                resp_column,
                "selectivity_index",
            ],
            axis=1,
        )
        return tuning.rename(columns={resp_column: "resp"})

    def _match_twostim(tuning, driver_suffix, competitor_suffix):
        # Keep two-sound rows whose fixed (driver) and varying (competitor)
        # positions and fixed intensity match the single-sound pair. The
        # single-sound "resp" becomes "resp_single" through the join suffix.
        joined = twostim_rlf.join(tuning, how="inner", rsuffix="_single")
        keep = (
            (joined["fixedintensity"] == joined["intensity"])
            & (joined["fixedazi"] == joined["azimuth" + driver_suffix])
            & (joined["fixedele"] == joined["elevation" + driver_suffix])
            & (joined["varyingazi"] == joined["azimuth" + competitor_suffix])
            & (joined["varyingele"] == joined["elevation" + competitor_suffix])
        )
        return joined.loc[keep]

    if ax is None:
        _, ax = plt.subplots()

    in_azimuth = single_rlf["azimuth"].abs()
    out_azimuth = single_rlf_out["azimuth"].abs()

    # Frontal driver (single_rlf) vs. lateral competitor (single_rlf_out).
    tuning_index_1 = _pair_with_selectivity(
        single_rlf[in_azimuth <= frontal_max_azimuth],
        single_rlf_out[out_azimuth > frontal_max_azimuth],
        "resp",
    )
    # Frontal competitor (single_rlf_out) vs. lateral driver (single_rlf);
    # here the driver-alone response is the "_out" side of the join.
    tuning_index_2 = _pair_with_selectivity(
        single_rlf_out[out_azimuth <= frontal_max_azimuth],
        single_rlf[in_azimuth > frontal_max_azimuth],
        "resp_out",
    )

    merged_rlf = pd.concat(
        [
            _match_twostim(tuning_index_1, "", "_out"),
            _match_twostim(tuning_index_2, "_out", ""),
        ]
    )
    merged_rlf["change_in_response"] = (
        merged_rlf["resp"] - merged_rlf["resp_single"]
    ) / merged_rlf["resp_single"]

    # Positional boolean masks, so they stay valid after the index changes.
    fixed_azimuth = merged_rlf["fixedazi"].abs()
    varying_azimuth = merged_rlf["varyingazi"].abs()
    frontal_condition = (
        (fixed_azimuth <= frontal_max_azimuth)
        & (varying_azimuth > frontal_max_azimuth)
    ).to_numpy()
    lateral_condition = (
        (fixed_azimuth > frontal_max_azimuth)
        & (varying_azimuth <= frontal_max_azimuth)
    ).to_numpy()
    frontal_selective = (merged_rlf["selectivity_index"] > 0).to_numpy()
    lateral_selective = (merged_rlf["selectivity_index"] <= 0).to_numpy()

    def _count_units(mask):
        return np.unique(merged_rlf.index[mask]).size

    print(
        "Number of frontal driver units + frontal driver:",
        _count_units(frontal_condition & frontal_selective),
    )
    print(
        "Number of frontal driver units + lateral driver:",
        _count_units(lateral_condition & frontal_selective),
    )
    ##lateral driver selective units
    print(
        "Number of lateral driver units + frontal driver:",
        _count_units(frontal_condition & lateral_selective),
    )
    print(
        "Number of lateral driver units + lateral driver:",
        _count_units(lateral_condition & lateral_selective),
    )

    # The frames handed to anova_tukey are indexed by relative level.
    by_level = merged_rlf.set_index("relative_level")

    # The analysed groups use exactly the masks that were counted above.
    # (Previously the frontal group was selected with `< 45` / `>= 45`, which
    # dropped drivers at exactly +/-45 deg although they had been counted.)
    frontal_driver = by_level[frontal_condition & frontal_selective]
    peripheral_driver = by_level[lateral_condition & frontal_selective]

    stats_frontal_driver = anova_tukey(
        frontal_driver, val_col="change_in_response", group_col="relative_level"
    )
    stats_peripheral_driver = anova_tukey(
        peripheral_driver, val_col="change_in_response", group_col="relative_level"
    )

    frontal_by_level = frontal_driver.groupby("relative_level")["change_in_response"]
    peripheral_by_level = peripheral_driver.groupby("relative_level")[
        "change_in_response"
    ]
    frontal_driver_means = frontal_by_level.mean()
    frontal_driver_sems = frontal_by_level.sem()
    peripheral_driver_means = peripheral_by_level.mean()
    peripheral_driver_sems = peripheral_by_level.sem()

    # Every curve is plotted against its own levels, so a level missing from
    # one group cannot cause an x/y length mismatch.
    frontal_levels = frontal_driver_means.index.to_numpy()
    peripheral_levels = peripheral_driver_means.index.to_numpy()

    ax.errorbar(
        frontal_levels,
        frontal_driver_means,
        frontal_driver_sems,
        color=colors[0],
        label=f"Frontal Driver, p = {stats_frontal_driver['anova_p']:.3g}",
    )
    ax.errorbar(
        peripheral_levels,
        peripheral_driver_means,
        peripheral_driver_sems,
        color="b",
        label=f"Lateral Driver, p = {stats_peripheral_driver['anova_p']:.3g}",
    )

    tick_levels = np.union1d(frontal_levels, peripheral_levels).astype(int)
    ax.set_xticks(tick_levels, tick_levels)

    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.legend(loc="lower left", frameon=False, fontsize=9)

    ax.yaxis.set_major_formatter(mpl.ticker.FuncFormatter(lambda x, pos: f"{x:.0%}"))

    ax.axhline(0, color="k", ls="--", zorder=-1)
    ax.set_xlabel("Relative Level [dB]")
    ax.set_ylabel("Change in\nSpike Rate")
    ax.set_ylim(-0.5, 0.95)

## Single sound rate level function by selectivity index

In [19]:
def figure_singlesound_rlf_selectivity_resubmit(
    single_rlf: pd.DataFrame,
    single_rlf_out: pd.DataFrame,
    ax=None,
    colors=c_forebrain,
    brackets: dict[tuple[int, int], float] | None = None,
    anova_align="right",
):
    """Single-sound rate-level functions of frontal-selective units.

    Frontal (``|azimuth| <= 45``) and lateral (``|azimuth| > 45``)
    single-sound rows are paired per index entry at equal intensity. A
    selectivity index ``(resp_frontal - resp_lateral) / (resp_frontal +
    resp_lateral)`` is computed from the rows at intensity -20 only and
    spread to all rows through pandas index alignment. For rows with a
    positive index the mean +/- SEM of ``norm_resp`` (frontal stimulus alone)
    and ``norm_resp_out`` (lateral stimulus alone) is plotted against
    intensity, shifted by +63 dB.

    Parameters
    ----------
    single_rlf, single_rlf_out
        Single-sound data with columns ``azimuth``, ``elevation``,
        ``intensity``, ``resp`` and ``norm_resp``.
        NOTE(review): presumably driver-alone and competitor-alone
        recordings sharing a per-unit index, with at most one -20 pair per
        index entry (otherwise the index alignment fails) -- confirm.
    ax
        Axes to draw into; a new figure is created when ``None``.
    colors
        Sequence of colors; only ``colors[0]`` is used.
    brackets, anova_align
        Accepted for signature compatibility; not used by this function.

    Side effects
    ------------
    Prints unit/session counts and the plotted mean values.
    """
    INTENSITY_0_EQUALS_DB_SPL = +63.0
    reference_intensity = -20  # intensity at which selectivity is evaluated
    frontal_max_azimuth = 45  # |azimuth| up to this value counts as frontal

    def _pair_with_reference_selectivity(frontal, lateral):
        # Pair frontal and lateral rows of the same index entry at equal
        # intensity. The index computed on the reference-intensity rows is
        # aligned on the frame index when assigned, so every row of an index
        # entry receives that entry's value (NaN if it has no reference row).
        pairs = frontal.join(lateral, how="inner", rsuffix="_out")
        pairs = pairs.loc[pairs["intensity"] == pairs["intensity_out"]].copy()
        reference = pairs[pairs["intensity"] == reference_intensity]
        pairs["selectivity_index"] = (reference["resp"] - reference["resp_out"]) / (
            reference["resp"] + reference["resp_out"]
        )
        return pairs.filter(
            [
                "azimuth",
                "elevation",
                "azimuth_out",
                "elevation_out",
                "intensity",
                "norm_resp",
                "norm_resp_out",
                "selectivity_index",
            ],
            axis=1,
        )

    if ax is None:
        _, ax = plt.subplots(1, 1)

    in_azimuth = single_rlf["azimuth"].abs()
    out_azimuth = single_rlf_out["azimuth"].abs()

    merged_rlf = pd.concat(
        [
            _pair_with_reference_selectivity(
                single_rlf[in_azimuth <= frontal_max_azimuth],
                single_rlf_out[out_azimuth > frontal_max_azimuth],
            ),
            _pair_with_reference_selectivity(
                single_rlf_out[out_azimuth <= frontal_max_azimuth],
                single_rlf[in_azimuth > frontal_max_azimuth],
            ),
        ]
    )

    is_frontal_selective = merged_rlf["selectivity_index"] > 0
    is_lateral_selective = merged_rlf["selectivity_index"] <= 0
    at_reference = merged_rlf["intensity"] == reference_intensity
    driver_selective = merged_rlf[is_frontal_selective]

    print("Number of units:", np.unique(merged_rlf.index).size)
    print(
        "Number of driver units:",
        np.unique(merged_rlf[is_frontal_selective & at_reference].index).size,
    )
    print(
        "Number of competitor units:",
        np.unique(merged_rlf[is_lateral_selective & at_reference].index).size,
    )
    print(
        "Number of sessions:",
        len(set(idx[:2] for idx in merged_rlf.index.values)),
    )

    by_intensity = driver_selective.groupby("intensity")
    frontal_alone_means = by_intensity["norm_resp"].mean()
    frontal_alone_sems = by_intensity["norm_resp"].sem()
    lateral_alone_means = by_intensity["norm_resp_out"].mean()
    lateral_alone_sems = by_intensity["norm_resp_out"].sem()

    # x-values come from the grouped data itself, so they always line up with
    # the plotted means even if some intensity occurs only in rows that are
    # not frontal-selective.
    intensities = frontal_alone_means.index.to_numpy() + INTENSITY_0_EQUALS_DB_SPL

    # print() instead of IPython's display(): same number, works everywhere.
    print(len(np.unique(driver_selective.index)))

    print(np.asarray(frontal_alone_means.values))
    print(lateral_alone_means.values)

    ax.set_title("Frontal Driver Selective")
    ax.errorbar(
        intensities,
        frontal_alone_means.values,
        frontal_alone_sems.values,
        color=colors[0],
        label="Frontal stimulus alone",
    )
    ax.errorbar(
        intensities,
        lateral_alone_means.values,
        lateral_alone_sems.values,
        color="b",
        label="Lateral stimulus alone",
    )

    # Shaded band from -35 to -10 (before the +63 dB shift).
    # NOTE(review): presumably the level range used in the two-sound
    # experiments -- confirm.
    ax.axvspan(
        -35 + INTENSITY_0_EQUALS_DB_SPL,
        -10 + INTENSITY_0_EQUALS_DB_SPL,
        color="pink",
        alpha=0.5,
        zorder=-1,
    )

    ax.set_xlabel("Average binaural level [dB]")
    ax.set_ylabel("Norm Resp.")
    ax.legend(loc="lower right", frameon=False)

    ax.set_ylim(bottom=0, top=0.80)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

## Extreme selectivity competition 

In [20]:
def extreme_selectivity_competition_resubmit(
    twostim_rlf: pd.DataFrame,
    single_rlf: pd.DataFrame,
    single_rlf_out: pd.DataFrame,
    axs=None,
    colors=c_forebrain,
):
    """Change in spike rate vs. relative level for strongly selective units.

    A selectivity index ``(resp_frontal - resp_lateral) / (resp_frontal +
    resp_lateral)`` is computed from equal-intensity single-sound pairs
    (frontal: ``|azimuth| <= 45``, lateral: ``|azimuth| > 45``) and attached
    to the matching two-sound rows. Only rows with a frontal driver
    (``|fixedazi| <= 45``) are analysed. The relative change
    ``(resp - resp_single) / resp_single`` is averaged per relative level
    (mean +/- SEM) and drawn

    * in ``axs[0]`` for rows with selectivity index > 0.30, and
    * in ``axs[1]`` for rows with selectivity index < -0.50.

    Parameters
    ----------
    twostim_rlf
        Two-sound data with columns ``fixedintensity``, ``fixedazi``,
        ``fixedele``, ``varyingazi``, ``varyingele``, ``relative_level``,
        ``resp``.
    single_rlf, single_rlf_out
        Single-sound data with columns ``azimuth``, ``elevation``,
        ``intensity``, ``resp``.
        NOTE(review): presumably driver-alone and competitor-alone
        recordings sharing a per-unit index with ``twostim_rlf`` -- confirm.
    axs
        Sequence of two axes; a new 1x2 figure is created when ``None``.
    colors
        Accepted for signature compatibility; the curves use fixed colors.

    Side effects
    ------------
    Prints the number of unique index entries in both plotted groups.
    """
    frontal_max_azimuth = 45  # |azimuth| up to this value counts as frontal
    driver_selective_min = 0.30  # selectivity index above this: axs[0]
    competitor_selective_max = -0.50  # selectivity index below this: axs[1]

    def _pair_with_selectivity(frontal, lateral, resp_column):
        # Pair frontal and lateral single-sound rows of the same index entry
        # at equal intensity, add the selectivity index, and expose the
        # driver-alone response (`resp_column`) under the name "resp".
        pairs = frontal.join(lateral, how="inner", rsuffix="_out")
        pairs = pairs.loc[pairs["intensity"] == pairs["intensity_out"]].copy()
        pairs["selectivity_index"] = (pairs["resp"] - pairs["resp_out"]) / (
            pairs["resp"] + pairs["resp_out"]
        )
        tuning = pairs.filter(
            [
                "azimuth",
                "elevation",
                "azimuth_out",
                "elevation_out",
                "intensity",
                resp_column,
                "selectivity_index",
            ],
            axis=1,
        )
        return tuning.rename(columns={resp_column: "resp"})

    def _match_twostim(tuning, driver_suffix, competitor_suffix):
        # Keep two-sound rows whose fixed (driver) and varying (competitor)
        # positions and fixed intensity match the single-sound pair. The
        # single-sound "resp" becomes "resp_single" through the join suffix.
        joined = twostim_rlf.join(tuning, how="inner", rsuffix="_single")
        keep = (
            (joined["fixedintensity"] == joined["intensity"])
            & (joined["fixedazi"] == joined["azimuth" + driver_suffix])
            & (joined["fixedele"] == joined["elevation" + driver_suffix])
            & (joined["varyingazi"] == joined["azimuth" + competitor_suffix])
            & (joined["varyingele"] == joined["elevation" + competitor_suffix])
        )
        return joined.loc[keep]

    if axs is None:
        _, axs = plt.subplots(1, 2)

    in_azimuth = single_rlf["azimuth"].abs()
    out_azimuth = single_rlf_out["azimuth"].abs()

    # Frontal driver (single_rlf) vs. lateral competitor (single_rlf_out).
    tuning_index_1 = _pair_with_selectivity(
        single_rlf[in_azimuth <= frontal_max_azimuth],
        single_rlf_out[out_azimuth > frontal_max_azimuth],
        "resp",
    )
    # Frontal competitor (single_rlf_out) vs. lateral driver (single_rlf).
    tuning_index_2 = _pair_with_selectivity(
        single_rlf_out[out_azimuth <= frontal_max_azimuth],
        single_rlf[in_azimuth > frontal_max_azimuth],
        "resp_out",
    )

    merged_rlf = pd.concat(
        [
            _match_twostim(tuning_index_1, "", "_out"),
            _match_twostim(tuning_index_2, "_out", ""),
        ]
    )
    merged_rlf["change_in_response"] = (
        merged_rlf["resp"] - merged_rlf["resp_single"]
    ) / merged_rlf["resp_single"]

    ## sort data based on selectivity index
    # One mask per group, used for both the printed count and the plot, so
    # the two can never disagree. (The competitor count used to include
    # values exactly equal to -0.50 while the plotted data did not.)
    frontal_driver_rows = merged_rlf["fixedazi"].abs() <= frontal_max_azimuth
    driver_selective = merged_rlf[
        frontal_driver_rows & (merged_rlf["selectivity_index"] > driver_selective_min)
    ]
    competitor_selective = merged_rlf[
        frontal_driver_rows
        & (merged_rlf["selectivity_index"] < competitor_selective_max)
    ]

    print("Number of driver units:", np.unique(driver_selective.index).size)
    print("Number of competitor units:", np.unique(competitor_selective.index).size)

    driver_by_level = driver_selective.groupby("relative_level")["change_in_response"]
    competitor_by_level = competitor_selective.groupby("relative_level")[
        "change_in_response"
    ]
    driver_means = driver_by_level.mean()
    driver_sems = driver_by_level.sem()
    competitor_means = competitor_by_level.mean()
    competitor_sems = competitor_by_level.sem()

    # Tick positions of both panels follow the levels of the first group.
    relative_levels = np.unique(driver_means.index)

    axs[0].errorbar(
        driver_means.index.values,
        driver_means,
        driver_sems,
        color="orange",
        label="Driver Selective",
    )
    axs[1].errorbar(
        competitor_means.index.values,
        competitor_means,
        competitor_sems,
        color="b",
        label="Competitor Selective",
    )

    # Shared decoration; only the y-limits differ between the panels.
    for panel, y_limits in ((axs[0], (-0.60, 0.1)), (axs[1], (-0.10, 2.00))):
        panel.yaxis.set_major_formatter(
            mpl.ticker.FuncFormatter(lambda x, pos: f"{x:.0%}")
        )
        panel.axhline(0, ls=":", color="k")
        panel.set_xticks(relative_levels.astype(int), relative_levels.astype(int))
        panel.set_xlabel("Relative Level [dB]")
        panel.set_ylabel("Change in\nSpike Rate")
        panel.set_ylim(*y_limits)
        # Both panels show the frontal-driver condition (see the masks above).
        panel.text(
            0.02,
            0.20,
            "Frontal Driver+Lateral competitor",
            ha="left",
            va="top",
            fontsize=10,
            transform=panel.transAxes,
        )

## Modeling

In [21]:
def sigmoid(RL, b0, b1):
    """Logistic function of the linear predictor ``b0 + b1 * RL``.

    Works elementwise on arrays and on scalars; the result lies in (0, 1).
    """
    linear_predictor = b0 + b1 * RL
    return 1 / (1 + np.exp(-linear_predictor))

In [22]:
def logit(p):
    """Log-odds ``log(p / (1 - p))`` of a probability ``p``.

    Works elementwise on arrays and on scalars.
    """
    odds = p / (1 - p)
    return np.log(odds)

In [23]:
def get_sigmoid_parameters(switch_type):
    """Return ``(b0, b1)`` of a sigmoid pinned at its 90% and 10% points.

    The sigmoid ``1 / (1 + exp(-(b0 + b1 * RL)))`` is required to equal 0.9
    and 0.1 at two relative levels: -10 and -5 when ``switch_type`` is
    truthy, -15 and 15 otherwise. Applying the logit to both constraints
    gives two linear equations in ``b0`` and ``b1``, which are solved here.
    """
    # Relative levels at which 90% and 10% of the maximum are reached.
    if switch_type:
        level_at_90, level_at_10 = -10, -5
    else:
        level_at_90, level_at_10 = -15, 15

    # logit(p) = b0 + b1 * RL at both anchor points.
    design = np.array([[1, level_at_90], [1, level_at_10]])
    targets = np.array([[logit(0.9)], [logit(0.1)]])

    b0, b1 = np.linalg.solve(design, targets).ravel()
    return b0, b1

In [24]:
def rlf_modeling_data(
    single_rlf: pd.DataFrame,
    single_rlf_out: pd.DataFrame,
):
    """Mean single-sound rate-level functions of frontal-selective units.

    Frontal (``|azimuth| <= 45``) and lateral (``|azimuth| > 45``)
    single-sound rows are paired per index entry at equal intensity. A
    selectivity index ``(resp_frontal - resp_lateral) / (resp_frontal +
    resp_lateral)`` is computed from the rows at intensity -20 only and
    spread to all rows through pandas index alignment. Rows with a positive
    index are averaged per intensity.

    Parameters
    ----------
    single_rlf, single_rlf_out
        Single-sound data with columns ``azimuth``, ``elevation``,
        ``intensity``, ``resp`` and ``norm_resp``.
        NOTE(review): presumably driver-alone and competitor-alone
        recordings sharing a per-unit index, with at most one -20 pair per
        index entry (otherwise the index alignment fails) -- confirm.

    Returns
    -------
    intensities : numpy.ndarray
        Sorted intensities of the averaged rows, shifted by +63 dB.
    frontal_means : numpy.ndarray
        Mean ``norm_resp`` (frontal stimulus alone) per intensity.
    lateral_means : numpy.ndarray
        Mean ``norm_resp_out`` (lateral stimulus alone) per intensity.

    All three arrays have the same length.

    Side effects
    ------------
    Prints unit and session counts.
    """
    INTENSITY_0_EQUALS_DB_SPL = +63.0
    reference_intensity = -20  # intensity at which selectivity is evaluated
    frontal_max_azimuth = 45  # |azimuth| up to this value counts as frontal

    def _pair_with_reference_selectivity(frontal, lateral):
        # Pair frontal and lateral rows of the same index entry at equal
        # intensity. The index computed on the reference-intensity rows is
        # aligned on the frame index when assigned, so every row of an index
        # entry receives that entry's value (NaN if it has no reference row).
        pairs = frontal.join(lateral, how="inner", rsuffix="_out")
        pairs = pairs.loc[pairs["intensity"] == pairs["intensity_out"]].copy()
        reference = pairs[pairs["intensity"] == reference_intensity]
        pairs["selectivity_index"] = (reference["resp"] - reference["resp_out"]) / (
            reference["resp"] + reference["resp_out"]
        )
        return pairs.filter(
            [
                "azimuth",
                "elevation",
                "azimuth_out",
                "elevation_out",
                "intensity",
                "norm_resp",
                "norm_resp_out",
                "selectivity_index",
            ],
            axis=1,
        )

    in_azimuth = single_rlf["azimuth"].abs()
    out_azimuth = single_rlf_out["azimuth"].abs()

    merged_rlf = pd.concat(
        [
            _pair_with_reference_selectivity(
                single_rlf[in_azimuth <= frontal_max_azimuth],
                single_rlf_out[out_azimuth > frontal_max_azimuth],
            ),
            _pair_with_reference_selectivity(
                single_rlf_out[out_azimuth <= frontal_max_azimuth],
                single_rlf[in_azimuth > frontal_max_azimuth],
            ),
        ]
    )

    is_frontal_selective = merged_rlf["selectivity_index"] > 0
    is_lateral_selective = merged_rlf["selectivity_index"] <= 0
    at_reference = merged_rlf["intensity"] == reference_intensity
    driver_selective = merged_rlf[is_frontal_selective]

    print("Number of units:", np.unique(merged_rlf.index).size)
    print(
        "Number of driver units:",
        np.unique(merged_rlf[is_frontal_selective & at_reference].index).size,
    )
    print(
        "Number of competitor units:",
        np.unique(merged_rlf[is_lateral_selective & at_reference].index).size,
    )
    print(
        "Number of sessions:",
        len(set(idx[:2] for idx in merged_rlf.index.values)),
    )

    by_intensity = driver_selective.groupby("intensity")
    frontal_means = by_intensity["norm_resp"].mean()
    lateral_means = by_intensity["norm_resp_out"].mean()

    # Take the intensity axis from the grouped data, not from all merged
    # rows: an intensity present only in rows that are not frontal-selective
    # would otherwise make `intensities` longer than the returned means.
    intensities = frontal_means.index.to_numpy() + INTENSITY_0_EQUALS_DB_SPL

    return (
        intensities,
        frontal_means.values,
        lateral_means.values,
    )

In [25]:
def competition_modeling_data(
    twostim_rlf: pd.DataFrame,
    single_rlf: pd.DataFrame,
    single_rlf_out: pd.DataFrame,
):
    """Mean change in spike rate per relative level for frontal-selective units.

    A selectivity index ``(resp_frontal - resp_lateral) / (resp_frontal +
    resp_lateral)`` is computed from equal-intensity single-sound pairs
    (frontal: ``|azimuth| <= 45``, lateral: ``|azimuth| > 45``) and attached
    to the matching two-sound rows. For rows with a positive index the
    relative change ``(resp - resp_single) / resp_single`` is averaged per
    relative level, separately for the frontal-driver and the lateral-driver
    condition.

    Parameters
    ----------
    twostim_rlf
        Two-sound data with columns ``fixedintensity``, ``fixedazi``,
        ``fixedele``, ``varyingazi``, ``varyingele``, ``relative_level``,
        ``resp``.
    single_rlf, single_rlf_out
        Single-sound data with columns ``azimuth``, ``elevation``,
        ``intensity``, ``resp``.
        NOTE(review): presumably driver-alone and competitor-alone
        recordings sharing a per-unit index with ``twostim_rlf`` -- confirm.

    Returns
    -------
    relative_levels : numpy.ndarray
        Sorted relative levels of the frontal-driver group.
    frontal_driver_means : numpy.ndarray
        Mean change per relative level, frontal driver + lateral competitor.
    peripheral_driver_means : numpy.ndarray
        Mean change per relative level, lateral driver + frontal competitor.
        Ordered by that group's own sorted relative levels.

    Side effects
    ------------
    Prints the number of unique index entries in each of the four
    (selectivity sign x driver location) groups.
    """
    frontal_max_azimuth = 45  # |azimuth| up to this value counts as frontal

    def _pair_with_selectivity(frontal, lateral, resp_column):
        # Pair frontal and lateral single-sound rows of the same index entry
        # at equal intensity, add the selectivity index, and expose the
        # driver-alone response (`resp_column`) under the name "resp".
        pairs = frontal.join(lateral, how="inner", rsuffix="_out")
        pairs = pairs.loc[pairs["intensity"] == pairs["intensity_out"]].copy()
        pairs["selectivity_index"] = (pairs["resp"] - pairs["resp_out"]) / (
            pairs["resp"] + pairs["resp_out"]
        )
        tuning = pairs.filter(
            [
                "azimuth",
                "elevation",
                "azimuth_out",
                "elevation_out",
                "intensity",
                resp_column,
                "selectivity_index",
            ],
            axis=1,
        )
        return tuning.rename(columns={resp_column: "resp"})

    def _match_twostim(tuning, driver_suffix, competitor_suffix):
        # Keep two-sound rows whose fixed (driver) and varying (competitor)
        # positions and fixed intensity match the single-sound pair. The
        # single-sound "resp" becomes "resp_single" through the join suffix.
        joined = twostim_rlf.join(tuning, how="inner", rsuffix="_single")
        keep = (
            (joined["fixedintensity"] == joined["intensity"])
            & (joined["fixedazi"] == joined["azimuth" + driver_suffix])
            & (joined["fixedele"] == joined["elevation" + driver_suffix])
            & (joined["varyingazi"] == joined["azimuth" + competitor_suffix])
            & (joined["varyingele"] == joined["elevation" + competitor_suffix])
        )
        return joined.loc[keep]

    in_azimuth = single_rlf["azimuth"].abs()
    out_azimuth = single_rlf_out["azimuth"].abs()

    # Frontal driver (single_rlf) vs. lateral competitor (single_rlf_out).
    tuning_index_1 = _pair_with_selectivity(
        single_rlf[in_azimuth <= frontal_max_azimuth],
        single_rlf_out[out_azimuth > frontal_max_azimuth],
        "resp",
    )
    # Frontal competitor (single_rlf_out) vs. lateral driver (single_rlf).
    tuning_index_2 = _pair_with_selectivity(
        single_rlf_out[out_azimuth <= frontal_max_azimuth],
        single_rlf[in_azimuth > frontal_max_azimuth],
        "resp_out",
    )

    merged_rlf = pd.concat(
        [
            _match_twostim(tuning_index_1, "", "_out"),
            _match_twostim(tuning_index_2, "_out", ""),
        ]
    )
    merged_rlf["change_in_response"] = (
        merged_rlf["resp"] - merged_rlf["resp_single"]
    ) / merged_rlf["resp_single"]

    fixed_azimuth = merged_rlf["fixedazi"].abs()
    varying_azimuth = merged_rlf["varyingazi"].abs()
    frontal_condition = (fixed_azimuth <= frontal_max_azimuth) & (
        varying_azimuth > frontal_max_azimuth
    )
    lateral_condition = (fixed_azimuth > frontal_max_azimuth) & (
        varying_azimuth <= frontal_max_azimuth
    )
    frontal_selective = merged_rlf["selectivity_index"] > 0
    lateral_selective = merged_rlf["selectivity_index"] <= 0

    def _count_units(mask):
        return np.unique(merged_rlf[mask].index).size

    print(
        "Number of frontal driver units + frontal driver:",
        _count_units(frontal_condition & frontal_selective),
    )
    print(
        "Number of frontal driver units + lateral driver:",
        _count_units(lateral_condition & frontal_selective),
    )
    ##lateral driver selective units
    print(
        "Number of lateral driver units + frontal driver:",
        _count_units(frontal_condition & lateral_selective),
    )
    print(
        "Number of lateral driver units + lateral driver:",
        _count_units(lateral_condition & lateral_selective),
    )

    # The averaged groups use exactly the masks that were counted above.
    # (Previously the frontal group was selected with `< 45` / `>= 45`, which
    # dropped drivers at exactly +/-45 deg although they had been counted.)
    frontal_driver = merged_rlf[frontal_condition & frontal_selective]
    peripheral_driver = merged_rlf[lateral_condition & frontal_selective]

    frontal_driver_means = frontal_driver.groupby("relative_level")[
        "change_in_response"
    ].mean()
    peripheral_driver_means = peripheral_driver.groupby("relative_level")[
        "change_in_response"
    ].mean()

    relative_levels = np.unique(frontal_driver_means.index)

    return relative_levels, frontal_driver_means.values, peripheral_driver_means.values

In [26]:
def model(relative_resp, axs=None):
    """Plot the input-gain model and the change in spike rate it produces.

    The driver gain is ``0.7 * sigmoid(relative_level, b0, b1) + 0.3``. The
    competitor gain is then solved so that the gain-weighted sum of driver and
    competitor rates equals ``r_d * (1 + relative_resp)``; by construction the
    modelled change in spike rate therefore equals ``relative_resp``.

    Relies on names defined elsewhere in the notebook: ``relative_level``,
    ``sound_level``, ``rate_driver_4b``, ``rate_competitor_4b``,
    ``get_sigmoid_parameters`` and ``sigmoid``.

    Parameters
    ----------
    relative_resp : array-like
        Fractional change in response relative to the driver alone; presumably
        one value per entry of the global ``relative_level`` (verify against
        the caller).
    axs : sequence of two matplotlib Axes, optional
        ``axs[0]`` receives the two gain curves, ``axs[1]`` the modelled change
        in spike rate. A new 1x2 figure is created when omitted.
    """
    if axs is None:
        # The default used to fail on ``axs[0]``; create the two panels instead.
        _, axs = plt.subplots(1, 2)

    driver_level = 43
    competitor_level = driver_level + relative_level

    r_d = rate_driver_4b[sound_level == driver_level]
    # np.in1d is deprecated; np.isin on the flattened array is its equivalent.
    r_c = rate_competitor_4b[np.isin(np.ravel(sound_level), competitor_level)]

    switch_type = True
    b0, b1 = get_sigmoid_parameters(switch_type)

    gain_d = 0.7 * sigmoid(relative_level, b0, b1) + 0.3
    # Solve r_d * gain_d + r_c * gain_c == r_d * (1 + relative_resp) for gain_c.
    gain_c = r_d * (1 + relative_resp - gain_d) / r_c

    level_ticks = [int(lev) for lev in relative_level]

    axs[0].plot(relative_level, gain_c, color="b")
    axs[0].plot(relative_level, gain_d, color="orange")
    axs[0].axhline(1.0, ls=":", color="k")
    axs[0].set_ylim(-0.2, 2)

    axs[0].set_xlabel("Relative level [dB]")
    axs[0].set_ylabel("Gain on input")
    # Legend entries follow plot order: competitor first, then driver.
    axs[0].legend(
        ["Competitor position", "Driver position"], frameon=False, loc="upper left"
    )
    axs[0].set_xticks(level_ticks, level_ticks)

    r = r_d * gain_d + r_c * gain_c

    # Kept as a fraction; the axis formatter below renders it as a percentage.
    percent_change_d = (r - r_d) / r_d

    axs[1].plot(relative_level, percent_change_d, color="k")
    axs[1].axhline(0, ls=":", color="k")

    axs[1].set_ylim(-1, 1)
    axs[1].set_xticks(level_ticks, level_ticks)
    axs[1].set_xlabel("Relative level [dB]")
    axs[1].set_ylabel("Change in\nSpike Rate")
    axs[1].yaxis.set_major_formatter(
        mpl.ticker.FuncFormatter(lambda x, pos: f"{x:.0%}")
    )

## SRF correlation distribution

In [27]:
def srf_correlation_distribution(single_srf_df, twostim_ccg_df, axs=None):
    """Histogram of OT/forebrain spatial-tuning correlations, coloured by quartile.

    For every cross-region unit pair in ``twostim_ccg_df`` the Pearson
    correlation between the two units' ``norm_resp`` spatial response functions
    is computed. The correlations attached to the rows at relative level 0 are
    drawn as four histograms on the same axes (one colour per quartile) and
    their mean and standard deviation are printed.

    Parameters
    ----------
    single_srf_df : pandas.DataFrame
        Spatial response functions with a ``norm_resp`` column, looked up with
        ``.loc[date, owl, channel]``.
    twostim_ccg_df : pandas.DataFrame
        Pair table with ``corr_type`` and ``relative_level`` columns. Each index
        tuple is read as ``(date, owl, ..., channel_ot, channel_fore)``; the
        index is expected to carry the level names ``date``, ``owl``,
        ``channel1`` and ``channel2`` so it can be joined with the correlation
        table.
    axs : matplotlib Axes, optional
        Target axes; a new figure is created when omitted.
    """
    if axs is None:
        fig, axs = plt.subplots(1, 1)
    srf_df = single_srf_df.copy()
    twostim_ccg_df = twostim_ccg_df.copy()

    twostim_ccg_df = twostim_ccg_df[twostim_ccg_df["corr_type"] == "cross_region"]

    srf_corr_df = []
    n_skipped = 0
    for index in np.unique(twostim_ccg_df.index):
        try:
            date = index[0]
            owl = index[1]
            channel_ot = index[-2]
            channel_fore = index[-1]

            channel_ot_srf = srf_df.loc[date, owl, channel_ot]
            channel_fore_srf = srf_df.loc[date, owl, channel_fore]

            srf_corr = np.corrcoef(
                channel_ot_srf["norm_resp"], channel_fore_srf["norm_resp"]
            )[0][1]

            srf_corr_df.append(
                {
                    "date": date,
                    "owl": owl,
                    "channel1": channel_ot,
                    "channel2": channel_fore,
                    "srf_corr": srf_corr,
                }
            )
        except Exception:
            # Best effort: pairs without a usable SRF are left out, but unlike a
            # bare ``except`` this no longer swallows KeyboardInterrupt.
            n_skipped += 1
    if n_skipped:
        print(f"Skipped {n_skipped} unit pairs (SRF correlation not computable)")

    srf_corr_df = pd.DataFrame(srf_corr_df)
    srf_corr_df = srf_corr_df.set_index(["date", "owl", "channel1", "channel2"])
    srf_corr_df = srf_corr_df.filter(
        [
            "srf_corr",
        ],
        axis=1,
    )

    merge = twostim_ccg_df.join(srf_corr_df, how="inner", rsuffix="_srf").set_index(
        "relative_level"
    )
    # Only one relative level is used (presumably so each pair counts once).
    merge = merge.loc[0, "srf_corr"]

    mean_srf_corr = np.mean(merge)
    std_srf_corr = np.std(merge)

    print(f"Mean {mean_srf_corr:.2f}")
    print(f"STD {std_srf_corr:.2f}")

    quantiles = np.quantile(merge, [0.25, 0.50, 0.75])
    colors = ["slategrey", "royalblue", "blue", "navy"]

    # linspace includes the right edge; np.arange(-1, 1, 0.01) stopped at 0.99
    # and therefore dropped correlations above 0.99 from the histogram.
    bins = np.linspace(-1.0, 1.0, 201)
    axs.hist(
        merge[(merge <= quantiles[0])], bins, facecolor=colors[0], edgecolor="None"
    )
    axs.hist(
        merge[(merge > quantiles[0]) & ((merge <= quantiles[1]))],
        bins,
        facecolor=colors[1],
        edgecolor="None",
    )
    axs.hist(
        merge[(merge > quantiles[1]) & ((merge <= quantiles[2]))],
        bins,
        facecolor=colors[2],
        edgecolor="None",
    )
    axs.hist(merge[(merge > quantiles[2])], bins, facecolor=colors[3], edgecolor="None")

    axs.set_ylabel("Number of unit pairs")
    axs.set_xlabel("Spatial Tuning Correlation between OT and nRt")
    axs.spines["right"].set_visible(False)
    axs.spines["top"].set_visible(False)

## SRF correlation by frontal OT and peripheral OT

In [28]:
def srf_correlation_frontalperiperpheral(single_srf_df, twostim_ccg_df, axs=None):
    """Compare OT/forebrain tuning correlations for frontal vs. lateral stimuli.

    The Pearson correlation between the ``norm_resp`` spatial response functions
    of each cross-region pair is computed and attached to the rows of
    ``twostim_ccg_df`` recorded at relative level 0. Rows are split by
    ``abs(fixedazi) <= 45`` ("frontal") and ``> 45`` ("lateral"); the two
    distributions are drawn as density histograms with their means marked, and
    an independent-samples t-test between them is printed and annotated.

    Parameters
    ----------
    single_srf_df : pandas.DataFrame
        Spatial response functions with a ``norm_resp`` column, looked up with
        ``.loc[date, owl, channel]``.
    twostim_ccg_df : pandas.DataFrame
        Pair table with ``corr_type``, ``relative_level`` and ``fixedazi``
        columns. Each index tuple is read as
        ``(date, owl, ..., channel_ot, channel_fore)``; the index is expected to
        carry the level names ``date``, ``owl``, ``channel1``, ``channel2``.
    axs : matplotlib Axes, optional
        Target axes; a new figure is created when omitted.
    """
    if axs is None:
        fig, axs = plt.subplots(1, 1)
    srf_df = single_srf_df.copy()
    twostim_ccg_df = twostim_ccg_df.copy()

    twostim_ccg_df = twostim_ccg_df[twostim_ccg_df["corr_type"] == "cross_region"]

    srf_corr_df = []
    n_skipped = 0
    for index in np.unique(twostim_ccg_df.index):
        try:
            date = index[0]
            owl = index[1]
            channel_ot = index[-2]
            channel_fore = index[-1]

            channel_ot_srf = srf_df.loc[date, owl, channel_ot]
            channel_fore_srf = srf_df.loc[date, owl, channel_fore]

            srf_corr = np.corrcoef(
                channel_ot_srf["norm_resp"], channel_fore_srf["norm_resp"]
            )[0][1]

            srf_corr_df.append(
                {
                    "date": date,
                    "owl": owl,
                    "channel1": channel_ot,
                    "channel2": channel_fore,
                    "srf_corr": srf_corr,
                }
            )
        except Exception:
            # Best effort: pairs without a usable SRF are left out, but unlike a
            # bare ``except`` this no longer swallows KeyboardInterrupt.
            n_skipped += 1
    if n_skipped:
        print(f"Skipped {n_skipped} unit pairs (SRF correlation not computable)")

    srf_corr_df = pd.DataFrame(srf_corr_df)
    pair_table = srf_corr_df.copy()
    srf_corr_df = srf_corr_df.set_index(["date", "owl", "channel1", "channel2"])

    # Count distinct units on each side of the pairs.
    unique_ot = pair_table.set_index(["date", "owl", "channel1"])
    unique_forebrain = pair_table.set_index(["date", "owl", "channel2"])
    print("Number of OT units:", np.unique(unique_ot.index).size)
    print("Number of Forebrain units:", np.unique(unique_forebrain.index).size)

    srf_corr_df = srf_corr_df.filter(
        [
            "srf_corr",
        ],
        axis=1,
    )

    merge = twostim_ccg_df.join(srf_corr_df, how="inner", rsuffix="_srf")
    merge = merge[merge["relative_level"] == 0]

    # A session is identified by the first two index entries (date, owl).
    print(
        "Number of sessions:",
        len(sorted(set([idx[:2] for idx in merge.index.values]))),
    )

    # linspace includes the right edge; np.arange(-1, 1, 0.01) stopped at 0.99,
    # which dropped correlations above 0.99 and skewed the density normalisation.
    bins = np.linspace(-1.0, 1.0, 201)
    frontal_ot_units = merge[(abs(merge["fixedazi"]) <= 45)]["srf_corr"]
    peripheral_ot_units = merge[(abs(merge["fixedazi"]) > 45)]["srf_corr"]

    CI_frontal = confidence_intervale(frontal_ot_units)
    CI_peripheral = confidence_intervale(peripheral_ot_units)

    tval, p = scipy.stats.ttest_ind(frontal_ot_units, peripheral_ot_units)

    print(tval, p)
    print(f"frontal stats: {np.mean(frontal_ot_units)} +/- {CI_frontal}")
    print(f"peripheral stats: {np.mean(peripheral_ot_units)} +/- {CI_peripheral}")

    axs.hist(
        peripheral_ot_units,
        bins,
        edgecolor="r",
        facecolor="r",
        alpha=0.5,
        density=True,
        label="lateral",
    )
    axs.hist(
        frontal_ot_units,
        bins,
        edgecolor="b",
        facecolor="b",
        alpha=0.5,
        density=True,
        label="frontal",
    )
    # Mean markers sit at a fixed height (8) below the fixed upper y-limit (10).
    axs.plot(np.mean(frontal_ot_units), 8, "v", color="b")
    axs.plot(np.mean(peripheral_ot_units), 8, "v", color="r")
    axs.legend(frameon=False)
    axs.set_xlabel("Spatial tuning correlation\nbetween OT and Forebrain")
    axs.set_ylabel("Density")
    axs.spines["right"].set_visible(False)
    axs.spines["top"].set_visible(False)
    # p-value in the top-left corner, nudged 2 pt to the right of the axis.
    axs.text(
        0,
        1,
        f"p = {p:.3g}",
        horizontalalignment="left",
        verticalalignment="top",
        fontsize=10,
        transform=axs.transAxes
        + mpl.transforms.ScaledTranslation(
            +2 / 72,
            0 / 72,
            axs.figure.dpi_scale_trans,
        ),
    )
    axs.set_ylim(bottom=0, top=10)

## VS competitor as a function of VS driver

In [29]:
def figure_stimphase_driver_competitor(
    df: pd.DataFrame,
    axs=None,
    colors=[c_forebrain_am55, c_forebrain_am75],
    axs_pol=None,
):
    """Scatter competitor vs. driver vector strength, one panel per relative level.

    Points whose ``fixedstim_plv_p`` is below 0.05 are drawn in ``colors[0][0]``,
    the remaining points in grey; the fraction of points below 0.05 is written
    in the top-left corner of each panel. A dashed identity line spans 0 to 0.5.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs the columns ``relative_level``, ``fixedstim_plv``,
        ``varyingstim_plv`` and ``fixedstim_plv_p``.
    axs : Axes, sequence or array of Axes, optional
        One axes per relative level, used in sorted level order. A new row of
        panels is created when omitted.
    colors : list of colour lists
        Only ``colors[0][0]`` is used here.
    axs_pol : optional
        Unused; kept for call compatibility.
    """
    if axs is None:
        fig, axs = plt.subplots(1, df["relative_level"].unique().size)
    # plt.subplots returns a bare Axes for a single panel and callers may pass
    # a list; normalise to a flat array so ``axs[k]`` always works.
    axs = np.asarray(axs).flatten()

    df = df.set_index("relative_level")

    for k, level in enumerate(np.unique(df.index)):
        ax = axs[k]

        level_rayleigh_p = df.loc[level, "fixedstim_plv_p"]
        level_plv_driver = df.loc[level, "fixedstim_plv"]
        level_plv_competitor = df.loc[level, "varyingstim_plv"]

        # Highlight only units whose driver phase locking has p < 0.05.
        mask = level_rayleigh_p < 0.05

        ax.plot(
            level_plv_driver[~mask],
            level_plv_competitor[~mask],
            ls="none",
            marker=".",
            color=".5",
            mec="none",
        )

        ax.plot(
            level_plv_driver[mask],
            level_plv_competitor[mask],
            ls="none",
            marker=".",
            color=colors[0][0],
            mec="none",
            zorder=-1,
        )

        ax.plot([0, 0.5], [0, 0.5], color="k", lw=0.8, ls="--")
        # Fraction of highlighted units, inset by 1 pt from the corner.
        ax.text(
            0,
            1,
            f"{np.mean(mask):.0%}",
            ha="left",
            va="top",
            fontsize=8,
            color=colors[0][0],
            transform=ax.transAxes
            + mpl.transforms.ScaledTranslation(
                +1 / 72,
                -1 / 72,
                ax.figure.dpi_scale_trans,
            ),
        )

        ax.set_title(f"{level:.0f}", fontsize=10)

        ax.set_xlim(left=0, right=0.5)
        ax.set_xticks([0, 0.5], [])
        ax.set_xticks([0.25], minor=True)

        ax.set_ylim(bottom=0, top=0.5)
        ax.set_yticks([0, 0.5], [])
        ax.set_yticks([0.25], minor=True)
        ax.spines["right"].set_visible(False)
        ax.spines["top"].set_visible(False)
        if k == 0:
            # Only the first panel carries tick labels.
            ax.set_yticklabels(["0", "0.5"])
            ax.set_xticklabels(["0", "0.5"])
        if k == 3:
            # The x-label is attached to the fourth panel only (hard-coded).
            ax.set_xlabel("VS Driver")

## AM VS change with competition

In [30]:
def figure_stimphase_switch(
    df: pd.DataFrame,
    ax=None,
    colors=[c_forebrain_am55, c_forebrain_am75],
    axs_pol=None,
):
    """Plot mean vector strength of driver and competitor against relative level.

    ``df`` holds several rows per unit (one per relative level) that share the
    same index value. For each unit the ``fixedstim_plv`` and ``varyingstim_plv``
    values are collected in ascending ``relative_level`` order; the across-unit
    mean and SEM are drawn as two error-bar curves.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs the columns ``relative_level``, ``fixedstim_plv``,
        ``varyingstim_plv``, ``fixed_modulation_frequency`` and
        ``varying_modulation_frequency``; the index identifies the unit.
    ax : matplotlib Axes, optional
        Target axes; a new figure is created when omitted.
    colors : list of colour lists
        ``colors[0][0]`` is used for the driver, ``colors[1][0]`` for the
        competitor.
    axs_pol : optional
        Unused; kept for call compatibility.
    """
    if ax is None:
        fig, ax = plt.subplots(1, 1)

    levels = np.unique(df["relative_level"])

    driver_vs = []
    competitor_vs = []
    # Iterate over *unique* units: looping over ``df.index`` visited every unit
    # once per row, which replicated the data and shrank the SEM.
    for index in df.index.unique():
        # Sort so each unit's values line up with the ascending x positions.
        sub = df.loc[[index]].sort_values("relative_level")
        driver_plv = list(sub["fixedstim_plv"].values)
        competitor_plv = list(sub["varyingstim_plv"].values)
        if len(driver_plv) > levels.size:
            # More rows than levels: show the offending unit for inspection.
            display(sub)
        driver_vs.append(driver_plv)
        competitor_vs.append(competitor_plv)

    print([len(x) for x in driver_vs])

    ax.errorbar(
        levels,
        np.mean(driver_vs, axis=0),
        scipy.stats.sem(driver_vs),
        color=colors[0][0],
        label="Driver",
    )
    ax.errorbar(
        levels,
        np.mean(competitor_vs, axis=0),
        scipy.stats.sem(competitor_vs),
        color=colors[1][0],
        label="Competitor",
    )
    ax.set_xticks(
        levels,
        [int(val) for val in levels],
    )
    ax.set_xlabel("Relative Level [dB]")
    ax.set_ylabel("Vector Strength")
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)

    # Legend text comes from the modulation frequencies of the first row.
    fixfreq = df.iloc[0]["fixed_modulation_frequency"]
    varfreq = df.iloc[0]["varying_modulation_frequency"]

    mylabels = [f"{fixfreq}Hz driver", f"{varfreq}Hz competitor"]

    ax.legend(labels=mylabels, frameon=False, loc="lower left")

## Dual region xcorr by srf correlation

In [31]:
def fig_dual_region_xcorr_grid(single_srf_df, twostim_ccg_df, axs=None, caxs=None):
    """Grid of mean cross-region correlograms: tuning-correlation quartile x level.

    The Pearson correlation between the ``norm_resp`` spatial response functions
    of every cross-region pair is computed, the rows of ``twostim_ccg_df`` are
    split into quartiles of that correlation, and for each quartile (row) and
    relative level (column) the mean ``ccg`` (time-reversed) is plotted together
    with the number of correlograms averaged. A colour bar keyed to the quartile
    colours is added at the end.

    Parameters
    ----------
    single_srf_df : pandas.DataFrame
        Spatial response functions with a ``norm_resp`` column, looked up with
        ``.loc[date, owl, channel]``.
    twostim_ccg_df : pandas.DataFrame
        Pair table with ``corr_type``, ``relative_level`` and ``ccg`` columns.
        Each index tuple is read as ``(date, owl, ..., channel_ot,
        channel_fore)``; the index is expected to carry the level names
        ``date``, ``owl``, ``channel1``, ``channel2``.
    axs : 2-D array of matplotlib Axes, optional
        Indexed as ``axs[quartile, level]``; a 4x6 grid is created when omitted.
    caxs : matplotlib Axes, optional
        Axes for the colour bar. When omitted, space is taken from ``axs``.
    """
    # Function-scope import: the notebook's import cell only brings in
    # scipy.stats and scipy.fftpack, so scipy.signal is not guaranteed loaded.
    import scipy.signal

    if axs is None:
        fig, axs = plt.subplots(4, 6)

    srf_df = single_srf_df.copy()
    twostim_ccg_df = twostim_ccg_df.copy()

    twostim_ccg_df = twostim_ccg_df[twostim_ccg_df["corr_type"] == "cross_region"]

    srf_corr_df = []
    n_skipped = 0
    for index in np.unique(twostim_ccg_df.index):
        try:
            date = index[0]
            owl = index[1]
            channel_ot = index[-2]
            channel_fore = index[-1]

            channel_ot_srf = srf_df.loc[date, owl, channel_ot]
            channel_fore_srf = srf_df.loc[date, owl, channel_fore]

            srf_corr = np.corrcoef(
                channel_ot_srf["norm_resp"], channel_fore_srf["norm_resp"]
            )[0][1]

            srf_corr_df.append(
                {
                    "date": date,
                    "owl": owl,
                    "channel1": channel_ot,
                    "channel2": channel_fore,
                    "srf_corr": srf_corr,
                }
            )
        except Exception:
            # Best effort: pairs without a usable SRF are left out, but unlike a
            # bare ``except`` this no longer swallows KeyboardInterrupt.
            n_skipped += 1
    if n_skipped:
        print(f"Skipped {n_skipped} unit pairs (SRF correlation not computable)")

    srf_corr_df = pd.DataFrame(srf_corr_df)
    srf_corr_df = srf_corr_df.set_index(["date", "owl", "channel1", "channel2"])
    srf_corr_df = srf_corr_df.filter(
        [
            "srf_corr",
        ],
        axis=1,
    )

    merge = twostim_ccg_df.join(srf_corr_df, how="inner", rsuffix="_srf")

    quantiles = np.quantile(merge["srf_corr"], [0.25, 0.50, 0.75])

    low_corr = merge[(merge["srf_corr"] <= quantiles[0])].set_index("relative_level")
    lower_mid_corr = merge[
        (merge["srf_corr"] > quantiles[0]) & ((merge["srf_corr"] <= quantiles[1]))
    ].set_index("relative_level")
    higher_mid_corr = merge[
        (merge["srf_corr"] > quantiles[1]) & ((merge["srf_corr"] <= quantiles[2]))
    ].set_index("relative_level")
    high_corr = merge[(merge["srf_corr"] > quantiles[2])].set_index("relative_level")

    merge = merge.set_index("relative_level")

    # A full correlogram of two length-N signals has 2N - 1 samples.
    psth_len = (merge.iloc[0]["ccg"].size + 1) // 2
    lags = scipy.signal.correlation_lags(psth_len, psth_len) / 1000
    lags_mask = np.abs(lags) <= 0.95

    relative_levels = np.unique(merge.index)

    groupdata = [low_corr, lower_mid_corr, higher_mid_corr, high_corr]
    label = ["lowcorr", "lowermid", "highermid", "high"]
    colors = ["slategrey", "royalblue", "blue", "navy"]

    def formatter_y(x, pos):
        # Tick labels in units of 1e-6; non-positive ticks are labelled "0".
        if x > 0:
            return f"${{{x*1e6:.0f}}}$"
        else:
            return "0"

    display(high_corr.head(5))
    for ndata, data in enumerate(groupdata):
        for k, relative_level in enumerate(relative_levels):
            panel = axs[ndata, k]
            mean_ccg = np.mean(np.vstack(data.loc[relative_level, "ccg"]), axis=0)[::-1]

            mean_n_ccg = data.loc[relative_level, "ccg"].count()

            if relative_level == -15:
                # Report the peak lag of the mean correlogram at -15 dB only.
                peak_properties = get_peak(mean_ccg, lags)
                peak_time = round(peak_properties["peak_lag"], 3)
                print(label[ndata], relative_level, peak_time)

            panel.plot(
                lags[lags_mask], mean_ccg[lags_mask], color=colors[ndata], ls="-", lw=1
            )

            # Number of correlograms averaged in this panel.
            panel.text(
                0.5,
                0.95,
                f"{mean_n_ccg}",
                ha="center",
                va="top",
                fontsize=8,
                transform=panel.transAxes,
            )
            panel.spines["top"].set_visible(False)
            panel.spines["right"].set_visible(False)
            panel.set_ylim(bottom=-1.5e-6, top=6e-6)
            panel.set_xlim(left=-0.950, right=0.950)
            axs[ndata, 0].set_ylabel("Coinc./spk", labelpad=12)

            if ndata == 0:
                panel.set_title(relative_level)
                # NOTE(review): these labels are added to axs[0, 0] on every
                # column but positioned with the current column's transform, so
                # they appear in each top-row panel; confirm that is intended.
                axs[ndata, 0].text(
                    0.80,
                    0.80,
                    f"OT to\nnRt",
                    ha="center",
                    va="top",
                    fontsize=8,
                    transform=panel.transAxes,
                )
                axs[ndata, 0].text(
                    0.25,
                    0.80,
                    f"nRt to\nOT",
                    ha="center",
                    va="top",
                    fontsize=8,
                    transform=panel.transAxes,
                )

                axs[ndata, 0].axvline(0, ls=":", color="k", lw=0.4, zorder=-1)

        # Only the first column keeps a y-axis.
        for ax in axs[ndata, 1:]:
            ax.set_yticks([])
            ax.set_yticks([], minor=True)
            ax.spines["left"].set_visible(False)

        if ndata == 0:
            axs[0, 0].set_xticks([-0.95, 0, +0.95], ["-950", "0", "+950"], fontsize=8)
            axs[0, 0].set_xlabel("Lags [ms]", fontsize=8, labelpad=0)
            axs[0, 0].set_ylabel("Coinc./spk", labelpad=12)
            for ax in axs[ndata, 1:]:
                ax.set_xticks([-0.95, 0, 0.95], [])
        else:
            for ax in axs[ndata, :]:
                ax.set_xticks([-0.95, 0, 0.95], [])

        axs[ndata, 0].yaxis.set_major_formatter(mpl.ticker.FuncFormatter(formatter_y))
        axs[ndata, 0].set_yticks(np.array([0, 2.5, 5]) * 1e-6, minor=True)
        # Shared exponent label next to the top of the y-axis.
        axs[ndata, 0].text(
            0,
            1,
            "$10^{-6}$",
            ha="right",
            va="center",
            fontsize=8,
            transform=axs[ndata, 0].transAxes
            + mpl.transforms.ScaledTranslation(
                -2 / 72,
                0,
                axs[ndata, 0].figure.dpi_scale_trans,
            ),
        )
    cmap = mpl.colors.ListedColormap(colors[::-1]).with_extremes(
        under=colors[0], over=colors[-1]
    )

    # NOTE(review): boundaries are listed in descending order together with the
    # reversed colour list; BoundaryNorm documents increasing boundaries, so
    # confirm the rendered colour bar before changing either.
    qrt = [1.0, 0.75, 0.5, 0.25, 0]
    norm = mpl.colors.BoundaryNorm(qrt, cmap.N)
    # A free-standing ScalarMappable has no parent axes, so when no colour-bar
    # axes is supplied the grid itself is offered as the place to take space.
    colorbar_target = {"cax": caxs} if caxs is not None else {"ax": axs}
    plt.colorbar(
        mpl.cm.ScalarMappable(cmap=cmap, norm=norm),
        spacing="proportional",
        label="Spatial Tuning Correlation Quartile",
        **colorbar_target,
    )

## Dual Region CSI by relative level

In [32]:
def csi_boxplot_dualregion(
    single_srf_df,
    twostim_ccg_df,
    single_ccg_df,
    ax=None,
    brackets: dict[tuple[int, int], float] | None = None,
    anova_align="right",
):
    """Box plot of the cross-region CSI for each relative level.

    The two-stimulus cross-correlation peak of each cross-region pair is
    compared with the single-stimulus peak of the same pair, using the
    single-stimulus row whose intensity, azimuth and elevation equal those of
    the fixed stimulus::

        csi = (xcorr_peak - xcorr_peak_single) / (xcorr_peak + xcorr_peak_single)

    Only pairs for which an SRF correlation could be computed are kept. One box
    per relative level is drawn with the individual values overlaid, and the
    results of ``anova_tukey`` are shown as brackets and an ANOVA p-value.

    Parameters
    ----------
    single_srf_df : pandas.DataFrame
        Spatial response functions with a ``norm_resp`` column, looked up with
        ``.loc[date, owl, channel]``.
    twostim_ccg_df : pandas.DataFrame
        Two-stimulus pair table (``corr_type``, ``relative_level``,
        ``xcorr_peak``, ``fixedintensity``, ``fixedazi``, ``fixedele``); the
        index is expected to carry the level names ``date``, ``owl``,
        ``channel1``, ``channel2``.
    single_ccg_df : pandas.DataFrame
        Single-stimulus pair table (``corr_type``, ``xcorr_peak``,
        ``intensity``, ``azimuth``, ``elevation``).
    ax : matplotlib Axes, optional
        Target axes; a new figure is created when omitted.
    brackets : dict mapping (left, right) box indices to a y position, optional
        Comparisons to annotate (zero-based box indices). Defaults to all
        adjacent pairs at y = 0.8.
    anova_align : str
        ``"right"`` puts the ANOVA p-value in the top-right corner, anything
        else in the top-left corner.
    """
    srf_df = single_srf_df.copy()
    twostim_ccg_df = twostim_ccg_df.copy()
    single_ccg_df = single_ccg_df.copy()

    if ax is None:
        # A single axes is needed: the previous one-panel-per-level grid
        # returned an array, which has no ``boxplot`` method.
        fig, ax = plt.subplots(1, 1)

    twostim_ccg_df = twostim_ccg_df[twostim_ccg_df["corr_type"] == "cross_region"]
    single_ccg_df = single_ccg_df[single_ccg_df["corr_type"] == "cross_region"]

    srf_corr_df = []
    n_skipped = 0
    for index in np.unique(twostim_ccg_df.index):
        try:
            date = index[0]
            owl = index[1]
            channel_ot = index[-2]
            channel_fore = index[-1]

            channel_ot_srf = srf_df.loc[date, owl, channel_ot]
            channel_fore_srf = srf_df.loc[date, owl, channel_fore]

            srf_corr = np.corrcoef(
                channel_ot_srf["norm_resp"], channel_fore_srf["norm_resp"]
            )[0][1]

            srf_corr_df.append(
                {
                    "date": date,
                    "owl": owl,
                    "channel1": channel_ot,
                    "channel2": channel_fore,
                    "srf_corr": srf_corr,
                }
            )
        except Exception:
            # Best effort: pairs without a usable SRF are left out, but unlike a
            # bare ``except`` this no longer swallows KeyboardInterrupt.
            n_skipped += 1
    if n_skipped:
        print(f"Skipped {n_skipped} unit pairs (SRF correlation not computable)")

    srf_corr_df = pd.DataFrame(srf_corr_df)
    srf_corr_df = srf_corr_df.set_index(["date", "owl", "channel1", "channel2"])
    srf_corr_df = srf_corr_df.filter(
        [
            "srf_corr",
        ],
        axis=1,
    )

    # The inner join keeps only pairs with an SRF correlation.
    merge = twostim_ccg_df.join(srf_corr_df, how="inner", rsuffix="_srf")

    merge = merge.join(
        single_ccg_df,
        on=["date", "owl", "channel1", "channel2"],
        how="inner",
        rsuffix="_single",
    )

    # Keep the single-stimulus row that matches the fixed stimulus; copy so the
    # column assignment below does not write into a view.
    merge = merge.loc[
        (merge["fixedintensity"] == merge["intensity"])
        & (merge["fixedazi"] == merge["azimuth"])
        & (merge["fixedele"] == merge["elevation"])
    ].copy()
    print("Number of units:", np.unique(merge.index).size)
    print(
        "Number of sessions:",
        len(sorted(set([idx[:2] for idx in merge.index.values]))),
    )
    merge["csi"] = (merge["xcorr_peak"] - merge["xcorr_peak_single"]) / (
        merge["xcorr_peak"] + merge["xcorr_peak_single"]
    )

    merge = merge.set_index("relative_level")

    relative_levels = np.unique(merge.index)

    groupdata = [
        merge.loc[relative_level, "csi"].values for relative_level in relative_levels
    ]

    bp = ax.boxplot(
        groupdata,
        patch_artist=True,
        notch=False,
        showfliers=False,
        whis=(5, 95),
    )

    ax.set_xticklabels(relative_levels.astype(int))

    stats = anova_tukey(merge, val_col="csi", group_col="relative_level")
    colors = ["orange"] * len(relative_levels)
    for patch, color in zip(bp["boxes"], colors):
        patch.set_facecolor(color)

    for median in bp["medians"]:
        median.set(color="white", linewidth=2)

    # Overlay the individual values, sorted and spread across the box width.
    width = 0.7
    for k, groupvalues in enumerate(groupdata):
        pos = k + 1
        y = np.sort(groupvalues)
        vals_x = np.linspace(
            pos - width / 3,
            pos + width / 3,
            y.size,
            endpoint=True,
        )
        ax.plot(
            vals_x,
            y,
            color="k",
            ls="None",
            marker=".",
            markersize=2,
            markeredgewidth=0.0,
            zorder=5,
        )

    brackets_shrink = 0.85
    if brackets is None:
        brackets = {(k, k + 1): 0.8 for k in range(len(stats["tukey"]) - 1)}
        brackets_shrink = 0.8
    for (idx_left, idx_right), y in brackets.items():
        # Drop only a leading zero ("0.012" -> ".012"); blindly slicing the
        # first character rendered "1.000" as ".000" and "nan" as "an".
        p_text = f"{stats['tukey'].iloc[idx_left, idx_right]:.3f}"
        if p_text.startswith("0"):
            p_text = p_text[1:]
        plot_bracket(
            ax,
            left=idx_left + 1,  # box positions are one-based
            right=idx_right + 1,
            text=p_text,
            y=y,
            shrink=brackets_shrink,
        )

    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    if anova_align == "right":
        ax.text(
            1,
            1,
            f"p = {stats['anova_p']:.3g}",
            horizontalalignment="right",
            verticalalignment="top",
            fontsize=10,
            transform=ax.transAxes,
        )
    else:
        # Top-left corner, nudged 2 pt to the right of the axis.
        ax.text(
            0,
            1,
            f"p = {stats['anova_p']:.3g}",
            horizontalalignment="left",
            verticalalignment="top",
            fontsize=10,
            transform=ax.transAxes
            + mpl.transforms.ScaledTranslation(
                +2 / 72,
                0 / 72,
                ax.figure.dpi_scale_trans,
            ),
        )

    ax.axhline(0, color="k", ls="--", zorder=-1)
    ax.set_xlabel("Relative Level [dB]")
    ax.set_ylabel("CSI")
    ax.set_ylim(-1, 1.45)

## Dual Region CSI by SRF Corr

In [33]:
def csi_boxplot_dualregion_srfcorr(
    single_srf_df,
    twostim_ccg_df,
    single_ccg_df,
    ax=None,
    brackets: dict[tuple[int, int], float] | None = None,
    anova_align="right",
):
    """Box plot of the cross-region CSI for each spatial-tuning quartile.

    The CSI of every cross-region pair is computed from the two-stimulus and
    the matching single-stimulus cross-correlation peaks::

        csi = (xcorr_peak - xcorr_peak_single) / (xcorr_peak + xcorr_peak_single)

    Rows are split into quartiles of the OT/forebrain SRF correlation, and for
    each quartile the CSI values at relative levels -15 and -10 are pooled into
    one box. The results of ``anova_tukey`` across quartiles are shown as
    brackets and an ANOVA p-value.

    Parameters
    ----------
    single_srf_df : pandas.DataFrame
        Spatial response functions with a ``norm_resp`` column, looked up with
        ``.loc[date, owl, channel]``.
    twostim_ccg_df : pandas.DataFrame
        Two-stimulus pair table (``corr_type``, ``relative_level``,
        ``xcorr_peak``, ``fixedintensity``, ``fixedazi``, ``fixedele``); the
        index is expected to carry the level names ``date``, ``owl``,
        ``channel1``, ``channel2``.
    single_ccg_df : pandas.DataFrame
        Single-stimulus pair table (``corr_type``, ``xcorr_peak``,
        ``intensity``, ``azimuth``, ``elevation``).
    ax : matplotlib Axes, optional
        Target axes; a new figure is created when omitted.
    brackets : dict mapping (left, right) box indices to a y position, optional
        Comparisons to annotate (zero-based box indices). Defaults to all
        adjacent pairs at y = 0.8, as in ``csi_boxplot_dualregion``.
    anova_align : str
        ``"right"`` puts the ANOVA p-value in the top-right corner, anything
        else in the top-left corner.
    """
    srf_df = single_srf_df.copy()
    twostim_ccg_df = twostim_ccg_df.copy()
    single_ccg_df = single_ccg_df.copy()

    if ax is None:
        # A single axes is needed: the previous one-panel-per-level grid
        # returned an array, which has no ``boxplot`` method.
        fig, ax = plt.subplots(1, 1)

    twostim_ccg_df = twostim_ccg_df[twostim_ccg_df["corr_type"] == "cross_region"]
    single_ccg_df = single_ccg_df[single_ccg_df["corr_type"] == "cross_region"]

    srf_corr_df = []
    n_skipped = 0
    for index in np.unique(twostim_ccg_df.index):
        try:
            date = index[0]
            owl = index[1]
            channel_ot = index[-2]
            channel_fore = index[-1]

            channel_ot_srf = srf_df.loc[date, owl, channel_ot]
            channel_fore_srf = srf_df.loc[date, owl, channel_fore]

            srf_corr = np.corrcoef(
                channel_ot_srf["norm_resp"], channel_fore_srf["norm_resp"]
            )[0][1]

            srf_corr_df.append(
                {
                    "date": date,
                    "owl": owl,
                    "channel1": channel_ot,
                    "channel2": channel_fore,
                    "srf_corr": srf_corr,
                }
            )
        except Exception:
            # Best effort: pairs without a usable SRF are left out, but unlike a
            # bare ``except`` this no longer swallows KeyboardInterrupt.
            n_skipped += 1
    if n_skipped:
        print(f"Skipped {n_skipped} unit pairs (SRF correlation not computable)")

    srf_corr_df = pd.DataFrame(srf_corr_df)
    srf_corr_df = srf_corr_df.set_index(["date", "owl", "channel1", "channel2"])
    srf_corr_df = srf_corr_df.filter(
        [
            "srf_corr",
        ],
        axis=1,
    )

    merge = twostim_ccg_df.join(srf_corr_df, how="inner", rsuffix="_srf")

    merge = merge.join(
        single_ccg_df,
        on=["date", "owl", "channel1", "channel2"],
        how="inner",
        rsuffix="_single",
    )

    # Keep the single-stimulus row that matches the fixed stimulus; copy so the
    # column assignment below does not write into a view.
    merge = merge.loc[
        (merge["fixedintensity"] == merge["intensity"])
        & (merge["fixedazi"] == merge["azimuth"])
        & (merge["fixedele"] == merge["elevation"])
    ].copy()
    print("Number of units:", np.unique(merge.index).size)
    print(
        "Number of sessions:",
        len(sorted(set([idx[:2] for idx in merge.index.values]))),
    )
    merge["csi"] = (merge["xcorr_peak"] - merge["xcorr_peak_single"]) / (
        merge["xcorr_peak"] + merge["xcorr_peak_single"]
    )

    quantiles = np.quantile(merge["srf_corr"], [0.25, 0.50, 0.75])

    low_corr = merge[(merge["srf_corr"] <= quantiles[0])].set_index("relative_level")
    lower_mid_corr = merge[
        (merge["srf_corr"] > quantiles[0]) & ((merge["srf_corr"] <= quantiles[1]))
    ].set_index("relative_level")
    higher_mid_corr = merge[
        (merge["srf_corr"] > quantiles[1]) & ((merge["srf_corr"] <= quantiles[2]))
    ].set_index("relative_level")
    high_corr = merge[(merge["srf_corr"] > quantiles[2])].set_index("relative_level")

    srf_corr_groups = [low_corr, lower_mid_corr, higher_mid_corr, high_corr]
    srf_corr_labels = ["<25 %", "25-50th %", "50-75th %", ">75th %"]

    # Pool the two lowest relative levels (-15 and -10) for every quartile and
    # build the long-format table that anova_tukey consumes.
    plot_data = []
    stats_rows = []
    for group_label, data in zip(srf_corr_labels, srf_corr_groups):
        plot_values = data.loc[[-15, -10], "csi"].values
        plot_data.append(plot_values)
        stats_rows.extend(
            {"csi": val, "tuning_index": group_label} for val in plot_values
        )
    stats_df = pd.DataFrame(stats_rows).set_index("tuning_index")

    bp = ax.boxplot(
        plot_data,
        patch_artist=True,
        notch=False,
        showfliers=False,
        whis=(5, 95),
    )

    ax.set_xticklabels(srf_corr_labels, fontsize=9)
    stats = anova_tukey(stats_df, val_col="csi", group_col="tuning_index")
    colors = ["slategrey", "royalblue", "blue", "navy"]
    for patch, color in zip(bp["boxes"], colors):
        patch.set_facecolor(color)

    for median in bp["medians"]:
        median.set(color="white", linewidth=2)

    # Overlay the individual values, sorted and spread across the box width.
    width = 0.7
    for k, groupvalues in enumerate(plot_data):
        pos = k + 1
        y = np.sort(groupvalues)
        vals_x = np.linspace(
            pos - width / 3,
            pos + width / 3,
            y.size,
            endpoint=True,
        )
        ax.plot(
            vals_x,
            y,
            color="k",
            ls="None",
            marker=".",
            markersize=2,
            markeredgewidth=0.0,
            zorder=5,
        )

    brackets_shrink = 1.0
    if brackets is None:
        # The default used to crash on ``None.items()``; fall back to adjacent
        # pairs with the same defaults as csi_boxplot_dualregion.
        brackets = {(k, k + 1): 0.8 for k in range(len(stats["tukey"]) - 1)}
        brackets_shrink = 0.8
    for (idx_left, idx_right), y in brackets.items():
        # Drop only a leading zero ("0.012" -> ".012"); blindly slicing the
        # first character rendered "1.000" as ".000" and "nan" as "an".
        p_text = f"{stats['tukey'].iloc[idx_left, idx_right]:.3f}"
        if p_text.startswith("0"):
            p_text = p_text[1:]
        plot_bracket(
            ax,
            left=idx_left + 1,  # box positions are one-based
            right=idx_right + 1,
            text=p_text,
            y=y,
            shrink=brackets_shrink,
        )

    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    if anova_align == "right":
        ax.text(
            1,
            1,
            f"p = {stats['anova_p']:.3g}",
            horizontalalignment="right",
            verticalalignment="top",
            fontsize=10,
            transform=ax.transAxes,
        )
    else:
        # Top-left corner, nudged 2 pt to the right of the axis.
        ax.text(
            0,
            1,
            f"p = {stats['anova_p']:.3g}",
            horizontalalignment="left",
            verticalalignment="top",
            fontsize=10,
            transform=ax.transAxes
            + mpl.transforms.ScaledTranslation(
                +2 / 72,
                0 / 72,
                ax.figure.dpi_scale_trans,
            ),
        )

    ax.axhline(0, color="k", ls="--", zorder=-1)
    ax.set_xlabel("Spatial Tuning Quartile")
    ax.set_ylabel("CSI")
    ax.set_ylim(-1, 1.45)

## Gamma heat map

In [34]:
def gamma_heatmap_scatter_resubmit(
    single_srf_df, twostim_ccg_df, twostim_gamma_power_df, axs=None, caxs=None
):
    """Plot gamma-power heat maps and gamma-power vs. synchrony scatter plots.

    Cross-region channel pairs are split into four groups by the quartiles of
    the correlation between the two channels' "norm_resp" receptive fields.
    For every group and relative level the function averages each channel's
    "gammapower", and takes ``get_peak(...)["peak_corr"]`` of the (reversed)
    mean cross-correlogram as the synchrony value.

    Parameters
    ----------
    single_srf_df : pandas.DataFrame
        Looked up with ``.loc[date, owl, channel]``; must provide "norm_resp".
    twostim_ccg_df : pandas.DataFrame
        Only rows with ``corr_type == "cross_region"`` are used. Each index
        entry is read as ``(date, owl, ..., channel1, channel2)`` and the index
        levels "date", "owl", "channel1", "channel2" are used as join keys.
        Required columns: "ccg", "relative_level", "fixedintensity",
        "varyingintensity", "fixedazi", "fixedele".
    twostim_gamma_power_df : pandas.DataFrame
        Joined once per channel of each pair; must provide "gammapower" and the
        same four stimulus columns as ``twostim_ccg_df``.
    axs : sequence of 6 matplotlib Axes, optional
        ``axs[0:3]`` receive the heat maps (channel 1 / channel 2 / pair mean),
        ``axs[3:6]`` the corresponding scatter plots. A new 1x6 figure is
        created when omitted.
    caxs : sequence of 3 matplotlib Axes, optional
        Colorbar axes for the three heat maps. When omitted, each colorbar
        takes its space from the heat-map axes it belongs to.

    Raises
    ------
    ValueError
        If fewer than six axes are supplied, or if no channel pair has a
        receptive field for both channels.
    """
    # The notebook imports only ``scipy.stats``; import the submodule explicitly
    # instead of relying on it being reachable as an attribute of ``scipy``.
    import scipy.signal

    srf_df = single_srf_df.copy()
    twostim_ccg_df = twostim_ccg_df.copy()
    gamma_df = twostim_gamma_power_df.copy()

    if axs is None:
        # Six panels are drawn below (three heat maps + three scatter plots).
        _, axs = plt.subplots(1, 6)
    axs = np.asarray(axs, dtype=object).ravel()
    if axs.size < 6:
        raise ValueError(f"axs must provide 6 axes, got {axs.size}")

    twostim_ccg_df = twostim_ccg_df[twostim_ccg_df["corr_type"] == "cross_region"]

    # --- receptive-field correlation for every channel pair ----------------
    pair_rows = []
    for index in np.unique(twostim_ccg_df.index):
        date, owl = index[0], index[1]
        channel_ot, channel_fore = index[-2], index[-1]
        try:
            channel_ot_srf = srf_df.loc[date, owl, channel_ot]
            channel_fore_srf = srf_df.loc[date, owl, channel_fore]
            srf_corr = np.corrcoef(
                channel_ot_srf["norm_resp"], channel_fore_srf["norm_resp"]
            )[0][1]
        except (KeyError, ValueError):
            # Best effort: skip pairs without a receptive field for both
            # channels (KeyError) or with receptive fields of different
            # length (ValueError) instead of silently swallowing every error.
            continue
        pair_rows.append(
            {
                "date": date,
                "owl": owl,
                "channel1": channel_ot,
                "channel2": channel_fore,
                "srf_corr": srf_corr,
            }
        )

    if not pair_rows:
        raise ValueError("no cross-region pair has receptive fields for both channels")

    srf_corr_df = pd.DataFrame(pair_rows).set_index(
        ["date", "owl", "channel1", "channel2"]
    )[["srf_corr"]]

    # --- attach gamma power of both channels --------------------------------
    merge = twostim_ccg_df.join(srf_corr_df, how="inner", rsuffix="_srf")
    # Columns already present in `merge` get the suffix "1" (first channel) or
    # "2" (second channel); "gammapower" itself stays unsuffixed for channel 1.
    merge = merge.join(
        gamma_df, on=["date", "owl", "channel1"], how="inner", rsuffix="1"
    )
    merge = merge.join(
        gamma_df, on=["date", "owl", "channel2"], how="inner", rsuffix="2"
    )

    # Keep only rows where the correlogram and both gamma-power entries come
    # from the same stimulus condition. (a == b and a == c already imply
    # b == c, so the third comparison per column is not needed.)
    same_condition = np.ones(len(merge), dtype=bool)
    for column in ("fixedintensity", "varyingintensity", "fixedazi", "fixedele"):
        reference = merge[column].to_numpy()
        same_condition &= reference == merge[column + "1"].to_numpy()
        same_condition &= reference == merge[column + "2"].to_numpy()
    merge = merge.loc[same_condition].assign(
        mean_gamma_power=lambda df: (df["gammapower"] + df["gammapower2"]) / 2
    )

    # --- split pairs into receptive-field-correlation quartiles -------------
    corr = merge["srf_corr"]
    q25, q50, q75 = np.quantile(corr, [0.25, 0.50, 0.75])
    quartile_masks = [
        corr <= q25,
        (corr > q25) & (corr <= q50),
        (corr > q50) & (corr <= q75),
        corr > q75,
    ]
    groupdata = [merge[mask].set_index("relative_level") for mask in quartile_masks]
    quartile_labels = ["<25th %", "25-50th %", "50-75th %", ">75th %"]

    merge = merge.set_index("relative_level")
    relative_levels = np.unique(merge.index)

    # A full cross-correlogram of two length-n signals has 2n - 1 samples.
    psth_len = (merge.iloc[0]["ccg"].size + 1) // 2
    lags = scipy.signal.correlation_lags(psth_len, psth_len) / 1000

    heatmap_shape = (len(groupdata), len(relative_levels))
    gamma_heatmap = np.empty(heatmap_shape)
    gamma_heatmap_ot = np.empty(heatmap_shape)
    gamma_heatmap_fb = np.empty(heatmap_shape)
    xregion_synchrony = np.empty(heatmap_shape)

    for ndata, data in enumerate(groupdata):
        for k, relative_level in enumerate(relative_levels):
            # List indexer: always yields a DataFrame, even when only a single
            # row exists for this level (a scalar label would return the bare
            # cell value and break the stacking below).
            level_df = data.loc[[relative_level]]
            mean_ccg = np.mean(np.vstack(list(level_df["ccg"])), axis=0)[::-1]
            gamma_heatmap[ndata, k] = np.mean(
                np.vstack(list(level_df["mean_gamma_power"]))
            )
            gamma_heatmap_ot[ndata, k] = np.mean(
                np.vstack(list(level_df["gammapower"]))
            )
            gamma_heatmap_fb[ndata, k] = np.mean(
                np.vstack(list(level_df["gammapower2"]))
            )
            xregion_synchrony[ndata, k] = get_peak(mean_ccg, lags)["peak_corr"]

    # --- scatter plots: gamma power vs. cross-region synchrony --------------
    synchrony = xregion_synchrony.ravel()
    scatter_panels = [
        (axs[3], gamma_heatmap_ot, "Within OT", "Mean OT gamma power"),
        (axs[4], gamma_heatmap_fb, "Within nRt", "Mean nRt gamma power"),
        (axs[5], gamma_heatmap, "Cross Region", "Mean X-region gamma power"),
    ]
    for ax, power_map, title, xlabel in scatter_panels:
        power = power_map.ravel()
        corr_val = np.corrcoef(power, synchrony)[0][1]
        ax.scatter(power, synchrony, label=f"R = {corr_val:.2f}", color="k")
        ax.set_title(title)
        ax.text(
            0.05,
            0.92,
            f"R = {corr_val:.2f}",
            ha="left",
            va="top",
            fontsize=12,
            transform=ax.transAxes,
        )
        ax.set_xlabel(xlabel)
        ax.set_ylabel("X-Region Synchrony")
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

    # --- heat maps ----------------------------------------------------------
    # All three maps use the colour limits of the first (OT) map, as before.
    # NOTE(review): presumably intentional so the panels share one scale - confirm.
    vmin = np.min(gamma_heatmap_ot)
    vmax = np.max(gamma_heatmap_ot)
    heatmap_panels = [
        (gamma_heatmap_ot, "Mean OT gamma power", 10),
        (gamma_heatmap_fb, "Mean nRt gamma power", 10),
        (gamma_heatmap, "Mean X-region gamma power", 9),
    ]
    for panel, (values, cbar_label, label_size) in enumerate(heatmap_panels):
        ax = axs[panel]
        image = ax.imshow(values, cmap="viridis", vmin=vmin, vmax=vmax)
        ax.set_xticks(np.arange(len(relative_levels)), relative_levels)
        ax.set_yticks(np.arange(len(groupdata)), quartile_labels)
        # Each colorbar is built from its own image (the third one used to be
        # built from the second image).
        cax = None if caxs is None else caxs[panel]
        ax.figure.colorbar(image, cax=cax, ax=ax).set_label(
            cbar_label, size=label_size
        )
        ax.set_xlabel("Relative level [dB]")

## Spike Field Coherence

In [35]:
def within_area_sfc_competition(within_area_sfc_df, axs=None, colors=c_ot):
    """Plot spike-field coherence spectra for three relative levels.

    For the relative levels -15, 0 and +10 the mean +/- SEM (across rows) of
    "spike_field_coherence" is drawn in ``colors[0]`` together with the mean
    +/- SEM of "trial_shuffled_sfc" in grey.

    Parameters
    ----------
    within_area_sfc_df : pandas.DataFrame
        Must provide the columns "relative_level", "frequency",
        "spike_field_coherence" and "trial_shuffled_sfc". The frequency axis is
        taken from the first row and used for every trace.
    axs : sequence of 3 matplotlib Axes, optional
        One axes per level; any list or (possibly 2-D) array of axes is
        accepted. A new 1x3 figure is created when omitted.
    colors : sequence of colors
        Only the first entry is used.
    """
    if axs is None:
        _, axs = plt.subplots(1, 3)
    # The original called ``axs.flatten()`` and discarded the result; actually
    # flatten so that lists and 2-D axes grids can be indexed with one integer.
    axs = np.asarray(axs, dtype=object).ravel()

    lfp_df = within_area_sfc_df.set_index("relative_level")
    frequency = lfp_df.iloc[0]["frequency"]

    level_plots = [-15, 0, 10]
    level_plot_label = ["-15", "0", "+10"]

    for ax, level, label in zip(axs, level_plots, level_plot_label):
        # List indexer: always yields a DataFrame, even for a single row.
        level_df = lfp_df.loc[[level]]
        coherence = np.vstack(list(level_df["spike_field_coherence"]))
        shuffled = np.vstack(list(level_df["trial_shuffled_sfc"]))

        ax.set_title(label)
        ax.errorbar(
            frequency,
            np.mean(coherence, axis=0),
            scipy.stats.sem(coherence, axis=0),
            color=colors[0],
        )
        ax.errorbar(
            frequency,
            np.mean(shuffled, axis=0),
            scipy.stats.sem(shuffled, axis=0),
            color="0.4",
            label="trial shuffled",
        )

        ax.set_xlim(0, 100)
        if ax is axs[0]:
            # Only the first panel carries tick labels and axis labels.
            ax.set_xticks([0, 25, 50, 75], [0, 25, 50, 75])
            ax.set_xlabel("Frequency (Hz)")
            ax.set_ylabel("Spike Field\nCoherence")
        else:
            ax.set_xticks([0, 25, 50, 75], [])
            ax.set_yticklabels([])
        ax.set_ylim(bottom=0.05, top=0.450)

## Preferred Phase Analysis

In [36]:
def _phase_histogram_stats(spike_angles, preferred_phases, phase_bins):
    """Return (mean, SEM) across units of normalized, re-centred phase histograms.

    Each unit's spike angles are shifted by that unit's preferred phase, passed
    through `wrap_to_pi`, binned into `phase_bins` and divided by the total
    count.
    """
    traces = []
    for angles, preferred in zip(spike_angles, preferred_phases):
        counts, _ = np.histogram(wrap_to_pi(angles - preferred), phase_bins)
        # A unit without any spike in range yields 0/0 = NaN (unchanged behaviour).
        traces.append(counts / counts.sum())
    return np.mean(traces, axis=0), scipy.stats.sem(traces, axis=0)


def figure_withinregion_spike_phase(
    cross_region_df,
    colors="b",
    axs=None,
):
    """Plot spike-phase histograms relative to each unit's preferred gamma phase.

    For the relative levels -15, 0 and +10 the across-unit mean +/- SEM of the
    normalized phase histogram is drawn in `colors`, on top of the same
    quantity for the shuffled control (grey). The mean "gamma_plv" of the level
    is printed into the panel as "VS".

    Parameters
    ----------
    cross_region_df : pandas.DataFrame
        Must provide "relative_level", "gamma_spike_angles", "gamma_plv_angle",
        "gamma_plv", "gamma_spike_angles_shuffled" and
        "gamma_plv_angle_shuffled". The first two entries of each index value
        are counted as the session identifier in the printed summary.
    colors : matplotlib color
        Color of the non-shuffled trace.
    axs : sequence of 3 matplotlib Axes, optional
        One axes per level; a new 1x3 figure (shared y) is created when omitted.
    """
    if axs is None:
        _, axs = plt.subplots(1, 3, sharey=True)
    axs = np.asarray(axs, dtype=object).ravel()

    gamma_df = cross_region_df.copy()

    print("Number of units:", np.unique(gamma_df.index).size)
    print(
        "Number of sessions:",
        len({idx[:2] for idx in gamma_df.index.values}),
    )

    gamma_df = gamma_df.set_index("relative_level")
    relative_levels = np.unique(gamma_df.index)

    # 8 equally wide phase bins spanning -pi..pi
    phase_bins = np.linspace(-np.pi, np.pi, 9)

    levels_plot = [-15, 0, 10]
    for k, relative_level in enumerate(levels_plot):
        ax = axs[k]
        # List indexer: always yields a DataFrame, even for a single row.
        level_df = gamma_df.loc[[relative_level]]
        mean_vs = np.mean(level_df["gamma_plv"].values)

        mean_trace, sem_trace = _phase_histogram_stats(
            level_df["gamma_spike_angles"].values,
            level_df["gamma_plv_angle"].values,
            phase_bins,
        )
        mean_trace_shuff, sem_trace_shuff = _phase_histogram_stats(
            level_df["gamma_spike_angles_shuffled"].values,
            level_df["gamma_plv_angle_shuffled"].values,
            phase_bins,
        )

        # Shuffled control first so the real data is drawn on top.
        ax.errorbar(
            phase_bins[:-1],
            mean_trace_shuff,
            sem_trace_shuff,
            color="0.4",
            ls="-",
            lw=1,
        )
        ax.errorbar(phase_bins[:-1], mean_trace, sem_trace, color=colors, ls="-", lw=1)

        ax.text(
            0.5,
            0.80,
            f"VS {mean_vs:.3f}",
            # axes coordinates, nudged upwards by 4 pt
            transform=ax.transAxes
            + mpl.transforms.ScaledTranslation(
                0 / 72,
                +4 / 72,
                ax.figure.dpi_scale_trans,
            ),
            ha="center",
            fontsize=10,
        )
        ax.set_ylim(bottom=0.0, top=0.50)
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.set_xlim(-np.pi, +np.pi)
        ax.set_xticks([-np.pi, 0, +np.pi], labels=["-π", "0", "π"] if k == 0 else [])
        ax.set_xlabel("φ - Best φ")
        if k != 0:
            ax.set_yticks([])
            ax.set_xlabel(None)

    # NOTE(review): the result of `phase_properties` was never used here; the
    # call is kept only in case it has side effects - confirm and drop if not.
    for relative_level in relative_levels:
        phase_properties(gamma_df.loc[[relative_level], "gamma_plv_angle"].values)

# Figures 

## Figure 1 Experimental Design

In [None]:
# Figure 1: experimental-design image (A) and two example spatial receptive
# fields (B: OT unit, C: nRt unit), each with a narrow colorbar axes beside it.
fig = plt.figure(figsize=(3.30, 2.4 + 2.10 + 2.4 + 1.25))

# --- A: design schematic (image only, no frame or ticks) -------------------
ax1a = figure_add_axes_inch(
    fig,
    top=0.05,
    left=0.8,
    width=2.1,
    height=2.4,
    label="A",
)
plt.setp(ax1a, frame_on=False, xticks=[], yticks=[], zorder=20)
# Build the path with pathlib: the former r".\other_figures\..." literal uses
# backslash separators and therefore only resolves on Windows.
owl_image_path = (
    pathlib.Path("other_figures") / "experimental_design_forebrainmanuscript.png"
)
owl_image = matplotlib.image.imread(str(owl_image_path))
ax1a.imshow(owl_image)

subplot_indicator(ax1a, ha="left", va="top", pad_inch=0.7)

# --- B: example OT receptive field ------------------------------------------
ax1b = figure_add_axes_inch(
    fig,
    top=0.3 + 2.4 + 0.15,
    left=0.8,
    width=1.9,
    height=2.0,
    label="B",
)
ax1b_cb = figure_add_axes_inch(
    fig,
    top=0.3 + 2.4 + 0.15,
    left=0.8 + 1.9 + 0.05,
    width=0.05,
    height=2.0,
)
ax1b.set_title("Example OT unit")

figure_spatial_receptive_field(midhighcorr_srf_ot, ax=ax1b, cax=ax1b_cb)
subplot_indicator(ax1b, ha="left", va="top", pad_inch=0.7)

# --- C: example nRt receptive field (anchored to the figure bottom) ----------
ax1c = figure_add_axes_inch(
    fig,
    bottom=0.5,
    left=0.8,
    width=1.9,
    height=2.0,
    label="C",
)
ax1c_cb = figure_add_axes_inch(
    fig,
    bottom=0.5,
    left=0.8 + 1.9 + 0.05,
    width=0.05,
    height=2.0,
)
ax1c.set_title("Example nRt unit")

figure_spatial_receptive_field(midhighcorr_srf_fb, ax=ax1c, cax=ax1c_cb)
subplot_indicator(ax1c, ha="left", va="top", pad_inch=0.7)

fig.savefig(OUTDIR / "figure1.pdf")
fig.savefig(OUTDIR / "figure1.png")
fig.savefig(OUTDIR / "figure1.eps")

## Figure 2 Azimuth and Elevation tuning nRt

In [None]:
# Figure 2: azimuth (A, B) and elevation (C, D) tuning of forebrain units.
single_srf_df = single_srf[single_srf["region"] == "Forebrain"]
single_rlf_df = single_rlf[single_rlf["region"] == "Forebrain"]

fig = plt.figure(figsize=(6.80, 6.40))

# Left column: azimuth heat map (A) with colorbar, summary panel (B) below.
ax2a = figure_add_axes_inch(fig, top=0.3, left=0.6, width=2.1, height=3.0, label="A")
ax2a.set_title("nRt Azimuth Tuning")
ax2a_cb = figure_add_axes_inch(
    fig, top=0.3, left=0.6 + 2.1 + 0.05, width=0.05, height=3.0
)
ax2b = figure_add_axes_inch(
    fig, top=0.3 + 3.0 + 0.6, left=0.7, width=2.1, height=2.0, label="B"
)
azimuth_heatplot_alt(single_srf_df, axs=[ax2a, ax2b], caxs=ax2a_cb)

# Right column: elevation heat map (C) with colorbar, summary panel (D) below.
ax2c = figure_add_axes_inch(
    fig, top=0.3, left=0.5 + 2.4 + 1.2, width=2.1, height=3.0, label="C"
)
ax2c_cb = figure_add_axes_inch(
    fig, top=0.3, left=0.5 + 2.4 + 2.1 + 1.2 + 0.05, width=0.05, height=3.0
)
ax2c.set_title("nRt Elevation Tuning")
ax2d = figure_add_axes_inch(
    fig, top=0.3 + 3.0 + 0.6, left=0.6 + 2.4 + 1.2, width=2.1, height=2.0, label="D"
)
elevation_heatplot_alt(single_srf_df, axs=[ax2c, ax2d], caxs=ax2c_cb)

# Panel letters; the lower panels use a slightly larger padding.
for _panel_ax, _pad in ((ax2a, 0.6), (ax2b, 0.7), (ax2c, 0.6), (ax2d, 0.7)):
    subplot_indicator(_panel_ax, ha="left", va="bottom", pad_inch=_pad)

for _ext in ("pdf", "png", "eps"):
    fig.savefig(OUTDIR / f"figure2.{_ext}")

## Figure 3 nRt flat noise competition

In [None]:
# Figure 3 inputs: forebrain rate data and forebrain-forebrain correlograms.
forebrain_twostim = twostim_rlf[twostim_rlf["region"] == "Forebrain"]
forebrain_singlestim = single_rlf[single_rlf["region"] == "Forebrain"]

twostim_ccg_forebrain = twostim_ccg[twostim_ccg["corr_type"] == "within_Forebrain"]
single_ccg_forebrain = single_ccg[single_ccg["corr_type"] == "within_Forebrain"]

fig = plt.figure(figsize=(3.30, 5.10 + 2.6 + 0.1))

# --- A: change in spike rate per relative level -----------------------------
ax3a = figure_add_axes_inch(fig, top=0.1, left=0.8, width=2.4, height=1.5, label="A")
figure_competition_boxplot(
    forebrain_twostim,
    forebrain_singlestim,
    ax=ax3a,
    brackets={(0, 5): 0.9, (1, 5): 0.8, (2, 5): 0.7},
    colors=c_forebrain,
)
ax3a.set_ylabel("Change in\nSpike Rate")
ax3a.yaxis.set_label_coords(-0.4 / 2.4, 0.5)
subplot_indicator(ax3a, ha="left", va="top", pad_inch=0.7)

# --- B: example raster plots at the two extreme relative levels -------------
axs3b, axg3b = figure_add_axes_group_inch(
    fig,
    nrows=1,
    ncols=2,
    group_top=2.4,
    group_left=0.8,
    individual_width=(2.4 - 0.1) / 2,
    individual_height=2.0,
    wspace=0.1,
    hspace=0.0,
)
figure_coincident_rasterplot(
    spiketrain_data=example_spiketrains_forebrain_flat, axs=axs3b
)
axs3b[0, 0].set_title("-15 dB")
axs3b[0, 1].set_title("+10 dB")
subplot_indicator(axg3b, "B", ha="left", va="top", pad_inch=0.7)

# --- C: row of six cross-correlogram panels ----------------------------------
axs3c, axg3c = figure_add_axes_group_inch(
    fig,
    nrows=1,
    ncols=6,
    group_top=2.2 + 2.6 + 0.1,
    group_left=0.8,
    individual_width=2.4 / 6,
    individual_height=0.6,
    wspace=0.0,
    hspace=0.0,
)
figure_xcorr_ccg(
    twostim_ccg_forebrain,
    single_ccg_forebrain,
    axs=axs3c.flatten(),
    colors=c_forebrain,
)
# Common upper y-limit; only the first panel gets a (single) labelled tick.
for ax in axs3c[0, :]:
    ax.set_ylim(top=1.0e-5)
axs3c[0, 0].set_yticks(np.array([0.5]) * 1e-5, ["0.5"], minor=False)
axs3c[0, 0].yaxis.set_label_coords(-0.3 / (2.4 / 6), 0.5)
subplot_indicator(axg3c, "C", ha="left", va="top", pad_inch=0.7)

# --- D: cross-correlation box plot (no significance brackets) ---------------
ax3d = figure_add_axes_inch(
    fig, left=0.8, width=2.4, bottom=0.5, height=1.5, label="D"
)
figure_xcorr_boxplot(
    twostim_ccg_forebrain,
    single_ccg_forebrain,
    ax=ax3d,
    colors=c_forebrain,
    brackets={},
)
ax3d.yaxis.set_label_coords(-0.3 / 2.4, 0.5)
subplot_indicator(ax3d, ha="left", va="top", pad_inch=0.7)
ax3d.set_ylim(-0.75, 0.75)

# Stimulus-condition badge next to the raster plots.
condition_batch(
    fig,
    left=0.8 + 0.2,
    top=2.3,
    text="flat noise",
    color=c_forebrain[0],
    fontsize=12,
    ha="right",
)

for _ext in ("pdf", "png", "eps"):
    fig.savefig(OUTDIR / f"figure3.{_ext}")

# Last expression: shown as the cell output.
fig.get_size_inches()

## Figure 4 Tuning Selectivity Index + Competition

In [None]:
# Figure 4 inputs: forebrain subsets of the rate and correlogram tables.
_FOREBRAIN = "Forebrain"
forebrain_twostim = twostim_rlf[twostim_rlf["region"] == _FOREBRAIN]
forebrain_singlestim = single_rlf[single_rlf["region"] == _FOREBRAIN]
forebrain_singlestim_out = single_rlf_out[single_rlf_out["region"] == _FOREBRAIN]

forebrain_single_ccg = single_ccg[single_ccg["corr_type"] == "within_Forebrain"]
forebrain_ts_ccg = twostim_ccg[twostim_ccg["corr_type"] == "within_Forebrain"]

fig = plt.figure(figsize=(6.80, 4.5))

# --- A: full-width selectivity histogram ------------------------------------
ax4a = figure_add_axes_inch(fig, top=0.1, left=0.8, width=5.6, height=1.0, label="A")
histogram_selectivity_resubmit(
    forebrain_twostim,
    forebrain_singlestim,
    forebrain_singlestim_out,
    colors=c_forebrain,
    ax=ax4a,
)
subplot_indicator(ax4a, ha="left", va="top", pad_inch=0.75)

# --- B / C: two panels side by side below A ----------------------------------
ax4b = figure_add_axes_inch(
    fig, top=1.5 + 0.3 + 0.2, left=0.8, width=2.4, height=2.0, label="B"
)
ax4c = figure_add_axes_inch(
    fig, top=1.5 + 0.3 + 0.2, left=0.8 + 2.4 + 0.8, width=2.4, height=2.0, label="C"
)

figure_singlesound_rlf_selectivity_resubmit(
    forebrain_singlestim, forebrain_singlestim_out, ax4b
)
subplot_indicator(ax4b, ha="left", va="top", pad_inch=0.75)

figure_rlf_selectivity_resubmit(
    forebrain_twostim,
    forebrain_singlestim,
    forebrain_singlestim_out,
    colors=c_forebrain,
    ax=ax4c,
)
ax4c.set_ylim(-0.75, 0.75)
ax4c.yaxis.set_label_coords(-0.4 / 2.4, 0.5)
ax4c.set_title("Flat Noise Competition")
subplot_indicator(ax4c, ha="left", va="top", pad_inch=0.75)

for _ext in ("pdf", "png", "eps"):
    fig.savefig(OUTDIR / f"figure4.{_ext}")

## Figure 5 Modeling Figure

In [None]:
# Figure 5: model panels driven by the two relative-response curves returned
# by `competition_modeling_data`.
forebrain_singlestim = single_rlf[single_rlf["region"] == "Forebrain"]
forebrain_singlestim_out = single_rlf_out[single_rlf_out["region"] == "Forebrain"]

sound_level, rate_driver_4b, rate_competitor_4b = rlf_modeling_data(
    forebrain_singlestim, forebrain_singlestim_out
)

# The forebrain subsets are re-derived from the source tables before the
# second helper call, exactly as in the original cell.
forebrain_twostim = twostim_rlf[twostim_rlf["region"] == "Forebrain"]
forebrain_singlestim = single_rlf[single_rlf["region"] == "Forebrain"]
forebrain_singlestim_out = single_rlf_out[single_rlf_out["region"] == "Forebrain"]

relative_level, relative_response_d, relative_response_c = competition_modeling_data(
    forebrain_twostim, forebrain_singlestim, forebrain_singlestim_out
)

fig = plt.figure(figsize=(6.80, 4.90))

# --- top row (A, B): model fed with `relative_response_d` --------------------
ax5a = figure_add_axes_inch(fig, top=0.3, left=0.8, width=2.4, height=1.5, label="A")
ax5b = figure_add_axes_inch(
    fig, top=0.3, left=0.8 + 2.4 + 1.1, width=2.4, height=1.5, label="B"
)
model(relative_response_d, axs=[ax5a, ax5b])
ax5b.set_title("Invariant nRt Response")
ax5a.set_title("OT responses")

# --- bottom row (C, D): model fed with `relative_response_c` -----------------
ax5c = figure_add_axes_inch(
    fig, top=0.3 + 1.5 + 1.0, left=0.8, width=2.4, height=1.5, label="C"
)
ax5d = figure_add_axes_inch(
    fig, top=0.3 + 1.5 + 1.0, left=0.8 + 2.4 + 1.1, width=2.4, height=1.5, label="D"
)
model(relative_response_c, axs=[ax5c, ax5d])
ax5d.set_title("Increasing nRt Response")
ax5c.set_title("OT responses")

for _panel_ax in (ax5a, ax5b, ax5c, ax5d):
    subplot_indicator(_panel_ax, ha="left", va="bottom", pad_inch=0.7)

for _ext in ("pdf", "png", "eps"):
    fig.savefig(OUTDIR / f"figure5.{_ext}")

## Figure 6 Extreme Selectivity Resubmission

In [None]:
# Figure 6: competition for the two extreme driver-selectivity groups.
forebrain_twostim = twostim_rlf[twostim_rlf["region"] == "Forebrain"]
forebrain_singlestim = single_rlf[single_rlf["region"] == "Forebrain"]
forebrain_singlestim_out = single_rlf_out[single_rlf_out["region"] == "Forebrain"]

fig = plt.figure(figsize=(6.80, 2.5))

# Two equally sized panels side by side.
ax6a = figure_add_axes_inch(fig, top=0.5, left=0.8, width=2.4, height=1.5, label="A")
ax6a.set_title("Extreme Frontal\nDriver Selectivity")
ax6b = figure_add_axes_inch(
    fig, top=0.5, left=0.8 + 2.4 + 1.1, width=2.4, height=1.5, label="B"
)
ax6b.set_title("Extreme Lateral\nDriver Selectivity")

extreme_selectivity_competition_resubmit(
    forebrain_twostim,
    forebrain_singlestim,
    forebrain_singlestim_out,
    axs=[ax6a, ax6b],
)

for _panel_ax in (ax6a, ax6b):
    subplot_indicator(_panel_ax, ha="left", va="bottom", pad_inch=0.7)

for _ext in ("pdf", "png", "eps"):
    fig.savefig(OUTDIR / f"figure6.{_ext}")

## Figure 7 Frontal OT vs. Peripheral OT competition

In [None]:
# Figure 7: OT competition, split at |azimuth| = 45 into a frontal subset
# (<= 45) and a peripheral / lateral subset (> 45).
frontal_ot_twostim = twostim_rlf[
    (twostim_rlf["region"] == "OT") & (twostim_rlf["fixedazi"].abs() <= 45)
]
frontal_ot_singlestim = single_rlf[
    (single_rlf["region"] == "OT") & (single_rlf["azimuth"].abs() <= 45)
]
peripheral_ot_twostim = twostim_rlf[
    (twostim_rlf["region"] == "OT") & (twostim_rlf["fixedazi"].abs() > 45)
]
peripheral_ot_singlestim = single_rlf[
    (single_rlf["region"] == "OT") & (single_rlf["azimuth"].abs() > 45)
]

fig = plt.figure(figsize=(6.80, 2.30))
ax7a = figure_add_axes_inch(fig, top=0.3, left=0.9, width=2.4, height=1.5, label="A")
ax7b = figure_add_axes_inch(
    fig, top=0.3, left=0.8 + 2.4 + 1.1, width=2.4, height=1.5, label="B"
)

# A: frontal subset
figure_competition_boxplot(
    frontal_ot_twostim, frontal_ot_singlestim, colors=["b"] * 6, ax=ax7a
)
ax7a.set_title("Frontal OT")

# B: peripheral subset
figure_competition_boxplot(
    peripheral_ot_twostim, peripheral_ot_singlestim, colors=["b"] * 6, ax=ax7b
)
ax7b.set_title("Lateral OT")

for _panel_ax in (ax7a, ax7b):
    subplot_indicator(_panel_ax, ha="left", va="bottom", pad_inch=0.7)

for _ext in ("pdf", "png", "eps"):
    fig.savefig(OUTDIR / f"figure7.{_ext}")

# fig.outline()

## Figure 8 AM noise competition

In [None]:
# Figure 8: AM-noise competition in the forebrain, laid out as two columns
# that differ only in data, color and horizontal offset:
#   left column  (panels A, C, E, G): rows with (fixed) modulation frequency 55
#   right column (panels B, D, F, H): rows with (fixed) modulation frequency 75
# The right column is shifted by 2.4 + 0.3 + 0.8 relative to the left one.

## Forebrain Driver 55Hz, Competitor 75Hz
am_single_ccg_55 = am_single_ccg[
    (am_single_ccg["corr_type"] == "within_Forebrain")
    & (am_single_ccg["modulation_frequency"] == 55)
]
am_single_rlf_55 = am_single_rlf[
    (am_single_rlf["region"] == "Forebrain")
    & (am_single_rlf["modulation_frequency"] == 55)
]

am_twostim_ccg_55 = am_twostim_ccg[
    (am_twostim_ccg["fixed_modulation_frequency"] == 55)
    & (am_twostim_ccg["corr_type"] == "within_Forebrain")
]
am_twostim_rlf_55 = am_twostim_rlf[
    (am_twostim_rlf["fixed_modulation_frequency"] == 55)
    & (am_twostim_rlf["region"] == "Forebrain")
]

## Forebrain Driver 75Hz, Competitor 55Hz
am_single_ccg_75 = am_single_ccg[
    (am_single_ccg["corr_type"] == "within_Forebrain")
    & (am_single_ccg["modulation_frequency"] == 75)
]
am_single_rlf_75 = am_single_rlf[
    (am_single_rlf["region"] == "Forebrain")
    & (am_single_rlf["modulation_frequency"] == 75)
]

am_twostim_ccg_75 = am_twostim_ccg[
    (am_twostim_ccg["fixed_modulation_frequency"] == 75)
    & (am_twostim_ccg["corr_type"] == "within_Forebrain")
]
am_twostim_rlf_75 = am_twostim_rlf[
    (am_twostim_rlf["fixed_modulation_frequency"] == 75)
    & (am_twostim_rlf["region"] == "Forebrain")
]


fig = plt.figure(figsize=(6.80, 5.10 + 2.6 + 0.1))

# --- A: spike-rate competition box plot, 55 Hz subset ------------------------
ax8a = figure_add_axes_inch(fig, top=0.1, left=0.8, width=2.4, height=1.5, label="A")
figure_competition_boxplot(
    am_twostim_rlf_55,
    am_single_rlf_55,
    ax=ax8a,
    colors=c_forebrain_am55,
    brackets={(0, 5): 0.9, (1, 5): 0.8, (2, 5): 0.7},
)
ax8a.set_ylabel("Change in\nSpike Rate")
ax8a.yaxis.set_label_coords(-0.4 / 2.4, 0.5)
subplot_indicator(ax8a, ha="left", va="top", pad_inch=0.7)


# --- B: spike-rate competition box plot, 75 Hz subset ------------------------
ax8b = figure_add_axes_inch(
    fig,
    top=0.1,
    left=0.8 + 2.4 + 0.3 + 0.8,
    width=2.4,
    height=1.5,
    label="B",
)
figure_competition_boxplot(
    am_twostim_rlf_75,
    am_single_rlf_75,
    ax=ax8b,
    colors=c_forebrain_am75,
    brackets={(0, 5): 0.9, (1, 5): 0.8, (2, 5): 0.7},
)

ax8b.set_ylabel("Change in\nSpike Rate")
ax8b.yaxis.set_label_coords(-0.4 / 2.4, 0.5)
subplot_indicator(ax8b, ha="left", va="top", pad_inch=0.7)


# --- C: example raster plots (two panels), 55 Hz ------------------------------
axs8c, axg8c = figure_add_axes_group_inch(
    fig,
    nrows=1,
    ncols=2,
    group_top=2.4,
    group_left=0.8,
    individual_width=(2.4 - 0.1) / 2,
    individual_height=2.0,
    wspace=0.1,
    hspace=0.0,
)

figure_coincident_rasterplot(
    spiketrain_data=example_spiketrains_am55,
    axs=axs8c,
)
axs8c[0, 0].set_title("-15 dB")
axs8c[0, 1].set_title("+10 dB")
subplot_indicator(axg8c, "C", ha="left", va="top", pad_inch=0.7)


# --- D: example raster plots (two panels), 75 Hz ------------------------------
axs8d, axg8d = figure_add_axes_group_inch(
    fig,
    nrows=1,
    ncols=2,
    group_top=2.4,
    group_left=0.8 + 2.4 + 0.3 + 0.8,
    individual_width=(2.4 - 0.1) / 2,
    individual_height=2.0,
    wspace=0.1,
    hspace=0.0,
)

figure_coincident_rasterplot(
    spiketrain_data=example_spiketrains_am75,
    axs=axs8d,
)
axs8d[0, 0].set_title("-15 dB")
axs8d[0, 1].set_title("+10 dB")
subplot_indicator(axg8d, "D", ha="left", va="top", pad_inch=0.7)


# --- E: row of six cross-correlogram panels, 55 Hz ----------------------------
axs8e, axg8e = figure_add_axes_group_inch(
    fig,
    nrows=1,
    ncols=6,
    group_top=2.2 + 2.6 + 0.1,
    group_left=0.8,
    individual_width=2.4 / 6,
    individual_height=0.6,
    wspace=0.0,
    hspace=0.0,
)
figure_xcorr_ccg(
    am_twostim_ccg_55, am_single_ccg_55, axs=axs8e.flatten()[:], colors=c_forebrain_am55
)
# Same upper y-limit for all six panels; ticks are set on the first panel only
# (major ticks at 0 and 1e-5, one minor tick at 0.5e-5).
for ax in axs8e[0, :]:
    ax.set_ylim(top=1.5e-5)
axs8e[0, 0].set_yticks(np.array([0, 1]) * 1e-5, minor=False)
axs8e[0, 0].set_yticks(np.array([0.5]) * 1e-5, minor=True)
axs8e[0, 0].yaxis.set_label_coords(-0.3 / (2.4 / 6), 0.5)
subplot_indicator(axg8e, "E", ha="left", va="top", pad_inch=0.7)


# --- F: row of six cross-correlogram panels, 75 Hz ----------------------------
axs8f, axg8f = figure_add_axes_group_inch(
    fig,
    nrows=1,
    ncols=6,
    group_top=2.2 + 2.6 + 0.1,
    group_left=0.8 + 2.4 + 0.3 + 0.8,
    individual_width=2.4 / 6,
    individual_height=0.6,
    wspace=0.0,
    hspace=0.0,
)
figure_xcorr_ccg(
    am_twostim_ccg_75, am_single_ccg_75, axs=axs8f.flatten()[:], colors=c_forebrain_am75
)
for ax in axs8f[0, :]:
    ax.set_ylim(top=1.5e-5)
axs8f[0, 0].set_yticks(np.array([0, 1]) * 1e-5, minor=False)
axs8f[0, 0].set_yticks(np.array([0.5]) * 1e-5, minor=True)
axs8f[0, 0].yaxis.set_label_coords(-0.3 / (2.4 / 6), 0.5)
subplot_indicator(axg8f, "F", ha="left", va="top", pad_inch=0.7)

# --- G: cross-correlation box plot, 55 Hz (no brackets) -----------------------
ax8g = figure_add_axes_inch(
    fig,
    left=0.8,
    width=2.4,
    bottom=0.5,
    height=1.5,
    label="G",
)
figure_xcorr_boxplot(
    am_twostim_ccg_55, am_single_ccg_55, ax=ax8g, colors=c_forebrain_am55, brackets={}
)

ax8g.yaxis.set_label_coords(-0.3 / 2.4, 0.5)
subplot_indicator(ax8g, ha="left", va="top", pad_inch=0.7)

# --- H: cross-correlation box plot, 75 Hz (no brackets) -----------------------
ax8h = figure_add_axes_inch(
    fig,
    left=0.8 + 2.4 + 0.3 + 0.8,
    width=2.4,
    bottom=0.5,
    height=1.5,
    label="H",
)
figure_xcorr_boxplot(
    am_twostim_ccg_75, am_single_ccg_75, ax=ax8h, colors=c_forebrain_am75, brackets={}
)
ax8h.yaxis.set_label_coords(-0.3 / 2.4, 0.5)
subplot_indicator(ax8h, ha="left", va="top", pad_inch=0.7)

# Condition badges, one per column, in the column's color.
condition_batch(
    fig,
    left=0.8,
    top=2.3,
    text="Driver\n55 Hz",
    color=c_forebrain_am55[0],
    fontsize=12,
    ha="right",
)
condition_batch(
    fig,
    left=0.8 + 2.4 + 0.3 + 0.8,
    top=2.3,
    text="Driver\n75 Hz",
    color=c_forebrain_am75[0],
    fontsize=12,
    ha="right",
)

# NOTE(review): the commented-out savefig lines below still carry file names
# from an earlier figure numbering; the active output is "figure8.*" further down.
# fig.savefig(OUTDIR / "figure3.pdf")
# fig.savefig(OUTDIR / "forebrain_am_competition.png")
# fig.savefig(OUTDIR / "figure3.eps")

# fig.outline()
# Report the figure size in inches and in centimetres (1 inch = 2.54 cm).
print(
    f"Figure size: {' x '.join(fig.get_size_inches().astype('str'))} inches == {' x '.join(f'{l*2.54:.2f}' for l in fig.get_size_inches())} cm"
)

fig.savefig(OUTDIR / "figure8.pdf")
fig.savefig(OUTDIR / "figure8.png")
fig.savefig(OUTDIR / "figure8.eps")

## Figure 9 AM noise Switch

In [None]:
## Figure 9: two-column layout, one column per driver modulation frequency
## (left: driver 55 Hz, right: driver 75 Hz). Each column stacks a CCG strip,
## a power-spectrum strip, a driver-vs-competitor strip and a "switch" panel.


def _fig9_rows_where(table, **equals):
    """Return the rows of ``table`` where every named column equals its value.

    The per-column comparisons are combined with ``&``, exactly like a
    hand-written ``table[(a == x) & (b == y)]``.
    """
    keep = None
    for column, value in equals.items():
        hit = table[column] == value
        keep = hit if keep is None else keep & hit
    return table[keep]


## Forebrain Driver 55Hz, Competitor 75Hz
am_single_ccg_55 = _fig9_rows_where(
    am_single_ccg, corr_type="within_Forebrain", modulation_frequency=55
)
am_single_rlf_55 = _fig9_rows_where(
    am_single_rlf, region="Forebrain", modulation_frequency=55
)
am_twostim_ccg_55 = _fig9_rows_where(
    am_twostim_ccg, fixed_modulation_frequency=55, corr_type="within_Forebrain"
)
am_twostim_rlf_55 = _fig9_rows_where(
    am_twostim_rlf, fixed_modulation_frequency=55, region="Forebrain"
)

## Forebrain Driver 75Hz, Competitor 55Hz
am_single_ccg_75 = _fig9_rows_where(
    am_single_ccg, corr_type="within_Forebrain", modulation_frequency=75
)
am_single_rlf_75 = _fig9_rows_where(
    am_single_rlf, region="Forebrain", modulation_frequency=75
)
am_twostim_ccg_75 = _fig9_rows_where(
    am_twostim_ccg, fixed_modulation_frequency=75, corr_type="within_Forebrain"
)
am_twostim_rlf_75 = _fig9_rows_where(
    am_twostim_rlf, fixed_modulation_frequency=75, region="Forebrain"
)

# Stimulus phase-locking tables, forebrain only, one per driver frequency.
df_twostim_55 = _fig9_rows_where(
    am_twostim_stim_phaselocking, fixed_modulation_frequency=55, region="Forebrain"
)
df_twostim_75 = _fig9_rows_where(
    am_twostim_stim_phaselocking, fixed_modulation_frequency=75, region="Forebrain"
)

fig = plt.figure(figsize=(6.80, 5.6))

# Geometry shared by every 1 x 6 strip of small axes (values in inches).
_fig9_strip = dict(
    nrows=1,
    ncols=6,
    individual_width=2.4 / 6,
    individual_height=0.6,
    hspace=0.0,
)
# Strip offsets passed as group_top / group_left. The sums are kept in their
# original operand order so the resulting floats are unchanged.
_fig9_top_ccg = 0.3 + 0.15
_fig9_top_spectra = 0.3 + 1.0 + 0.15
_fig9_top_scatter = 0.4 + 1.0 + 0.15 + 0.6 + 0.4
_fig9_left_55 = 0.8
_fig9_left_75 = 0.8 + 2.4 + 0.1 + 0.8
# Colour order is [driver, competitor] as passed to the stim-phase plotters
# (NOTE(review): order inferred from the panel titles - confirm against the plotters).
_fig9_colors_55 = [c_forebrain_am55, c_forebrain_am75]
_fig9_colors_75 = [c_forebrain_am75, c_forebrain_am55]


def _fig9_finish_ccg(axs, axg, letter):
    """Fix the y-range, keep a single '.5' tick and add the panel letter to a CCG strip."""
    for strip_ax in axs[0, :]:
        strip_ax.set_ylim(bottom=-0.10e-5, top=1.0e-5)
    first_ax = axs[0, 0]
    first_ax.set_yticks(np.array([0.5]) * 1e-5, [".5"], minor=False)
    first_ax.yaxis.set_label_coords(-0.3 / (2.4 / 6), 0.5)
    subplot_indicator(axg, letter, ha="left", va="top", pad_inch=0.7)


def _fig9_finish_scatter(axs, axg, letter):
    """Clear the titles, label the group axis and clamp the first axes to [0, 0.5]."""
    for strip_ax in axs[0, :]:
        strip_ax.set_title("")
    axg.set_title("")
    axg.set_ylabel("VS Competitor")
    axg.yaxis.set_label_coords(-0.3 / 2.4, 0.5)
    subplot_indicator(axg, letter, ha="left", va="bottom", pad_inch=0.7)
    axs[0, 0].set_xlim(0, 0.5)
    axs[0, 0].set_ylim(0, 0.5)


def _fig9_switch_panel(label, left, phaselocking, colors):
    """Add the bottom panel of a column, draw the switch plot and return its axes."""
    panel = figure_add_axes_inch(
        fig, bottom=0.5, left=left, width=2.55, height=1.5, label=label
    )
    figure_stimphase_switch(phaselocking, ax=panel, colors=colors)
    subplot_indicator(panel, ha="left", va="top", pad_inch=0.7)
    panel.set_ylim(0, 0.2)
    return panel


# ---------------- left column: driver 55 Hz ----------------
axs9a, axg9a = figure_add_axes_group_inch(
    fig,
    group_top=_fig9_top_ccg,
    group_left=_fig9_left_55,
    wspace=0.030,
    **_fig9_strip,
)
figure_xcorr_ccg(
    am_twostim_ccg_55, am_single_ccg_55, axs=axs9a, colors=c_forebrain_am55
)
_fig9_finish_ccg(axs9a, axg9a, "A")

axs9b, axg9b = figure_add_axes_group_inch(
    fig,
    group_top=_fig9_top_spectra,
    group_left=_fig9_left_55,
    wspace=0.030,
    **_fig9_strip,
)
figure_xcorr_pspectra(
    am_twostim_ccg_55, am_single_ccg_55, axs=axs9b, colors=c_forebrain_am55
)
subplot_indicator(axg9b, "B", ha="left", va="top", pad_inch=0.7)

axs9c, axg9c = figure_add_axes_group_inch(
    fig,
    group_top=_fig9_top_scatter,
    group_left=_fig9_left_55,
    wspace=0.030,
    **_fig9_strip,
)
figure_stimphase_driver_competitor(
    df_twostim_55, axs=axs9c[:], colors=_fig9_colors_55
)
_fig9_finish_scatter(axs9c, axg9c, "C")

ax9d = _fig9_switch_panel("D", _fig9_left_55, df_twostim_55, _fig9_colors_55)

# ---------------- right column: driver 75 Hz ----------------
# The CCG and spectrum strips in this column use a slightly tighter wspace
# (0.020) than the left column (0.030); kept as in the original layout.
axs9e, axg9e = figure_add_axes_group_inch(
    fig,
    group_top=_fig9_top_ccg,
    group_left=_fig9_left_75,
    wspace=0.020,
    **_fig9_strip,
)
figure_xcorr_ccg(
    am_twostim_ccg_75,
    am_single_ccg_75,
    axs=axs9e.flatten()[:],
    colors=c_forebrain_am75,
)
_fig9_finish_ccg(axs9e, axg9e, "E")

axs9f, axg9f = figure_add_axes_group_inch(
    fig,
    group_top=_fig9_top_spectra,
    group_left=_fig9_left_75,
    wspace=0.020,
    **_fig9_strip,
)
figure_xcorr_pspectra(
    am_twostim_ccg_75, am_single_ccg_75, axs=axs9f, colors=c_forebrain_am75
)
subplot_indicator(axg9f, "F", ha="left", va="top", pad_inch=0.7)

axs9g, axg9g = figure_add_axes_group_inch(
    fig,
    group_top=_fig9_top_scatter,
    group_left=_fig9_left_75,
    wspace=0.030,
    **_fig9_strip,
)
figure_stimphase_driver_competitor(
    df_twostim_75, axs=axs9g[:], colors=_fig9_colors_75
)
_fig9_finish_scatter(axs9g, axg9g, "G")

ax9h = _fig9_switch_panel("H", _fig9_left_75, df_twostim_75, _fig9_colors_75)

# Column headers go on the top strips' group axes, after all plotting is done.
axg9a.set_title("Driver 55Hz", pad=22, fontsize=12)
axg9e.set_title("Driver 75Hz", pad=22, fontsize=12)

# Export the same figure in all three formats.
for _fig9_ext in ("pdf", "png", "eps"):
    fig.savefig(OUTDIR / f"figure9.{_fig9_ext}")

## Figure 10 Spatial Tuning Correlation Distribution

In [None]:
## Figure 10: a wide panel (A) on top, a 2 x 2 block of spatial-receptive-field
## maps with colorbars (B-E; columns titled OT and nRt), and a wide panel (F)
## at the bottom.
fig = plt.figure(figsize=(6.80, 10.2))

# Offsets in inches, passed as top= / left= to figure_add_axes_inch. The sums
# keep their original operand order so the resulting floats are unchanged.
_fig10_left_ot = 0.8
_fig10_left_nrt = 0.8 + 1.4 + 2.0
_fig10_top_lowcorr = 0.3 + 1.5 + 0.7
# NOTE(review): this row starts with 0.03 where the row above uses 0.3 -
# kept as is; confirm that it is intentional.
_fig10_top_highcorr = 0.03 + 2.0 + 1.0 + 1.5 + 0.7


def _fig10_wide_panel(label, top):
    """Add one of the full-width (5.8 x 1.5 inch) panels and return its axes."""
    return figure_add_axes_inch(
        fig, top=top, left=0.8, width=5.8, height=1.5, label=label
    )


def _fig10_srf_panel(label, top, left, side, srf):
    """Add a square map axes with a slim colorbar axes to its right and draw ``srf``.

    Returns ``(map_axes, colorbar_axes)``.
    """
    map_ax = figure_add_axes_inch(
        fig, top=top, left=left, width=side, height=side, label=label
    )
    # The colorbar starts 0.05 inch right of the map and is 0.05 inch wide.
    bar_ax = figure_add_axes_inch(
        fig, top=top, left=left + side + 0.05, width=0.05, height=side
    )
    figure_spatial_receptive_field(srf, ax=map_ax, cax=bar_ax)
    subplot_indicator(map_ax, ha="left", va="top", pad_inch=0.7)
    return map_ax, bar_ax


# A: correlation distribution across the full width.
ax10a = _fig10_wide_panel("A", top=0.1)
srf_correlation_distribution(single_srf, twostim_ccg, ax10a)
subplot_indicator(ax10a, ha="left", va="top", pad_inch=0.7)

# B / C: low-correlation examples; this row carries the column titles.
ax10b, ax10b_cb = _fig10_srf_panel(
    "B", top=_fig10_top_lowcorr, left=_fig10_left_ot, side=2, srf=lowcorr_srf_ot
)
ax10b.set_title("OT")

ax10c, ax10c_cb = _fig10_srf_panel(
    "C", top=_fig10_top_lowcorr, left=_fig10_left_nrt, side=2, srf=lowcorr_srf_fb
)
ax10c.set_title("nRt")

# D / E: high-correlation examples, same columns, no titles.
ax10d, ax10d_cb = _fig10_srf_panel(
    "D", top=_fig10_top_highcorr, left=_fig10_left_ot, side=2.0, srf=highcorr_srf_ot
)
ax10e, ax10e_cb = _fig10_srf_panel(
    "E", top=_fig10_top_highcorr, left=_fig10_left_nrt, side=2.0, srf=highcorr_srf_fb
)

# F: the y-range is fixed before plotting, as in the original cell.
ax10f = _fig10_wide_panel("F", top=_fig10_top_highcorr + 2.0 + 0.80)
ax10f.set_ylim(0, 20)
srf_correlation_frontalperiperpheral(single_srf, twostim_ccg, axs=ax10f)
subplot_indicator(ax10f, ha="left", va="top", pad_inch=0.7)

# Export the same figure in all three formats.
for _fig10_ext in ("pdf", "png", "eps"):
    fig.savefig(OUTDIR / f"figure10.{_fig10_ext}")

## Figure 11 Dual Region Xcorr by SRF Correlation

In [None]:
## Figure 11: a 4 x 6 grid of cross-correlation axes with one shared colorbar (A)
## above two box-plot panels (B, C).
fig = plt.figure(figsize=(6.80, 7.0))

# Individual grid cells are scaled to 85 % of (5.10 / 6) x (5.0 / 6) inch.
_fig11_cell_scale = 0.85
# Offset passed as top= for both box plots (same sum the original used for C).
_fig11_top_boxplots = 0.3 + 3.6 + 0.5 + 0.3 + 0.3

axs11a, axg11a = figure_add_axes_group_inch(
    fig,
    nrows=4,
    ncols=6,
    group_top=0.4,
    group_left=0.8,
    individual_width=(5.10 / 6) * _fig11_cell_scale,
    individual_height=(5.0 / 6) * _fig11_cell_scale,
    wspace=0.15,
    hspace=0.45,
)
# Slim colorbar axes to the right of the grid, starting at the same top offset.
ax11a_cb = figure_add_axes_inch(
    fig,
    top=0.4,
    left=6.1,
    width=0.07,
    height=(5.0 * _fig11_cell_scale),
)
fig_dual_region_xcorr_grid(single_srf, twostim_ccg, axs=axs11a, caxs=ax11a_cb)

# B: box plot with five significance brackets; each value is passed through
# as that bracket's setting (presumably its height - TODO confirm).
ax11b = figure_add_axes_inch(
    fig,
    top=_fig11_top_boxplots,
    left=0.8,
    width=2.6,
    height=1.5,
    label="B",
)
_fig11_brackets_b = {
    (0, 1): 0.62,
    (1, 2): 0.62,
    (1, 3): 0.72,
    (1, 4): 0.82,
    (1, 5): 0.92,
}
csi_boxplot_dualregion(
    single_srf,
    twostim_ccg,
    single_ccg,
    ax=ax11b,
    anova_align="left",
    brackets=_fig11_brackets_b,
)

# C: the SRF-correlation variant of the box plot.
ax11c = figure_add_axes_inch(
    fig,
    top=_fig11_top_boxplots,
    left=0.8 + 2.1 + 1.2,
    width=2.6,
    height=1.5,
    label="C",
)
csi_boxplot_dualregion_srfcorr(
    single_srf,
    twostim_ccg,
    single_ccg,
    ax=ax11c,
    anova_align="left",
    brackets={(0, 3): 0.9, (0, 2): 0.8},
)

# Panel letters: the grid gets an explicit "A"; B and C are called without a
# letter argument, as in the original cell.
subplot_indicator(axg11a, "A", ha="left", va="bottom", pad_inch=0.7)
for _fig11_panel in (ax11b, ax11c):
    subplot_indicator(_fig11_panel, ha="left", va="bottom", pad_inch=0.7)

# Export the same figure in all three formats.
for _fig11_ext in ("pdf", "png", "eps"):
    fig.savefig(OUTDIR / f"figure11.{_fig11_ext}")

## Figure 12 Within and Across Region Gamma Power

In [None]:
## Figure 12: left column = three heat-map panels (A-C), each with its own
## colorbar axes; right column = three scatter panels (D-F). All six axes and
## the three colorbar axes are handed to a single plotting call.
fig = plt.figure(figsize=(6.80, 7.3))


def _fig12_heatmap_panel(label, title, panel_top, colorbar_top):
    """Create one titled heat-map axes and its colorbar axes.

    Returns ``(heatmap_axes, colorbar_axes)``.
    """
    heat_ax = figure_add_axes_inch(
        fig,
        top=panel_top,
        left=0.8 + 0.3,
        width=2.0,
        height=2.1,
        label=label,
    )
    bar_ax = figure_add_axes_inch(
        fig,
        top=colorbar_top,
        left=0.8 + 2.0 + 0.05 + 0.3,
        width=0.05,
        height=1.5,
    )
    heat_ax.set_title(title)
    heat_ax.set_ylabel("Spatial Tuning\nCorrelation Percentile", x=-1, fontsize=10)
    return heat_ax, bar_ax


def _fig12_scatter_panel(label, panel_top):
    """Create one of the right-hand 2.0 x 1.5 inch axes."""
    return figure_add_axes_inch(
        fig,
        top=panel_top,
        left=0.8 + 2.1 + 0.05 + 1.5,
        width=2.0,
        height=1.5,
        label=label,
    )


# Left column, top to bottom. The top offsets keep their original operand
# order so the resulting floats are unchanged.
ax12a, ax12a_cb = _fig12_heatmap_panel(
    "A",
    "Within OT",
    panel_top=0.20,
    colorbar_top=0.20 + 0.3,
)
ax12b, ax12b_cb = _fig12_heatmap_panel(
    "B",
    "Within nRt",
    panel_top=0.4 + 0.1 + 2.1,
    colorbar_top=0.4 + 0.3 + 0.3 + 1.9,
)
ax12c, ax12c_cb = _fig12_heatmap_panel(
    "C",
    "Cross Region",
    panel_top=0.60 + 0.1 + 2.1 + 0.1 + 2.1,
    colorbar_top=0.60 + 0.3 + 0.3 + 1.9 + 0.3 + 1.9,
)

# Right column, top to bottom.
ax12d = _fig12_scatter_panel("D", panel_top=0.1 + 0.4)
ax12e = _fig12_scatter_panel("E", panel_top=0.5 + 0.1 + 2.3)
ax12f = _fig12_scatter_panel("F", panel_top=0.9 + 0.1 + 2.1 + 0.1 + 2.1)

# Only the low-gamma rows of the gamma-power table are plotted here.
_fig12_is_lowgamma = twostim_gamma_power["gamma_band"] == "lowgamma"
lowgamma = twostim_gamma_power[_fig12_is_lowgamma]
gamma_heatmap_scatter_resubmit(
    single_srf,
    twostim_ccg,
    lowgamma,
    [ax12a, ax12b, ax12c, ax12d, ax12e, ax12f],
    [ax12a_cb, ax12b_cb, ax12c_cb],
)

# Panel letters, in A-F order.
for _fig12_panel in (ax12a, ax12b, ax12c, ax12d, ax12e, ax12f):
    subplot_indicator(_fig12_panel, ha="left", va="bottom", pad_inch=0.7)

# Export the same figure in all three formats.
for _fig12_ext in ("pdf", "png", "eps"):
    fig.savefig(OUTDIR / f"figure12.{_fig12_ext}")

## Figure 13 Spike-LFP Coherence 

In [None]:
## spike field coherence figure
## Layout: three rows x two columns of 1 x 3 axes strips. Left column = OT,
## right column = forebrain (titled "nRt"). Row 1 holds the within-area
## spike-field coherence, rows 2 and 3 the spike-phase plots for the low- and
## high-gamma bands.
fig = plt.figure(figsize=(6.80, 5.7))

# Strip offsets in inches, passed as group_top / group_left. The sums keep
# their original operand order so the resulting floats are unchanged.
_fig13_left_ot = 0.8
_fig13_left_fb = 0.8 + 2.4 + 0.8
_fig13_top_sfc = 0.6
_fig13_top_low = 0.6 + 1.0 + 0.8
_fig13_top_high = 0.6 + 1.0 + 0.6 + 1.0 + 1.0


def _fig13_strip(top, left):
    """Add a 1 x 3 strip of axes; returns ``(axes_array, group_axes)``."""
    return figure_add_axes_group_inch(
        fig,
        nrows=1,
        ncols=3,
        group_top=top,
        group_left=left,
        individual_width=2.4 / 3,
        individual_height=1.0,
        wspace=0.05,
        hspace=0.0,
    )


def _fig13_gamma_rows(band, region):
    """Rows of ``twostim_gamma_power`` for one gamma band and one recording region."""
    in_band = twostim_gamma_power["gamma_band"] == band
    in_region = twostim_gamma_power["region"] == region
    return twostim_gamma_power[(in_band) & (in_region)]


# ---- Row 1: within-area spike-field coherence ----
axs13a, axg13a = _fig13_strip(_fig13_top_sfc, _fig13_left_ot)
within_area_sfc_competition(
    within_area_sfc[within_area_sfc["region"] == "OT"],
    axs=axs13a[0, :],
    colors=c_ot,
)
axg13a.set_title("OT", y=1.3)

axs13b, axg13b = _fig13_strip(_fig13_top_sfc, _fig13_left_fb)
within_area_sfc_competition(
    within_area_sfc[within_area_sfc["region"] == "Forebrain"],
    axs=axs13b[0, :],
    colors=c_forebrain,
)
# The right-hand strip drops its y-label after plotting.
axs13b[0, 0].set_ylabel(None)
axg13b.set_title("nRt", y=1.3)

# ---- Row 2: low gamma ----
axs13c, axg13c = _fig13_strip(_fig13_top_low, _fig13_left_ot)
axs13c[0, 0].set_ylabel("Normalized\nSpike Count")
axg13c.set_title("Low Gamma 20-50Hz")
gamma_band = _fig13_gamma_rows("lowgamma", "OT")
figure_withinregion_spike_phase(gamma_band, colors="b", axs=axs13c[0, :])

axs13d, axg13d = _fig13_strip(_fig13_top_low, _fig13_left_fb)
axg13d.set_title("Low Gamma 20-50Hz")
gamma_band = _fig13_gamma_rows("lowgamma", "Forebrain")
figure_withinregion_spike_phase(gamma_band, colors="orange", axs=axs13d[0, :])

# ---- Row 3: high gamma ----
axs13e, axg13e = _fig13_strip(_fig13_top_high, _fig13_left_ot)
axs13e[0, 0].set_ylabel("Normalized\nSpike Count")
axg13e.set_title("High Gamma 50-75Hz")
gamma_band = _fig13_gamma_rows("highgamma", "OT")
figure_withinregion_spike_phase(gamma_band, colors="b", axs=axs13e[0, :])

axs13f, axg13f = _fig13_strip(_fig13_top_high, _fig13_left_fb)
gamma_band = _fig13_gamma_rows("highgamma", "Forebrain")
figure_withinregion_spike_phase(gamma_band, colors="orange", axs=axs13f[0, :])
# Panel F sets its title after plotting (the others set it before); order kept.
axg13f.set_title("High Gamma 50-75Hz")

# Panel letters: left-column strips use a 0.7 inch pad, right-column strips 0.5.
for _fig13_group, _fig13_letter, _fig13_pad in (
    (axg13a, "A", 0.7),
    (axg13b, "B", 0.5),
    (axg13c, "C", 0.7),
    (axg13d, "D", 0.5),
    (axg13e, "E", 0.7),
    (axg13f, "F", 0.5),
):
    subplot_indicator(
        _fig13_group, _fig13_letter, ha="left", va="bottom", pad_inch=_fig13_pad
    )

# Export the same figure in all three formats.
for _fig13_ext in ("pdf", "png", "eps"):
    fig.savefig(OUTDIR / f"figure13.{_fig13_ext}")