# Imports and Settings

Imports

In [None]:
# Builtin
import pathlib
from typing import Sequence, cast

# 3rd party
import numpy as np
import pandas as pd
import scipy.stats
import scipy.signal
import matplotlib.pyplot as plt
from matplotlib.image import imread
from matplotlib.transforms import ScaledTranslation
from matplotlib.ticker import MultipleLocator, FuncFormatter
from matplotlib.transforms import blended_transform_factory
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.patches import Patch

# Custom
from powltools.analysis.grouping import group_by_param
from funcs_plotting import (
    plot_bracket,
    condition_batch,
    figure_outline,
    figure_add_axes_inch,
    figure_add_axes_group_inch,
    subplot_indicator,
    ColorType,
)
from funcs_statistics import anova_tukey, t_test_ind
from funcs_common import binary_spiketrain, make_psth_bins

In [None]:
# Do not crop inline Figures:
%config InlineBackend.print_figure_kwargs = {'bbox_inches': None}

Path settings

In [None]:
# Input: intermediate analysis results, stored as .feather tables.
DATADIR = pathlib.Path("intermediate_results").resolve()

# Output: rendered figures (created on first run if missing).
OUTDIR = pathlib.Path("figure_output").resolve()
OUTDIR.mkdir(exist_ok=True)

def save_show_close(fig: Figure, name: str):
    """Export *fig* to OUTDIR as ``<name>.pdf`` and ``<name>.png``, show it, close it.

    After displaying the figure (via the notebook-provided ``display``), its size
    is printed in inches and centimetres, and the figure is then closed so
    pyplot releases it.
    """
    for extension in ("pdf", "png"):  # ".eps" export is currently disabled
        fig.savefig(str(OUTDIR / f"{name}.{extension}"))
    display(fig)

    width_in, height_in = fig.get_size_inches()
    cm_per_inch = 2.54
    print(
        f"Figure size: {width_in:.2f} x {height_in:.2f} inches == "
        f"{width_in*cm_per_inch:.2f} x {height_in*cm_per_inch:.2f} cm"
    )
    plt.close(fig)

Color Schemes

In [None]:
## Flat noise competition: colors for relative levels (from -15 to +10).
# Each list first spells out a graded palette (one color per level), then is
# collapsed to its first color repeated, so every level is currently drawn in
# the same color. The graded palettes are kept for reference only.
c_flat_noise = [
    "mediumblue",
    "cornflowerblue",
    "lightsteelblue",
    "slategray",
    "mediumseagreen",
    "forestgreen",
]
c_flat_noise = [c_flat_noise[0] for _ in c_flat_noise]

# AM noise, modulation frequency 55 (same scheme as above).
c_fAM_55 = [
    "deeppink",
    "hotpink",
    "orchid",
    "slategray",
    "mediumpurple",
    "darkorchid",
]
c_fAM_55 = [c_fAM_55[0] for _ in c_fAM_55]

# AM noise, modulation frequency 75 (same scheme as above).
c_fAM_75 = [
    "darkorchid",
    "mediumpurple",
    "plum",
    "slategray",
    "hotpink",
    "deeppink",
]
c_fAM_75 = [c_fAM_75[0] for _ in c_fAM_75]

# Per-area accent colors.
c_ot = "#D98B4C"
c_icx = "#4BC0D9"

In [None]:
## Constants

# Calibration offset: a stimulus intensity of 0 corresponds to 63 dB SPL.
INTENSITY_0_EQUALS_DB_SPL = 63.0

# First-spike latency at or below which a unit counts as short-latency
# (same units as the `first_spike_latency` column, presumably seconds).
SHORT_LATENCY_THRESHOLD = 14e-3

Load Data from intermediate results

In [None]:
## Optic Tectum (OT)

# Index columns of per-unit tables and of per-unit-pair (CCG) tables.
_OT_UNIT_INDEX = ["date", "owl", "channel"]
_OT_PAIR_INDEX = ["date", "owl", "channel1", "channel2"]


def _read_ot(stem: str, index_cols: list[str] | None = None) -> pd.DataFrame:
    """Load ``DATADIR / f"{stem}_ot.feather"``, indexed by *index_cols* if given."""
    table = pd.read_feather(DATADIR / f"{stem}_ot.feather")
    if index_cols is None:
        return table
    return table.set_index(index_cols)


def _add_latency_flags_ot(table: pd.DataFrame) -> None:
    """Add boolean ``short_latency`` / ``long_latency`` columns to *table* in place.

    Short latency means ``first_spike_latency <= SHORT_LATENCY_THRESHOLD`` and
    long latency is its negation.
    NOTE(review): rows with a NaN latency (if any) compare False, so they end up
    flagged as long-latency.
    """
    is_short = table["first_spike_latency"] <= SHORT_LATENCY_THRESHOLD
    table["short_latency"] = is_short
    table["long_latency"] = ~is_short


# Flat noise - single stimuli
single_rlf_ot = _read_ot("single_rlf", _OT_UNIT_INDEX)
_add_latency_flags_ot(single_rlf_ot)
single_ccg_ot = _read_ot("single_ccg", _OT_PAIR_INDEX)
single_gamma_power_ot = _read_ot("single_gamma_power", _OT_UNIT_INDEX)

# Flat noise - competing stimuli
twostim_rlf_ot = _read_ot("twostim_rlf", _OT_UNIT_INDEX)
twostim_ccg_ot = _read_ot("twostim_ccg", _OT_PAIR_INDEX)
twostim_gamma_power_ot = _read_ot("twostim_gamma_power", _OT_UNIT_INDEX)

# AM noise - single stimuli
am_single_rlf_ot = _read_ot("am_single_rlf", _OT_UNIT_INDEX)
_add_latency_flags_ot(am_single_rlf_ot)
am_single_ccg_ot = _read_ot("am_single_ccg", _OT_PAIR_INDEX)
am_single_stim_phaselocking_ot = _read_ot(
    "am_single_stim_phaselocking", _OT_UNIT_INDEX
)
_add_latency_flags_ot(am_single_stim_phaselocking_ot)
am_single_gamma_power_ot = _read_ot("am_single_gamma_power", _OT_UNIT_INDEX)

# AM noise - competing stimuli
am_twostim_rlf_ot = _read_ot("am_twostim_rlf", _OT_UNIT_INDEX)
am_twostim_ccg_ot = _read_ot("am_twostim_ccg", _OT_PAIR_INDEX)
am_twostim_stim_phaselocking_ot = _read_ot(
    "am_twostim_stim_phaselocking", _OT_UNIT_INDEX
)
am_twostim_gamma_power_ot = _read_ot("am_twostim_gamma_power", _OT_UNIT_INDEX)

# Elevation tuning (not re-indexed)
elevation_tunings_ot = _read_ot("elevation_tunings")
elevation_signalcorr_ot = _read_ot("elevation_signalcorr")

In [None]:
## External nucleus of the Inferior Colliculus (ICx)

# Index columns of per-unit tables and of per-unit-pair (CCG) tables.
_ICX_UNIT_INDEX = ["date", "owl", "channel"]
_ICX_PAIR_INDEX = ["date", "owl", "channel1", "channel2"]


def _read_icx(stem: str, index_cols: list[str] | None = None) -> pd.DataFrame:
    """Load ``DATADIR / f"{stem}_icx.feather"``, indexed by *index_cols* if given."""
    table = pd.read_feather(DATADIR / f"{stem}_icx.feather")
    if index_cols is None:
        return table
    return table.set_index(index_cols)


# Flat noise - single stimuli
single_rlf_icx = _read_icx("single_rlf", _ICX_UNIT_INDEX)
single_ccg_icx = _read_icx("single_ccg", _ICX_PAIR_INDEX)
single_gamma_power_icx = _read_icx("single_gamma_power", _ICX_UNIT_INDEX)

# Flat noise - competing stimuli
twostim_rlf_icx = _read_icx("twostim_rlf", _ICX_UNIT_INDEX)
twostim_ccg_icx = _read_icx("twostim_ccg", _ICX_PAIR_INDEX)
twostim_gamma_power_icx = _read_icx("twostim_gamma_power", _ICX_UNIT_INDEX)

# AM noise - single stimuli
# NOTE: unlike the OT tables, no latency flags and no AM gamma-power tables
# are loaded for ICx here.
am_single_rlf_icx = _read_icx("am_single_rlf", _ICX_UNIT_INDEX)
am_single_ccg_icx = _read_icx("am_single_ccg", _ICX_PAIR_INDEX)
am_single_stim_phaselocking_icx = _read_icx(
    "am_single_stim_phaselocking", _ICX_UNIT_INDEX
)

# AM noise - competing stimuli
am_twostim_rlf_icx = _read_icx("am_twostim_rlf", _ICX_UNIT_INDEX)
am_twostim_ccg_icx = _read_icx("am_twostim_ccg", _ICX_PAIR_INDEX)
am_twostim_stim_phaselocking_icx = _read_icx(
    "am_twostim_stim_phaselocking", _ICX_UNIT_INDEX
)

# Elevation tuning (not re-indexed)
elevation_tunings_icx = _read_icx("elevation_tunings")
elevation_signalcorr_icx = _read_icx("elevation_signalcorr")

In [None]:
## Example data

# Example spatial receptive field (OT)
example_srf = pd.read_feather(DATADIR / "example_srf_20230420_33.feather")

# Example spike trains of two simultaneously recorded units, one file per area
# (OT / ICx) and stimulus condition; all files share the same name tag.
_EXAMPLE_SPIKETRAIN_TAG = "20230523_40"


def _read_example_spiketrains(area: str, condition: str) -> pd.DataFrame:
    """Load ``example_spiketrains_<area>_<condition>_<tag>.feather`` from DATADIR."""
    filename = (
        f"example_spiketrains_{area}_{condition}_{_EXAMPLE_SPIKETRAIN_TAG}.feather"
    )
    return pd.read_feather(DATADIR / filename)


# OT
example_spiketrains_ot_flat = _read_example_spiketrains("ot", "flat")
example_spiketrains_ot_driver55 = _read_example_spiketrains("ot", "driver55")
example_spiketrains_ot_driver75 = _read_example_spiketrains("ot", "driver75")

# ICx
example_spiketrains_icx_flat = _read_example_spiketrains("icx", "flat")
example_spiketrains_icx_driver55 = _read_example_spiketrains("icx", "driver55")
example_spiketrains_icx_driver75 = _read_example_spiketrains("icx", "driver75")

# Plot Functions

In [None]:
# Switch to True to also run the example ("usage") cell after each plot function.
do_usage = False

## Spatial Receptive Field Plot

In [None]:
def figure_spatial_receptive_field(
    srf_data: pd.DataFrame,
    ax: Axes,
    cax: Axes | None = None,
):
    """Scatter-plot a spatial receptive field on *ax* and add a colorbar.

    Each row of *srf_data* is one stimulus position. The marker is placed at
    (``azimuth``, ``elevation``) and its grey level encodes ``response_mean``,
    with the colour scale starting at 0. The colorbar goes into *cax* when
    given; otherwise matplotlib takes space from the scatter's own axes.

    Parameters
    ----------
    srf_data
        Table with ``azimuth``, ``elevation`` and ``response_mean`` columns.
    ax
        Axes receiving the scatter plot.
    cax
        Optional axes for the colorbar.
    """
    azi, ele, rate = (
        np.asarray(srf_data[column].values)
        for column in ("azimuth", "elevation", "response_mean")
    )

    points = ax.scatter(
        x=azi,
        y=ele,
        c=rate,
        cmap="Greys",
        vmin=0.0,
        edgecolors="0.7",
        linewidths=0.5,
    )
    colorbar = cast(Figure, ax.figure).colorbar(points, cax=cax)
    colorbar.set_label("Response [spks/stim]")

    # Major ticks every 30 degrees on both axes.
    ax.xaxis.set_major_locator(MultipleLocator(30))
    ax.yaxis.set_major_locator(MultipleLocator(30))
    ax.set_xlabel("Azimuth [deg]")
    ax.set_ylabel("Elevation [deg]")


if do_usage:
    # Example: OT spatial receptive field with a slim colorbar to its right.
    srf_left, srf_width, srf_height = 0.8, 2.4, 2.0
    fig = cast(Figure, plt.figure(figsize=(3.80, 2.6)))
    ax = figure_add_axes_inch(
        fig, top=0.05, left=srf_left, width=srf_width, height=srf_height, label="A"
    )
    ax_cb = figure_add_axes_inch(
        fig,
        top=0.05,
        left=srf_left + srf_width + 0.05,
        width=0.05,
        height=srf_height,
    )
    figure_spatial_receptive_field(example_srf, ax=ax, cax=ax_cb)
    display(fig)
    plt.close(fig)
    del fig, ax, ax_cb, srf_left, srf_width, srf_height

## RLF Boxplot

In [None]:
def figure_rlf_boxplot(
    twostim_rlf: pd.DataFrame,
    single_rlf: pd.DataFrame,
    ax: Axes,
    colors: Sequence[ColorType] = c_flat_noise,
    brackets: dict[tuple[int, int], float] | None = None,
    anova_align="right",
):
    """Boxplot of Change in Spike Rate as function of Relative Level.

    Competing-stimulus responses (*twostim_rlf*) are inner-joined to
    single-stimulus responses (*single_rlf*) on their shared index. Only rows
    whose fixed stimulus (intensity, azimuth, elevation) equals the single
    stimulus are kept. The relative change ``(resp - resp_single) / resp_single``
    is grouped by ``relative_level`` and drawn as one box per level (whiskers at
    the 5th/95th percentile, no fliers). The individual values are overlaid,
    pairwise p-values from ``anova_tukey`` are drawn as brackets, and the ANOVA
    p-value is written in a top corner.

    Parameters
    ----------
    twostim_rlf, single_rlf
        Rate tables sharing an index (the loaders use date/owl/channel).
    ax
        Target axes.
    colors
        Box face colours, one per relative level; extra entries are ignored.
    brackets
        Maps ``(left_group, right_group)`` (0-based group positions) to the
        bracket height in data units. ``None`` draws brackets between
        neighbouring groups at height 0.8 with shrink 0.8.
    anova_align
        ``"right"`` puts the ANOVA p-value top-right; anything else top-left.
    """
    merged_rlf = twostim_rlf.join(single_rlf, how="inner", rsuffix="_single")
    # Keep only pairings whose fixed stimulus matches the single stimulus.
    merged_rlf = merged_rlf.loc[
        (merged_rlf["fixedintensity"] == merged_rlf["intensity"])
        & (merged_rlf["fixedazi"] == merged_rlf["azimuth"])
        & (merged_rlf["fixedele"] == merged_rlf["elevation"])
    ]

    print("Number of units:", np.unique(merged_rlf.index).size)
    print(
        "Number of sessions:",
        len(set([idx[:2] for idx in merged_rlf.index.values])),
    )
    # Non-inplace set_index returns a fresh frame, so the column assignment
    # below does not write into a filtered slice (no SettingWithCopyWarning).
    merged_rlf = merged_rlf.set_index("relative_level")

    # NOTE(review): rows with resp_single == 0 produce inf/NaN here.
    merged_rlf["change_in_response"] = (
        merged_rlf["resp"] - merged_rlf["resp_single"]
    ) / merged_rlf["resp_single"]
    relative_levels = np.unique(merged_rlf.index)
    # List-based .loc always yields a Series, even for a level with a single row
    # (a scalar key would return a bare float there).
    groupdata = [
        np.asarray(merged_rlf.loc[[relative_level], "change_in_response"])
        for relative_level in relative_levels
    ]
    for lev, values in zip(relative_levels, groupdata):
        print(f"{lev}: {np.median(values)} , sem: {scipy.stats.sem(values)}")

    bp = ax.boxplot(
        groupdata,
        patch_artist=True,
        notch=False,
        showfliers=False,
        whis=(5, 95),
    )

    ax.set_xticklabels([str(level) for level in relative_levels.astype(int)])

    stats = anova_tukey(
        merged_rlf, val_col="change_in_response", group_col="relative_level"
    )

    for patch, color in zip(bp["boxes"], colors):
        patch: Patch
        patch.set_facecolor(color)

    for median in bp["medians"]:
        median.set(color="white", linewidth=2)

    # Individual data: sorted values spread evenly across the middle of each box.
    half_spread = 0.7 / 3
    for pos, values in enumerate(groupdata, start=1):
        y = np.sort(values)
        vals_x = np.linspace(pos - half_spread, pos + half_spread, y.size, endpoint=True)
        ax.plot(
            vals_x,
            y,
            color="k",
            ls="None",
            marker=".",
            markersize=2,
            markeredgewidth=0.0,
            zorder=5,
        )

    # Significance brackets (boxes sit at x = 1, 2, ...).
    brackets_shrink = 1.0
    if brackets is None:
        brackets = {(k, k + 1): 0.8 for k in range(len(stats["tukey"]) - 1)}
        brackets_shrink = 0.8
    for (idx_left, idx_right), y in brackets.items():
        p_value = stats["tukey"].iloc[idx_left, idx_right]
        plot_bracket(
            ax,
            left=idx_left + 1,
            right=idx_right + 1,
            # Drop only a leading zero (".042"); slicing off the first character
            # would turn "1.000" into ".000".
            text=f"{p_value:.3f}".removeprefix("0"),
            y=y,
            shrink=brackets_shrink,
        )

    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    anova_label = f"p = {stats['anova_p']:.3g}"
    if anova_align == "right":
        ax.text(
            1,
            1,
            anova_label,
            horizontalalignment="right",
            verticalalignment="top",
            fontsize=10,
            transform=ax.transAxes,
        )
    else:
        # Top-left corner, nudged 2 pt to the right of the axes edge.
        ax.text(
            0,
            1,
            anova_label,
            horizontalalignment="left",
            verticalalignment="top",
            fontsize=10,
            transform=ax.transAxes
            + ScaledTranslation(
                +2 / 72, 0 / 72, cast(Figure, ax.figure).dpi_scale_trans
            ),
        )

    ax.yaxis.set_major_formatter(FuncFormatter(lambda x, pos: f"{x:.0%}"))

    ax.axhline(0, color="k", ls="--", zorder=-1)
    ax.set_xlabel("Relative Level [dB]")
    ax.set_ylabel("Change in Spike Rate")
    ax.set_ylim(-1, 1.5)


if do_usage:
    # Example: OT rate-level boxplot with hand-placed bracket heights.
    example_brackets = {
        (0, 1): 0.90,
        (1, 2): 0.80,
        (2, 3): 0.70,
        (3, 4): 0.80,
        (4, 5): 0.70,
    }
    fig = cast(Figure, plt.figure(figsize=(3.30, 2.1)))
    ax = figure_add_axes_inch(fig, top=0.1, left=0.8, width=2.4, height=1.5)
    figure_rlf_boxplot(
        twostim_rlf_ot, single_rlf_ot, ax=ax, brackets=example_brackets
    )
    figure_outline(fig)
    display(fig)
    plt.close(fig)
    del fig, ax, example_brackets

## Coincident Rasterplot

In [None]:
def figure_coincident_rasterplot(
    spiketrain_data: pd.DataFrame,
    axs: list[Axes],
    relative_levels: Sequence[float] = (-15, +10),
    n_trials: int = 10,
):
    """Raster of two simultaneously recorded units and their coincident spikes.

    Draws one panel per entry of *relative_levels*, using the first *n_trials*
    trials of that level. Each trial k occupies four raster rows:
    ``4k`` unit 1 (green), ``4k + 1`` coincidences (black), ``4k + 2`` unit 2
    (orange), and ``4k + 3`` left empty as a gap. A coincidence is a PSTH bin
    in which both units' binary spike trains are set. It is drawn at that
    bin's left edge.

    Parameters
    ----------
    spiketrain_data
        One row per trial with ``spiketrain_unit1``, ``spiketrain_unit2`` and
        ``relative_level`` columns.
    axs
        One axes per relative level.
    relative_levels
        Levels to plot, in panel order.
    n_trials
        Maximum number of trials shown per panel.

    Raises
    ------
    ValueError
        If ``len(axs) != len(relative_levels)``.
    """
    if len(axs) != len(relative_levels):
        raise ValueError(
            f"Need one axes per relative level, got {len(axs)} axes "
            f"for {len(relative_levels)} levels"
        )

    spiketrains_unit1 = np.asarray(spiketrain_data["spiketrain_unit1"].values)
    spiketrains_unit2 = np.asarray(spiketrain_data["spiketrain_unit2"].values)
    trial_relative_levels = np.asarray(spiketrain_data["relative_level"].values)

    stim_spikes1 = group_by_param(spiketrains_unit1, trial_relative_levels)
    stim_spikes2 = group_by_param(spiketrains_unit2, trial_relative_levels)

    # Single source of truth for the bin width, reused to convert bin indices
    # of coincidences back to times.
    binsize = 0.001
    psth_bins = make_psth_bins(
        stimdelay=1.0,
        stimduration=1.0,
        binsize=binsize,
        offset=0.000,
    )

    for ax, relative_level in zip(axs, relative_levels):
        spiketrains1 = stim_spikes1[relative_level][:n_trials]
        spiketrains2 = stim_spikes2[relative_level][:n_trials]
        coincidences = [
            np.where(
                binary_spiketrain(st1, bins=psth_bins)
                & binary_spiketrain(st2, bins=psth_bins)
            )[0]
            * binsize
            + psth_bins[0]
            for st1, st2 in zip(spiketrains1, spiketrains2)
        ]
        ax.eventplot(
            spiketrains1,
            lineoffsets=(4 * np.arange(len(spiketrains1))).tolist(),
            linelengths=1,
            color="#00dd33",
            alpha=0.3,
        )
        ax.eventplot(
            spiketrains2,
            lineoffsets=(2 + 4 * np.arange(len(spiketrains2))).tolist(),
            linelengths=1,
            color="#ffaa00",
            alpha=0.3,
        )
        ax.eventplot(
            coincidences,
            lineoffsets=(1 + 4 * np.arange(len(coincidences))).tolist(),
            linelengths=1.4,  # type: ignore
            color="k",
        )
        # Shaded first 50 ms after t = 1.0 s; x ticks are relabelled relative to
        # 1.0 s (presumably stimulus onset, given stimdelay=1.0 above).
        ax.axvspan(1.0, 1.050, color="#ddf", zorder=-50)
        ax.set_xlim(left=1.0, right=2.0)
        ax.set_xticks([1.0, 1.5, 2.0], labels=["0", "0.5", "1"], fontsize=8)
        ax.set_ylim(bottom=-0.5, top=2 + 4 * (len(spiketrains2) - 1) + 0.5)

    for ax in axs[1:]:
        ax.set_yticks([])

    axs[0].set_xlabel("Time [s]", fontsize=8)
    # Row labels: the unit rows of trial `ytick_trial`, and the coincidence row
    # two trials higher. Ticks are unlabelled; text is placed left of the axes.
    ytick_trial = 1
    row_labels = [
        (4 * ytick_trial, "Unit A", "top"),
        (2 + 4 * ytick_trial, "Unit B", "bottom"),
        (1 + 4 * (ytick_trial + 2), "Coinc.", "bottom"),
    ]
    axs[0].set_yticks([y for y, _, _ in row_labels], ["", "", ""])
    axs[0].tick_params(axis="y", length=4)
    # x in axes coordinates, y in data coordinates, shifted 4 pt to the left.
    label_transform = blended_transform_factory(
        axs[0].transAxes, axs[0].transData
    ) + ScaledTranslation(-4 / 72, 0, cast(Figure, axs[0].figure).dpi_scale_trans)
    for y, text, va in row_labels:
        axs[0].text(
            0,
            y,
            text,
            fontsize=10,
            ha="right",
            va=va,
            color="k",
            transform=label_transform,
        )


if do_usage:
    # Example: coincidence rasters for all relative levels, OT (top) and ICx (bottom).
    relative_levels = [-15, -10, -5, 0, +5, +10]
    n_levels = len(relative_levels)
    fig = cast(Figure, plt.figure(figsize=(0.8 + n_levels * 1.25, 5.2)))
    axs, _ = figure_add_axes_group_inch(
        fig,
        nrows=2,
        ncols=n_levels,
        group_top=0.1 + 0.2,
        group_left=0.8,
        individual_width=1.15,
        individual_height=2.0,
        wspace=0.1,
        hspace=0.4,
    )
    for row, spiketrains in enumerate(
        (example_spiketrains_ot_flat, example_spiketrains_icx_flat)
    ):
        figure_coincident_rasterplot(
            spiketrains,
            axs=list(axs[row, :].flatten()),
            relative_levels=relative_levels,
        )
    for relative_level, ax in zip(relative_levels, axs[0, :].flatten()):
        ax.set_title(f"{relative_level:+} dB")
    for row, area in enumerate(("OT", "ICx")):
        subplot_indicator(axs[row, 0], area, va="top")
    display(fig)
    plt.close(fig)
    del fig, axs, relative_levels, relative_level, ax
    del n_levels, row, spiketrains, area

