In [None]:
# ==== 实验 Notebook - 代码块 1 ====
# 公共导入 + 实验 1：静态方向 + 长时间窗（T=800ms）

import os
import tempfile

import numpy as np
import matplotlib.pyplot as plt

# 保证在项目根目录下运行 Notebook
# 目录结构中应该有 encoding/ 和 decoding/ 这些包
from encoding.dataset_builder import generate_direction_encoding_dataset
from decoding.linear_decoder import train_and_eval_linear
from decoding.rate_decoder_mlp import train_rate_mlp
from decoding.snn_decoder import train_snn_decoder

# Global figure defaults for every plot below: sharper rendering, larger base font.
plt.rcParams.update({"figure.dpi": 120, "font.size": 11})


# ===================== 实验 1：静态方向 + 长时间窗 =====================

def run_experiment_static_long_T(
    T_ms=800.0,
    n_neurons=40,
    trials_per_dir=100,
    seed=0,
    save_tmp=True,
    tmp_path="exp1_static_T800.npz",
    n_epochs_mlp=40,
    n_epochs_snn=40,
):
    """
    Experiment 1: 8 static directions with a long time window (default T=800 ms).

    - Linear / MLP decoders: spike-count (rate) input.
    - SNN decoder: full spike-train input.

    The decoders load their data from an ``.npz`` path, so the generated
    dataset is always written to disk before decoding.

    Parameters
    ----------
    T_ms : float
        Trial duration in ms (simulated with dt = 1.0).
    n_neurons : int
        Number of neurons in the simulated population.
    trials_per_dir : int
        Trials generated for each of the 8 directions.
    seed : int
        Seed passed to the dataset generator.
    save_tmp : bool
        True: write the dataset to ``tmp_path`` and keep it afterwards.
        False: write it to a uniquely named temporary file that is deleted
        once the decoders have run (``tmp_path`` is ignored).
    tmp_path : str
        Output path used when ``save_tmp`` is True.
    n_epochs_mlp, n_epochs_snn : int
        Training epochs for the MLP and SNN decoders.

    Returns
    -------
    dict
        Test accuracies under ``test_acc_lin``, ``test_acc_mlp``, ``test_acc_snn``.
    """
    directions = np.arange(0, 360, 45)  # 8 directions, 45 deg apart
    print(f"=== Experiment 1: Static directions, T={T_ms} ms, N={n_neurons} ===")

    # 1) Generate the dataset (medium noise level).
    dataset = generate_direction_encoding_dataset(
        directions_deg=directions,
        n_neurons=n_neurons,
        trials_per_dir=trials_per_dir,
        T=T_ms,
        dt=1.0,
        r_baseline=8.0,
        r_max_mean=25.0,
        r_max_std=6.0,
        tuning_sigma_deg=50.0,
        jitter_pref_deg=7.0,
        gain_sigma=0.25,
        shared_std=3.0,
        indep_std=2.0,
        seed=seed,
    )

    spikes = dataset["spikes"]
    print("Dataset shape (spikes):", spikes.shape)

    # 2) The decoders only accept a file path, so the dataset must hit disk.
    #    With save_tmp=False a throwaway, uniquely named file is used instead.
    if save_tmp:
        dataset_path = tmp_path
    else:
        fd, dataset_path = tempfile.mkstemp(prefix="exp1_static_", suffix=".npz")
        os.close(fd)  # np.savez reopens the path itself

    try:
        np.savez(
            dataset_path,
            spikes=spikes,
            labels=dataset["labels"],
            directions_deg=dataset["directions_deg"],
            meta=np.array([dataset["meta"]], dtype=object),
        )
        if save_tmp:
            print("Saved temp dataset to:", dataset_path)

        # 3) Linear decoder; the inner pair is (train_acc, test_acc), keep test only.
        print("\n[Experiment 1] Linear decoder (rate)")
        _, (_, test_acc_lin) = train_and_eval_linear(
            dataset_path=dataset_path,
            test_size=0.2,
            random_state=0,
        )

        # 4) MLP decoder
        print("\n[Experiment 1] MLP decoder (rate)")
        _, (test_acc_mlp, _) = train_rate_mlp(
            dataset_path=dataset_path,
            test_size=0.2,
            random_state=0,
            n_epochs=n_epochs_mlp,
        )

        # 5) SNN decoder
        print("\n[Experiment 1] SNN decoder (spike train)")
        _, (test_acc_snn, _) = train_snn_decoder(
            dataset_path=dataset_path,
            test_size=0.2,
            random_state=0,
            n_epochs=n_epochs_snn,
        )
    finally:
        # Only the throwaway file is ours to delete; tmp_path is kept on purpose.
        if not save_tmp and os.path.exists(dataset_path):
            os.remove(dataset_path)

    # 6) Small summary bar chart
    methods = ["Linear", "MLP (rate)", "SNN (spike)"]
    accs = [test_acc_lin, test_acc_mlp, test_acc_snn]

    plt.figure(figsize=(5, 4))
    plt.bar(methods, accs)
    plt.ylim(0, 1.05)
    plt.ylabel("Test accuracy")
    plt.title(f"Experiment 1: Static directions, T={T_ms} ms")
    for i, a in enumerate(accs):
        plt.text(i, a + 0.02, f"{a:.2f}", ha="center")
    plt.tight_layout()
    plt.show()

    return {
        "test_acc_lin": test_acc_lin,
        "test_acc_mlp": test_acc_mlp,
        "test_acc_snn": test_acc_snn,
    }