## Correlation Cross-Correlograms

In [None]:
def figure_xcorr_ccg(
    twostim_ccg: pd.DataFrame,
    single_ccg: pd.DataFrame,
    axs: list[Axes],
    colors: Sequence[ColorType] = c_flat_noise,
):
    """Plot the mean cross-correlogram (CCG) per relative level, one panel each.

    Competing-stimulus CCGs (*twostim_ccg*) are inner-joined to single-stimulus
    CCGs (*single_ccg*) on their shared index. Only rows whose fixed stimulus
    (intensity, azimuth, elevation) equals the single stimulus are kept.

    For each relative level (ascending), the panel shows:
    - the mean ``ccg`` over unit pairs, for lags within ±0.05 (labelled ±50 ms);
    - a horizontal line in ``colors[0]`` at the peak of the mean
      single-stimulus CCG;
    - the number of pairs, written at the top.

    Parameters
    ----------
    twostim_ccg, single_ccg
        CCG tables sharing an index (the loaders use date/owl/channel1/channel2),
        each with a ``ccg`` array column (``ccg_single`` after the join).
    axs
        One axes per relative level.
    colors
        Only the first colour is used (single-stimulus reference line).
    """
    merged_ccg = twostim_ccg.join(single_ccg, how="inner", rsuffix="_single")
    merged_ccg = merged_ccg.loc[
        (merged_ccg["fixedintensity"] == merged_ccg["intensity"])
        & (merged_ccg["fixedazi"] == merged_ccg["azimuth"])
        & (merged_ccg["fixedele"] == merged_ccg["elevation"])
    ]

    print("Number of unit pairs:", np.unique(merged_ccg.index).size)
    print(
        "Number of sessions:",
        len(sorted(set([idx[:2] for idx in merged_ccg.index.values]))),
    )

    # The join can repeat a pair's single-stimulus CCG on several rows; take one
    # per pair (before re-indexing by relative_level). Rows = pairs, cols = lags.
    single_ccgs = np.vstack(
        cast(
            Sequence[float],
            merged_ccg["ccg_single"].groupby(merged_ccg.index.names).first(),
        )
    )
    single_mean_peak = np.max(np.mean(single_ccgs, axis=0))
    # Spread of the individual per-pair peaks: max along the lag axis (axis=1).
    single_std_peak = np.std(np.max(single_ccgs, axis=1))
    print(f"mean: {single_mean_peak = :.4g} std: {single_std_peak}")

    merged_ccg = merged_ccg.set_index("relative_level")

    # A full cross-correlation of two length-L signals has 2L - 1 lags;
    # correlation_lags expects integer lengths.
    psth_len = (merged_ccg.iloc[0]["ccg_single"].size + 1) // 2
    # Lag bins -> time (presumably 1 ms bins; the ticks label 0.05 as "50" ms).
    lags = scipy.signal.correlation_lags(psth_len, psth_len) / 1000
    lags_mask = np.abs(lags) <= 0.05

    relative_levels = np.unique(merged_ccg.index)

    for k, relative_level in enumerate(relative_levels):
        ax = axs[k]
        # List-based .loc keeps a Series even when a level has a single pair.
        level_ccgs = merged_ccg.loc[[relative_level], "ccg"]
        mean_ccg = np.mean(np.vstack(cast(Sequence[float], level_ccgs)), axis=0)
        n_ccg = len(level_ccgs)
        print(n_ccg)

        ax.axhline(single_mean_peak, lw=1, color=colors[0])
        ax.plot(lags[lags_mask], mean_ccg[lags_mask], color="k", ls="-", lw=1)
        ax.text(
            0.5,
            1,
            f"{n_ccg}",
            ha="center",
            va="top",
            fontsize=8,
            transform=ax.transAxes,
        )
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

        ax.set_ylim(bottom=0, top=0.000075)
        ax.set_xlim(left=-0.060, right=0.060)

    axs[0].set_ylabel("Coinc./spk", labelpad=12)

    # Y ticks in units of 1e-5; the multiplier is shown once above the axis.
    axs[0].set_yticks([0, 5e-5])
    axs[0].yaxis.set_major_formatter(
        FuncFormatter(lambda x, pos: f"${{{x*1e5:.0f}}}$" if x > 0 else "0")
    )
    axs[0].set_yticks(np.array([1, 2, 3, 4, 6, 7]) * 1e-5, minor=True)
    axs[0].text(
        0,
        1,
        "$10^{-5}$",
        ha="right",
        va="center",
        fontsize=8,
        transform=axs[0].transAxes
        + ScaledTranslation(-2 / 72, 0, cast(Figure, axs[0].figure).dpi_scale_trans),
    )

    for ax in axs[1:]:
        ax.set_yticks([])
        ax.set_yticks([], minor=True)
        ax.spines["left"].set_visible(False)

    axs[0].set_xticks([-0.05, 0, +0.05], ["-50", "0", "+50"], fontsize=8)
    axs[0].set_xlabel("Lags [ms]", fontsize=8, labelpad=0)
    for ax in axs[1:]:
        ax.set_xticks([-0.05, 0, +0.05], [])


if do_usage:
    # Example: mean OT cross-correlograms, one panel per relative level.
    fig = cast(Figure, plt.figure(figsize=(3.30, 1.1)))
    axs, axg = figure_add_axes_group_inch(
        fig,
        nrows=1,
        ncols=6,
        group_top=0.1,
        group_left=0.8,
        individual_width=0.4,
        individual_height=0.6,
        wspace=0.0,
    )
    figure_xcorr_ccg(twostim_ccg_ot, single_ccg_ot, axs=axs.flatten().tolist())
    # Show once and release the figure, like the other example cells, so it does
    # not linger in pyplot or leave example globals behind.
    display(fig)
    plt.close(fig)
    del fig, axs, axg

## Correlation CSI Boxplots

In [None]:
def figure_xcorr_boxplot(
    twostim_rlf: pd.DataFrame,
    single_rlf: pd.DataFrame,
    ax: Axes,
    colors: Sequence[ColorType] = c_flat_noise,
    brackets: dict[tuple[int, int], float] | None = None,
    anova_align="right",
):
    """Boxplot of the cross-correlation change index (CSI) per relative level.

    Despite the parameter names, both inputs must carry an ``xcorr_peak``
    column (e.g. the CCG tables). The competing-stimulus table is inner-joined
    to the single-stimulus table on their shared index. Only rows whose fixed
    stimulus (intensity, azimuth, elevation) equals the single stimulus are
    kept. The CSI is

        (xcorr_peak - xcorr_peak_single) / (xcorr_peak + xcorr_peak_single)

    It is grouped by ``relative_level`` and drawn as one box per level
    (whiskers at the 5th/95th percentile, no fliers). Individual values,
    pairwise p-value brackets from ``anova_tukey``, and the ANOVA p-value are
    added.

    Parameters
    ----------
    twostim_rlf, single_rlf
        Competing- and single-stimulus tables sharing an index.
    ax
        Target axes.
    colors
        Box face colours, one per relative level; extra entries are ignored.
    brackets
        Maps ``(left_group, right_group)`` (0-based group positions) to the
        bracket height in data units. ``None`` draws brackets between
        neighbouring groups at height 0.8 with shrink 0.8.
    anova_align
        ``"right"`` puts the ANOVA p-value top-right; anything else top-left.
    """
    merged_rlf = twostim_rlf.join(single_rlf, how="inner", rsuffix="_single")
    merged_rlf = merged_rlf.loc[
        (merged_rlf["fixedintensity"] == merged_rlf["intensity"])
        & (merged_rlf["fixedazi"] == merged_rlf["azimuth"])
        & (merged_rlf["fixedele"] == merged_rlf["elevation"])
    ]

    # The inputs are indexed per unit pair, so count pairs (as figure_xcorr_ccg does).
    print("Number of unit pairs:", np.unique(merged_rlf.index).size)
    print(
        "Number of sessions:",
        len(sorted(set([idx[:2] for idx in merged_rlf.index.values]))),
    )
    # Non-inplace set_index returns a fresh frame, so adding the "csi" column
    # below does not write into a filtered slice (no SettingWithCopyWarning).
    merged_rlf = merged_rlf.set_index("relative_level")

    # NOTE(review): a zero denominator yields NaN for that row.
    merged_rlf["csi"] = (merged_rlf["xcorr_peak"] - merged_rlf["xcorr_peak_single"]) / (
        merged_rlf["xcorr_peak"] + merged_rlf["xcorr_peak_single"]
    )
    relative_levels = np.unique(merged_rlf.index)
    # List-based .loc keeps 1-D data even for a level with a single row
    # (a scalar would make np.sort fail below).
    groupdata = [
        np.asarray(merged_rlf.loc[[relative_level], "csi"])
        for relative_level in relative_levels
    ]

    bp = ax.boxplot(
        groupdata,
        patch_artist=True,
        notch=False,
        showfliers=False,
        whis=(5, 95),
    )

    ax.set_xticklabels([str(level) for level in relative_levels.astype(int)])

    stats = anova_tukey(merged_rlf, val_col="csi", group_col="relative_level")

    for patch, color in zip(bp["boxes"], colors):
        patch: Patch
        patch.set_facecolor(color)

    for median in bp["medians"]:
        median.set(color="white", linewidth=2)

    # Individual data: sorted values spread evenly across the middle of each box.
    half_spread = 0.7 / 3
    for pos, values in enumerate(groupdata, start=1):
        y = np.sort(values)
        vals_x = np.linspace(pos - half_spread, pos + half_spread, y.size, endpoint=True)
        ax.plot(
            vals_x,
            y,
            color="k",
            ls="None",
            marker=".",
            markersize=2,
            markeredgewidth=0.0,
            zorder=5,
        )

    # Significance brackets (boxes sit at x = 1, 2, ...).
    brackets_shrink = 1.0
    if brackets is None:
        brackets = {(k, k + 1): 0.8 for k in range(len(stats["tukey"]) - 1)}
        brackets_shrink = 0.8
    for (idx_left, idx_right), y in brackets.items():
        p_value = stats["tukey"].iloc[idx_left, idx_right]
        plot_bracket(
            ax,
            left=idx_left + 1,
            right=idx_right + 1,
            # Drop only a leading zero (".042"); slicing off the first character
            # would turn "1.000" into ".000".
            text=f"{p_value:.3f}".removeprefix("0"),
            y=y,
            shrink=brackets_shrink,
        )

    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    anova_label = f"p = {stats['anova_p']:.3g}"
    if anova_align == "right":
        ax.text(
            1,
            1,
            anova_label,
            horizontalalignment="right",
            verticalalignment="top",
            fontsize=10,
            transform=ax.transAxes,
        )
    else:
        # Top-left corner, nudged 2 pt to the right of the axes edge.
        ax.text(
            0,
            1,
            anova_label,
            horizontalalignment="left",
            verticalalignment="top",
            fontsize=10,
            transform=ax.transAxes
            + ScaledTranslation(+2 / 72, 0 / 72, cast(Figure, ax.figure).dpi_scale_trans),
        )

    ax.axhline(0, color="k", ls="--", zorder=-1)
    ax.set_xlabel("Relative Level [dB]")
    ax.set_ylabel("CSI")
    ax.set_ylim(-1, 1.45)


if do_usage:
    # Example: OT CSI boxplot with default brackets.
    fig = cast(Figure, plt.figure(figsize=(3.30, 2.1)))
    ax = figure_add_axes_inch(
        fig,
        top=0.1,
        left=0.8,
        width=2.4,
        height=1.5,
    )

    figure_xcorr_boxplot(twostim_ccg_ot, single_ccg_ot, ax=ax)
    figure_outline(fig)
    # Show once and release the figure, like the other example cells.
    display(fig)
    plt.close(fig)
    del fig, ax

## Stim Phase Locking Plot

In [None]:
def figure_stimphase_single_scatterplot(
    df: pd.DataFrame, axs: list[Axes], color_phase_locking="k"
):
    """Scatter phase-locking angle against vector strength, one panel per level.

    The absolute level is ``intensity + INTENSITY_0_EQUALS_DB_SPL``. Panels are
    filled in descending level order (``axs[0]`` = highest level). Each row of
    *df* is one point (``singlestim_plv``, ``singlestim_plv_angle``). Rows with
    Rayleigh p (``singlestim_plv_p``) below 0.05 are drawn in
    *color_phase_locking*, the others in grey. A black horizontal line marks the
    circular mean angle of the significant rows, drawn only if there are any.
    """
    df = df.copy()
    df["level"] = df["intensity"] + INTENSITY_0_EQUALS_DB_SPL
    df.set_index("level", inplace=True)

    # Exactly five ticks at -π, -π/2, 0, π/2, π. A float-step arange can gain
    # an extra tick through rounding and then no longer match the five labels.
    phase_ticks = np.linspace(-np.pi, np.pi, 5)

    for k, level in enumerate(reversed(np.unique(df.index))):
        ax = axs[k]

        # List-based .loc keeps 1-D arrays even for a level with a single row.
        level_rows = df.loc[[level]]
        level_rayleigh_p = np.asarray(level_rows["singlestim_plv_p"])
        level_plv = np.asarray(level_rows["singlestim_plv"])
        level_angles = np.asarray(level_rows["singlestim_plv_angle"])

        # Unit phasors, averaged below to get the circular mean angle.
        level_phases = np.cos(level_angles) + np.sin(level_angles) * 1j

        # Significant phase locking only:
        mask = level_rayleigh_p < 0.05

        ax.plot(
            level_plv[~mask],
            level_angles[~mask],
            ls="none",
            marker=".",
            color=".5",
            mec="none",
        )
        ax.plot(
            level_plv[mask],
            level_angles[mask],
            ls="none",
            marker=".",
            color=color_phase_locking,
            mec="none",
        )
        if mask.any():
            # The mean of an empty selection would be NaN plus a RuntimeWarning.
            mean_angle = np.angle(np.mean(level_phases[mask])).item()
            ax.axhline(mean_angle, color="k", lw=2)  # type: ignore
        ax.set_title(f"{level:.0f}", fontsize=10)

        ax.set_xlim(left=0, right=1)
        ax.set_xticks([0, 1], [])
        ax.set_xticks([0.5], minor=True)

        ax.set_ylim(bottom=-np.pi, top=+np.pi)
        ax.set_yticks(phase_ticks, [])

        if k == 0:
            ax.set_ylabel("Phase")
            ax.set_yticklabels(["-π", "-π/2", "0", "π/2", "π"])
            ax.set_xticklabels(["0", "1"])
        if k == 4:
            ax.set_xlabel("Vector Strength")


if do_usage:
    # Example: OT single-stimulus phase locking, AM 55 (top row) and 75 (bottom row).
    fig = cast(Figure, plt.figure(figsize=(3.80, 3.2 + 0.2)))
    axs, axg = figure_add_axes_group_inch(
        fig,
        nrows=2,
        ncols=10,
        group_top=0.3,
        group_left=0.8,
        individual_width=0.225,
        individual_height=1.0,
        wspace=0.05,
        hspace=0.7,
    )

    figure_stimphase_single_scatterplot(
        am_single_stim_phaselocking_ot[
            am_single_stim_phaselocking_ot["modulation_frequency"] == 55
        ],
        axs=axs[0, :].tolist(),
        color_phase_locking=c_fAM_55[0],
    )

    figure_stimphase_single_scatterplot(
        am_single_stim_phaselocking_ot[
            am_single_stim_phaselocking_ot["modulation_frequency"] == 75
        ],
        axs=axs[1, :].tolist(),
        color_phase_locking=c_fAM_75[0],
    )
    # Show once and release the figure, like the other example cells.
    display(fig)
    plt.close(fig)
    del fig, axs, axg

## Stim Phase Locking Plot Competition

In [None]:
def figure_stimphase_twostim_scatterplot(
    df: pd.DataFrame,
    axs: list[Axes],
    colors: Sequence[ColorType] = c_flat_noise,
):
    """Scatter fixed-stimulus phase angle against vector strength per relative level.

    Panels follow ascending ``relative_level``. Each row of *df* is one point
    (``fixedstim_plv``, ``fixedstim_plv_angle``). Rows with Rayleigh p
    (``fixedstim_plv_p``) below 0.05 are drawn in ``colors[0]``, the others in
    grey. A black line marks the circular mean angle of the significant rows,
    drawn only if there are any. The bottom-right number is the length of the
    mean phasor over ALL rows of the level.
    NOTE(review): that number uses all rows, unlike the mean angle. Confirm
    this is intended.
    """
    df = df.set_index("relative_level")

    # Exactly five ticks at -π, -π/2, 0, π/2, π (a float-step arange can gain
    # an extra tick through rounding).
    phase_ticks = np.linspace(-np.pi, np.pi, 5)

    for k, level in enumerate(np.unique(df.index)):
        ax = axs[k]

        # List-based .loc keeps 1-D arrays even for a level with a single row.
        level_rows = df.loc[[level]]
        level_rayleigh_p = np.asarray(level_rows["fixedstim_plv_p"])
        level_plv = np.asarray(level_rows["fixedstim_plv"])
        level_angles = np.asarray(level_rows["fixedstim_plv_angle"])

        # Unit phasors, averaged below for circular statistics.
        level_phases = np.cos(level_angles) + np.sin(level_angles) * 1j

        # Significant phase locking only:
        mask = level_rayleigh_p < 0.05

        vector_strength = abs(np.mean(level_phases).item())

        ax.plot(
            level_plv[~mask],
            level_angles[~mask],
            ls="none",
            marker=".",
            color=".5",
            mec="none",
        )
        ax.plot(
            level_plv[mask],
            level_angles[mask],
            ls="none",
            marker=".",
            color=colors[0],
            mec="none",
        )
        if mask.any():
            # The mean of an empty selection would be NaN plus a RuntimeWarning.
            mean_angle = np.angle(np.mean(level_phases[mask]).item()).item()
            ax.axhline(mean_angle, color="k", lw=2)  # type: ignore
        ax.text(
            1,
            0,
            f"{vector_strength:.2f}",
            ha="right",
            va="bottom",
            fontsize=8,
            transform=ax.transAxes
            + ScaledTranslation(
                -1 / 72, +1 / 72, cast(Figure, ax.figure).dpi_scale_trans
            ),
        )

        ax.set_title(f"{level:.0f}", fontsize=10)

        ax.set_xlim(left=0, right=1)
        ax.set_xticks([0, 1], [])
        ax.set_xticks([0.5], minor=True)

        ax.set_ylim(bottom=-np.pi, top=+np.pi)
        ax.set_yticks([])

        if k == 0:
            ax.set_ylabel("Phase")
            ax.set_yticks(
                phase_ticks,
                labels=["-π", "-π/2", "0", "π/2", "π"],
            )
            ax.set_xticklabels(["0", "1"])
        if k == 3:
            ax.set_xlabel("Vector Strength")


if do_usage:
    # Example: OT competing-stimulus phase locking for driver AM 55 (top) and 75 (bottom).
    fig = cast(Figure, plt.figure(figsize=(3.4, 3.3)))
    axs, axg = figure_add_axes_group_inch(
        fig,
        2,
        6,
        group_left=0.8,
        group_top=0.3,
        individual_width=(2.4 - 5 * 0.05) / 6,
        individual_height=1.0,
        wspace=0.05,
        hspace=0.6,
    )

    df_twostim_55 = am_twostim_stim_phaselocking_ot[
        am_twostim_stim_phaselocking_ot["fixed_modulation_frequency"] == 55
    ]
    figure_stimphase_twostim_scatterplot(
        df_twostim_55, axs=axs[0, :].flatten().tolist(), colors=c_fAM_55
    )

    df_twostim_75 = am_twostim_stim_phaselocking_ot[
        am_twostim_stim_phaselocking_ot["fixed_modulation_frequency"] == 75
    ]
    figure_stimphase_twostim_scatterplot(
        df_twostim_75, axs=axs[1, :].flatten().tolist(), colors=c_fAM_75
    )
    # Show once and release the figure, like the other example cells.
    display(fig)
    plt.close(fig)
    del fig, axs, axg, df_twostim_55, df_twostim_75

## Stim Phase Locking Plot Driver Competitor

In [None]:
def figure_stimphase_driver_competitor(
    df: pd.DataFrame, axs: list[Axes], colors: Sequence[ColorType] = c_flat_noise
):
    """Scatter vector strength to the driver against vector strength to the competitor.

    One panel is drawn per unique ``relative_level`` of *df*, in ascending order,
    into ``axs[0]``, ``axs[1]``, ...

    Units whose Rayleigh p-value for locking to the fixed (driver) stimulus
    (``fixedstim_plv_p``) is below 0.05 are drawn in ``colors[0]``; the others
    are drawn in grey. The dashed identity line is added, and the fraction of
    significant units is written in the upper-left corner.

    Both axes span 0-1 with unlabeled major ticks at 0 and 1 and a minor tick
    at 0.5. Only the first panel gets tick labels. The fourth panel
    (index 3) gets the x label "VS Driver".

    Parameters
    ----------
    df
        Needs columns ``relative_level``, ``fixedstim_plv_p``,
        ``fixedstim_plv`` (x axis) and ``varyingstim_plv`` (y axis).
    axs
        One Axes per relative level.
    colors
        Only the first entry is used.
    """
    by_level = df.set_index("relative_level")
    point_style = dict(ls="none", marker=".", mec="none")

    for k, level in enumerate(np.unique(by_level.index)):
        ax = axs[k]
        p_rayleigh, vs_driver, vs_competitor = (
            np.asarray(by_level.loc[level, column])
            for column in ("fixedstim_plv_p", "fixedstim_plv", "varyingstim_plv")
        )
        significant = p_rayleigh < 0.05

        # Non-significant units in grey. Significant units in colour, drawn with
        # zorder=-1, i.e. underneath the grey points.
        ax.plot(
            vs_driver[~significant],
            vs_competitor[~significant],
            color=".5",
            **point_style,
        )
        ax.plot(
            vs_driver[significant],
            vs_competitor[significant],
            color=colors[0],
            zorder=-1,
            **point_style,
        )
        ax.plot([0, 1], [0, 1], color="k", lw=0.8, ls="--")  # identity line

        # Fraction of significant units, offset 1 pt right / 1 pt down
        # from the upper-left corner of the axes.
        inch = cast(Figure, ax.figure).dpi_scale_trans
        ax.text(
            0,
            1,
            f"{np.mean(significant):.0%}",
            ha="left",
            va="top",
            fontsize=8,
            transform=ax.transAxes + ScaledTranslation(+1 / 72, -1 / 72, inch),
        )

        ax.set_title(f"{level:.0f}", fontsize=10)

        for set_lim, set_ticks in (
            (ax.set_xlim, ax.set_xticks),
            (ax.set_ylim, ax.set_yticks),
        ):
            set_lim(0, 1)
            set_ticks([0, 1], [])
            set_ticks([0.5], minor=True)

        if k == 0:
            ax.set_yticklabels(["0", "1"])
            ax.set_xticklabels(["0", "1"])
        elif k == 3:
            ax.set_xlabel("VS Driver")


# Example figure: VS driver vs. VS competitor per relative level.
# Top row: fixed_modulation_frequency == 55; bottom row: == 75.
if do_usage:
    fig = cast(Figure, plt.figure(figsize=(3.4, 1.6)))
    # 2 x 6 grid of square axes (individual_height == individual_width).
    axs, axg = figure_add_axes_group_inch(
        fig,
        2,
        6,
        group_left=0.8,
        group_top=0.3,
        individual_width=(2.4 - 5 * 0.05) / 6,
        individual_height=(2.4 - 5 * 0.05) / 6,
        wspace=0.05,
        hspace=0.2,
    )

    df_twostim_55 = am_twostim_stim_phaselocking_ot[
        am_twostim_stim_phaselocking_ot["fixed_modulation_frequency"] == 55
    ]
    figure_stimphase_driver_competitor(
        df_twostim_55, axs=axs[0, :].flatten().tolist(), colors=c_fAM_55
    )

    df_twostim_75 = am_twostim_stim_phaselocking_ot[
        am_twostim_stim_phaselocking_ot["fixed_modulation_frequency"] == 75
    ]
    figure_stimphase_driver_competitor(
        df_twostim_75, axs=axs[1, :].flatten().tolist(), colors=c_fAM_75
    )

    # Keep tick labels and the "VS Driver" x label only on the bottom row,
    # keep level titles only on the top row, and use one shared y label on
    # the group axes.
    axs[0, 0].set_xticklabels([])
    axs[0, 3].set_xlabel("")
    for ax in axs[1, :]:
        ax.set_title("")
    axg.set_ylabel("VS Competitor", labelpad=16)

## Gamma Power Boxplot