# Run Experiment 1 with its default settings; the bare expression on the last
# line lets the notebook display the returned accuracy dict.
results_exp1 = run_experiment_static_long_T()
results_exp1


In [None]:
# ==== 实验 Notebook - 代码块 2 ====
# 实验 2：短时间窗 + 方向快速变化（两段不同方向，标签为后半段方向）

from encoding.tuning import (
    generate_preferred_directions,
    sample_r_max,
    direction_tuning_gaussian,
)
from encoding.poisson_spike import poisson_population_spikes


def _apply_trial_noise(rates_segment, rng, gain_sigma, shared_std, indep_std):
    """Return a noisy, non-negative copy of one segment's firing rates (Hz).

    Noise is applied in a fixed order, which also fixes how ``rng`` is
    consumed. A step is skipped when its parameter is not > 0:

    1. multiplicative gain ~ LogNormal(0, gain_sigma), common to all neurons;
    2. additive offset ~ Normal(0, shared_std), one scalar common to all neurons;
    3. additive per-neuron noise ~ Normal(0, indep_std).

    The result is clipped at 0. ``rates_segment`` itself is not modified.
    """
    rates = rates_segment.copy()
    if gain_sigma > 0.0:
        rates *= rng.lognormal(mean=0.0, sigma=gain_sigma)
    if shared_std > 0.0:
        rates += rng.normal(0.0, shared_std)
    if indep_std > 0.0:
        rates += rng.normal(0.0, indep_std, size=rates.shape)
    return np.clip(rates, 0.0, None)


def generate_two_segment_direction_dataset(
    directions_deg,
    n_neurons=40,
    trials_total=800,
    T1=100.0,
    T2=100.0,
    dt=1.0,
    r_baseline=8.0,
    r_max_mean=25.0,
    r_max_std=6.0,
    tuning_sigma_deg=50.0,
    jitter_pref_deg=7.0,
    gain_sigma=0.25,
    shared_std=3.0,
    indep_std=2.0,
    seed=123,
    distinct_directions=True,
):
    """
    Generate a time-dependent dataset in which each trial switches direction.

    Each trial:
    - first T1 ms: direction theta1
    - next T2 ms: direction theta2
    - label: class index of theta2

    Trial-by-trial rate noise (gain, shared offset, independent per-neuron
    noise) is drawn separately for each of the two segments.

    Parameters (beyond the population/noise settings)
    ------------------------------------------------
    directions_deg : array-like, shape (n_dirs,)
        Candidate stimulus directions in degrees.
    trials_total : int
        Total number of trials; theta1 is drawn uniformly per trial.
    T1, T2, dt : float
        Segment durations and time step (same time unit, ms in this notebook).
    seed : int
        Seed for the single RNG that drives every random draw.
    distinct_directions : bool
        True (default): theta2 is drawn uniformly from the directions other
        than theta1, so every trial really contains a direction change.
        False: theta1 and theta2 are independent and may coincide
        (probability 1/n_dirs per trial).

    Returns
    -------
    dict with
    - spikes : (n_examples, T_steps_total, n_neurons), uint8
    - labels : (n_examples,), int64 index of theta2 into directions_deg
    - labels_first_segment : (n_examples,), int64 index of theta1
    - directions_deg : (n_dirs,)
    - theta_prefs, r_max : sampled population parameters
    - meta : dict of generation settings

    Raises
    ------
    ValueError
        If ``distinct_directions`` is True but fewer than 2 directions are given.
    """
    rng = np.random.default_rng(seed)
    directions_deg = np.asarray(directions_deg)
    n_dirs = directions_deg.shape[0]
    if distinct_directions and n_dirs < 2:
        raise ValueError(
            "distinct_directions=True needs at least 2 directions, "
            f"got {n_dirs}"
        )
    T1_steps = int(T1 / dt)
    T2_steps = int(T2 / dt)
    T_steps_total = T1_steps + T2_steps

    # Population parameters
    theta_prefs = generate_preferred_directions(
        n_neurons=n_neurons,
        jitter_deg=jitter_pref_deg,
        seed=rng.integers(1_000_000_000),
    )
    r_max = sample_r_max(
        n_neurons=n_neurons,
        r_max_mean=r_max_mean,
        r_max_std=r_max_std,
        min_rate=1.0,
        max_rate=None,
        seed=rng.integers(1_000_000_000),
    )

    # Noise-free tuning for every direction, computed once.
    base_rates_all = direction_tuning_gaussian(
        theta_stim_deg=directions_deg,
        theta_pref_deg=theta_prefs,
        r_baseline=r_baseline,
        r_max=r_max,
        sigma_deg=tuning_sigma_deg,
    )  # (n_dirs, n_neurons)

    theta1_idx_all = rng.integers(low=0, high=n_dirs, size=trials_total)
    if distinct_directions:
        # A non-zero offset modulo n_dirs guarantees theta2 != theta1; since the
        # offset is uniform on 1..n_dirs-1 and theta1 is uniform, theta2 (the
        # label) stays uniform over all directions.
        offsets = rng.integers(low=1, high=n_dirs, size=trials_total)
        theta2_idx_all = (theta1_idx_all + offsets) % n_dirs
    else:
        theta2_idx_all = rng.integers(low=0, high=n_dirs, size=trials_total)

    spikes = np.zeros((trials_total, T_steps_total, n_neurons), dtype=np.uint8)

    for i in range(trials_total):
        # Both segments' noise is drawn before either spike train, in that order.
        rates1_noisy = _apply_trial_noise(
            base_rates_all[theta1_idx_all[i]], rng, gain_sigma, shared_std, indep_std
        )
        rates2_noisy = _apply_trial_noise(
            base_rates_all[theta2_idx_all[i]], rng, gain_sigma, shared_std, indep_std
        )

        spikes[i, :T1_steps] = poisson_population_spikes(
            rates_hz=rates1_noisy,
            T=T1,
            dt=dt,
            rng=rng,
        )  # (T1_steps, n_neurons)
        spikes[i, T1_steps:] = poisson_population_spikes(
            rates_hz=rates2_noisy,
            T=T2,
            dt=dt,
            rng=rng,
        )  # (T2_steps, n_neurons)

    labels = theta2_idx_all.astype(np.int64)  # label = second-segment direction

    meta = {
        "T_total": T1 + T2,
        "T1": T1,
        "T2": T2,
        "dt": dt,
        "r_baseline": r_baseline,
        "r_max_mean": r_max_mean,
        "r_max_std": r_max_std,
        "tuning_sigma_deg": tuning_sigma_deg,
        "jitter_pref_deg": jitter_pref_deg,
        "gain_sigma": gain_sigma,
        "shared_std": shared_std,
        "indep_std": indep_std,
        "n_neurons": n_neurons,
        "trials_total": trials_total,
        "seed": seed,
        "distinct_directions": distinct_directions,
        "task": "decode_second_segment_direction",
    }

    return {
        "spikes": spikes,
        "labels": labels,
        "labels_first_segment": theta1_idx_all.astype(np.int64),
        "directions_deg": directions_deg,
        "theta_prefs": theta_prefs,
        "r_max": r_max,
        "meta": meta,
    }