In [None]:
def figure_gamma_power_boxplot(
    twostim_gamma: pd.DataFrame,
    single_gamma: pd.DataFrame,
    ax: Axes,
    colors: Sequence[ColorType] = c_flat_noise,
    brackets: dict[tuple[int, int], float] | None = None,
    anova_align="right",
):
    """Boxplot of the change in gamma power (two-stimulus minus single-stimulus)
    as a function of relative level.

    ``twostim_gamma`` is inner-joined with ``single_gamma`` on their index.
    Overlapping single-stimulus columns get the suffix ``_single``. Only rows
    whose single stimulus equals the fixed stimulus of the pair (same
    intensity, azimuth and elevation) are kept.

    The function:
    - prints the numbers of units and sessions;
    - draws one box per relative level with the individual values overlaid;
    - annotates a one-way ANOVA with Tukey post-hoc brackets (``anova_tukey``).

    Parameters
    ----------
    twostim_gamma, single_gamma
        After the join the frame must provide ``fixedintensity``/``intensity``,
        ``fixedazi``/``azimuth``, ``fixedele``/``elevation``,
        ``relative_level``, ``gammapower`` and ``gammapower_single``. The
        session count uses the first two entries of each index tuple, so the
        index is presumably a MultiIndex starting with the session keys
        (TODO confirm).
    ax
        Axes to draw into.
    colors
        Box face colors, one per relative level (extra entries are ignored).
    brackets
        Maps ``(left, right)`` 0-based group indices to the bracket height
        passed to ``plot_bracket``. Defaults to brackets between neighbouring
        groups at ``y=0.8`` with shrink 0.8.
    anova_align
        ``"right"`` puts the ANOVA p-value in the top-right corner; any other
        value puts it top-left.
    """
    merged = twostim_gamma.join(single_gamma, how="inner", rsuffix="_single")
    # Keep only pairs whose single-stimulus reference is the fixed stimulus.
    same_stimulus = (
        (merged["fixedintensity"] == merged["intensity"])
        & (merged["fixedazi"] == merged["azimuth"])
        & (merged["fixedele"] == merged["elevation"])
    )
    merged = merged.loc[same_stimulus]

    print("Number of units:", np.unique(merged.index).size)
    print(
        "Number of sessions:",
        len({idx[:2] for idx in merged.index.values}),
    )

    # set_index without inplace returns a fresh frame, so the column
    # assignment below never writes into a view of the filtered join.
    merged = merged.set_index("relative_level")
    merged["change_in_power"] = merged["gammapower"] - merged["gammapower_single"]

    relative_levels = np.unique(merged.index)
    groupdata = [
        np.asarray(merged.loc[relative_level, "change_in_power"])
        for relative_level in relative_levels
    ]

    bp = ax.boxplot(
        groupdata,
        patch_artist=True,
        notch=False,
        showfliers=False,
        whis=(5, 95),
    )

    ax.set_xticklabels(relative_levels.astype(int))

    stats = anova_tukey(merged, val_col="change_in_power", group_col="relative_level")

    for patch, color in zip(bp["boxes"], colors):
        patch.set_facecolor(color)

    for median in bp["medians"]:
        median.set(color="white", linewidth=2)

    # Individual data: sorted values spread evenly across each box.
    width = 0.7
    for k, groupvalues in enumerate(groupdata):
        pos = k + 1
        y = np.sort(groupvalues)
        vals_x = np.linspace(
            pos - width / 3,
            pos + width / 3,
            y.size,
            endpoint=True,
        )
        ax.plot(
            vals_x,
            y,
            color="k",
            ls="None",
            marker=".",
            markersize=2,
            markeredgewidth=0.0,
            zorder=5,
        )

    # Tukey post-hoc brackets.
    brackets_shrink = 1.0
    if brackets is None:
        brackets = {(k, k + 1): 0.8 for k in range(len(stats["tukey"]) - 1)}
        brackets_shrink = 0.8
    for (idx_left, idx_right), y in brackets.items():
        # Show p-values without the leading zero (".012"). Only strip an
        # actual leading "0": slicing off the first character would turn a
        # p-value that rounds to 1.000 into the misleading ".000".
        p_label = f"{stats['tukey'].iloc[idx_left, idx_right]:.3f}".removeprefix("0")
        plot_bracket(
            ax,
            left=idx_left + 1,
            right=idx_right + 1,
            text=p_label,
            y=y,
            shrink=brackets_shrink,
        )

    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    anova_text = f"p = {stats['anova_p']:.3g}"
    if anova_align == "right":
        ax.text(
            1,
            1,
            anova_text,
            horizontalalignment="right",
            verticalalignment="top",
            fontsize=10,
            transform=ax.transAxes,
        )
    else:
        # Top-left, nudged 2 pt to the right of the axes edge.
        ax.text(
            0,
            1,
            anova_text,
            horizontalalignment="left",
            verticalalignment="top",
            fontsize=10,
            transform=ax.transAxes
            + ScaledTranslation(+2 / 72, 0 / 72, cast(Figure, ax.figure).dpi_scale_trans),
        )

    ax.axhline(0, color="k", ls="--", zorder=-1)
    ax.set_xlabel("Relative Level [dB]")
    ax.set_ylabel(r"$Change\ in\ \gamma\ Power$")


# Example figure: gamma-power change per relative level, OT units.
if do_usage:
    fig = cast(Figure, plt.figure(figsize=(3.30, 2.3)))
    ax = figure_add_axes_inch(
        fig,
        top=0.3,
        left=0.8,
        width=2.4,
        height=1.5,
    )

    figure_gamma_power_boxplot(twostim_gamma_power_ot, single_gamma_power_ot, ax=ax)
    ax.yaxis.set_major_locator(MultipleLocator(10))
    ax.set_ylim(-30, +30)
    ax.set_title("OT")
    display(fig)
    plt.close(fig)
    # Remove the globals so later cells cannot accidentally reuse this figure.
    del fig, ax

# Example figure: gamma-power change per relative level, ICx units.
if do_usage:
    fig = cast(Figure, plt.figure(figsize=(3.30, 2.3)))
    ax = figure_add_axes_inch(
        fig,
        top=0.3,
        left=0.8,
        width=2.4,
        height=1.5,
    )

    figure_gamma_power_boxplot(twostim_gamma_power_icx, single_gamma_power_icx, ax=ax)
    ax.yaxis.set_major_locator(MultipleLocator(10))
    ax.set_ylim(-30, +30)
    ax.set_title("ICx")
    display(fig)
    plt.close(fig)
    # Remove the globals so later cells cannot accidentally reuse this figure.
    del fig, ax


## Correlation Scatterplots

In [None]:
def _abs_phase_difference(angle_a, angle_b) -> np.ndarray:
    """Return ``|angle_a - angle_b|`` wrapped onto the circle, i.e. in [0, pi].

    Inputs are assumed to be phases in radians (TODO confirm for the gamma
    angles). Without wrapping, phases just above -pi and just below +pi would
    appear almost 2*pi apart although they are nearly identical. NaNs
    propagate.
    """
    delta = np.asarray(angle_a, dtype=float) - np.asarray(angle_b, dtype=float)
    return np.abs(np.angle(np.exp(1j * delta)))