def run_experiment_two_segment(
    T1=100.0,
    T2=100.0,
    n_neurons=40,
    trials_total=1600,
    seed=1,
    tmp_path="exp2_two_segment_T100_100.npz",
    n_epochs_mlp=40,
    n_epochs_snn=40,
):
    """
    Experiment 2: short windows with a fast direction change.

    Each trial shows direction theta1 for T1 ms, then theta2 for T2 ms; the
    label is theta2. The dataset is written to ``tmp_path`` (and left there)
    because the decoders read from an ``.npz`` path. Linear and MLP decoders
    use whole-window rates, while the SNN decoder gets the full spike train and
    can therefore exploit the temporal structure. Test accuracies are shown in
    a bar chart.

    Returns
    -------
    dict
        Test accuracies under ``test_acc_lin``, ``test_acc_mlp``, ``test_acc_snn``.
    """
    print(
        f"=== Experiment 2: Two-segment stimulus, T1={T1} ms, T2={T2} ms, "
        f"N={n_neurons} ==="
    )

    dataset = generate_two_segment_direction_dataset(
        directions_deg=np.arange(0, 360, 45),
        n_neurons=n_neurons,
        trials_total=trials_total,
        T1=T1,
        T2=T2,
        dt=1.0,
        seed=seed,
    )
    print("Dataset shape (spikes):", dataset["spikes"].shape)

    arrays_to_save = {name: dataset[name] for name in ("spikes", "labels", "directions_deg")}
    arrays_to_save["meta"] = np.array([dataset["meta"]], dtype=object)
    np.savez(tmp_path, **arrays_to_save)
    print("Saved temp dataset to:", tmp_path)

    # Arguments shared by all three decoders.
    shared_kwargs = dict(dataset_path=tmp_path, test_size=0.2, random_state=0)

    print("\n[Experiment 2] Linear decoder (whole-window rate)")
    _, (_, lin_acc) = train_and_eval_linear(**shared_kwargs)

    print("\n[Experiment 2] MLP decoder (whole-window rate)")
    _, (mlp_acc, _) = train_rate_mlp(**shared_kwargs, n_epochs=n_epochs_mlp)

    print("\n[Experiment 2] SNN decoder (spike train, time-structure)")
    _, (snn_acc, _) = train_snn_decoder(**shared_kwargs, n_epochs=n_epochs_snn)

    summary = {
        "test_acc_lin": lin_acc,
        "test_acc_mlp": mlp_acc,
        "test_acc_snn": snn_acc,
    }
    bar_names = ["Linear", "MLP (rate)", "SNN (spike)"]
    bar_heights = list(summary.values())

    plt.figure(figsize=(5, 4))
    plt.bar(bar_names, bar_heights)
    plt.ylim(0, 1.05)
    plt.ylabel("Test accuracy")
    plt.title("Experiment 2: Two-segment task (label = θ₂)")
    for x_pos, height in enumerate(bar_heights):
        plt.text(x_pos, height + 0.02, f"{height:.2f}", ha="center")
    plt.tight_layout()
    plt.show()

    return summary


# Run Experiment 2 with its default settings; the bare expression on the last
# line lets the notebook display the returned accuracy dict.
results_exp2 = run_experiment_two_segment()
results_exp2


In [None]:
# ==== 实验 Notebook - 代码块 3 ====
# 实验 3：population size / tuning width / 噪声强度 sweep

# Generator settings shared by every sweep (same values as Experiment 1, with
# dt = 1.0); each sweep overrides only the parameter(s) it varies.
_SWEEP_BASE_PARAMS = {
    "dt": 1.0,
    "r_baseline": 8.0,
    "r_max_mean": 25.0,
    "r_max_std": 6.0,
    "tuning_sigma_deg": 50.0,
    "jitter_pref_deg": 7.0,
    "gain_sigma": 0.25,
    "shared_std": 3.0,
    "indep_std": 2.0,
}


def _print_sweep_header(text):
    """Print ``text`` framed by 60-character '=' rules, after a blank line."""
    print("\n" + "=" * 60)
    print(text)
    print("=" * 60)


def _decode_sweep_point(dataset, n_epochs_mlp, n_epochs_snn, file_prefix):
    """Train the Linear, MLP and SNN decoders on one sweep dataset.

    The decoders read from an ``.npz`` path, so the dataset is written to a
    freshly created temporary file (unique name, so no existing file can be
    overwritten). The file is removed in ``finally``, even if a decoder raises.

    Returns
    -------
    tuple
        ``(test_acc_lin, test_acc_mlp, test_acc_snn)``.
    """
    fd, tmp_path = tempfile.mkstemp(prefix=file_prefix, suffix=".npz")
    os.close(fd)  # np.savez reopens the path itself
    try:
        np.savez(
            tmp_path,
            spikes=dataset["spikes"],
            labels=dataset["labels"],
            directions_deg=dataset["directions_deg"],
            meta=np.array([dataset["meta"]], dtype=object),
        )

        # Linear: the inner pair is (train_acc, test_acc); keep test only.
        _, (_, test_acc_lin) = train_and_eval_linear(
            dataset_path=tmp_path,
            test_size=0.2,
            random_state=0,
        )
        _, (test_acc_mlp, _) = train_rate_mlp(
            dataset_path=tmp_path,
            test_size=0.2,
            random_state=0,
            n_epochs=n_epochs_mlp,
        )
        _, (test_acc_snn, _) = train_snn_decoder(
            dataset_path=tmp_path,
            test_size=0.2,
            random_state=0,
            n_epochs=n_epochs_snn,
        )
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    return test_acc_lin, test_acc_mlp, test_acc_snn


def sweep_population_size(
    N_list=(10, 20, 40, 80),
    T_ms=400.0,
    trials_per_dir=100,
    base_seed=100,
    n_epochs_mlp=30,
    n_epochs_snn=30,
):
    """Measure decoder test accuracy as a function of population size N.

    One 8-direction dataset (seed ``base_seed + i``) is generated for each
    entry of ``N_list``; all other generator settings are ``_SWEEP_BASE_PARAMS``.

    Returns
    -------
    tuple of np.ndarray
        ``(N_arr, acc_lin, acc_mlp, acc_snn)``, each aligned with ``N_list``.
    """
    directions = np.arange(0, 360, 45)
    acc_lin, acc_mlp, acc_snn = [], [], []

    for i, N in enumerate(N_list):
        _print_sweep_header(f"[Sweep N] N_neurons = {N}")
        dataset = generate_direction_encoding_dataset(
            directions_deg=directions,
            n_neurons=N,
            trials_per_dir=trials_per_dir,
            T=T_ms,
            seed=base_seed + i,
            **_SWEEP_BASE_PARAMS,
        )
        lin, mlp, snn = _decode_sweep_point(
            dataset, n_epochs_mlp, n_epochs_snn, f"tmp_sweep_N_{N}_"
        )
        acc_lin.append(lin)
        acc_mlp.append(mlp)
        acc_snn.append(snn)

    N_arr = np.array(N_list, dtype=int)
    return N_arr, np.array(acc_lin), np.array(acc_mlp), np.array(acc_snn)


def sweep_tuning_sigma(
    sigma_list=(30.0, 45.0, 60.0),
    T_ms=400.0,
    n_neurons=40,
    trials_per_dir=100,
    base_seed=200,
    n_epochs_mlp=30,
    n_epochs_snn=30,
):
    """Measure decoder test accuracy as a function of tuning width sigma (deg).

    One 8-direction dataset (seed ``base_seed + i``) is generated for each
    entry of ``sigma_list``; the other settings are ``_SWEEP_BASE_PARAMS``.

    Returns
    -------
    tuple of np.ndarray
        ``(sigma_arr, acc_lin, acc_mlp, acc_snn)``, aligned with ``sigma_list``.
    """
    directions = np.arange(0, 360, 45)
    acc_lin, acc_mlp, acc_snn = [], [], []

    for i, sigma in enumerate(sigma_list):
        _print_sweep_header(f"[Sweep sigma] sigma = {sigma} deg")
        params = {**_SWEEP_BASE_PARAMS, "tuning_sigma_deg": sigma}
        dataset = generate_direction_encoding_dataset(
            directions_deg=directions,
            n_neurons=n_neurons,
            trials_per_dir=trials_per_dir,
            T=T_ms,
            seed=base_seed + i,
            **params,
        )
        lin, mlp, snn = _decode_sweep_point(
            dataset, n_epochs_mlp, n_epochs_snn, f"tmp_sweep_sigma_{sigma}_"
        )
        acc_lin.append(lin)
        acc_mlp.append(mlp)
        acc_snn.append(snn)

    sigma_arr = np.array(sigma_list, dtype=float)
    return sigma_arr, np.array(acc_lin), np.array(acc_mlp), np.array(acc_snn)