def _finite_pearson_r(x, y) -> float:
    """Pearson correlation of *x* and *y* over entries where both are finite.

    Returns NaN when fewer than two such entries remain.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    keep = np.isfinite(x) & np.isfinite(y)
    if np.count_nonzero(keep) < 2:
        return float("nan")
    return float(np.corrcoef(x[keep], y[keep])[0, 1])


def _annotate_r(ax: Axes, r: float) -> None:
    """Write ``R = <r>`` near the top-left of *ax* (axes coordinates)."""
    ax.text(
        0.1,
        1,
        f"R = {r:.2f}",
        horizontalalignment="left",
        verticalalignment="top",
        fontsize=10,
        transform=ax.transAxes,
    )


def figure_correlation_scatterplots(
    twostim_ccg: pd.DataFrame,
    twostim_gamma: pd.DataFrame,
    twostim_stim_phaselocking: pd.DataFrame,
    axs: list[Axes],
    caxs: list[Axes] | None = None,
):
    """Scatter spike-train synchrony of unit pairs against their phase locking.

    Per-unit gamma-LFP and stimulus phase-locking measures are attached to
    both units of every pair in *twostim_ccg*. ``xcorr_peak`` is then plotted
    against:

    - ``axs[0]``: mean gamma-LFP vector strength of the two units;
    - ``axs[1]``: mean vector strength to the fixed stimulus' f_AM;
    - ``axs[2]``: circular |gamma phase difference|, coloured by mean gamma VS;
    - ``axs[3]``: circular |stimulus phase difference|, coloured by mean stim VS.

    Each panel is annotated with Pearson's R, computed over pairs where both
    plotted values are finite. Rows without a match in the per-unit tables
    are NaN after the left joins and are excluded.

    Parameters
    ----------
    twostim_ccg
        One row per unit pair with ``channel1``, ``channel2``,
        ``relative_level``, ``fixed_modulation_frequency`` and
        ``xcorr_peak``. ``date``/``owl`` must be columns or index levels.
    twostim_gamma, twostim_stim_phaselocking
        Per-unit tables. After ``relative_level`` and
        ``fixed_modulation_frequency`` are appended to their index, they are
        looked up by (date, owl, channelN, relative_level,
        fixed_modulation_frequency). Presumably their index is
        (date, owl, channel); TODO confirm. They must provide ``gamma_plv``
        and ``gamma_plv_angle``, resp. ``fixedstim_plv`` and
        ``fixedstim_plv_angle``.
    axs
        Four scatter Axes.
    caxs
        Optional colorbar Axes for panels 2 and 3. If None, ``plt.colorbar``
        makes room itself.
    """
    condition_levels = ["relative_level", "fixed_modulation_frequency"]

    def _with_condition_index(df: pd.DataFrame) -> pd.DataFrame:
        # Returns a new frame; the caller's data is left untouched.
        return df.set_index(condition_levels, append=True).sort_index()

    pairs = _with_condition_index(twostim_ccg)
    # Attach per-unit measures for both members of each pair. Columns for the
    # first unit keep their names (suffix "1" only on a clash) and columns
    # for the second unit get suffix "2". The join order determines this
    # column naming, so keep it.
    for per_unit in (twostim_gamma, twostim_stim_phaselocking):
        indexed = _with_condition_index(per_unit)
        for channel_col, suffix in (("channel1", "1"), ("channel2", "2")):
            pairs = pairs.join(
                indexed,
                on=["date", "owl", channel_col, *condition_levels],
                rsuffix=suffix,
            )

    pairs["mean_gamma_plv"] = (pairs["gamma_plv"] + pairs["gamma_plv2"]) / 2
    pairs["mean_stim_plv"] = (pairs["fixedstim_plv"] + pairs["fixedstim_plv2"]) / 2
    pairs["diff_gamma_plv_angle"] = _abs_phase_difference(
        pairs["gamma_plv_angle"], pairs["gamma_plv_angle2"]
    )
    pairs["diff_stim_plv_angle"] = _abs_phase_difference(
        pairs["fixedstim_plv_angle"], pairs["fixedstim_plv_angle2"]
    )
    xcorr_peak = pairs["xcorr_peak"]

    # Panels 0 and 1: synchrony vs. mean vector strength.
    for ax, x_col, color, xlabel in (
        (axs[0], "mean_gamma_plv", "blue", "VS Gamma LFP"),
        (axs[1], "mean_stim_plv", "red", "VS Stim $f_{AM}$"),
    ):
        ax.scatter(
            pairs[x_col],
            xcorr_peak,
            c=color,
            marker=".",
            edgecolors="none",
        )
        ax.set_xlabel(xlabel)
        _annotate_r(ax, _finite_pearson_r(pairs[x_col], xcorr_peak))

    # Panels 2 and 3: synchrony vs. phase difference, coloured by mean VS.
    for ax, cax, x_col, c_col, xlabel, cbar_label in (
        (
            axs[2],
            caxs[0] if caxs is not None else None,
            "diff_gamma_plv_angle",
            "mean_gamma_plv",
            "Gamma phase diff.",
            "VS Gamma LFP",
        ),
        (
            axs[3],
            caxs[1] if caxs is not None else None,
            "diff_stim_plv_angle",
            "mean_stim_plv",
            "Stim phase diff.",
            "VS Stim $f_{AM}$",
        ),
    ):
        # Draw in ascending colour value so high-VS pairs end up on top.
        order = np.argsort(np.asarray(pairs[c_col]))
        handle = ax.scatter(
            pairs[x_col].values[order],
            xcorr_peak.values[order],
            c=pairs[c_col].values[order],
            cmap="jet",
            marker=".",
            edgecolors="none",
            vmin=0.0,
            vmax=1.0,
        )
        ax.set_xlabel(xlabel)
        plt.colorbar(handle, cax=cax, label=cbar_label)
        _annotate_r(ax, _finite_pearson_r(pairs[x_col], xcorr_peak))

    for ax in axs:
        ax.set_ylim(bottom=0, top=0.0006)
        ax.set_xlim(left=0)
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

    def formatter_y(x, pos):
        # Tick values are shown in units of 1e-4 (see the "10^-4" label below).
        if x > 0:
            return f"${{{x*1e4:.0f}}}$"
        else:
            return "0"

    axs[0].set_yticks([0, 5e-4])
    axs[0].yaxis.set_major_formatter(FuncFormatter(formatter_y))
    axs[0].set_yticks(np.array([1, 2, 3, 4, 6, 7]) * 1e-4, minor=True)
    axs[0].text(
        0,
        1,
        "$10^{-4}$",
        ha="right",
        va="center",
        fontsize=8,
        transform=axs[0].transAxes
        + ScaledTranslation(-2 / 72, 0, cast(Figure, axs[0].figure).dpi_scale_trans),
    )

    for ax in axs[1:]:
        ax.set_yticks([])
        ax.set_yticks(np.array([1, 2, 3, 4, 5, 6, 7]) * 1e-4, minor=True)


# Example figure: correlation panels C-F for the OT pairs, with dedicated
# colorbar axes next to panels E and F.
if do_usage:
    fig = cast(Figure, plt.figure(figsize=(6.80, 1.8)))

    ax_c = figure_add_axes_inch(
        fig, left=0.8, width=1.0, top=0.2, height=1.0, label="C"
    )

    ax_d = figure_add_axes_inch(
        fig, left=2.0, width=1.0, top=0.2, height=1.0, label="D"
    )

    ax_e = figure_add_axes_inch(
        fig, left=3.2, width=1.0, top=0.2, height=1.0, label="E"
    )

    # Colorbar for E: 0.05 wide, starting 0.05 right of E's right edge (3.2 + 1.0).
    ax_e_cb = figure_add_axes_inch(
        fig, left=4.2 + 0.05, width=0.05, top=0.2, height=1.0
    )

    ax_f = figure_add_axes_inch(
        fig, left=5.1, width=1.0, top=0.2, height=1.0, label="F"
    )

    # Colorbar for F: 0.05 wide, starting 0.05 right of F's right edge (5.1 + 1.0).
    ax_f_cb = figure_add_axes_inch(
        fig, left=6.1 + 0.05, width=0.05, top=0.2, height=1.0
    )

    figure_correlation_scatterplots(
        am_twostim_ccg_ot,
        am_twostim_gamma_power_ot,
        am_twostim_stim_phaselocking_ot,
        axs=[ax_c, ax_d, ax_e, ax_f],
        caxs=[ax_e_cb, ax_f_cb],
    )

    ax_c.set_ylabel("Spike Train\nSynchrony")
    # -0.3 / 1.0: an offset of 0.3 (presumably inch) expressed as a fraction
    # of the 1.0-wide axes.
    ax_c.yaxis.set_label_coords(-0.3 / 1.0, 0.5)

## Elevation Tunings

In [None]:
def figure_best_elevation_hist(tunings_ot, tunings_icx, ax: list[Axes]):
    """Plot best-elevation histograms for OT (``ax[0]``) and ICx (``ax[1]``) units.

    Both panels use 19 bins of width 10 from -95 to +95 (presumably degrees).
    The bin edges are printed. An open triangle marks each region's median
    ``best_ele`` at 80 % of the axes height, and the region name is written
    in the top-left corner.

    Limits are x -95..95 and y top 35. The first panel has no major x ticks
    and is raised with ``set_zorder(+5)``; the second has major x ticks every
    90. A thin vertical line marks 0 on both.
    """
    edges = np.linspace(-95, 95, num=20, endpoint=True)
    print(edges)

    regions = ((tunings_ot, c_ot, "OT"), (tunings_icx, c_icx, "ICx"))

    # Draw bars, median marker and region label.
    for i, (tunings, color, name) in enumerate(regions):
        panel = ax[i]
        best = tunings["best_ele"].values
        counts, _ = np.histogram(best, edges)
        panel.bar(edges[:-1], counts, align="edge", width=10, color=color)
        panel.plot(
            np.median(best),
            0.8,
            ls="none",
            marker="v",
            color="k",
            markerfacecolor="w",
            # x in data coordinates, y in axes coordinates.
            transform=blended_transform_factory(panel.transData, panel.transAxes),
            zorder=+20,
        )
        panel.text(
            0,
            1,
            name,
            fontsize=10,
            ha="left",
            va="top",
            color="k",
            transform=panel.transAxes
            + ScaledTranslation(
                +2 / 72, -4 / 72, cast(Figure, panel.figure).dpi_scale_trans
            ),
        )

    # Axis limits, locators and spines.
    for i in (0, 1):
        panel = ax[i]
        panel.set_xlim(left=-95, right=+95)
        panel.set_ylim(top=35)
        panel.yaxis.set_major_locator(MultipleLocator(20))
        panel.yaxis.set_minor_locator(MultipleLocator(10))
        if i == 1:
            panel.xaxis.set_major_locator(MultipleLocator(90))
        panel.xaxis.set_minor_locator(MultipleLocator(30))
        if i == 0:
            panel.set_xticks([], minor=False)
            panel.set_zorder(+5)
        panel.spines["top"].set_visible(False)
        panel.spines["right"].set_visible(False)

    for i in (0, 1):
        ax[i].axvline(0, lw=0.8, color="k", zorder=-10)

In [None]:
def figure_best_elevation_boxplots(tunings_ot, tunings_icx, ax: Axes):
    """Compare the per-session spread of best elevation between OT and ICx.

    Units are grouped by session (``date``, ``owl``), and the standard
    deviation of ``best_ele`` is taken per session. The two sets of
    per-session values are drawn as boxes (5-95 % whiskers, no fliers) with
    the individual sessions overlaid as open diamonds. The p-value of
    ``scipy.stats.wilcoxon`` is written near the top-left, and the y axis
    starts at 0.

    NOTE(review): ``wilcoxon`` is a paired (signed-rank) test, and the two
    arrays are paired purely by position in groupby's sorted session order.
    That is only meaningful if OT and ICx cover exactly the same sessions;
    confirm.
    """
    spreads = [
        np.asarray(tunings.groupby(["date", "owl"])["best_ele"].std())
        for tunings in (tunings_ot, tunings_icx)
    ]

    bp = ax.boxplot(
        spreads,
        patch_artist=True,
        notch=False,
        showfliers=False,
        whis=(5, 95),
        widths=0.6,
    )
    ax.set_xticklabels(["OT", "ICx"])

    stats = scipy.stats.wilcoxon(*spreads)

    for box, face in zip(bp["boxes"], (c_ot, c_icx)):
        box.set_facecolor(face)
        box.set_linewidth(0.0)
    for median_line in bp["medians"]:
        median_line.set(color="k", linewidth=1)

    # Individual sessions: sorted values spread evenly across +-0.7/3 of each box.
    half_spread = 0.7 / 3
    for pos, values in enumerate(spreads, start=1):
        y = np.sort(values)
        x = np.linspace(pos - half_spread, pos + half_spread, y.size, endpoint=True)
        ax.plot(
            x,
            y,
            color="k",
            ls="None",
            marker="D",
            markersize=3,
            markeredgewidth=0.8,
            markerfacecolor="w",
            zorder=5,
        )

    for side in ("top", "right"):
        ax.spines[side].set_visible(False)

    ax.text(
        0.1,
        1,
        f"p = {stats.pvalue:.3g}",
        horizontalalignment="left",
        verticalalignment="top",
        fontsize=10,
        transform=ax.transAxes,
    )

    ax.set_ylim(bottom=0)

In [None]:
def figure_elevation_width_hist(tunings_ot, tunings_icx, ax: list[Axes]):
    """Plot elevation tuning-width histograms for OT (``ax[0]``) and ICx (``ax[1]``).

    Both panels use 18 bins of width 10 from 0 to 180, and the bin edges are
    printed. An open triangle marks each region's median ``ele_width`` at
    80 % of the axes height.

    Limits are x 0..180 and y top 55. The first panel has no major x ticks
    and is raised with ``set_zorder(+5)``; the second has major x ticks every
    60. A thin vertical line is drawn at 0 on both.
    """
    edges = np.linspace(0, 180, num=19, endpoint=True)
    print(edges)

    regions = ((tunings_ot, c_ot), (tunings_icx, c_icx))

    for i, (tunings, color) in enumerate(regions):
        panel = ax[i]
        widths = tunings["ele_width"].values
        counts, _ = np.histogram(widths, edges)
        panel.bar(edges[:-1], counts, align="edge", width=10, color=color)
        panel.plot(
            np.median(widths),
            0.8,
            ls="none",
            marker="v",
            color="k",
            markerfacecolor="w",
            # x in data coordinates, y in axes coordinates.
            transform=blended_transform_factory(panel.transData, panel.transAxes),
            zorder=+20,
        )

    for i in (0, 1):
        panel = ax[i]
        panel.set_xlim(left=0, right=+180)
        panel.set_ylim(top=55)
        panel.yaxis.set_major_locator(MultipleLocator(30))
        panel.yaxis.set_minor_locator(MultipleLocator(10))
        if i == 1:
            panel.xaxis.set_major_locator(MultipleLocator(60))
        panel.xaxis.set_minor_locator(MultipleLocator(30))
        if i == 0:
            panel.set_xticks([], minor=False)
            panel.set_zorder(+5)
        panel.spines["top"].set_visible(False)
        panel.spines["right"].set_visible(False)

    for i in (0, 1):
        ax[i].axvline(0, lw=0.8, color="k", zorder=-10)

In [None]:
def figure_elevation_width_boxplots(tunings_ot, tunings_icx, ax: Axes):
    """Compare the per-session spread of elevation tuning width between OT and ICx.

    Units are grouped by session (``date``, ``owl``), and the standard
    deviation of ``ele_width`` is taken per session (kept as pandas Series).
    Boxes use 5-95 % whiskers and no fliers, with individual sessions
    overlaid as open diamonds. The ``scipy.stats.wilcoxon`` p-value is
    written near the top-left, and the y axis starts at 0.

    NOTE(review): ``wilcoxon`` is a paired test, and the Series are paired by
    position only. This is valid only if both regions cover the same
    sessions; confirm.
    """
    session_keys = ["date", "owl"]
    std_by_region = {
        "OT": tunings_ot.groupby(session_keys)["ele_width"].std(),
        "ICx": tunings_icx.groupby(session_keys)["ele_width"].std(),
    }
    groupdata = list(std_by_region.values())

    bp = ax.boxplot(
        groupdata,
        patch_artist=True,
        notch=False,
        showfliers=False,
        whis=(5, 95),
        widths=0.6,
    )

    ax.set_xticklabels(list(std_by_region))

    stats = scipy.stats.wilcoxon(groupdata[0], groupdata[1])

    box_colors = [c_ot, c_icx]
    for idx, box in enumerate(bp["boxes"][: len(box_colors)]):
        box.set_facecolor(box_colors[idx])
        box.set_linewidth(0.0)

    for line in bp["medians"]:
        line.set(color="k", linewidth=1)

    spread = 0.7 / 3
    for idx, values in enumerate(groupdata):
        center = idx + 1
        ordered = np.sort(values)
        ax.plot(
            np.linspace(center - spread, center + spread, ordered.size, endpoint=True),
            ordered,
            color="k",
            ls="None",
            marker="D",
            markersize=3,
            markeredgewidth=0.8,
            markerfacecolor="w",
            zorder=5,
        )

    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    ax.text(
        0.1,
        1,
        f"p = {stats.pvalue:.3g}",
        horizontalalignment="left",
        verticalalignment="top",
        fontsize=10,
        transform=ax.transAxes,
    )

    ax.set_ylim(bottom=0)

In [None]:
def figure_elevation_signalcorr_hist(signalcorr_ot, signalcorr_icx, ax: list[Axes]):
    """Plot signal-correlation histograms for OT (``ax[0]``) and ICx (``ax[1]``) pairs.

    Both panels use 40 bins of width 0.05 from -1 to +1. An open triangle
    marks each region's median ``signal_corr`` at 80 % of the axes height.

    Limits are x -1..1 and y top 55. The first panel has no major x ticks and
    is raised with ``set_zorder(+5)``; the second has major x ticks every 1.
    A thin vertical line marks 0 on both.
    """
    edges = np.linspace(-1, 1, num=41, endpoint=True)
    bar_width = 2 / 40

    regions = ((signalcorr_ot, c_ot), (signalcorr_icx, c_icx))

    for i, (pairs, color) in enumerate(regions):
        panel = ax[i]
        corr = pairs["signal_corr"].values
        counts, _ = np.histogram(corr, edges)
        panel.bar(edges[:-1], counts, align="edge", width=bar_width, color=color)
        panel.plot(
            np.median(corr),
            0.8,
            ls="none",
            marker="v",
            color="k",
            markerfacecolor="w",
            # x in data coordinates, y in axes coordinates.
            transform=blended_transform_factory(panel.transData, panel.transAxes),
            zorder=+20,
        )

    for i in (0, 1):
        panel = ax[i]
        panel.set_xlim(left=-1, right=+1)
        panel.set_ylim(top=55)
        panel.yaxis.set_major_locator(MultipleLocator(30))
        panel.yaxis.set_minor_locator(MultipleLocator(10))
        if i == 1:
            panel.xaxis.set_major_locator(MultipleLocator(1))
        panel.xaxis.set_minor_locator(MultipleLocator(0.5))
        if i == 0:
            panel.set_xticks([], minor=False)
            panel.set_zorder(+5)
        panel.spines["top"].set_visible(False)
        panel.spines["right"].set_visible(False)

    for i in (0, 1):
        ax[i].axvline(0, lw=0.8, color="k", zorder=-10)

In [None]:
def figure_elevation_signalcorr_boxplots(
    signalcorr_ot: pd.DataFrame, signalcorr_icx: pd.DataFrame, ax: Axes
):
    """Compare the per-session spread of signal correlation between OT and ICx.

    Pairs are grouped by session (``date``, ``owl``), and the standard
    deviation of ``signal_corr`` is taken per session. Boxes use 5-95 %
    whiskers and no fliers, with individual sessions overlaid as open
    diamonds. The ``scipy.stats.wilcoxon`` p-value is written near the
    top-left, and the y axis is fixed to 0..0.6.

    NOTE(review): ``wilcoxon`` is a paired test, and the arrays are paired by
    position only. This is valid only if both regions cover the same
    sessions; confirm.
    """

    def per_session_std(frame: pd.DataFrame) -> np.ndarray:
        return np.asarray(frame.groupby(["date", "owl"])["signal_corr"].std())

    std_ot = per_session_std(signalcorr_ot)
    std_icx = per_session_std(signalcorr_icx)
    groupdata = [std_ot, std_icx]

    bp = ax.boxplot(
        groupdata,
        patch_artist=True,
        notch=False,
        showfliers=False,
        whis=(5, 95),
        widths=0.6,
    )

    ax.set_xticklabels(["OT", "ICx"])

    stats = scipy.stats.wilcoxon(std_ot, std_icx)

    for box, face in zip(bp["boxes"], [c_ot, c_icx]):
        box.set_facecolor(face)
        box.set_linewidth(0.0)

    for med in bp["medians"]:
        med.set(color="k", linewidth=1)

    jitter_half_width = 0.7 / 3
    for pos, session_values in zip((1, 2), groupdata):
        ys = np.sort(session_values)
        xs = np.linspace(
            pos - jitter_half_width, pos + jitter_half_width, ys.size, endpoint=True
        )
        ax.plot(
            xs,
            ys,
            color="k",
            ls="None",
            marker="D",
            markersize=3,
            markeredgewidth=0.8,
            markerfacecolor="w",
            zorder=5,
        )

    for side in ("top", "right"):
        ax.spines[side].set_visible(False)

    ax.text(
        0.1,
        1,
        f"p = {stats.pvalue:.3g}",
        horizontalalignment="left",
        verticalalignment="top",
        fontsize=10,
        transform=ax.transAxes,
    )

    ax.set_ylim(bottom=0, top=0.6)

## Latency Subpopulations

In [None]:
# Subpopulation counts. Join the two-stimulus table (fixed_modulation_frequency
# == 55) with the single-stimulus table (modulation_frequency == 55), keep rows
# whose single stimulus equals the pair's fixed stimulus (same intensity,
# azimuth and elevation), then count the units flagged short_latency /
# long_latency. The globals defined here may be used by later cells.
am55_single = am_single_rlf_ot[am_single_rlf_ot['modulation_frequency'] == 55]
am55_twostim = am_twostim_rlf_ot[am_twostim_rlf_ot['fixed_modulation_frequency'] == 55]

df_ts = am55_twostim.copy()
df_s = am55_single.copy()

# Inner join on the index; overlapping single-stimulus columns get "_single".
merged_rlf = df_ts.join(df_s, how="inner", rsuffix="_single")
merged_rlf = merged_rlf.loc[
    (merged_rlf["fixedintensity"] == merged_rlf["intensity"])
]
merged_rlf = merged_rlf.loc[(merged_rlf["fixedazi"] == merged_rlf["azimuth"])]
merged_rlf = merged_rlf.loc[(merged_rlf["fixedele"] == merged_rlf["elevation"])]

# NOTE(review): despite the "pl"/"npl" names, these are selected by the
# latency flags, not by a phase-locking criterion; confirm the intended naming.
pl_units = merged_rlf[(merged_rlf['short_latency'])]
npl_units = merged_rlf[(merged_rlf['long_latency'])]

# Count distinct index entries (units), not rows.
print("Number of short latency units:", np.unique(pl_units.index).size)
print("Number of long latency units:", np.unique(npl_units.index).size)


### PSTH

In [None]:
def psth_by_subpopulation(df_phaselocking_ot, ax: Axes):
    """Plot the mean PSTH of short- and long-latency units.

    The mean of the ``psth`` column is drawn for rows flagged
    ``short_latency`` (colour ``c_fAM_55[0]``, drawn on top) and
    ``long_latency`` (grey). The median ``first_spike_latency`` of each group
    is marked by a dashed (short) or dotted (long) vertical line.

    The time axis assumes 1 ms PSTH bins, with the stimulus starting 1.010
    into the trace, so times are relative to stimulus onset.
    """
    STIM_START = 1.010

    n_bins = len(df_phaselocking_ot["psth"].values[0])
    times = np.arange(0, n_bins) * 0.001 - STIM_START

    # (flag column, line colour, legend label, extra plot kwargs, vline style)
    groups = (
        ("short_latency", c_fAM_55[0], "Short latency", {"zorder": 5}, "--"),
        ("long_latency", "0.5", "Long latency", {}, ":"),
    )

    for flag, color, label, extra, _ in groups:
        members = df_phaselocking_ot.loc[df_phaselocking_ot[flag], "psth"].values
        ax.plot(
            times,
            np.mean(members, axis=0),
            color=color,
            lw=1.0,
            label=label,
            **extra,
        )

    for flag, _, _, _, linestyle in groups:
        latencies = df_phaselocking_ot.loc[
            df_phaselocking_ot[flag], "first_spike_latency"
        ].values
        ax.axvline(np.median(latencies), color="k", ls=linestyle, zorder=-10)

    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    ax.set_xlim(-0.02, 0.32)
    ax.set_ylim(top=0.35)
    ax.set_xlabel("Time rel Stimulus Onset [s]")
    ax.set_ylabel("Response [spks/bin]")
    ax.legend(frameon=False, loc="upper right")


# Example figure: mean PSTH of short- vs. long-latency OT units, single
# stimulus with modulation_frequency == 55 at intensity == -5.
if do_usage:
    fig = cast(Figure, plt.figure(figsize=(3.30, 2.10)))

    ax = figure_add_axes_inch(
        fig,
        top=0.1,
        left=0.8,
        width=2.4,
        height=1.5,
        label="B",
    )

    df_latency = am_single_stim_phaselocking_ot[
        (am_single_stim_phaselocking_ot["modulation_frequency"] == 55)
        & (am_single_stim_phaselocking_ot["intensity"] == -5)
    ]

    psth_by_subpopulation(df_latency, ax=ax)

### Latency Histogram

In [None]:
def figure_latency_histogram(df_phaselocking_ot, ax: Axes, threshold_ms: float = 14):
    """Histogram of first-spike latencies for short- vs. long-latency units.

    ``first_spike_latency`` (presumably in seconds) is multiplied by 1000 and
    rounded, giving milliseconds. Latencies are binned in 1 ms bins from 0 to
    35 ms; values outside this range are not counted.

    The median and SEM of each subpopulation are printed. Short-latency units
    are drawn in ``c_fAM_55[0]`` and long-latency units in semi-transparent
    grey. A dotted vertical line marks ``threshold_ms``, presumably the
    latency cut separating the two groups (TODO confirm).

    Parameters
    ----------
    df_phaselocking_ot
        Needs boolean ``short_latency``/``long_latency`` columns and
        ``first_spike_latency``.
    ax
        Axes to draw into.
    threshold_ms
        Position of the dotted reference line in ms (default 14, as before).
    """
    df = df_phaselocking_ot
    latency_bins = np.linspace(0, 35, 36)  # 1 ms bins, 0..35 ms

    latency_short = (
        (df.loc[df["short_latency"], "first_spike_latency"] * 1000).round().values
    )
    latency_long = (
        (df.loc[df["long_latency"], "first_spike_latency"] * 1000).round().values
    )

    print(
        f"Short Latency Median: {np.median(latency_short)} +/- {scipy.stats.sem(latency_short)}"
    )
    print(
        f"Long Latency Median: {np.median(latency_long)} +/- {scipy.stats.sem(latency_long)}"
    )

    ax.hist(latency_short, latency_bins, color=c_fAM_55[0], label="Short latency")
    ax.hist(
        latency_long,
        latency_bins,
        color="0.5",
        label="Long latency",
        alpha=0.5,
    )

    ax.axvline(threshold_ms, color="k", ls=":")
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    # The histogram is in milliseconds (x1000 above), not seconds.
    ax.set_xlabel("Time rel Stimulus Onset [ms]")
    ax.set_ylabel("# Units")
    ax.legend(frameon=False, loc="upper right")


# Example figure: first-spike latency histogram of OT units, single stimulus
# with modulation_frequency == 55 at intensity == -5.
if do_usage:
    fig = cast(Figure, plt.figure(figsize=(3.30, 2.10)))

    ax = figure_add_axes_inch(
        fig,
        top=0.1,
        left=0.8,
        width=2.4,
        height=1.5,
        label="B",
    )
    df_latency = am_single_stim_phaselocking_ot[
        (am_single_stim_phaselocking_ot["modulation_frequency"] == 55)
        & (am_single_stim_phaselocking_ot["intensity"] == -5)
    ]
    figure_latency_histogram(
        df_latency,
        ax=ax,
    )

### Stimulus Vector Strength by Latency Subpopulation

In [None]:
def figure_latency_plv_histogram(
    df_phaselocking_ot: pd.DataFrame,
    ax: Axes,
    colors,
    *,
    show_individual_data: bool = True,
):
    """Compare stimulus vector strength between short- and long-latency units.

    Despite its name this draws a boxplot (whiskers at the 5th/95th
    percentiles, fliers hidden) of ``singlestim_plv`` for the rows flagged by
    the boolean columns ``short_latency`` and ``long_latency``. It optionally
    overlays every unit's value and annotates the Mann-Whitney U p-value in
    the upper-right corner.

    Parameters
    ----------
    df_phaselocking_ot : pd.DataFrame
        Needs ``short_latency``, ``long_latency`` and ``singlestim_plv``.
    ax : Axes
        Axes to draw into.
    colors : sequence of colors
        Box face colors in the order (short, long).
    show_individual_data : bool, optional
        Overlay the individual unit values next to each box. Defaults to
        True, the previously hard-coded behavior.
    """
    df = df_phaselocking_ot

    # NaN entries make the box statistics and the U test undefined; drop them.
    latency_short = df.loc[df["short_latency"], "singlestim_plv"].dropna().to_numpy()
    latency_long = df.loc[df["long_latency"], "singlestim_plv"].dropna().to_numpy()
    groups = [latency_short, latency_long]

    bp = ax.boxplot(
        groups,
        labels=["Short", "Long"],
        patch_artist=True,
        notch=False,
        showfliers=False,
        whis=(5, 95),
    )
    for patch, color in zip(bp["boxes"], colors):
        patch.set_facecolor(color)

    for median in bp["medians"]:
        median.set(color="white", linewidth=2)

    if show_individual_data:
        width = 0.7
        for pos, groupvalues in enumerate(groups, start=1):
            y = np.sort(groupvalues)
            # Spread the sorted values evenly across the box so dots don't overlap.
            vals_x = np.linspace(
                pos - width / 3,
                pos + width / 3,
                y.size,
                endpoint=True,
            )
            ax.plot(
                vals_x,
                y,
                color="k",
                ls="None",
                marker=".",
                markersize=2,
                markeredgewidth=0.0,
                zorder=5,
            )

    # mannwhitneyu raises on empty samples; skip the annotation in that case.
    if latency_short.size and latency_long.size:
        _, pvalue = scipy.stats.mannwhitneyu(latency_short, latency_long)
        ax.text(
            0.95,
            0.95,
            f"p = {pvalue:.3g}",
            horizontalalignment="right",
            verticalalignment="top",
            fontsize=10,
            transform=ax.transAxes,
        )
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.set_ylabel("VS Stim $f_{AM}$")
    ax.set_xlabel("Subpopulation by Latency")


# Preview cell: stimulus vector strength of short- vs long-latency units for
# modulation_frequency == 55 and intensity == -5.
if do_usage:
    fig = cast(Figure, plt.figure(figsize=(3.30, 2.10)))

    ax = figure_add_axes_inch(
        fig,
        top=0.1,
        left=0.8,
        width=2.4,
        height=1.5,
        label="B",
    )
    df_latency = am_single_stim_phaselocking_ot[
        (am_single_stim_phaselocking_ot["modulation_frequency"] == 55)
        & (am_single_stim_phaselocking_ot["intensity"] == -5)
    ]
    # Inner join (on the index) with the latency classification from
    # single_rlf_ot at intensity -5; units without a match are dropped.
    # NOTE(review): for columns that already exist on the left, pandas keeps the
    # left column and renames the right one with "_flat". So if
    # am_single_stim_phaselocking_ot already has short/long_latency, the
    # plotting function uses the AM classification, not single_rlf_ot's — confirm.
    df_latency = df_latency.join(
        single_rlf_ot[single_rlf_ot["intensity"] == -5][
            ["short_latency", "long_latency", "first_spike_latency"]
        ],
        how="inner",
        rsuffix="_flat",
    )
    figure_latency_plv_histogram(
        df_latency,
        ax=ax,
        colors=[c_fAM_55[0], "0.5"],
    )

### VS Gamma

In [None]:
def figure_latency_gamma_plv_histogram(
    df_phaselocking_ot: pd.DataFrame,
    ax: Axes,
    colors,
    *,
    show_individual_data: bool = True,
):
    """Compare gamma-band LFP vector strength between latency subpopulations.

    Draws a boxplot (whiskers at the 5th/95th percentiles, fliers hidden) of
    ``gamma_plv`` for the rows flagged by the boolean columns
    ``short_latency`` and ``long_latency``. It optionally overlays every
    unit's value and annotates the Mann-Whitney U p-value in the upper-right
    corner.

    Parameters
    ----------
    df_phaselocking_ot : pd.DataFrame
        Needs ``short_latency``, ``long_latency`` and ``gamma_plv``. When the
        gamma table was attached with a left join, ``gamma_plv`` can be NaN
        for units without gamma data; those rows are ignored.
    ax : Axes
        Axes to draw into.
    colors : sequence of colors
        Box face colors in the order (short, long).
    show_individual_data : bool, optional
        Overlay the individual unit values next to each box. Defaults to
        True, the previously hard-coded behavior.
    """
    df = df_phaselocking_ot

    # NaN entries make the box statistics and the U test undefined; drop them.
    latency_short = df.loc[df["short_latency"], "gamma_plv"].dropna().to_numpy()
    latency_long = df.loc[df["long_latency"], "gamma_plv"].dropna().to_numpy()
    groups = [latency_short, latency_long]

    bp = ax.boxplot(
        groups,
        labels=["Short", "Long"],
        patch_artist=True,
        notch=False,
        showfliers=False,
        whis=(5, 95),
    )

    for patch, color in zip(bp["boxes"], colors):
        patch.set_facecolor(color)

    for median in bp["medians"]:
        median.set(color="white", linewidth=2)

    if show_individual_data:
        width = 0.7
        for pos, groupvalues in enumerate(groups, start=1):
            y = np.sort(groupvalues)
            # Spread the sorted values evenly across the box so dots don't overlap.
            vals_x = np.linspace(
                pos - width / 3,
                pos + width / 3,
                y.size,
                endpoint=True,
            )
            ax.plot(
                vals_x,
                y,
                color="k",
                ls="None",
                marker=".",
                markersize=2,
                markeredgewidth=0.0,
                zorder=5,
            )

    # mannwhitneyu raises on empty samples; skip the annotation in that case.
    if latency_short.size and latency_long.size:
        _, pvalue = scipy.stats.mannwhitneyu(latency_short, latency_long)
        ax.text(
            0.95,
            0.95,
            f"p = {pvalue:.3g}",
            horizontalalignment="right",
            verticalalignment="top",
            fontsize=10,
            transform=ax.transAxes,
        )

    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.set_ylabel("VS Gamma LFP")
    ax.set_xlabel("Subpopulation by Latency")


# Preview cell: gamma-LFP vector strength of short- vs long-latency units for
# modulation_frequency == 55 and intensity == -5.
if do_usage:
    fig = cast(Figure, plt.figure(figsize=(3.30, 2.10)))
    ax = figure_add_axes_inch(
        fig,
        top=0.1,
        left=0.8,
        width=2.4,
        height=1.5,
        label="B",
    )

    df_latency = am_single_stim_phaselocking_ot[
        (am_single_stim_phaselocking_ot["modulation_frequency"] == 55)
        & (am_single_stim_phaselocking_ot["intensity"] == -5)
    ]
    # Shape before the join, for comparison with the shape printed after it.
    print(df_latency.shape)

    # Left join (on the index) with the gamma table of the same condition.
    # Units without a matching gamma row keep NaN in the gamma columns.
    # Duplicate gamma index entries would multiply rows, which the second
    # shape print would reveal.
    df_latency = df_latency.join(
        am_single_gamma_power_ot[
            (am_single_gamma_power_ot["modulation_frequency"] == 55)
            & (am_single_gamma_power_ot["intensity"] == -5)
        ],
        how="left",
        rsuffix="_gamma",
    )
    print(df_latency.shape)

    figure_latency_gamma_plv_histogram(
        df_latency,
        ax=ax,
        colors=[c_fAM_55[0], "0.5"],
    )
    ax.set_ylim(0, 0.9)

### Phase Differences

In [None]:
def phase_properties(phases):
    """Circular summary statistics of a set of phase angles (radians).

    Each phase is mapped onto the unit circle and the vectors are averaged.

    Returns a dict with
      - ``"vector_strength"``: length of the mean vector (0 = uniform,
        1 = all phases identical),
      - ``"mean_angle"``: direction of the mean vector in radians,
      - ``"pval"``: the Rayleigh-test p-value approximation
        exp(sqrt(1 + 4n + 4(n^2 - R^2)) - (1 + 2n)), where n is the number of
        phases and R = n * vector_strength.
    """
    n_phases = len(phases)
    # Unit vectors e^{i*phi}, written out explicitly as cos + i*sin.
    unit_vectors = np.cos(phases) + np.sin(phases) * 1j
    resultant = np.mean(unit_vectors)

    strength = np.abs(resultant)
    r_total = n_phases * strength
    radicand = 1 + 4 * n_phases + 4 * (n_phases**2 - r_total**2)

    return {
        "vector_strength": strength,
        "mean_angle": np.angle(resultant),
        "pval": np.exp(np.sqrt(radicand) - (1 + 2 * n_phases)),
    }


def merge_phase_data(
    singlestim_phase_locking: pd.DataFrame,
    twostim_phase_locking: pd.DataFrame,
    axs: list[list[Axes]],
):
    """Plot mean stimulus-phase-difference histograms per latency subpopulation.

    The two-stimulus table is inner-joined with the single-stimulus table on
    their shared index; clashing single-stimulus columns get the ``_single``
    suffix. Only rows whose fixed stimulus (``fixedintensity``, ``fixedazi``,
    ``fixedele``) equals the single stimulus (``intensity``, ``azimuth``,
    ``elevation``) are kept. The rows are then split by the boolean
    ``short_latency`` / ``long_latency`` columns.

    For each relative level in ``plot_levels``, every unit's
    ``stim_phase_differences`` are histogrammed into 12 bins on [-pi, pi]. The
    mean histogram across units (± SEM) is plotted, and the units' mean vector
    strength is written at the bottom.

    Parameters
    ----------
    singlestim_phase_locking, twostim_phase_locking : pd.DataFrame
        Single- and two-stimulus phase-locking tables.
    axs : list[list[Axes]]
        ``axs[0]`` for short-latency units and ``axs[1]`` for long-latency
        units, each with one axes per plotted level (3).

    Raises
    ------
    KeyError
        If a subpopulation has no units at one of ``plot_levels``.
    """
    # join() returns a new frame and never modifies its inputs, so no copies are needed.
    merged_df = twostim_phase_locking.join(
        singlestim_phase_locking, how="inner", rsuffix="_single"
    )
    merged_df = merged_df.loc[
        (merged_df["fixedintensity"] == merged_df["intensity"])
        & (merged_df["fixedazi"] == merged_df["azimuth"])
        & (merged_df["fixedele"] == merged_df["elevation"])
    ]

    pl_units = merged_df[merged_df["short_latency"]].set_index("relative_level")
    npl_units = merged_df[merged_df["long_latency"]].set_index("relative_level")

    plot_levels = [-15, 0, 10]

    def _phases_per_level(units: pd.DataFrame) -> list[list[np.ndarray]]:
        # A list selector ([level]) always yields a Series. A scalar selector
        # returns the bare phase array when only one unit exists at a level,
        # and the loops below would then iterate over single phases.
        return [
            list(units.loc[[level], "stim_phase_differences"]) for level in plot_levels
        ]

    pl_groupdata = _phases_per_level(pl_units)
    npl_groupdata = _phases_per_level(npl_units)

    phase_bins = np.linspace(-np.pi, np.pi, 13)
    # Plot each count at its bin centre. Plotting at the left edges shifted
    # the curve by half a bin (pi/12) towards negative phases.
    phase_centers = (phase_bins[:-1] + phase_bins[1:]) / 2

    for g, (groupdata, color) in enumerate(
        [(pl_groupdata, c_fAM_55[0]), (npl_groupdata, ".5")]
    ):
        for k, subdata in enumerate(groupdata):
            ax: Axes = axs[g][k]
            inch = cast(Figure, ax.figure).dpi_scale_trans
            # Text anchors nudged 4 pt inside the top / bottom axes edges.
            below_top = ax.transAxes + ScaledTranslation(0, -4 / 72, inch)
            above_bottom = ax.transAxes + ScaledTranslation(0, +4 / 72, inch)

            all_vs = np.array(
                [
                    phase_properties(unit_phases)["vector_strength"]
                    for unit_phases in subdata
                ]
            )
            # One row per unit: spike counts per phase bin.
            hdata = np.vstack(
                [np.histogram(unit_phases, phase_bins)[0] for unit_phases in subdata]
            )

            if k == 0:
                # Same anchor as the level label. The leading newline puts
                # "Rel. Level" on the line below it.
                ax.text(
                    0.5,
                    1.0,
                    "\nRel. Level",
                    transform=below_top,
                    ha="center",
                    va="top",
                    fontsize=10,
                )

            ax.text(
                0.5,
                1.0,
                f"{plot_levels[k]} dB",
                transform=below_top,
                ha="center",
                va="top",
                fontsize=10,
            )
            ax.errorbar(
                phase_centers,
                np.mean(hdata, axis=0),
                yerr=scipy.stats.sem(hdata, axis=0),
                color=color,
                ls="-",
            )
            ax.text(
                0.5,
                0.0,
                f"VS {np.mean(all_vs):.2f}",
                transform=above_bottom,
                ha="center",
                fontsize=10,
            )
            # Tick labels only on the left-most axes of each row.
            ax.set_xticks(
                [-np.pi, 0, +np.pi], labels=["-π", "0", "π"] if k == 0 else []
            )
            ax.set_yticks(
                [0, 20, 40, 60], labels=["0", "20", "40", "60"] if k == 0 else []
            )

            ax.set_ylim(bottom=0, top=75)
            ax.set_xlim(-np.pi, +np.pi)

            ax.spines["top"].set_visible(False)
            ax.spines["right"].set_visible(False)

    axs[0][1].set_title("Short Latency Units")
    axs[1][1].set_title("Long Latency Units")
    axs[0][0].set_ylabel("# Spikes")
    axs[1][0].set_ylabel("# Spikes")
    axs[0][0].set_xlabel("Phase Difference")
    axs[1][0].set_xlabel("Phase Difference")


# Preview cell: phase-difference histograms for the AM condition with
# (fixed_)modulation_frequency == 55, for all intensities/relative levels.
if do_usage:
    df_single_latency = am_single_stim_phaselocking_ot[
        (am_single_stim_phaselocking_ot["modulation_frequency"] == 55)
    ]
    df_twostim_latency = am_twostim_stim_phaselocking_ot[
        (am_twostim_stim_phaselocking_ot["fixed_modulation_frequency"] == 55)
    ]

    fig = cast(Figure, plt.figure(figsize=(3.4, 4.0)))
    # 2 x 3 grid: row 0 short-latency units, row 1 long-latency units,
    # one column per plotted relative level.
    axs, axg = figure_add_axes_group_inch(
        fig,
        2,
        3,
        group_left=0.8,
        group_top=0.3,
        individual_width=0.8,
        individual_height=1.2,
        wspace=0.05,
        hspace=0.75,
    )
    merge_phase_data(
        df_single_latency, df_twostim_latency, axs=[axs[0].tolist(), axs[1].tolist()]
    )

### Competition and Spike Response

In [None]:
def competition_spkresp(
    single_rlf_ot: pd.DataFrame,
    ts_rlf_ot: pd.DataFrame,
    color: ColorType,
    ax: Axes,
):
    """Plot the relative change in spike rate under competition per subpopulation.

    Two-stimulus rate-level rows are inner-joined with the single-stimulus
    rows on their shared index; clashing single-stimulus columns get the
    ``_single`` suffix. Only rows whose fixed stimulus (``fixedintensity``,
    ``fixedazi``, ``fixedele``) equals the single stimulus (``intensity``,
    ``azimuth``, ``elevation``) are kept.

    The change in response is ``(resp - resp_single) / resp_single``; its
    mean ± SEM per relative level is plotted for short- and long-latency
    units. The p-value of ``t_test_ind`` at each level is written along the
    bottom edge. Unit and session counts and the t-test table are printed or
    displayed.

    Parameters
    ----------
    single_rlf_ot, ts_rlf_ot : pd.DataFrame
        Single- and two-stimulus rate-level tables.
    color : ColorType
        Color of the short-latency curve (the long-latency curve is grey).
    ax : Axes
        Axes to draw into.
    """
    merged_rlf = ts_rlf_ot.join(single_rlf_ot, how="inner", rsuffix="_single")
    merged_rlf = merged_rlf.loc[
        (merged_rlf["fixedintensity"] == merged_rlf["intensity"])
        & (merged_rlf["fixedazi"] == merged_rlf["azimuth"])
        & (merged_rlf["fixedele"] == merged_rlf["elevation"])
    ]

    print("Number of units:", np.unique(merged_rlf.index).size)
    # The first two index levels identify a recording session.
    print(
        "Number of sessions:",
        len({idx[:2] for idx in merged_rlf.index.values}),
    )
    # Non-inplace: returns an independent frame, so the column assignment
    # below cannot hit a SettingWithCopy situation.
    merged_rlf = merged_rlf.set_index("relative_level")

    # A zero single-stimulus response would give an infinite relative change
    # that dominates the group mean; treat it as undefined (NaN) instead.
    resp_single = merged_rlf["resp_single"].replace(0, np.nan)
    merged_rlf["change_in_response"] = (merged_rlf["resp"] - resp_single) / resp_single
    relative_levels = np.unique(merged_rlf.index)

    pl_units = merged_rlf[merged_rlf["short_latency"]]
    npl_units = merged_rlf[merged_rlf["long_latency"]]

    pl_change = pl_units.groupby("relative_level")["change_in_response"]
    npl_change = npl_units.groupby("relative_level")["change_in_response"]
    pl_means, pl_sems = pl_change.mean(), pl_change.sem()
    npl_means, npl_sems = npl_change.mean(), npl_change.sem()

    h1 = ax.errorbar(
        pl_means.index.values, pl_means, pl_sems, color=color, label="Short Lat."
    )
    h2 = ax.errorbar(
        npl_means.index.values, npl_means, npl_sems, color="0.5", label="Long Lat."
    )

    t_test_results = t_test_ind(pl_units, npl_units, val_col="change_in_response")
    display(t_test_results)
    # x in data coordinates, y at the bottom of the axes, nudged up by 2 pt.
    p_text_transform = blended_transform_factory(
        ax.transData, ax.transAxes
    ) + ScaledTranslation(0, +2 / 72, cast(Figure, ax.figure).dpi_scale_trans)
    for relative_level, row in t_test_results.iterrows():
        ax.text(
            cast(float, relative_level),
            0.0,
            # Drop the leading zero: "0.034" -> ".034".
            f"{row['p']:.2g}".removeprefix("0"),
            transform=p_text_transform,
            fontsize=8,
            ha="center",
            va="bottom",
        )

    ax.axhline(0, color="k", ls=":")
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.set_xlabel("Relative Level [dB]")
    ax.set_ylabel("Change in Spike Rate")

    # errorbar() returns a container; its first element is the data line.
    ax.legend(
        handles=[h1[0], h2[0]],
        labels=[h1.get_label(), h2.get_label()],
        frameon=False,
        loc="upper right",
    )
    ax.set_xlim(left=relative_levels[0] - 2, right=relative_levels[-1] + 2)
    ax.yaxis.set_major_formatter(FuncFormatter(lambda x, pos: f"{x:.0%}"))
    tick_levels = relative_levels.astype(int)
    ax.set_xticks(tick_levels, labels=[str(level) for level in tick_levels])


# Preview cell: change in spike rate under competition, AM noise at f_AM = 55.
if do_usage:
    fig = cast(Figure, plt.figure(figsize=(3.30, 2.30)))
    ax = figure_add_axes_inch(
        fig,
        top=0.3,
        left=0.8,
        width=2.4,
        height=1.5,
    )
    competition_spkresp(
        am_single_rlf_ot[(am_single_rlf_ot["modulation_frequency"] == 55)],
        am_twostim_rlf_ot[(am_twostim_rlf_ot["fixed_modulation_frequency"] == 55)],
        color=c_fAM_55[0],
        ax=ax,
    )
    ax.set_ylim(-0.55, 0.30)
    # Only the symbol goes into math mode. Inside $...$, mathtext would set
    # "55 Hz" in italics and drop the space before the unit.
    ax.set_title("AM Noise ($f_{AM}$ = 55 Hz)")
    display(fig)
    plt.close(fig)
    del ax, fig

# Preview cell: change in spike rate under competition, flat (unmodulated) noise.
if do_usage:
    fig = cast(Figure, plt.figure(figsize=(3.30, 2.30)))
    ax = figure_add_axes_inch(
        fig,
        top=0.3,
        left=0.8,
        width=2.4,
        height=1.5,
    )
    competition_spkresp(single_rlf_ot, twostim_rlf_ot, color=c_flat_noise[0], ax=ax)
    # Same y-range as the AM preview so the two panels are comparable.
    ax.set_ylim(-0.55, 0.30)
    ax.set_title("Flat Noise")
    display(fig)
    plt.close(fig)
    del ax, fig

### Competition and Synchrony

In [None]:
def _single_ccg_peak_stats(units: pd.DataFrame) -> tuple[float, float]:
    """Return (peak of the mean single-stimulus CCG, SD of per-pair CCG peaks).

    ``units`` must still be indexed by unit pair, not by ``relative_level``.
    Grouping on all index levels and taking ``first()`` then yields one
    ``ccg_single`` curve per pair, even if a pair presumably appears once per
    relative level. Returns ``(nan, nan)`` for an empty frame.
    """
    if units.empty:
        return np.nan, np.nan
    per_pair = (
        units["ccg_single"].groupby(level=list(range(units.index.nlevels))).first()
    )
    ccgs = np.vstack(list(per_pair))  # rows = unit pairs, columns = lags
    mean_peak = np.max(np.mean(ccgs, axis=0))
    # Peak of each pair's own CCG (max over lags), then the spread across pairs.
    std_peak = np.std(np.max(ccgs, axis=1))
    return mean_peak, std_peak


def _plot_mean_ccgs_by_level(
    units: pd.DataFrame,
    relative_levels: np.ndarray,
    axs: Sequence[Axes],
    peak_stats: tuple[float, float],
    *,
    line_color: ColorType,
    text_y: float,
    text_va: str,
    text_color=None,
) -> None:
    """Print single-CCG stats, then plot the mean two-stimulus CCG per level.

    ``units`` must be indexed by ``relative_level``. ``relative_levels[k]`` is
    drawn on ``axs[k]``. Levels missing from ``units`` are skipped, so two
    subpopulations plotted with the same ``relative_levels`` stay aligned.
    The number of pairs at each level is written at ``(0.5, text_y)`` in axes
    coordinates.
    """
    single_mean_peak, single_std_peak = peak_stats
    print(f"mean: {single_mean_peak = :.4g} std: {single_std_peak}")
    if units.empty:
        return

    # Signal length N recovered from the CCG length, assuming a full-mode
    # correlation of length 2N - 1. Integer division keeps the lags integral.
    psth_len = (units.iloc[0]["ccg_single"].size + 1) // 2
    # Lags in bins -> seconds (division by 1000, i.e. assuming 1 ms bins).
    lags = scipy.signal.correlation_lags(psth_len, psth_len) / 1000
    lags_mask = np.abs(lags) <= 0.05

    for k, relative_level in enumerate(relative_levels):
        if relative_level not in units.index:
            continue
        ax = axs[k]
        # A list selector keeps a DataFrame even when only one pair exists at
        # this level.
        level_rows = units.loc[[relative_level]]
        mean_ccg = np.mean(np.vstack(list(level_rows["ccg"])), axis=0)
        n_ccg = level_rows["xcorr_peak_single"].count()

        print(
            f"{relative_level:5}"
            f" | {np.mean(level_rows['xcorr_peak_single']):.4g} {n_ccg = }"
            f" | {np.mean(level_rows['xcorr_peak']):.4g} {n_ccg = }"
            f" | peak curve: {np.max(mean_ccg)}"
        )
        ax.plot(lags[lags_mask], mean_ccg[lags_mask], color=line_color, ls="-", lw=1)
        ax.text(
            0.5,
            text_y,
            f"{n_ccg}",
            ha="center",
            va=text_va,
            fontsize=8,
            transform=ax.transAxes,
            color=text_color,
        )


def figure_xcorr_ccg_subpop(
    twostim_ccg: pd.DataFrame,
    single_ccg: pd.DataFrame,
    single_rlf: pd.DataFrame,
    axs: Sequence[Axes],
    color: ColorType = c_flat_noise[0],
):
    """Plot mean cross-correlograms (CCGs) of short- and long-latency unit pairs.

    Each two-stimulus CCG row is joined with:
      - the single-stimulus rate-level row of its first channel (inner join on
        date/owl/channel1);
      - that of its second channel (left join on date/owl/channel2, clashing
        columns suffixed ``_rlf2``);
      - the single-stimulus CCG row on the shared index (inner join, clashing
        columns suffixed ``_single``).

    Only rows whose fixed stimulus matches all single-stimulus conditions are
    kept. A pair counts as short (long) latency only if both channels are
    flagged ``short_latency`` (``long_latency``).

    One axes per relative level shows the mean CCG within ±50 ms: long-latency
    pairs in grey, short-latency pairs in ``color``. Pair and session counts
    and per-level summaries are printed.

    Parameters
    ----------
    twostim_ccg, single_ccg : pd.DataFrame
        Two- and single-stimulus CCG tables (``ccg``, ``xcorr_peak``, ...).
    single_rlf : pd.DataFrame
        Single-stimulus rate-level table with the latency classification.
    axs : Sequence[Axes]
        One axes per relative level, left to right.
    color : ColorType, optional
        Color of the short-latency curves and counts.
    """
    merged_ccg = (
        twostim_ccg.join(
            single_rlf,
            on=["date", "owl", "channel1"],
            how="inner",
            rsuffix="_rlf",
        )
        .join(
            single_rlf,
            on=["date", "owl", "channel2"],
            how="left",
            rsuffix="_rlf2",
        )
        .join(
            single_ccg,
            how="inner",
            rsuffix="_single",
        )
    )

    merged_ccg = merged_ccg.loc[
        (merged_ccg["fixedintensity"] == merged_ccg["intensity_single"])
        & (merged_ccg["fixedintensity"] == merged_ccg["intensity"])
        & (merged_ccg["fixedazi"] == merged_ccg["azimuth"])
        & (merged_ccg["fixedele"] == merged_ccg["elevation"])
        & (merged_ccg["fixedazi"] == merged_ccg["azimuth_single"])
        & (merged_ccg["fixedele"] == merged_ccg["elevation_single"])
        & (merged_ccg["intensity"] == merged_ccg["intensity_rlf2"])
        & (merged_ccg["azimuth"] == merged_ccg["azimuth_rlf2"])
        & (merged_ccg["elevation"] == merged_ccg["elevation_rlf2"])
    ]

    pl_units = merged_ccg[
        merged_ccg["short_latency"] & merged_ccg["short_latency_rlf2"]
    ]
    npl_units = merged_ccg[merged_ccg["long_latency"] & merged_ccg["long_latency_rlf2"]]

    # The first two index levels identify a recording session.
    for prefix, units in (("", merged_ccg), ("SHORT ", pl_units), ("LONG ", npl_units)):
        print(f"{prefix}Number of unit pairs:", np.unique(units.index).size)
        print(
            "Number of sessions:",
            len({idx[:2] for idx in units.index.values}),
        )

    # Single-stimulus CCG statistics need the per-pair index, so compute them
    # before re-indexing by relative level.
    npl_peak_stats = _single_ccg_peak_stats(npl_units)
    pl_peak_stats = _single_ccg_peak_stats(pl_units)

    # Non-inplace set_index: both frames are boolean slices of merged_ccg.
    pl_units = pl_units.set_index("relative_level")
    npl_units = npl_units.set_index("relative_level")

    # A single shared level -> axes mapping keeps both populations aligned.
    relative_levels = np.union1d(pl_units.index.to_numpy(), npl_units.index.to_numpy())

    ### LONG
    _plot_mean_ccgs_by_level(
        npl_units,
        relative_levels,
        axs,
        npl_peak_stats,
        line_color=".5",
        text_y=1.05,
        text_va="bottom",
    )
    ### SHORT
    _plot_mean_ccgs_by_level(
        pl_units,
        relative_levels,
        axs,
        pl_peak_stats,
        line_color=color,
        text_y=1,
        text_va="top",
        text_color=color,
    )

    for ax in axs[: len(relative_levels)]:
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.set_ylim(bottom=0, top=0.000125)
        ax.set_xlim(left=-0.060, right=0.060)

    axs[0].set_ylabel("Coinc./spk", labelpad=12)

    axs[0].set_yticks([0, 10e-5], labels=["0", "$10^{-4}$"])
    axs[0].yaxis.set_minor_locator(MultipleLocator(5e-5))

    # Only the left-most axes keeps y ticks and a left spine.
    for ax in axs[1:]:
        ax.set_yticks([])
        ax.set_yticks([], minor=True)
        ax.spines["left"].set_visible(False)

    axs[0].set_xticks([-0.05, 0, +0.05], ["-50", "0", "+50"], fontsize=8)
    axs[0].set_xlabel("Lags [ms]", fontsize=8, labelpad=0)
    for ax in axs[1:]:
        ax.set_xticks([-0.05, 0, +0.05], [])


# Preview cell: mean CCGs per relative level for flat noise (1 x 6 axes row).
if do_usage:
    fig = cast(Figure, plt.figure(figsize=(3.4, 1.2)))
    axs, axg = figure_add_axes_group_inch(
        fig,
        1,
        6,
        individual_width=0.4,
        individual_height=0.6,
        wspace=0.0,
    )
    figure_xcorr_ccg_subpop(
        twostim_ccg_ot,
        single_ccg_ot,
        single_rlf_ot,
        axs=axs.flatten().tolist(),
        color=c_flat_noise[0],
    )
    display(fig)
    plt.close(fig)
    del fig, axs, axg

# Preview cell: mean CCGs per relative level for AM noise at f_AM = 55.
if do_usage:
    fig = cast(Figure, plt.figure(figsize=(3.4, 1.2)))
    axs, axg = figure_add_axes_group_inch(
        fig,
        1,
        6,
        individual_width=0.4,
        individual_height=0.6,
        wspace=0.0,
    )
    figure_xcorr_ccg_subpop(
        am_twostim_ccg_ot[(am_twostim_ccg_ot["fixed_modulation_frequency"] == 55)],
        am_single_ccg_ot[(am_single_ccg_ot["modulation_frequency"] == 55)],
        am_single_rlf_ot[(am_single_rlf_ot["modulation_frequency"] == 55)],
        axs=axs.flatten().tolist(),
        color=c_fAM_55[0],
    )
    # Raise the shared y-limit set inside figure_xcorr_ccg_subpop for this condition.
    for ax in axs[0, :]:
        ax.set_ylim(top=0.000165)
    display(fig)
    plt.close(fig)
    del fig, axs, axg

In [None]:
# Exploratory cell: mean two-stimulus CCG peak (xcorr_peak) per relative level
# for short- and long-latency pairs, AM noise at f_AM = 55.
df_am55_spkresp = am_single_rlf_ot[(am_single_rlf_ot["modulation_frequency"] == 55)]
df_ts_ccg = am_twostim_ccg_ot[(am_twostim_ccg_ot["fixed_modulation_frequency"] == 55)]
df_am55_ccg = am_single_ccg_ot[(am_single_ccg_ot["modulation_frequency"] == 55)]

twostim_ccg = df_ts_ccg.copy()
single_rlf = df_am55_spkresp.copy()


# NOTE(review): only channel1's rate-level row is joined, so pairs are
# classified by the first channel's latency alone. figure_xcorr_ccg_subpop and
# competition_synchrony_csi require both channels to match.
merged_ccg = twostim_ccg.join(single_rlf, on = ['date','owl','channel1'], how="inner", rsuffix="_single")

merged_ccg = merged_ccg.loc[
    (merged_ccg["fixedintensity"] == merged_ccg["intensity"])
    & (merged_ccg["fixedazi"] == merged_ccg["azimuth"])
    & (merged_ccg["fixedele"] == merged_ccg["elevation"])
]

pl_units = merged_ccg[merged_ccg["short_latency"]].set_index("relative_level")
npl_units = merged_ccg[merged_ccg["long_latency"]].set_index("relative_level")

print("Number of unit pairs:", np.unique(merged_ccg.index).size)
print(
    "Number of sessions:",
    len(sorted(set([idx[:2] for idx in merged_ccg.index.values]))),
)
# NOTE(review): merged_ccg is re-indexed here but not used again in this cell.
merged_ccg.set_index("relative_level", inplace=True)

# Mean CCG peak per relative level (computed, not displayed in this cell).
pl_means = pl_units.groupby('relative_level')['xcorr_peak'].mean()
npl_means = npl_units.groupby('relative_level')['xcorr_peak'].mean()


In [None]:
def competition_synchrony_csi(
    single_rlf: pd.DataFrame,
    single_ccg: pd.DataFrame,
    twostim_ccg: pd.DataFrame,
    *,
    ax: Axes,
    color: ColorType = c_flat_noise[0],
):
    """Plot the competition synchrony index (CSI) per latency subpopulation.

    Each single-stimulus CCG row is joined with:
      - the single-stimulus rate-level row of channel1 (inner join, clashing
        columns suffixed ``_rlf``);
      - that of channel2 (left join, clashing columns suffixed ``_rlf2``).

    The result is then inner-joined onto the two-stimulus CCG rows via the
    shared index; clashing columns are suffixed ``_single_ccg``. Only rows
    whose fixed stimulus equals the single-stimulus CCG and both rate-level
    conditions are kept.

    The CSI is ``(xcorr_peak - xcorr_peak_single_ccg) /
    (xcorr_peak + xcorr_peak_single_ccg)``. For non-negative peaks it lies in
    [-1, 1]; positive means more synchrony under competition. Its mean ± SEM
    per relative level is plotted for pairs where both channels are short
    latency and where both are long latency. The per-level ``t_test_ind``
    p-values are written along the bottom edge.

    Parameters
    ----------
    single_rlf : pd.DataFrame
        Single-stimulus rate-level table with the latency classification.
    single_ccg, twostim_ccg : pd.DataFrame
        Single- and two-stimulus CCG tables.
    ax : Axes
        Axes to draw into.
    color : ColorType, optional
        Color of the short-latency curve (the long-latency curve is grey).
    """
    merge_single = single_ccg.join(
        single_rlf,
        on=["date", "owl", "channel1"],
        how="inner",
        rsuffix="_rlf",
    ).join(
        single_rlf,
        on=["date", "owl", "channel2"],
        how="left",
        rsuffix="_rlf2",
    )

    merged_ccg = twostim_ccg.join(merge_single, how="inner", rsuffix="_single_ccg")

    merged_ccg = merged_ccg.loc[
        (merged_ccg["fixedintensity"] == merged_ccg["intensity"])
        & (merged_ccg["fixedintensity"] == merged_ccg["intensity_rlf"])
        & (merged_ccg["fixedintensity"] == merged_ccg["intensity_rlf2"])
        & (merged_ccg["fixedazi"] == merged_ccg["azimuth"])
        & (merged_ccg["fixedazi"] == merged_ccg["azimuth_rlf"])
        & (merged_ccg["fixedazi"] == merged_ccg["azimuth_rlf2"])
        & (merged_ccg["fixedele"] == merged_ccg["elevation"])
        & (merged_ccg["fixedele"] == merged_ccg["elevation_rlf"])
        & (merged_ccg["fixedele"] == merged_ccg["elevation_rlf2"])
    ]

    print("Number of unit pairs:", np.unique(merged_ccg.index).size)
    # The first two index levels identify a recording session.
    print(
        "Number of sessions:",
        len({idx[:2] for idx in merged_ccg.index.values}),
    )

    # Non-inplace: returns an independent frame for the column assignment below.
    merged_ccg = merged_ccg.set_index("relative_level")
    relative_levels = np.unique(merged_ccg.index)
    peak_twostim = merged_ccg["xcorr_peak"]
    peak_single = merged_ccg["xcorr_peak_single_ccg"]
    merged_ccg["csi"] = (peak_twostim - peak_single) / (peak_twostim + peak_single)

    # A pair belongs to a subpopulation only if both channels are classified so.
    pl_units = merged_ccg[
        merged_ccg["short_latency"] & merged_ccg["short_latency_rlf2"]
    ]
    npl_units = merged_ccg[merged_ccg["long_latency"] & merged_ccg["long_latency_rlf2"]]

    pl_csi = pl_units.groupby("relative_level")["csi"]
    npl_csi = npl_units.groupby("relative_level")["csi"]
    pl_means, pl_sems = pl_csi.mean(), pl_csi.sem()
    npl_means, npl_sems = npl_csi.mean(), npl_csi.sem()

    h1 = ax.errorbar(pl_means.index, pl_means, pl_sems, color=color, label="Short Lat.")
    h2 = ax.errorbar(
        npl_means.index, npl_means, npl_sems, color="0.6", label="Long Lat."
    )

    t_test_results = t_test_ind(pl_units, npl_units, val_col="csi")
    display(t_test_results)
    # x in data coordinates, y at the bottom of the axes, nudged up by 2 pt.
    p_text_transform = blended_transform_factory(
        ax.transData, ax.transAxes
    ) + ScaledTranslation(0, +2 / 72, cast(Figure, ax.figure).dpi_scale_trans)
    for relative_level, row in t_test_results.iterrows():
        ax.text(
            cast(float, relative_level),
            0.0,
            # Drop the leading zero: "0.034" -> ".034".
            f"{row['p']:.2g}".removeprefix("0"),
            transform=p_text_transform,
            fontsize=8,
            ha="center",
            va="bottom",
        )

    ax.axhline(0, color="k", ls=":")
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.set_xlabel("Relative Level [dB]")
    ax.set_ylabel("CSI")
    ax.set_xlim(left=relative_levels[0] - 2, right=relative_levels[-1] + 2)
    # errorbar() returns a container; its first element is the data line.
    ax.legend(
        handles=[h1[0], h2[0]],
        labels=[h1.get_label(), h2.get_label()],
        frameon=False,
        loc="upper right",
    )
    tick_levels = relative_levels.astype(int)
    ax.set_xticks(tick_levels, [str(level) for level in tick_levels])


# Preview cell: CSI per relative level for AM noise (f_AM = 55) and flat noise.
if do_usage:
    fig = cast(Figure, plt.figure(figsize=(3.30, 2.10)))
    fig.suptitle("fAM 55")
    ax = figure_add_axes_inch(
        fig,
        top=0.1,
        left=0.8,
        width=2.4,
        height=1.5,
        label="-",
    )
    df_am55_spkresp = am_single_rlf_ot[(am_single_rlf_ot["modulation_frequency"] == 55)]
    df_ts_ccg = am_twostim_ccg_ot[
        (am_twostim_ccg_ot["fixed_modulation_frequency"] == 55)
    ]
    df_am55_ccg = am_single_ccg_ot[(am_single_ccg_ot["modulation_frequency"] == 55)]

    competition_synchrony_csi(
        df_am55_spkresp, df_am55_ccg, df_ts_ccg, color=c_fAM_55[0], ax=ax
    )
    # Set limits on the explicit axes rather than on pyplot's "current" axes.
    ax.set_ylim(-0.17, 0.15)

    fig = cast(Figure, plt.figure(figsize=(3.30, 2.10)))
    fig.suptitle("Flat noise")
    ax = figure_add_axes_inch(
        fig,
        top=0.1,
        left=0.8,
        width=2.4,
        height=1.5,
        label="-",
    )

    competition_synchrony_csi(
        single_rlf_ot, single_ccg_ot, twostim_ccg_ot, color=c_flat_noise[0], ax=ax
    )
    ax.set_ylim(-0.17, 0.15)

### Synchrony and Gamma 

In [None]:
def synchrony_correlation_plot(
    twostim_ccg: pd.DataFrame,
    twostim_gamma: pd.DataFrame,
    twostim_phase_locking: pd.DataFrame,
    am_single_rlf: pd.DataFrame,
    axs: Sequence[Axes],
):
    """Scatter pairwise spike-train synchrony against gamma-LFP and stimulus VS.

    Each two-stimulus CCG row is one unit pair, with columns ``channel1`` and
    ``channel2``. It is inner-joined with the gamma-LFP phase locking and the
    stimulus phase locking of *both* units. Rows are matched on ``date``,
    ``owl``, channel, ``relative_level`` and ``fixed_modulation_frequency``.

    Both units are then joined with their single-stimulus rows from
    ``am_single_rlf``. Only rows whose intensity, azimuth and elevation equal
    the fixed (driver) stimulus of the two-stimulus condition are kept.

    Two populations are plotted:

    - pairs in which both units are flagged ``short_latency``;
    - pairs in which both units are flagged ``long_latency``.

    For each population, the pair-mean vector strength (gamma LFP, and
    stimulus f_AM) is scattered against ``xcorr_peak``. Each panel is
    annotated with the Pearson R.

    Parameters
    ----------
    twostim_ccg, twostim_gamma, twostim_phase_locking, am_single_rlf
        Input tables. They are not modified.
        NOTE(review): the joins assume the existing index of each right-hand
        table is (date, owl, channel), in that order. Confirm this against
        the code that builds these tables.
    axs
        Four axes, filled as follows:

        - [0] short-latency pairs vs gamma VS;
        - [1] short-latency pairs vs stimulus VS;
        - [2] long-latency pairs vs gamma VS;
        - [3] long-latency pairs vs stimulus VS.
    """
    condition_levels = ["relative_level", "fixed_modulation_frequency"]

    def _index_by_condition(df: pd.DataFrame) -> pd.DataFrame:
        # Append the stimulus-condition columns to the existing index. This
        # returns a new frame, so the caller's table is left untouched.
        return df.set_index(condition_levels, append=True).sort_index()

    def _pair_keys(channel: str) -> list[str]:
        # Join keys for one unit of the pair: recording, channel, condition.
        return ["date", "owl", channel, *condition_levels]

    ccg = _index_by_condition(twostim_ccg)
    gamma = _index_by_condition(twostim_gamma)
    stim_pl = _index_by_condition(twostim_phase_locking)

    # Attach gamma-LFP locking, then stimulus locking, for unit 1 and unit 2.
    # pandas applies rsuffix only to overlapping column names. Unit 1's values
    # therefore keep their plain names (e.g. gamma_plv). Unit 2's values get
    # the "2" suffix (e.g. gamma_plv2).
    ccg = (
        ccg.join(gamma, on=_pair_keys("channel1"), how="inner", rsuffix="1")
        .join(gamma, on=_pair_keys("channel2"), how="inner", rsuffix="2")
        .join(stim_pl, on=_pair_keys("channel1"), how="inner", rsuffix="1")
        .join(stim_pl, on=_pair_keys("channel2"), how="inner", rsuffix="2")
    )

    merge = ccg.join(
        am_single_rlf,
        on=["date", "owl", "channel1"],
        rsuffix="_rlf",
        how="inner",
    ).join(
        am_single_rlf,
        on=["date", "owl", "channel2"],
        rsuffix="_rlf2",
        how="inner",
    )
    # Keep only single-stimulus rows recorded at the driver's fixed intensity
    # and location, for both units. The .copy() lets new columns be added
    # below without writing into a slice.
    merge = merge.loc[
        (merge["fixedintensity"] == merge["intensity"])
        & (merge["fixedintensity"] == merge["intensity_rlf2"])
        & (merge["fixedazi"] == merge["azimuth"])
        & (merge["fixedazi"] == merge["azimuth_rlf2"])
        & (merge["fixedele"] == merge["elevation"])
        & (merge["fixedele"] == merge["elevation_rlf2"])
    ].copy()

    merge["mean_gamma_plv"] = (merge["gamma_plv"] + merge["gamma_plv2"]) / 2
    merge["mean_stim_plv"] = (merge["fixedstim_plv"] + merge["fixedstim_plv2"]) / 2

    pl_units = merge[merge["short_latency"] & merge["short_latency_rlf2"]]
    npl_units = merge[merge["long_latency"] & merge["long_latency_rlf2"]]

    pl_style = {"c": c_fAM_55[0]}
    npl_style = {"c": "0.5", "alpha": 0.5}
    # (axes, population, x column, scatter style, x label)
    panels = [
        (axs[0], pl_units, "mean_gamma_plv", pl_style, "VS Gamma LFP"),
        (axs[2], npl_units, "mean_gamma_plv", npl_style, "VS Gamma LFP"),
        (axs[1], pl_units, "mean_stim_plv", pl_style, "VS Stim $f_{AM}$"),
        (axs[3], npl_units, "mean_stim_plv", npl_style, "VS Stim $f_{AM}$"),
    ]
    for ax, units, x_col, style, xlabel in panels:
        r_value = np.corrcoef(units[x_col], units["xcorr_peak"])[0, 1]
        ax.scatter(
            units[x_col],
            units["xcorr_peak"],
            marker=".",
            edgecolor="none",
            **style,
        )
        ax.text(
            0.1,
            1,
            f"R = {r_value:.2f}",
            horizontalalignment="left",
            verticalalignment="top",
            fontsize=10,
            transform=ax.transAxes,
        )
        ax.set_xlabel(xlabel)

    for ax in axs:
        ax.set_ylim(bottom=0, top=0.0006)
        ax.set_xlim(0, 1.0)
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

    def formatter_y(x, pos):
        # Tick labels are in units of 1e-4; the multiplier is drawn separately.
        if x > 0:
            return f"${{{x*1e4:.0f}}}$"
        else:
            return "0"

    def _label_in_units_of_1e4(ax) -> None:
        # Major ticks 0 and 5e-4 are labelled via formatter_y, and a
        # "10^-4" multiplier is drawn just left of the axes' top-left corner,
        # in this axes' own coordinates.
        # NOTE: set_yticks expands the view limits so that all given ticks are
        # visible. The minor tick at 7e-4 therefore raises the top above the
        # 6e-4 set earlier, so keep this after set_ylim.
        ax.set_yticks([0, 5e-4])
        ax.yaxis.set_major_formatter(FuncFormatter(formatter_y))
        ax.set_yticks(np.array([1, 2, 3, 4, 6, 7]) * 1e-4, minor=True)
        ax.text(
            0,
            1,
            "$10^{-4}$",
            ha="right",
            va="center",
            fontsize=8,
            transform=ax.transAxes
            + ScaledTranslation(
                -2 / 72,
                0,
                cast(Figure, ax.figure).dpi_scale_trans,
            ),
        )

    _label_in_units_of_1e4(axs[0])
    _label_in_units_of_1e4(axs[2])

    # Right-hand panels: no major tick labels, minor ticks only.
    for ax in (axs[1], axs[3]):
        ax.set_yticks([])
        ax.set_yticks(np.array([1, 2, 3, 4, 5, 6, 7]) * 1e-4, minor=True)

# Assembled Figures

## Figure 1 (SRF, Paradigm, Flat Response)

**Figure 1. Spike response rates in OT decrease with competition of flat
broadband noise.** __A)__ Spatial receptive field of an example OT unit. Color
intensity in each circle represents the spike response rate measured for stimuli
played from that speaker location. __B)__ Experimental design: The driver
stimulus (blue) was played from a speaker located at the preferred location of
recorded units at a fixed level. The competitor stimulus (green) was played from
a speaker outside the receptive field, more than 50 degrees away. The sound level
of the competitor varied from -15 to +10 dB relative to the driver level. __C)__
Percent change in spike response rates for relative level of competing sounds.
Box plots show the interquartile range (25th-75th percentile) with whiskers
indicating 5-95% confidence interval, and circles represent individual units.
ANOVA p-value and relevant post hoc p-values are indicated.

In [None]:
# Figure 1 (OT, flat noise). Panels:
#   A) spatial receptive field of the example unit;
#   B) experimental-design illustration;
#   C) change in spike rate across relative competitor levels.
# figsize is in inches (matplotlib). The *_inch layout helpers presumably
# take inches as well.
fig1 = cast(Figure, plt.figure(figsize=(3.3, 6.9)))

# Panel A: receptive-field map of the example unit.
ax1a = figure_add_axes_inch(
    fig1,
    top=0.05,
    left=0.8,
    width=1.9,
    height=2.0,
    label="A",
)

# Thin colorbar axes for panel A. It is placed 0.05 to the right of ax1a and
# shares its top and height.
ax1a_cb = figure_add_axes_inch(
    fig1,
    top=0.05,
    left=0.8 + 1.9 + 0.05,
    width=0.05,
    height=2.0,
)

figure_spatial_receptive_field(example_srf, ax=ax1a, cax=ax1a_cb)
subplot_indicator(ax1a, ha="left", va="top", pad_inch=0.7)


# Panel B: static experimental-design image, drawn without frame or ticks.
ax1b = figure_add_axes_inch(
    fig1,
    top=2.5,
    left=0.8,
    width=2.1,
    height=2.4,
    label="B",
)
plt.setp(ax1b, frame_on=False, xticks=[], yticks=[], zorder=20)
# NOTE(review): this path is relative to the current working directory,
# unlike DATADIR/OUTDIR, which are resolved at the top of the notebook.
owl_image = imread("./other_figures/owl_experimental_design.png")
ax1b.imshow(owl_image)

subplot_indicator(ax1b, ha="left", va="top", pad_inch=0.7)


# Panel C: box plot comparing two-stimulus and single-stimulus OT spike rates.
ax1c = figure_add_axes_inch(
    fig1,
    top=4.9,
    left=0.8,
    width=2.4,
    height=1.5,
    label="C",
)
figure_rlf_boxplot(
    twostim_rlf_ot, single_rlf_ot, ax=ax1c, colors=c_flat_noise
)
ax1c.set_ylabel("Change in\nSpike Rate")
# Pin the y-label at an axes-fraction offset of -0.4/2.4, i.e. 0.4 of the
# 2.4-wide axes. The same offset is used for the other rate box plots.
ax1c.yaxis.set_label_coords(-0.4 / 2.4, 0.5)
subplot_indicator(ax1c, ha="left", va="top", pad_inch=0.7)

save_show_close(fig1, "figure1")


## Figure 2 (OT - Flat)

**Figure 2. Spike train synchrony decreases with competition of flat broadband
noise.** __A)__ Example raster plots for two relative level conditions. Spike
times of two simultaneously recorded units (green and orange) show more
coincident spike times between the unit pair (black) when the driver was louder
(left, -15 dB) than when the competitor was louder (right, +10 dB). __B)__
Cross-correlograms for concurrent stimuli, averaged across all simultaneously
recorded unit pairs with significant spike train synchrony. Horizontal line
indicates mean peak values for a single driver stimulus. Number of unit pairs
indicated above each cross-correlogram. __C)__ Competition synchrony index
(CSI), showing the change in spike train synchrony with competition relative to
single stimulus presentation, across relative levels. Box plots show the
interquartile range (25th-75th percentile) with whiskers indicating 5-95%
confidence interval, and circles represent individual units. ANOVA p-value and
relevant post hoc p-values are indicated.

In [None]:
# Figure 2 (OT, flat noise). Panels:
#   A) example coincident-spike rasters at -15 dB and +10 dB;
#   B) average CCGs, one panel per relative level;
#   C) CSI box plot.
# The height is reduced by float.fromhex("0x1.0000000000000p-51") == 2**-51.
# NOTE(review): presumably this nudges a floating-point rounding of the
# figure size — confirm before removing.
fig2 = cast(Figure, plt.figure(figsize=(3.30, 5.7-float.fromhex("0x1.0000000000000p-51"))))

# Panel A: 1x2 group of raster axes (left: -15 dB, right: +10 dB).
axs2a, axg2a = figure_add_axes_group_inch(
    fig2,
    nrows=1,
    ncols=2,
    group_top=0.3,
    group_left=0.8,
    individual_width=(2.4 - 0.1) / 2,
    individual_height=2.0,
    wspace=0.1,
    hspace=0.0,
)

figure_coincident_rasterplot(
    spiketrain_data=example_spiketrains_ot_flat,
    axs=axs2a.flatten().tolist(),
)
axs2a[0, 0].set_title("-15 dB")
axs2a[0, 1].set_title("+10 dB")

# Panel B: 1x6 strip of CCG axes with no spacing, spanning the same 2.4 width.
axs2b, axg2b = figure_add_axes_group_inch(
    fig2,
    nrows=1,
    ncols=6,
    group_top=2.8,
    group_left=0.8,
    individual_width=2.4 / 6,
    individual_height=0.6,
    wspace=0.0,
    hspace=0.0,
)

figure_xcorr_ccg(
    twostim_ccg_ot, single_ccg_ot, axs=axs2b.flatten().tolist(), colors=c_flat_noise
)

# Panel C: CSI box plot, positioned from the bottom of the figure.
ax2c = figure_add_axes_inch(
    fig2,
    left=0.8,
    width=2.4,
    bottom=0.5,
    height=1.5,
    label="C",
)
figure_xcorr_boxplot(twostim_ccg_ot, single_ccg_ot, ax=ax2c, colors=c_flat_noise)

subplot_indicator(axg2a, "A", ha="left", va="top", pad_inch=0.7)

# The label-coordinate x offsets are axes fractions: -0.3 divided by each
# axes' width.
axs2b[0, 0].yaxis.set_label_coords(-0.3 / (2.4 / 6), 0.5)
subplot_indicator(axg2b, "B", ha="left", va="top", pad_inch=0.7)
ax2c.yaxis.set_label_coords(-0.3 / 2.4, 0.5)
subplot_indicator(ax2c, ha="left", va="top", pad_inch=0.7)

# "flat noise" condition badge anchored at the panel-A group origin.
condition_batch(fig2, 0.8, 0.3, text="flat noise", fontsize=10, color=c_flat_noise[0], y_pt=2, x_pt=0, pad_pt=2)

save_show_close(fig2, "figure2")


## Figure 3 (OT - AM)

**Figure 3. Spike response rates and spike train synchrony in OT decrease with
competition of amplitude modulated noise.** Competition between stimuli with
different modulation frequencies (left column: driver fAM = 55 Hz, competitor
fAM = 75 Hz; right column: driver fAM = 75 Hz, competitor fAM = 55 Hz). __A,B)__
Percent change in spike response rates for relative level of competing sounds.
__C,D)__ Example raster plots for two relative level conditions. Spike times of
two simultaneously recorded units (green and orange) show more coincident spike
times between the unit pair (black) when the driver was louder (left, -15 dB)
than when the competitor was louder (right, +10 dB). __E,F)__ Cross-correlograms
for concurrent stimuli, averaged across all simultaneously recorded unit pairs
with significant spike train synchrony. Horizontal lines indicate mean peak
values for a single driver stimulus. Number of unit pairs indicated above each
cross-correlogram. __G,H)__ Competition synchrony index (CSI), showing the change in
spike train synchrony with competition relative to single stimulus presentation,
across relative levels. All box plots show the interquartile range (25th-75th
percentile) with whiskers indicating 5-95% confidence interval, and circles
represent individual units. ANOVA p-values and relevant post hoc p-values are
indicated.

In [None]:
# Figure 3 (OT, AM noise). Two columns: driver 55 Hz on the left, driver
# 75 Hz on the right. Rows:
#   A,B) change in spike rate;
#   C,D) example rasters;
#   E,F) average CCGs;
#   G,H) CSI box plots.
# The height is reduced by 2**-51. NOTE(review): presumably this avoids a
# floating-point rounding artifact in the figure size — confirm.
fig3 = cast(Figure, plt.figure(figsize=(6.80, 7.8-float.fromhex("0x1.0000000000000p-51"))))

# Driver-55 Hz subsets: single-stimulus tables filter on
# "modulation_frequency", two-stimulus tables on "fixed_modulation_frequency".
am_single_ccg_ot_55 = am_single_ccg_ot[am_single_ccg_ot["modulation_frequency"] == 55]
am_single_rlf_ot_55 = am_single_rlf_ot[am_single_rlf_ot["modulation_frequency"] == 55]
am_twostim_ccg_ot_55 = am_twostim_ccg_ot[
    am_twostim_ccg_ot["fixed_modulation_frequency"] == 55
]
am_twostim_rlf_ot_55 = am_twostim_rlf_ot[
    am_twostim_rlf_ot["fixed_modulation_frequency"] == 55
]

# Driver-75 Hz subsets, filtered the same way.
am_single_ccg_ot_75 = am_single_ccg_ot[am_single_ccg_ot["modulation_frequency"] == 75]
am_single_rlf_ot_75 = am_single_rlf_ot[am_single_rlf_ot["modulation_frequency"] == 75]
am_twostim_ccg_ot_75 = am_twostim_ccg_ot[
    am_twostim_ccg_ot["fixed_modulation_frequency"] == 75
]
am_twostim_rlf_ot_75 = am_twostim_rlf_ot[
    am_twostim_rlf_ot["fixed_modulation_frequency"] == 75
]


# Panel A: spike-rate change, driver 55 Hz (left column).
ax3a = figure_add_axes_inch(
    fig3,
    top=0.1,
    left=0.8,
    width=2.4,
    height=1.5,
    label="A",
)
figure_rlf_boxplot(am_twostim_rlf_ot_55, am_single_rlf_ot_55, ax=ax3a, colors=c_fAM_55)
ax3a.set_ylabel("Change in\nSpike Rate")
ax3a.yaxis.set_label_coords(-0.4 / 2.4, 0.5)
subplot_indicator(ax3a, ha="left", va="top", pad_inch=0.7)


# Panel B: spike-rate change, driver 75 Hz. The right column starts at
# 0.8 + 2.4 + 0.3 + 0.8 (= left margin + panel width + gap + left margin).
ax3b = figure_add_axes_inch(
    fig3,
    top=0.1,
    left=0.8 + 2.4 + 0.3 + 0.8,
    width=2.4,
    height=1.5,
    label="B",
)
figure_rlf_boxplot(am_twostim_rlf_ot_75, am_single_rlf_ot_75, ax=ax3b, colors=c_fAM_75)

ax3b.set_ylabel("Change in\nSpike Rate")
ax3b.yaxis.set_label_coords(-0.4 / 2.4, 0.5)
subplot_indicator(ax3b, ha="left", va="top", pad_inch=0.7)


# Panel C: example rasters, driver 55 Hz (-15 dB / +10 dB).
axs3c, axg3c = figure_add_axes_group_inch(
    fig3,
    nrows=1,
    ncols=2,
    group_top=2.4,
    group_left=0.8,
    individual_width=(2.4 - 0.1) / 2,
    individual_height=2.0,
    wspace=0.1,
    hspace=0.0,
)

figure_coincident_rasterplot(
    spiketrain_data=example_spiketrains_ot_driver55,
    axs=axs3c.flatten().tolist(),
)
axs3c[0, 0].set_title("-15 dB")
axs3c[0, 1].set_title("+10 dB")
subplot_indicator(axg3c, "C", ha="left", va="top", pad_inch=0.7)


# Panel D: example rasters, driver 75 Hz.
axs3d, axg3d = figure_add_axes_group_inch(
    fig3,
    nrows=1,
    ncols=2,
    group_top=2.4,
    group_left=0.8 + 2.4 + 0.3 + 0.8,
    individual_width=(2.4 - 0.1) / 2,
    individual_height=2.0,
    wspace=0.1,
    hspace=0.0,
)

figure_coincident_rasterplot(
    spiketrain_data=example_spiketrains_ot_driver75,
    axs=axs3d.flatten().tolist(),
)
axs3d[0, 0].set_title("-15 dB")
axs3d[0, 1].set_title("+10 dB")
subplot_indicator(axg3d, "D", ha="left", va="top", pad_inch=0.7)


# Panel E: average CCGs per relative level, driver 55 Hz. All six panels get
# the same y top; minor ticks are set on the first (labelled) panel.
axs3e, axg3e = figure_add_axes_group_inch(
    fig3,
    nrows=1,
    ncols=6,
    group_top=2.2 + 2.6 + 0.1,
    group_left=0.8,
    individual_width=2.4 / 6,
    individual_height=0.6,
    wspace=0.0,
    hspace=0.0,
)
figure_xcorr_ccg(
    am_twostim_ccg_ot_55,
    am_single_ccg_ot_55,
    axs=axs3e.flatten().tolist(),
    colors=c_fAM_55,
)
for ax in axs3e[0, :]:
    ax.set_ylim(top=9.5e-5)
axs3e[0, 0].set_yticks(np.array([1, 2, 3, 4, 6, 7, 8, 9]) * 1e-5, minor=True)
axs3e[0, 0].yaxis.set_label_coords(-0.3 / (2.4 / 6), 0.5)
subplot_indicator(axg3e, "E", ha="left", va="top", pad_inch=0.7)


# Panel F: average CCGs per relative level, driver 75 Hz (same formatting as E).
axs3f, axg3f = figure_add_axes_group_inch(
    fig3,
    nrows=1,
    ncols=6,
    group_top=2.2 + 2.6 + 0.1,
    group_left=0.8 + 2.4 + 0.3 + 0.8,
    individual_width=2.4 / 6,
    individual_height=0.6,
    wspace=0.0,
    hspace=0.0,
)
figure_xcorr_ccg(
    am_twostim_ccg_ot_75,
    am_single_ccg_ot_75,
    axs=axs3f.flatten().tolist(),
    colors=c_fAM_75,
)
for ax in axs3f[0, :]:
    ax.set_ylim(top=9.5e-5)
axs3f[0, 0].set_yticks(np.array([1, 2, 3, 4, 6, 7, 8, 9]) * 1e-5, minor=True)
axs3f[0, 0].yaxis.set_label_coords(-0.3 / (2.4 / 6), 0.5)
subplot_indicator(axg3f, "F", ha="left", va="top", pad_inch=0.7)

# Panel G: CSI box plot, driver 55 Hz, positioned from the figure bottom.
ax3g = figure_add_axes_inch(
    fig3,
    left=0.8,
    width=2.4,
    bottom=0.5,
    height=1.5,
    label="G",
)
figure_xcorr_boxplot(
    am_twostim_ccg_ot_55, am_single_ccg_ot_55, ax=ax3g, colors=c_fAM_55
)

ax3g.yaxis.set_label_coords(-0.3 / 2.4, 0.5)
subplot_indicator(ax3g, ha="left", va="top", pad_inch=0.7)

# Panel H: CSI box plot, driver 75 Hz.
ax3h = figure_add_axes_inch(
    fig3,
    left=0.8 + 2.4 + 0.3 + 0.8,
    width=2.4,
    bottom=0.5,
    height=1.5,
    label="H",
)
figure_xcorr_boxplot(
    am_twostim_ccg_ot_75, am_single_ccg_ot_75, ax=ax3h, colors=c_fAM_75
)
ax3h.yaxis.set_label_coords(-0.3 / 2.4, 0.5)
subplot_indicator(ax3h, ha="left", va="top", pad_inch=0.7)

# Column badges ("Driver 55 Hz" / "Driver 75 Hz"), right-aligned at each
# column's left edge.
condition_batch(
    fig3,
    left=0.8,
    top=2.3,
    text="Driver\n55 Hz",
    color=c_fAM_55[0],
    fontsize=12,
    ha="right",
)
condition_batch(
    fig3,
    left=0.8 + 2.4 + 0.3 + 0.8,
    top=2.3,
    text="Driver\n75 Hz",
    color=c_fAM_75[0],
    fontsize=12,
    ha="right",
)

save_show_close(fig3, "figure3")

## Figure 4 (ICx Responses)

**Figure 4. Spike response rates decrease with competition of flat and amplitude
modulated noise in ICx.** __A)__ Percent change in spike response rates for
relative level of competing flat noise. __B,C)__ Percent change in spike
response rates for relative level of competing amplitude modulated noise (B:
driver fAM = 55 Hz, competitor fAM = 75 Hz; C: driver fAM = 75 Hz, competitor
fAM = 55 Hz). All box plots show the interquartile range (25th-75th percentile)
with whiskers indicating 5-95% confidence interval, and circles represent
individual units. ANOVA p-values and relevant post hoc p-values are indicated.

In [None]:
# Figure 4 (ICx spike rates). Panels:
#   A) flat noise;
#   B) AM noise, driver 55 Hz;
#   C) AM noise, driver 75 Hz.
# The height (and panel C's top) is increased by 2**-51.
# NOTE(review): presumably a floating-point rounding nudge — confirm.
fig4 = cast(Figure, plt.figure(figsize=(3.30, 6.3 + float.fromhex("0x1.0000000000000p-51"))))

# Panel A: flat-noise spike-rate change with post hoc brackets between the
# listed level-index pairs. The values are presumably bracket heights
# (verify in figure_rlf_boxplot).
ax4a = figure_add_axes_inch(
    fig4,
    top=0.1,
    left=0.8,
    width=2.4,
    height=1.5,
    label="A",
)
figure_rlf_boxplot(
    twostim_rlf_icx,
    single_rlf_icx,
    ax=ax4a,
    colors=c_flat_noise,
    brackets={(0, 5): 0.85, (1, 5): 0.75, (2, 5): 0.65},
    anova_align="left",
)
ax4a.set_ylim(top=1.9)
ax4a.yaxis.set_major_locator(MultipleLocator(1))
ax4a.yaxis.set_minor_locator(MultipleLocator(0.5))
ax4a.set_ylabel("Change in\nSpike Rate")
ax4a.yaxis.set_label_coords(-0.4 / 2.4, 0.5)
subplot_indicator(ax4a, ha="right", va="top", pad_inch=0.05)

# Condition badge at the top-left corner of the figure.
condition_batch(
    fig4,
    left=0.0,
    top=0.0,
    text="flat noise",
    color=c_flat_noise[0],
    fontsize=8,
    ha="left",
    va="top",
    pad_pt=2,
    y_pt=4,
    x_pt=4,
)

# Driver-55 Hz ICx subsets.
am_single_rlf_icx_55 = am_single_rlf_icx[
    am_single_rlf_icx["modulation_frequency"] == 55
]
am_twostim_rlf_icx_55 = am_twostim_rlf_icx[
    am_twostim_rlf_icx["fixed_modulation_frequency"] == 55
]

# Panel B: AM noise, driver 55 Hz (same formatting as A).
ax4b = figure_add_axes_inch(
    fig4,
    top=2.2,
    left=0.8,
    width=2.4,
    height=1.5,
    label="B",
)
figure_rlf_boxplot(
    am_twostim_rlf_icx_55,
    am_single_rlf_icx_55,
    ax=ax4b,
    colors=c_fAM_55,
    brackets={(0, 5): 0.85, (1, 5): 0.75, (2, 5): 0.65},
    anova_align="left",
)
ax4b.set_ylim(top=1.9)
ax4b.yaxis.set_major_locator(MultipleLocator(1))
ax4b.yaxis.set_minor_locator(MultipleLocator(0.5))
ax4b.set_ylabel("Change in\nSpike Rate")
ax4b.yaxis.set_label_coords(-0.4 / 2.4, 0.5)
subplot_indicator(ax4b, ha="right", va="top", pad_inch=0.05)

condition_batch(
    fig4,
    left=0.0,
    top=2.1,
    text="Driver\n55 Hz",
    color=c_fAM_55[0],
    fontsize=8,
    ha="left",
    va="top",
    x_pt=4,
    y_pt=4,
)

# Driver-75 Hz ICx subsets.
am_single_rlf_icx_75 = am_single_rlf_icx[
    am_single_rlf_icx["modulation_frequency"] == 75
]
am_twostim_rlf_icx_75 = am_twostim_rlf_icx[
    am_twostim_rlf_icx["fixed_modulation_frequency"] == 75
]

# Panel C: AM noise, driver 75 Hz. The brackets sit slightly higher here
# (0.9/0.8/0.7).
ax4c = figure_add_axes_inch(
    fig4,
    top=4.3+float.fromhex("0x1.0000000000000p-51"),
    left=0.8,
    width=2.4,
    height=1.5,
    label="C",
)
figure_rlf_boxplot(
    am_twostim_rlf_icx_75,
    am_single_rlf_icx_75,
    ax=ax4c,
    colors=c_fAM_75,
    brackets={(0, 5): 0.9, (1, 5): 0.8, (2, 5): 0.7},
    anova_align="left",
)
ax4c.set_ylim(top=1.9)
ax4c.yaxis.set_major_locator(MultipleLocator(1))
ax4c.yaxis.set_minor_locator(MultipleLocator(0.5))
ax4c.set_ylabel("Change in\nSpike Rate")
ax4c.yaxis.set_label_coords(-0.4 / 2.4, 0.5)
subplot_indicator(ax4c, ha="right", va="top", pad_inch=0.05)

condition_batch(
    fig4,
    left=0.0,
    top=4.2,
    text="Driver\n75 Hz",
    color=c_fAM_75[0],
    fontsize=8,
    ha="left",
    va="top",
    x_pt=4,
    y_pt=4,
)

save_show_close(fig4, "figure4")


## Figure 5 (ICx - Flat)

**Figure 5. Spike train synchrony does not change with competition in ICx for flat noise.** __A)__ Example raster plots for two relative level conditions. Spike times of two simultaneously recorded units (green and orange) show comparable coincident spike times between the unit pair (black) when either the driver (left, -15 dB) or competitor (right, +10 dB) was louder. __B)__ Cross-correlograms for concurrent stimuli, averaged across CCGs of all simultaneously recorded unit pairs with significant spike train synchrony. Horizontal line indicates mean peak values for a single driver stimulus. Number of unit pairs indicated above each cross-correlogram. __C)__ Competition synchrony index (CSI), showing the change in spike train synchrony with competition relative to single stimulus presentation, across relative levels. Box plots show the interquartile range (25th-75th percentile) with whiskers indicating 5-95% confidence interval, and circles represent individual units. ANOVA p-values and relevant post hoc p-values are indicated.

In [None]:
# Figure 5 (ICx, flat noise). Panels:
#   A) example coincident-spike rasters at -15 dB and +10 dB;
#   B) average CCGs per relative level;
#   C) CSI box plot.
# Same layout as Figure 2. The 2**-51 height nudge presumably avoids a
# floating-point rounding artifact — confirm before removing.
fig5 = cast(Figure, plt.figure(figsize=(3.30, 5.7 - float.fromhex("0x1.0000000000000p-51"))))

# Panel A: 1x2 raster group (left: -15 dB, right: +10 dB).
axs5a, axg5a = figure_add_axes_group_inch(
    fig5,
    nrows=1,
    ncols=2,
    group_top=0.3,
    group_left=0.8,
    individual_width=(2.4 - 0.1) / 2,
    individual_height=2.0,
    wspace=0.1,
    hspace=0.0,
)

figure_coincident_rasterplot(
    spiketrain_data=example_spiketrains_icx_flat,
    axs=axs5a.flatten().tolist(),
)
axs5a[0, 0].set_title("-15 dB")
axs5a[0, 1].set_title("+10 dB")

# Panel B: 1x6 strip of CCG axes.
axs5b, axg5b = figure_add_axes_group_inch(
    fig5,
    nrows=1,
    ncols=6,
    group_top=2.8,
    group_left=0.8,
    individual_width=2.4 / 6,
    individual_height=0.6,
    wspace=0.0,
    hspace=0.0,
)

figure_xcorr_ccg(
    twostim_ccg_icx, single_ccg_icx, axs=axs5b.flatten().tolist(), colors=c_flat_noise
)

# ICx CCG peaks are smaller, so use a common, lower y top and explicit ticks.
for ax in axs5b[0, :]:
    ax.set_ylim(top=1.5e-5)
axs5b[0, 0].set_yticks(np.array([0, 1]) * 1e-5, minor=False)
axs5b[0, 0].set_yticks(np.array([0.5]) * 1e-5, minor=True)


# Panel C: CSI box plot without post hoc brackets. Its own name (ax5c) avoids
# overwriting Figure 2's ax2c.
ax5c = figure_add_axes_inch(
    fig5,
    left=0.8,
    width=2.4,
    bottom=0.5,
    height=1.5,
    label="C",
)
figure_xcorr_boxplot(
    twostim_ccg_icx, single_ccg_icx, ax=ax5c, colors=c_flat_noise, brackets={}
)

subplot_indicator(axg5a, "A", ha="left", va="top", pad_inch=0.7)

axs5b[0, 0].yaxis.set_label_coords(-0.3 / (2.4 / 6), 0.5)
subplot_indicator(axg5b, "B", ha="left", va="top", pad_inch=0.7)
ax5c.yaxis.set_label_coords(-0.3 / 2.4, 0.5)
subplot_indicator(ax5c, ha="left", va="top", pad_inch=0.7)

# "flat noise" condition badge anchored at the panel-A group origin.
condition_batch(
    fig5,
    0.8,
    0.3,
    text="flat noise",
    fontsize=10,
    color=c_flat_noise[0],
    y_pt=2,
    x_pt=0,
    pad_pt=2,
)

save_show_close(fig5, "figure5")

## Figure 6 (ICx - AM)

**Figure 6. Spike train synchrony does not change with competition in ICx for amplitude modulated noise.** Competition with amplitude modulated noise (left column: driver fAM = 55 Hz, competitor fAM = 75 Hz; right column: driver fAM = 75 Hz, competitor fAM = 55 Hz). __A,B)__ Example raster plots for two relative level conditions. Spike times of two simultaneously recorded units (green and orange) show comparable coincident spike times between the unit pair (black) when either the driver (left, -15 dB) or competitor (right, +10 dB) was louder. __C,D)__ Cross-correlograms for concurrent stimuli, averaged across all simultaneously recorded unit pairs with significant spike train synchrony. Horizontal lines indicate mean peak values for a single driver stimulus. Number of unit pairs indicated above each cross-correlogram. __E,F)__ Competition synchrony index (CSI), showing the change in spike train synchrony with competition relative to single stimulus presentation, across relative levels. All box plots show the interquartile range (25th-75th percentile) with whiskers indicating 5-95% confidence interval, and circles represent individual units. ANOVA p-values and relevant post hoc p-values are indicated.

In [None]:
# Figure 6 (ICx, AM noise). Left column: driver 55 Hz (group_left=0.8).
# Right column: driver 75 Hz (group_left=4.3). Rows:
#   A,B) example rasters;
#   C,D) average CCGs;
#   E,F) CSI box plots.
# The 2**-51 height nudge presumably avoids a rounding artifact — confirm.
fig6 = cast(Figure, plt.figure(figsize=(6.80, 5.7 - float.fromhex("0x1.0000000000000p-51"))))

# Driver-55 Hz ICx subsets. The *_rlf_* tables are recomputed here
# identically to the Figure 4 cell but are not used in this figure.
am_single_ccg_icx_55 = am_single_ccg_icx[
    am_single_ccg_icx["modulation_frequency"] == 55
]
am_single_rlf_icx_55 = am_single_rlf_icx[
    am_single_rlf_icx["modulation_frequency"] == 55
]
am_twostim_ccg_icx_55 = am_twostim_ccg_icx[
    am_twostim_ccg_icx["fixed_modulation_frequency"] == 55
]
am_twostim_rlf_icx_55 = am_twostim_rlf_icx[
    am_twostim_rlf_icx["fixed_modulation_frequency"] == 55
]

# Driver-75 Hz ICx subsets (same filtering as above).
am_single_ccg_icx_75 = am_single_ccg_icx[
    am_single_ccg_icx["modulation_frequency"] == 75
]
am_single_rlf_icx_75 = am_single_rlf_icx[
    am_single_rlf_icx["modulation_frequency"] == 75
]
am_twostim_ccg_icx_75 = am_twostim_ccg_icx[
    am_twostim_ccg_icx["fixed_modulation_frequency"] == 75
]
am_twostim_rlf_icx_75 = am_twostim_rlf_icx[
    am_twostim_rlf_icx["fixed_modulation_frequency"] == 75
]

# Panel A: example rasters, driver 55 Hz.
axs6a, axg6a = figure_add_axes_group_inch(
    fig6,
    nrows=1,
    ncols=2,
    group_top=0.3,
    group_left=0.8,
    individual_width=(2.4 - 0.1) / 2,
    individual_height=2.0,
    wspace=0.1,
    hspace=0.0,
)

figure_coincident_rasterplot(
    spiketrain_data=example_spiketrains_icx_driver55,
    axs=axs6a.flatten().tolist(),
)
axs6a[0, 0].set_title("-15 dB")
axs6a[0, 1].set_title("+10 dB")
subplot_indicator(axg6a, "A", ha="left", va="top", pad_inch=0.7)


# Panel B: example rasters, driver 75 Hz.
axs6b, axg6b = figure_add_axes_group_inch(
    fig6,
    nrows=1,
    ncols=2,
    group_top=0.3,
    group_left=4.3,
    individual_width=(2.4 - 0.1) / 2,
    individual_height=2.0,
    wspace=0.1,
    hspace=0.0,
)

figure_coincident_rasterplot(
    spiketrain_data=example_spiketrains_icx_driver75,
    axs=axs6b.flatten().tolist(),
)
axs6b[0, 0].set_title("-15 dB")
axs6b[0, 1].set_title("+10 dB")
subplot_indicator(axg6b, "B", ha="left", va="top", pad_inch=0.7)


# Panel C: average CCGs, driver 55 Hz, with the ICx y-scale (top 1.5e-5).
axs6c, axg6c = figure_add_axes_group_inch(
    fig6,
    nrows=1,
    ncols=6,
    group_top=2.8,
    group_left=0.8,
    individual_width=2.4 / 6,
    individual_height=0.6,
    wspace=0.0,
    hspace=0.0,
)
figure_xcorr_ccg(
    am_twostim_ccg_icx_55,
    am_single_ccg_icx_55,
    axs=axs6c.flatten().tolist(),
    colors=c_fAM_55,
)

for ax in axs6c[0, :]:
    ax.set_ylim(top=1.5e-5)
axs6c[0, 0].set_yticks(np.array([0, 1]) * 1e-5, minor=False)
axs6c[0, 0].set_yticks(np.array([0.5]) * 1e-5, minor=True)
axs6c[0, 0].yaxis.set_label_coords(-0.3 / (2.4 / 6), 0.5)

subplot_indicator(axg6c, "C", ha="left", va="top", pad_inch=0.7)


# Panel D: average CCGs, driver 75 Hz (same formatting as C).
axs6d, axg6d = figure_add_axes_group_inch(
    fig6,
    nrows=1,
    ncols=6,
    group_top=2.8,
    group_left=4.3,
    individual_width=2.4 / 6,
    individual_height=0.6,
    wspace=0.0,
    hspace=0.0,
)
figure_xcorr_ccg(
    am_twostim_ccg_icx_75,
    am_single_ccg_icx_75,
    axs=axs6d.flatten().tolist(),
    colors=c_fAM_75,
)

for ax in axs6d[0, :]:
    ax.set_ylim(top=1.5e-5)
axs6d[0, 0].set_yticks(np.array([0, 1]) * 1e-5, minor=False)
axs6d[0, 0].set_yticks(np.array([0.5]) * 1e-5, minor=True)
axs6d[0, 0].yaxis.set_label_coords(-0.3 / (2.4 / 6), 0.5)

subplot_indicator(axg6d, "D", ha="left", va="top", pad_inch=0.7)

# Panel E: CSI box plot, driver 55 Hz, without post hoc brackets.
ax6e = figure_add_axes_inch(
    fig6,
    left=0.8,
    width=2.4,
    bottom=0.5,
    height=1.5,
    label="E",
)
figure_xcorr_boxplot(
    am_twostim_ccg_icx_55, am_single_ccg_icx_55, ax=ax6e, colors=c_fAM_55, brackets={}
)

ax6e.yaxis.set_label_coords(-0.3 / 2.4, 0.5)
subplot_indicator(ax6e, ha="left", va="top", pad_inch=0.7)

# Panel F: CSI box plot, driver 75 Hz, without post hoc brackets.
ax6f = figure_add_axes_inch(
    fig6,
    left=4.3,
    width=2.4,
    bottom=0.5,
    height=1.5,
    label="F",
)
figure_xcorr_boxplot(
    am_twostim_ccg_icx_75, am_single_ccg_icx_75, ax=ax6f, colors=c_fAM_75, brackets={}
)
ax6f.yaxis.set_label_coords(-0.3 / 2.4, 0.5)
subplot_indicator(ax6f, ha="left", va="top", pad_inch=0.7)

# Column badges, right-aligned just left of each column (0.7 / 4.2).
condition_batch(
    fig6,
    left=0.7,
    top=0.6,
    text="Driver\n55 Hz",
    color=c_fAM_55[0],
    fontsize=12,
    ha="right",
    va="top",
)
condition_batch(
    fig6,
    left=4.2,
    top=0.6,
    text="Driver\n75 Hz",
    color=c_fAM_75[0],
    fontsize=12,
    ha="right",
    va="top",
)

save_show_close(fig6, "figure6")

## Figure 7 (Elevation Tunings)

**Figure 7. Simultaneously recorded units in ICx and OT exhibit similarly homogeneous elevation tuning and signal correlations.** Elevation tuning was characterized from the spatial receptive field at the preferred azimuth of simultaneously recorded units. Histograms of __A)__ preferred elevations, __B)__ elevation tuning width and __C)__ tuning similarity assessed by pairwise signal correlations in all OT (orange) and ICx (cyan) units. Triangles above each histogram indicate the median. Standard deviations per recording session of __D)__ best elevations, __E)__ elevation widths and __F)__ signal correlations for OT and ICx. All box plots show the interquartile range (25th-75th percentile) with whiskers indicating 5-95% confidence interval, and circles represent individual recording sessions. Wilcoxon p-values are indicated.

In [None]:
# Figure 7 (elevation tuning). Left column: histograms, each a 2x1 group with
# one axes per area (OT, ICx):
#   A) best elevation;
#   B) elevation width;
#   C) signal correlation.
# Right column: box plots of the per-session SDs, D-F.
fig7 = cast(Figure, plt.figure(figsize=(3.30, 4.9)))

# Panel A: best-elevation histograms.
axs7a, axg7a = figure_add_axes_group_inch(
    fig7,
    nrows=2,
    ncols=1,
    group_top=0.1,
    group_left=0.6,
    individual_width=1.0,
    individual_height=0.5,
    hspace=0.0,
)

figure_best_elevation_hist(
    elevation_tunings_ot, elevation_tunings_icx, axs7a.flatten().tolist()
)

axs7a[1, 0].set_xlabel("Best Elevation")

subplot_indicator(axg7a, label="A", ha="right", va="top", pad_inch=0.3)


# Panel B: elevation-width histograms.
axs7b, axg7b = figure_add_axes_group_inch(
    fig7,
    nrows=2,
    ncols=1,
    group_top=1.7,
    group_left=0.6,
    individual_width=1.0,
    individual_height=0.5,
    hspace=0.0,
)

figure_elevation_width_hist(
    elevation_tunings_ot, elevation_tunings_icx, axs7b.flatten().tolist()
)

axs7b[1, 0].set_xlabel("Elevation Width")

subplot_indicator(axg7b, label="B", ha="right", va="center", pad_inch=0.3)


# Panel C: pairwise signal-correlation histograms. The 2**-51 nudge on
# group_top presumably avoids a rounding artifact — confirm.
axs7c, axg7c = figure_add_axes_group_inch(
    fig7,
    nrows=2,
    ncols=1,
    group_top=3.3 + float.fromhex("0x1.0000000000000p-51"),
    group_left=0.6,
    individual_width=1.0,
    individual_height=0.5,
    hspace=0.0,
)

figure_elevation_signalcorr_hist(
    elevation_signalcorr_ot, elevation_signalcorr_icx, ax=axs7c.flatten().tolist()
)

axs7c[1, 0].set_xlabel("Signal Correlation")

subplot_indicator(axg7c, label="C", ha="right", va="bottom", pad_inch=0.3)


# Shared y-labels go on the group axes so that they span both histograms.
for ax in [axg7a, axg7b]:
    ax.set_ylabel("# Units")
    ax.yaxis.set_label_coords(-0.3 / 1.0, 0.5)

axg7c.set_ylabel("# Unit Pairs")
axg7c.yaxis.set_label_coords(-0.3 / 1.0, 0.5)

# Panel D: SD of best elevation per recording session.
ax7d = figure_add_axes_inch(
    fig7,
    left=2.2,
    width=1,
    top=0.1,
    height=1.0,
    label="D",
)

figure_best_elevation_boxplots(elevation_tunings_ot, elevation_tunings_icx, ax7d)

subplot_indicator(ax7d, ha="right", va="top", pad_inch=0.3)


# Panel E: SD of elevation width per recording session.
ax7e = figure_add_axes_inch(
    fig7,
    left=2.2,
    width=1,
    top=1.7,
    height=1.0,
    label="E",
)

figure_elevation_width_boxplots(elevation_tunings_ot, elevation_tunings_icx, ax7e)

subplot_indicator(ax7e, ha="right", va="center", pad_inch=0.3)

# Panel F: SD of signal correlation per recording session.
ax7f = figure_add_axes_inch(
    fig7,
    left=0.8 + 1 + 0.4,
    width=1,
    top=3.3 + float.fromhex("0x1.0000000000000p-51"),
    height=1.0,
    label="F",
)

figure_elevation_signalcorr_boxplots(
    elevation_signalcorr_ot, elevation_signalcorr_icx, ax7f
)

subplot_indicator(ax7f, ha="right", va="bottom", pad_inch=0.3)

for ax in [ax7d, ax7e, ax7f]:
    ax.set_ylabel("SD")
    ax.yaxis.set_label_coords(-0.3 / 1.0, 0.5)

save_show_close(fig7, "figure7")

## Figure 8 (Stim Phaselocking)

**Figure 8. Single stimulus levels and competition shift phase locking properties to the driver in OT.** Scatterplots of phase as a function of vector strength (VS) to the driver. __A,B)__ Units which phase lock to single stimuli respond at earlier phases with increasing sound levels, as shown by a negative phase shift. __C,D)__ Units respond at later phases when competitor levels are increasing. Numbers indicate the mean population VS. Each dot represents one unit, gray dots show units that did not phase lock (Rayleigh’s test of uniformity). Black horizontal lines indicate mean phase for the population.

In [None]:
fig8 = cast(Figure, plt.figure(figsize=(6.80, 3.4 + float.fromhex("0x1.0000000000000p-51"))))

# Figure 8: stimulus phase locking of OT units.
# Left half (A, B) shows the single stimulus; right half (C, D) shows competition.
# The top row uses the 55 Hz driver and the bottom row the 75 Hz driver.
_single_fam = am_single_stim_phaselocking_ot["modulation_frequency"]
df_single_55 = am_single_stim_phaselocking_ot[_single_fam == 55]
df_single_75 = am_single_stim_phaselocking_ot[_single_fam == 75]

_twostim_fam = am_twostim_stim_phaselocking_ot["fixed_modulation_frequency"]
df_twostim_55 = am_twostim_stim_phaselocking_ot[_twostim_fam == 55]
df_twostim_75 = am_twostim_stim_phaselocking_ot[_twostim_fam == 75]

# Left groups: 10 columns separated by 0.05 in, spanning (just under) 2.7 in.
_single_col_width = (2.7 - float.fromhex("0x1.0000000000000p-51") - 9 * 0.05) / 10
# Right groups: 6 columns separated by 0.05 in, spanning 2.4 in.
_twostim_col_width = (2.4 - 5 * 0.05) / 6


def _fig8_row_group(ncols, top, left, col_width):
    """Add a one-row group of 1.0 in tall scatter axes to fig8.

    Returns the (axes array, group axes) pair from figure_add_axes_group_inch.
    """
    return figure_add_axes_group_inch(
        fig8,
        nrows=1,
        ncols=ncols,
        group_top=top,
        group_left=left,
        individual_width=col_width,
        individual_height=1.0,
        wspace=0.05,
        hspace=0.0,
    )


# A: single stimulus, 55 Hz
axs8a, axg8a = _fig8_row_group(10, 0.5, 0.8, _single_col_width)
figure_stimphase_single_scatterplot(
    df_single_55, axs=axs8a.flatten().tolist(), color_phase_locking=c_fAM_55[0]
)
axg8a.set_title("Single Stimulus (dB SPL)", pad=22, fontsize=12)
subplot_indicator(axg8a, "A", ha="left", va="top", pad_inch=0.7)

# B: single stimulus, 75 Hz
axs8b, axg8b = _fig8_row_group(10, 2.1, 0.8, _single_col_width)
figure_stimphase_single_scatterplot(
    df_single_75, axs=axs8b.flatten().tolist(), color_phase_locking=c_fAM_75[0]
)
subplot_indicator(axg8b, "B", ha="left", va="top", pad_inch=0.7)

# C: competition, 55 Hz driver
axs8c, axg8c = _fig8_row_group(6, 0.5, 4.3, _twostim_col_width)
figure_stimphase_twostim_scatterplot(
    df_twostim_55, axs=axs8c.flatten().tolist(), colors=c_fAM_55
)
axg8c.set_title("Competition (Relative Level)", pad=22, fontsize=12)
subplot_indicator(axg8c, "C", ha="left", va="top", pad_inch=0.7)

# D: competition, 75 Hz driver
axs8d, axg8d = _fig8_row_group(6, 2.1, 4.3, _twostim_col_width)
figure_stimphase_twostim_scatterplot(
    df_twostim_75, axs=axs8d.flatten().tolist(), colors=c_fAM_75
)
subplot_indicator(axg8d, "D", ha="left", va="top", pad_inch=0.7)

# Driver badges placed just left of the competition groups, one per row.
for _badge_top, _badge_text, _badge_color in (
    (0.4, "Driver\n55 Hz", c_fAM_55[0]),
    (2.0, "Driver\n75 Hz", c_fAM_75[0]),
):
    condition_batch(
        fig8,
        left=4.2,
        top=_badge_top,
        text=_badge_text,
        color=_badge_color,
        fontsize=10,
        ha="right",
        va="bottom",
    )

save_show_close(fig8, "figure8")

## Figure 9 (VS Driver vs Competitor)

**Figure 9. Phase locking to competing stimuli is similarly strong but more frequent in ICx than in OT.** Scatterplots of vector strength (VS) to the competitor as a function of the VS to the driver across relative levels. Each dot represents one unit, gray dots show units that did not phase lock to the driver (Rayleigh’s test of uniformity, percentage of units indicated within each subplot), black dashed lines indicate unity. __A)__ About half of OT units phase locked to the driver and also to the competitor when the relative level increased. __B)__ Almost all ICx units phase locked to the driver and the competitor as relative levels increased.

In [None]:
fig9 = cast(Figure, plt.figure(figsize=(6.80, 1.8)))

# Figure 9: VS to the competitor vs VS to the driver, for OT (A) and ICx (B).
# Each region is a 2x6 grid: the top row is the 55 Hz driver and the bottom row
# the 75 Hz driver.
_icx_fam = am_twostim_stim_phaselocking_icx["fixed_modulation_frequency"]
df_twostim_55_icx = am_twostim_stim_phaselocking_icx[_icx_fam == 55]
df_twostim_75_icx = am_twostim_stim_phaselocking_icx[_icx_fam == 75]

_ot_fam = am_twostim_stim_phaselocking_ot["fixed_modulation_frequency"]
df_twostim_55_ot = am_twostim_stim_phaselocking_ot[_ot_fam == 55]
df_twostim_75_ot = am_twostim_stim_phaselocking_ot[_ot_fam == 75]

# Square axes: six per row with 0.05 in gaps fill 2.4 in.
_cell_side = (2.4 - 5 * 0.05) / 6


def _fig9_region(group_left, df_55, df_75, region_title, letter):
    """Draw one 2x6 driver/competitor VS grid on fig9 and return (axs, axg)."""
    axs, axg = figure_add_axes_group_inch(
        fig9,
        nrows=2,
        ncols=6,
        group_top=0.5,
        group_left=group_left,
        individual_width=_cell_side,
        individual_height=_cell_side,
        wspace=0.05,
        hspace=0.2,
    )

    figure_stimphase_driver_competitor(
        df_55, axs=axs[0, :].flatten().tolist(), colors=c_fAM_55
    )
    figure_stimphase_driver_competitor(
        df_75, axs=axs[1, :].flatten().tolist(), colors=c_fAM_75
    )

    # Clear the top row's x tick labels (first axes) and x label (fourth axes),
    # and the bottom row's titles, so each appears only once per column.
    axs[0, 0].set_xticklabels([])
    axs[0, 3].set_xlabel("")
    for lower_ax in axs[1, :]:
        lower_ax.set_title("")

    axg.set_ylabel("VS Competitor")
    axg.yaxis.set_label_coords(-0.3 / 2.4, 0.5)

    axg.set_title(region_title, pad=22, fontsize=12)
    subplot_indicator(axg, letter, ha="left", va="bottom", pad_inch=0.7)
    return axs, axg


axs9a, axg9a = _fig9_region(0.8, df_twostim_55_ot, df_twostim_75_ot, "OT", "A")
axs9b, axg9b = _fig9_region(4.3, df_twostim_55_icx, df_twostim_75_icx, "ICx", "B")

# Driver badges in the gap between the two regions, one per grid row.
for _badge_top, _badge_text, _badge_color in (
    (0.5, "Driver\n55 Hz", c_fAM_55[0]),
    (1.0, "Driver\n75 Hz", c_fAM_75[0]),
):
    condition_batch(
        fig9,
        left=3.3,
        top=_badge_top,
        text=_badge_text,
        color=_badge_color,
        fontsize=10,
        ha="left",
        va="top",
    )

save_show_close(fig9, "figure9")

## Figure 10 (Gamma)

**Figure 10. Spike train synchrony is more strongly correlated with gamma
oscillations than with stimulus modulations.** __A,B)__ Changes in gamma power
with competing flat noise across relative levels in OT (A) and ICx (B). All box
plots show the interquartile range (25th-75th percentile) with whiskers
indicating 5-95% confidence interval, and circles represent individual units.
ANOVA p-values and relevant post hoc p-values are indicated. __C)__ Spike train
synchrony is strongly correlated with the vector strengths (VS) associated with
the gamma range LFP in OT. Each dot represents a unit pair’s synchrony and
averaged vector strength. __D)__ Spike train synchrony is only observed for unit
pairs with low average vector strengths associated with the driver modulation
frequency. __E)__ Spike train synchrony is high for unit pairs that lock to
similar phases of the gamma range LFP. __F)__ Unit pairs with a high average
vector strength associated with the stimulus exhibit low spike train synchrony
and respond to similar phases of the stimulus. Pearson correlation coefficients
are indicated for each scatter plot.

In [None]:
fig10 = cast(Figure, plt.figure(figsize=(6.80, 4.1)))

# Figure 10: gamma power and spike train synchrony.
# Top row: gamma power box plots for OT (A) and ICx (B), identical geometry.
# Bottom row: four 1x1 in scatter panels (C-F); E and F get slim colorbar axes.

# --- A: OT gamma power --------------------------------------------------------
ax10a = figure_add_axes_inch(
    fig10,
    left=0.8,
    width=2.4,
    top=0.3,
    height=1.5,
    label="A",
)

figure_gamma_power_boxplot(
    twostim_gamma_power_ot, single_gamma_power_ot, ax=ax10a, anova_align="left"
)

ax10a.set_title("OT")
subplot_indicator(ax10a, ha="left", va="bottom", pad_inch=0.7)
# y label 0.3 in left of the axes, expressed as a fraction of the 2.4 in width.
ax10a.yaxis.set_label_coords(-0.3 / 2.4, 0.5)

# --- B: ICx gamma power -------------------------------------------------------
ax10b = figure_add_axes_inch(
    fig10,
    left=4.3,
    width=2.4,
    top=0.3,
    height=1.5,
    label="B",
)

figure_gamma_power_boxplot(
    twostim_gamma_power_icx,
    single_gamma_power_icx,
    ax=ax10b,
    # NOTE(review): presumably (group_i, group_j) -> bracket height for the
    # post hoc annotations; confirm against figure_gamma_power_boxplot.
    brackets={(0, 5): 0.8, (1, 5): 0.9},
    anova_align="left",
)

ax10b.set_title("ICx")
# Same pad as panel A. Both panels share geometry and y-label offset, so the
# A/B letters line up, matching the A/B rows of Figures 11-13.
subplot_indicator(ax10b, ha="left", va="bottom", pad_inch=0.7)
ax10b.yaxis.set_label_coords(-0.3 / 2.4, 0.5)

# --- C-F: correlation scatter plots (bottom row) ------------------------------
ax10c = figure_add_axes_inch(
    fig10,
    left=0.8,
    width=1.0,
    top=2.6,
    height=1.0,
    label="C",
)

ax10d = figure_add_axes_inch(
    fig10,
    left=2.0,
    width=1.0,
    top=2.6,
    height=1.0,
    label="D",
)

ax10e = figure_add_axes_inch(
    fig10,
    left=3.2,
    width=1.0,
    top=2.6,
    height=1.0,
    label="E",
)

# Thin colorbar axes 0.05 in to the right of panel E.
ax10e_cb = figure_add_axes_inch(
    fig10,
    left=4.2 + 0.05,
    width=0.05,
    top=2.6,
    height=1.0,
)

ax10f = figure_add_axes_inch(
    fig10,
    left=5.1,
    width=1.0,
    top=2.6,
    height=1.0,
    label="F",
)

# Thin colorbar axes 0.05 in to the right of panel F.
ax10f_cb = figure_add_axes_inch(
    fig10,
    left=6.1 + 0.05,
    width=0.05,
    top=2.6,
    height=1.0,
)


figure_correlation_scatterplots(
    am_twostim_ccg_ot,
    am_twostim_gamma_power_ot,
    am_twostim_stim_phaselocking_ot,
    axs=[ax10c, ax10d, ax10e, ax10f],
    caxs=[ax10e_cb, ax10f_cb],
)


# Panel C carries the only y label in the row, hence its larger pad.
subplot_indicator(ax10c, ha="left", va="bottom", pad_inch=0.7)
subplot_indicator(ax10d, ha="left", va="bottom", pad_inch=0.3)
subplot_indicator(ax10e, ha="left", va="bottom", pad_inch=0.3)
subplot_indicator(ax10f, ha="left", va="bottom", pad_inch=0.3)

ax10c.set_ylabel("Spike Train\nSynchrony")
ax10c.yaxis.set_label_coords(-0.3 / 1.0, 0.5)

save_show_close(fig10, "figure10")

## Figure 11 (Subpopulations and Latency)

**Figure 11. Two subpopulations of OT units are distinguished by response
latency to a single amplitude modulated stimulus (fAM 55 Hz).** __A)__ Average
PSTH across units in OT distinguished as short (pink) and long (gray) response
latency. Dashed and dotted lines indicate the median latency for short and long
latency subpopulations, respectively. __B)__ Histogram of response latencies
(half-max method). Dotted line at 14 ms indicates the latency at which the
subpopulations were distinguished (colors as in A). __C)__ Vector strengths of
spike phase locking to modulations of a single driver stimulus. __D)__ Vector
strengths of spike phase locking to gamma range LFPs. All box plots show the
interquartile range (25th-75th percentile) with whiskers indicating 5-95%
confidence interval, and circles represent individual units. Wilcoxon test
p-values of vector strength differences are indicated.

In [None]:
# Figure 11 inputs: OT units at the 55 Hz modulation frequency and intensity -5.
_pl_ot = am_single_stim_phaselocking_ot
_pl_mask = (_pl_ot["modulation_frequency"] == 55) & (_pl_ot["intensity"] == -5)
df_latency = _pl_ot[_pl_mask]

# Inner join with the three latency columns of single_rlf_ot (intensity -5);
# overlapping columns from the right frame get the "_flat" suffix.
_rlf_at_m5 = single_rlf_ot[single_rlf_ot["intensity"] == -5]
_latency_cols = _rlf_at_m5[["short_latency", "long_latency", "first_spike_latency"]]
df_latency_stim = df_latency.join(_latency_cols, how="inner", rsuffix="_flat")

# Left join with gamma-power rows for the same frequency/intensity; overlapping
# right-hand columns get the "_gamma" suffix.
_gp_ot = am_single_gamma_power_ot
_gp_mask = (_gp_ot["modulation_frequency"] == 55) & (_gp_ot["intensity"] == -5)
df_latency_gamma = df_latency.join(_gp_ot[_gp_mask], how="left", rsuffix="_gamma")

fig11 = cast(Figure, plt.figure(figsize=(6.80, 4.2)))


def _fig11_panel(top, left, label):
    """Add one 2.4 x 1.5 in labelled panel to fig11 and return its axes."""
    return figure_add_axes_inch(
        fig11,
        top=top,
        left=left,
        width=2.4,
        height=1.5,
        label=label,
    )


# A: PSTH per latency subpopulation
ax11a = _fig11_panel(0.1, 0.8, "A")
psth_by_subpopulation(df_latency, ax=ax11a)
ax11a.set_ylabel("Response\n[spks/bin]")

# B: latency histogram
ax11b = _fig11_panel(0.1, 4.3, "B")
figure_latency_histogram(df_latency, ax=ax11b)

# C: stimulus phase locking by subpopulation (per the caption: short latency in
# the 55 Hz driver color, long latency in gray)
ax11c = _fig11_panel(2.2, 0.8, "C")
figure_latency_plv_histogram(df_latency_stim, ax=ax11c, colors=[c_fAM_55[0], "0.5"])

# D: gamma phase locking by subpopulation
ax11d = _fig11_panel(2.2, 4.3, "D")
figure_latency_gamma_plv_histogram(
    df_latency_gamma, colors=[c_fAM_55[0], "0.5"], ax=ax11d
)
ax11d.set_ylim(0, 0.9)

for _panel_ax in (ax11a, ax11b, ax11c, ax11d):
    subplot_indicator(_panel_ax, ha="left", va="top", pad_inch=0.7)

save_show_close(fig11, "figure11")


## Figure 12 (Subpopulations and Competition)

**Figure 12. Spike response rates of OT subpopulations change with competition.** __A,B)__ Percent change in spike response rates across relative levels of competing sounds, either amplitude modulated with different modulation frequencies (driver f<sub>AM</sub> = 55 Hz, competitor f<sub>AM</sub> = 75 Hz; A) or flat noise (B), for short latency (pink, blue) and long latency (gray) subpopulations in OT. Unpaired t-test p-values between subpopulations are indicated. __C,D)__ Mean number of spikes observed across modulation phase differences of competing amplitude modulated sounds for three different relative levels (-15, 0, +10 dB) for short latency (C) and long latency (D) subpopulations.

In [None]:
# Figure 12 inputs.
# Panel A (driver 55 Hz AM): single-stimulus and two-stimulus rate-level tables.
df_am55_spkresp = am_single_rlf_ot[(am_single_rlf_ot["modulation_frequency"] == 55)]
df_tsam55_rlf = am_twostim_rlf_ot[
    (am_twostim_rlf_ot["fixed_modulation_frequency"] == 55)
]

# Panels C and D: phase-locking tables for the 55 Hz driver (computed once; the
# earlier duplicate re-filtering further down the cell was redundant).
df_single_latency = am_single_stim_phaselocking_ot[
    (am_single_stim_phaselocking_ot["modulation_frequency"] == 55)
]
df_twostim_latency = am_twostim_stim_phaselocking_ot[
    (am_twostim_stim_phaselocking_ot["fixed_modulation_frequency"] == 55)
]


fig12 = cast(Figure, plt.figure(figsize=(6.80, 4.4)))

# --- A: spike-rate change with a 55 Hz AM driver ------------------------------
ax12a = figure_add_axes_inch(
    fig12,
    top=0.3,
    left=0.8,
    width=2.4,
    height=1.5,
    label="A",
)

competition_spkresp(df_am55_spkresp, df_tsam55_rlf, color=c_fAM_55[0], ax=ax12a)
# Same y range as panel B so the two conditions are directly comparable.
ax12a.set_ylim(-0.55, 0.30)
ax12a.set_title("Driver $f_{AM}$ 55Hz")
ax12a.set_ylabel("Change in\nSpike Rate")
ax12a.yaxis.set_label_coords(-0.4 / 2.4, 0.5)

# --- B: spike-rate change with flat noise -------------------------------------
ax12b = figure_add_axes_inch(
    fig12,
    top=0.3,
    left=4.3,
    width=2.4,
    height=1.5,
    label="B",
)

competition_spkresp(single_rlf_ot, twostim_rlf_ot, color=c_flat_noise[0], ax=ax12b)
ax12b.set_ylim(-0.55, 0.30)
ax12b.set_title("Flat Noise")
ax12b.set_ylabel("Change in\nSpike Rate")
ax12b.yaxis.set_label_coords(-0.4 / 2.4, 0.5)


# --- C, D: spikes vs. phase difference (axes groups named axs12e / axs12f) ----
axs12e, axg12e = figure_add_axes_group_inch(
    fig12,
    nrows=1,
    ncols=3,
    group_top=2.7,
    group_left=0.8,
    individual_width=(2.4 - 2 * 0.05) / 3,
    individual_height=1.2,
    wspace=0.05,
)

axs12f, axg12f = figure_add_axes_group_inch(
    fig12,
    nrows=1,
    ncols=3,
    group_top=2.7,
    group_left=4.3,
    individual_width=(2.4 - 2 * 0.05) / 3,
    individual_height=1.2,
    wspace=0.05,
)

# One call fills both groups. Per the caption, C is the short-latency and D the
# long-latency subpopulation -- TODO confirm against merge_phase_data.
merge_phase_data(
    df_single_latency,
    df_twostim_latency,
    axs=[axs12e.flatten().tolist(), axs12f.flatten().tolist()],
)


# Replace the x label on each group's first axes with one label per group.
axs12e[0, 0].set_xlabel("")
axs12f[0, 0].set_xlabel("")
axg12e.set_xlabel("Phase Difference", labelpad=18)
axg12f.set_xlabel("Phase Difference", labelpad=18)

subplot_indicator(ax12a, ha="left", va="top", pad_inch=0.7)
subplot_indicator(ax12b, ha="left", va="top", pad_inch=0.7)

subplot_indicator(axg12e, label="C", ha="left", va="bottom", pad_inch=0.7)
subplot_indicator(axg12f, label="D", ha="left", va="bottom", pad_inch=0.7)

save_show_close(fig12, "figure12")

## Figure 13

**Figure 13. Spike train synchrony of OT subpopulations changes with competition.** __A)__ Cross-correlograms for concurrent amplitude modulated stimuli (driver f<sub>AM</sub> = 55 Hz, competitor f<sub>AM</sub> = 75 Hz), averaged across all simultaneously recorded unit pairs with significant spike train synchrony where both units belong to the same subpopulation (pink: short latency, gray: long latency). The number of unit pairs is indicated above each cross-correlogram. __B)__ Cross-correlograms for concurrent flat noise stimuli for both subpopulations (blue: short latency, gray: long latency). Presentation as in A. __C,D)__ Competition synchrony index (CSI), showing the change in spike train synchrony with competition relative to single stimulus presentation, across relative levels for each OT subpopulation in response to amplitude modulated stimuli (C) or flat noise (D). __E-H)__ Scatter plots of spike train synchrony as a function of vector strengths (VS) associated with either the gamma range LFP (E,G) or amplitude modulations (F,H) of the stimulus for short latency units (pink) and long latency units (gray). Pearson correlation coefficients are indicated at the top of each scatter plot.

In [None]:
# Figure 13 inputs for the 55 Hz AM driver (panels A and C).
# Each frame is filtered exactly once; the unused rate-level / phase-locking
# tables and the dead first assignment of df_am55_spkresp were removed.
df_am55_spkresp = am_single_rlf_ot[(am_single_rlf_ot["modulation_frequency"] == 55)]
df_am55_ccg = am_single_ccg_ot[(am_single_ccg_ot["modulation_frequency"] == 55)]

# Two-stimulus CCGs for the 55 Hz driver. Panels A and C each receive their own
# filtered frame (boolean indexing returns a new object), so an in-place change
# by one plotting helper, should it happen, cannot affect the other panel.
df_ts_ccg = am_twostim_ccg_ot[(am_twostim_ccg_ot["fixed_modulation_frequency"] == 55)]
df_tsam55_ccg = am_twostim_ccg_ot[am_twostim_ccg_ot["fixed_modulation_frequency"] == 55]


fig13 = cast(Figure, plt.figure(figsize=(6.80, 5.5)))

# --- A: cross-correlograms, 55 Hz AM driver (6 relative levels) ---------------
axs13a, axg13a = figure_add_axes_group_inch(
    fig13,
    nrows=1,
    ncols=6,
    group_top=0.5,
    group_left=0.8,
    individual_width=2.4 / 6,
    individual_height=0.6,
    wspace=0.0,
    hspace=0.0,
)

figure_xcorr_ccg_subpop(
    df_ts_ccg,
    df_am55_ccg,
    df_am55_spkresp,
    axs=axs13a.flatten().tolist(),
    color=c_fAM_55[0],
)
axg13a.set_title("Driver $f_{AM}$ 55Hz", pad=18)

for ax in axs13a[0, :]:
    ax.set_ylim(top=0.000165)


# --- B: cross-correlograms, flat noise ----------------------------------------
axs13b, axg13b = figure_add_axes_group_inch(
    fig13,
    nrows=1,
    ncols=6,
    group_top=0.5,
    group_left=4.3,
    individual_width=2.4 / 6,
    individual_height=0.6,
    wspace=0.0,
    hspace=0.0,
)

figure_xcorr_ccg_subpop(
    twostim_ccg_ot,
    single_ccg_ot,
    single_rlf_ot,
    axs=axs13b.flatten().tolist(),
    color=c_flat_noise[0],
)
axg13b.set_title("Flat Noise", pad=18)

# --- C: competition synchrony index, 55 Hz AM driver --------------------------
ax13c = figure_add_axes_inch(
    fig13,
    top=1.6,
    left=0.8,
    width=2.4,
    height=1.5,
    label="C",
)

competition_synchrony_csi(
    df_am55_spkresp, df_am55_ccg, df_tsam55_ccg, color=c_fAM_55[0], ax=ax13c
)
# Same y range as panel D for direct comparison.
ax13c.set_ylim(-0.19, 0.12)


# --- D: competition synchrony index, flat noise -------------------------------
ax13d = figure_add_axes_inch(
    fig13,
    top=1.6,
    left=4.3,
    width=2.4,
    height=1.5,
    label="D",
)

competition_synchrony_csi(
    single_rlf_ot, single_ccg_ot, twostim_ccg_ot, color=c_flat_noise[0], ax=ax13d
)
ax13d.set_ylim(-0.19, 0.12)


# --- E-H: synchrony vs. vector strength scatter plots -------------------------
ax13e = figure_add_axes_inch(
    fig13,
    left=0.8,
    width=1.0,
    top=3.9,
    height=1.0,
    label="E",
)

ax13f = figure_add_axes_inch(
    fig13,
    left=2.2,
    width=1.0,
    top=3.9,
    height=1.0,
    label="F",
)

ax13g = figure_add_axes_inch(
    fig13,
    left=4.3,
    width=1.0,
    top=3.9,
    height=1.0,
    label="G",
)
ax13h = figure_add_axes_inch(
    fig13,
    left=5.7,
    width=1.0,
    top=3.9,
    height=1.0,
    label="H",
)


synchrony_correlation_plot(
    am_twostim_ccg_ot,
    am_twostim_gamma_power_ot,
    am_twostim_stim_phaselocking_ot,
    am_single_rlf_ot,
    axs=[ax13e, ax13f, ax13g, ax13h],
)

subplot_indicator(axg13a, "A", ha="left", va="bottom", pad_inch=0.7)
subplot_indicator(axg13b, "B", ha="left", va="bottom", pad_inch=0.7)

subplot_indicator(ax13c, ha="left", va="center", pad_inch=0.7)
subplot_indicator(ax13d, ha="left", va="center", pad_inch=0.7)


# E and G carry y labels (set below), hence the larger pad.
subplot_indicator(ax13e, ha="left", va="bottom", pad_inch=0.7)
subplot_indicator(ax13f, ha="left", va="bottom", pad_inch=0.3)
subplot_indicator(ax13g, ha="left", va="bottom", pad_inch=0.7)
subplot_indicator(ax13h, ha="left", va="bottom", pad_inch=0.3)


# Move the y label from each group's first axes to the group axes; the x offset
# is 0.35 in expressed as a fraction of the 2.4 in group width.
axs13a[0, 0].set_ylabel("")
axg13a.set_ylabel("Coinc./Spk")
axg13a.yaxis.set_label_coords(-0.35 / 2.4, 0.5)

axs13b[0, 0].set_ylabel("")
axg13b.set_ylabel("Coinc./Spk")
axg13b.yaxis.set_label_coords(-0.35 / 2.4, 0.5)


ax13c.yaxis.set_label_coords(-0.35 / 2.4, 0.5)
ax13d.yaxis.set_label_coords(-0.35 / 2.4, 0.5)


ax13e.set_ylabel("Spike Train\nSynchrony")
ax13e.yaxis.set_label_coords(-0.3 / 1.0, 0.5)

ax13g.set_ylabel("Spike Train\nSynchrony")
ax13g.yaxis.set_label_coords(-0.3 / 1.0, 0.5)

save_show_close(fig13, "figure13")