def sweep_noise_level(
    noise_levels=(0.0, 0.2, 0.4),
    T_ms=400.0,
    n_neurons=40,
    trials_per_dir=100,
    base_seed=300,
    n_epochs_mlp=30,
    n_epochs_snn=30,
):
    """Measure decoder test accuracy as a function of a single noise scale.

    Each ``noise_level`` scales all three trial-noise terms together:
    gain_sigma = noise_level
    shared_std = 3.0 * noise_level
    indep_std  = 2.0 * noise_level
    The remaining settings are ``_SWEEP_BASE_PARAMS``; the seed is ``base_seed + i``.

    Returns
    -------
    tuple of np.ndarray
        ``(nl_arr, acc_lin, acc_mlp, acc_snn)``, aligned with ``noise_levels``.
    """
    directions = np.arange(0, 360, 45)
    acc_lin, acc_mlp, acc_snn = [], [], []

    for i, nl in enumerate(noise_levels):
        _print_sweep_header(f"[Sweep noise] noise_level = {nl}")
        params = {
            **_SWEEP_BASE_PARAMS,
            "gain_sigma": nl,
            "shared_std": 3.0 * nl,
            "indep_std": 2.0 * nl,
        }
        dataset = generate_direction_encoding_dataset(
            directions_deg=directions,
            n_neurons=n_neurons,
            trials_per_dir=trials_per_dir,
            T=T_ms,
            seed=base_seed + i,
            **params,
        )
        lin, mlp, snn = _decode_sweep_point(
            dataset, n_epochs_mlp, n_epochs_snn, f"tmp_sweep_noise_{nl}_"
        )
        acc_lin.append(lin)
        acc_mlp.append(mlp)
        acc_snn.append(snn)

    nl_arr = np.array(noise_levels, dtype=float)
    return nl_arr, np.array(acc_lin), np.array(acc_mlp), np.array(acc_snn)


# === Experiment 3: run all three sweeps, then plot the accuracy curves side by side ===

N_arr, accN_lin, accN_mlp, accN_snn = sweep_population_size()
sigma_arr, accS_lin, accS_mlp, accS_snn = sweep_tuning_sigma()
nl_arr, accNoise_lin, accNoise_mlp, accNoise_snn = sweep_noise_level()

# (marker, linestyle, legend label) per decoder, in plotting order.
_decoder_styles = (
    ("o", "-", "Linear"),
    ("s", "--", "MLP (rate)"),
    ("D", "-.", "SNN (spike)"),
)
# (x values, per-decoder accuracies, x-axis label, title) per panel, left to right.
_sweep_panels = (
    (N_arr, (accN_lin, accN_mlp, accN_snn),
     "Number of neurons (N)", "Population size sweep"),
    (sigma_arr, (accS_lin, accS_mlp, accS_snn),
     "Tuning width σ (deg)", "Tuning width sweep"),
    (nl_arr, (accNoise_lin, accNoise_mlp, accNoise_snn),
     "Noise level (relative)", "Noise sweep"),
)

fig, axes = plt.subplots(1, 3, figsize=(14, 4))

for panel_idx, (ax, (x_vals, y_series, x_label, panel_title)) in enumerate(
    zip(axes, _sweep_panels)
):
    for y_vals, (mk, ls, lbl) in zip(y_series, _decoder_styles):
        ax.plot(x_vals, y_vals, marker=mk, linestyle=ls, label=lbl)
    ax.set_xlabel(x_label)
    # Only the leftmost panel carries the y label and the legend.
    if panel_idx == 0:
        ax.set_ylabel("Test accuracy")
    ax.set_title(panel_title)
    ax.set_ylim(0, 1.05)
    ax.grid(True, alpha=0.3)
    if panel_idx == 0:
        ax.legend()

plt.tight_layout()
plt.show()
