# RNN Fixed-Point Analysis Guide: The FlipFlop Task

**Goal:** This document is a how-to guide for using the `FixedPointFinder` tool to analyze an RNN trained on the FlipFlop task.

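**Structure:**
1.  **Background**: what is a fixed point?
2.  **Setup**: import the libraries the script needs.
3.  **Components**: walk through the `FlipFlopData` class, the `FlipFlopRNN` class, and the `train_flipflop_rnn` function.
4.  **Core usage**: demonstrate the `FixedPointFinder` workflow step by step.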

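## 1. Background: What Is a Fixed Point?

A **fixed point** is a core concept in dynamical systems. An RNN can be viewed as a function `h_{t+1} = F(h_t, u_t)`, where `h` is the hidden state and `u` is the input.

When the **input `u` is held constant** (for example `u=0` during the "memory" phase of the FlipFlop task, when no pulse arrives), the system reduces to `h_{t+1} = F(h_t)`.

A **fixed point `x*`** is a state that satisfies `x* = F(x*)`.
* **Stable fixed point**: an attractor. If the RNN state `h` gets close enough to `x*`, it converges to `x*` and stays there.
* **Unstable fixed point**: a repeller or a saddle point. A state that starts near `x*` moves away from it along at least one direction.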
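
As a minimal sketch of these definitions, the toy one-dimensional map `F(h) = tanh(2h)` below stands in for an RNN step with `u=0`. The map, the helper name `iterate`, and the starting values are illustrative assumptions and are not part of the `flipflop_fixed_points.py` script.

```python
import math

def F(h):
    """Toy 1-D autonomous map, a stand-in for an RNN step with u = 0."""
    return math.tanh(2.0 * h)

def iterate(h0, n_steps=200):
    """Apply F repeatedly and return the final state."""
    h = h0
    for _ in range(n_steps):
        h = F(h)
    return h

x_pos = iterate(0.01)    # starts just above the origin
x_neg = iterate(-0.01)   # starts just below the origin
x_zero = iterate(0.0)    # starts exactly on the origin

print(x_pos, x_neg, x_zero)
```

States that start on either side of the origin settle near `+0.9575` and `-0.9575`, where `h = tanh(2h)` holds, so these two states behave as stable fixed points. The origin also satisfies `F(0) = 0`, but any small perturbation moves away from it, which makes it an unstable fixed point.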

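**Core idea:** We train an RNN on the FlipFlop task. Once training succeeds, the RNN has learned to create a **stable fixed point** for **each state** it needs to remember (for example `[+1, +1]` or `[+1, -1]`). When the input is `u=0`, the RNN state flows to one of these fixed points and stays there, which is how the network implements memory.

The **goal** of this tutorial is to use `FixedPointFinder` to locate all of the stable fixed points hidden inside the trained RNN.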

## 2. Imports
Import all libraries used by the `flipflop_fixed_points.py` script.

In [None]:
import brainstate as bst
import braintools as bts
import jax
import jax.numpy as jnp
import numpy as np
import random
from canns.analyzer.plotting import plot_fixed_points_2d, plot_fixed_points_3d, PlotConfig
from canns.analyzer.slow_points import FixedPointFinder, save_checkpoint, load_checkpoint

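## 3. Components: Data, Model, and Training

This section defines the three core components of the `flipflop_fixed_points.py` script in full.

### 3.1 Component 1: The FlipFlopData Class
This is the `FlipFlopData` class from the `flipflop_fixed_points.py` script. It generates sparse `+1`/`-1` input pulses on each channel, and the target for each channel holds the sign of the most recent pulse.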

In [None]:
class FlipFlopData:
    """Generator for flip-flop memory task data."""

    def __init__(self, n_bits=3, n_time=64, p=0.5, random_seed=0):
        """Initialize FlipFlopData generator.

        Args:
            n_bits: Number of memory channels.
            n_time: Number of timesteps per trial.
            p: Probability of input pulse at each timestep.
            random_seed: Random seed for reproducibility.
        """
        self.rng = np.random.RandomState(random_seed)
        self.n_time = n_time
        self.n_bits = n_bits
        self.p = p

    def generate_data(self, n_trials):
        """Generate flip-flop task data.

        Args:
            n_trials: Number of trials to generate.

        Returns:
            dict with 'inputs' and 'targets' arrays [n_trials x n_time x n_bits].
        """
        n_time = self.n_time
        n_bits = self.n_bits
        p = self.p

        # Generate unsigned input pulses
        unsigned_inputs = self.rng.binomial(1, p, [n_trials, n_time, n_bits])

        # Ensure every trial starts with a pulse
        unsigned_inputs[:, 0, :] = 1

        # Generate random signs {-1, +1}
        random_signs = 2 * self.rng.binomial(1, 0.5, [n_trials, n_time, n_bits]) - 1

        # Apply random signs
        inputs = unsigned_inputs * random_signs

        # Compute targets
        targets = np.zeros([n_trials, n_time, n_bits])
        for trial_idx in range(n_trials):
            for bit_idx in range(n_bits):
                input_seq = inputs[trial_idx, :, bit_idx]
                t_flip = np.where(input_seq != 0)[0]
                for flip_idx in range(len(t_flip)):
                    t_flip_i = t_flip[flip_idx]
                    targets[trial_idx, t_flip_i:, bit_idx] = inputs[
                        trial_idx, t_flip_i, bit_idx
                    ]

        return {
            "inputs": inputs.astype(np.float32),
            "targets": targets.astype(np.float32),
        }
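
To make the target rule concrete, here is a small vectorized sketch of the "hold the most recent pulse" logic for a single trial. The helper name `hold_last_pulse` and the example pulse sequence are assumptions made for illustration; the class above computes the same targets with explicit loops.

```python
import numpy as np

def hold_last_pulse(inputs):
    """Return targets that hold the most recent nonzero input per channel.

    inputs: [n_time, n_bits] array with entries in {-1, 0, +1}.
    """
    n_time = inputs.shape[0]
    # Time index of each pulse, 0 where there is no pulse.
    pulse_times = np.where(inputs != 0, np.arange(n_time)[:, None], 0)
    # The running maximum is the index of the most recent pulse so far.
    last_pulse = np.maximum.accumulate(pulse_times, axis=0)
    return np.take_along_axis(inputs, last_pulse, axis=0)

# One channel, seven timesteps: pulses at t = 0, 3, 5.
pulses = np.array([[+1], [0], [0], [-1], [0], [+1], [0]], dtype=np.float32)
targets = hold_last_pulse(pulses)
print(targets[:, 0].tolist())  # [1.0, 1.0, 1.0, -1.0, -1.0, 1.0, 1.0]
```

The target switches only when a new pulse arrives and otherwise keeps its value, which is the memory behavior the RNN has to learn.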

### 3.2 Component 2: The FlipFlopRNN Class
This is the `FlipFlopRNN` class from the `flipflop_fixed_points.py` script.

**Usage note:** `FixedPointFinder` searches for states that satisfy `x = F(x, u)`. To evaluate `F(x, u)`, it calls `rnn_model(inputs, hidden)`.
It passes `inputs` with shape `[batch, 1, n_inputs]` and `hidden` with shape `[batch, n_hidden]`.

Your `__call__` method therefore **must** handle the `n_time == 1` case and return `(outputs, h_next)`.

The `if n_time == 1:` branch in the code below exists specifically to satisfy this interface.

In [None]:
class FlipFlopRNN(bst.nn.Module):
    """RNN model for the flip-flop memory task."""

    def __init__(self, n_inputs, n_hidden, n_outputs, rnn_type="gru", seed=0):
        """Initialize FlipFlop RNN.

        Args:
            n_inputs: Number of input channels.
            n_hidden: Number of hidden units.
            n_outputs: Number of output channels.
            rnn_type: Type of RNN cell ('tanh', 'gru').
            seed: Random seed for weight initialization.
        """
        super().__init__()
        self.n_inputs = n_inputs
        self.n_hidden = n_hidden
        self.n_outputs = n_outputs
        self.rnn_type = rnn_type.lower()

        # Initialize RNN cell parameters
        key = jax.random.PRNGKey(seed)
        k1, k2, k3, k4 = jax.random.split(key, 4)

        # Compare against the lower-cased type so "GRU" or "Tanh" also work
        if self.rnn_type == "tanh":
            # Simple tanh RNN
            self.w_ih = bst.ParamState(
                jax.random.normal(k1, (n_inputs, n_hidden)) * 0.1
            )
            self.w_hh = bst.ParamState(
                jax.random.normal(k2, (n_hidden, n_hidden)) * 0.5
            )
            self.b_h = bst.ParamState(jnp.zeros(n_hidden))
        elif self.rnn_type == "gru":
            # GRU cell
            self.w_ir = bst.ParamState(
                jax.random.normal(k1, (n_inputs, n_hidden)) * 0.1
            )
            self.w_hr = bst.ParamState(
                jax.random.normal(k2, (n_hidden, n_hidden)) * 0.5
            )
            self.w_iz = bst.ParamState(
                jax.random.normal(k3, (n_inputs, n_hidden)) * 0.1
            )
            self.w_hz = bst.ParamState(
                jax.random.normal(k4, (n_hidden, n_hidden)) * 0.5
            )
            k5, k6, k7, k8 = jax.random.split(k4, 4)
            self.w_in = bst.ParamState(
                jax.random.normal(k5, (n_inputs, n_hidden)) * 0.1
            )
            self.w_hn = bst.ParamState(
                jax.random.normal(k6, (n_hidden, n_hidden)) * 0.5
            )
            self.b_r = bst.ParamState(jnp.zeros(n_hidden))
            self.b_z = bst.ParamState(jnp.zeros(n_hidden))
            self.b_n = bst.ParamState(jnp.zeros(n_hidden))
        else:
            raise ValueError(f"Unsupported rnn_type: {rnn_type}")

        # Readout layer
        self.w_out = bst.ParamState(
            jax.random.normal(k3, (n_hidden, n_outputs)) * 0.1
        )
        self.b_out = bst.ParamState(jnp.zeros(n_outputs))

        # Initial hidden state
        self.h0 = bst.ParamState(jnp.zeros(n_hidden))

    def step(self, x_t, h):
        """Single RNN step.

        Args:
            x_t: [batch_size x n_inputs] input at time t.
            h: [batch_size x n_hidden] hidden state.

        Returns:
            h_next: [batch_size x n_hidden] next hidden state.
        """
        if self.rnn_type == "tanh":
            # Simple tanh RNN step
            h_next = jnp.tanh(
                x_t @ self.w_ih.value + h @ self.w_hh.value + self.b_h.value
            )
        elif self.rnn_type == "gru":
            # GRU step
            r = jax.nn.sigmoid(
                x_t @ self.w_ir.value + h @ self.w_hr.value + self.b_r.value
            )
            z = jax.nn.sigmoid(
                x_t @ self.w_iz.value + h @ self.w_hz.value + self.b_z.value
            )
            n = jnp.tanh(
                x_t @ self.w_in.value + (r * h) @ self.w_hn.value + self.b_n.value
            )
            h_next = (1 - z) * n + z * h
        else:
            raise ValueError(f"Unknown rnn_type: {self.rnn_type}")

        return h_next

    def __call__(self, inputs, hidden=None):
        """Forward pass through the RNN. Optimized with jax.lax.scan."""
        batch_size = inputs.shape[0]
        n_time = inputs.shape[1]

        # Initialize hidden state
        if hidden is None:
            h = jnp.tile(self.h0.value, (batch_size, 1))
        else:
            h = hidden

        # Single-step computation mode for the fixed-point finder
        if n_time == 1:
            x_t = inputs[:, 0, :]
            h_next = self.step(x_t, h)
            y = h_next @ self.w_out.value + self.b_out.value
            return y[:, None, :], h_next

        # Full sequence case
        def scan_fn(carry, x_t):
            """Single-step scan function"""
            h_prev = carry
            h_next = self.step(x_t, h_prev)
            y_t = h_next @ self.w_out.value + self.b_out.value
            return h_next, (y_t, h_next)

        # (batch, time, features) -> (time, batch, features)
        inputs_transposed = inputs.transpose(1, 0, 2)

        # Run the scan
        _, (outputs_seq, hiddens_seq) = jax.lax.scan(scan_fn, h, inputs_transposed)

        outputs = outputs_seq.transpose(1, 0, 2)
        hiddens = hiddens_seq.transpose(1, 0, 2)

        return outputs, hiddens
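
The single-step interface described in the usage note can be checked with a quick shape test. The sketch below uses a hypothetical stand-alone function `toy_rnn` with random tanh weights in place of `FlipFlopRNN`, so that it runs without `brainstate`; only the input and output shapes are meant to mirror the contract stated above.

```python
import jax
import jax.numpy as jnp

n_inputs, n_hidden, n_outputs, batch = 3, 16, 3, 5
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
w_ih = jax.random.normal(k1, (n_inputs, n_hidden)) * 0.1
w_hh = jax.random.normal(k2, (n_hidden, n_hidden)) * 0.5
w_out = jax.random.normal(k3, (n_hidden, n_outputs)) * 0.1

def toy_rnn(inputs, hidden):
    """Hypothetical single-step model: inputs [batch, 1, n_inputs], hidden [batch, n_hidden]."""
    x_t = inputs[:, 0, :]                          # drop the length-1 time axis
    h_next = jnp.tanh(x_t @ w_ih + hidden @ w_hh)  # one step of F(h, u)
    y = h_next @ w_out                             # linear readout
    return y[:, None, :], h_next                   # restore the time axis on the output

inputs = jnp.zeros((batch, 1, n_inputs))
hidden = jnp.zeros((batch, n_hidden))
outputs, h_next = toy_rnn(inputs, hidden)
print(outputs.shape, h_next.shape)  # (5, 1, 3) (5, 16)
```

Running the same check on your own model before handing it to `FixedPointFinder` is a cheap way to catch shape mismatches early.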

### 3.3 Component 3: The train_flipflop_rnn Function
This is the `train_flipflop_rnn` function from the `flipflop_fixed_points.py` script.

In [None]:
def train_flipflop_rnn(rnn, train_data, valid_data,
                       learning_rate=0.08,
                       batch_size=128,
                       max_epochs=1000,
                       min_loss=1e-4,
                       print_every=10):
    print("\n" + "=" * 70)
    print("Training FlipFlop RNN (Adam, fixed learning rate)")
    print("=" * 70)

    # Prepare data
    train_inputs = jnp.array(train_data['inputs'])
    train_targets = jnp.array(train_data['targets'])
    valid_inputs = jnp.array(valid_data['inputs'])
    valid_targets = jnp.array(valid_data['targets'])
    n_train = train_inputs.shape[0]
    n_batches = n_train // batch_size

    # Flatten parameter keys
    def flatten_key(key):
        return '.'.join(key) if isinstance(key, tuple) else key

    trainable_states = {flatten_key(name): state for name, state in rnn.states().items() if
                        isinstance(state, bst.ParamState)}
    trainable_params = {name: state.value for name, state in trainable_states.items()}

    optimizer = bts.optim.Adam(
        lr=learning_rate
    )

    # Register trainable weights
    optimizer.register_trainable_weights(trainable_states)

    # Define JIT-compiled gradient step
    @jax.jit
    def grad_step(params, batch_inputs, batch_targets):
        """Pure function to compute loss and gradients"""
        def forward_pass(p, inputs):
            batch_size = inputs.shape[0]
            h = jnp.tile(p['h0'], (batch_size, 1))

            def scan_fn(carry, x_t):
                h_prev = carry
                if rnn.rnn_type == "tanh":
                    h_next = jnp.tanh(x_t @ p['w_ih'] + h_prev @ p['w_hh'] + p['b_h'])
                elif rnn.rnn_type == "gru":
                    r = jax.nn.sigmoid(x_t @ p['w_ir'] + h_prev @ p['w_hr'] + p['b_r'])
                    z = jax.nn.sigmoid(x_t @ p['w_iz'] + h_prev @ p['w_hz'] + p['b_z'])
                    n = jnp.tanh(x_t @ p['w_in'] + (r * h_prev) @ p['w_hn'] + p['b_n'])
                    h_next = (1 - z) * n + z * h_prev
                else:
                    h_next = h_prev
                y_t = h_next @ p['w_out'] + p['b_out']
                return h_next, y_t

            inputs_transposed = inputs.transpose(1, 0, 2)
            _, outputs_seq = jax.lax.scan(scan_fn, h, inputs_transposed)
            outputs = outputs_seq.transpose(1, 0, 2)
            return outputs

        def loss_fn(p):
            outputs = forward_pass(p, batch_inputs)
            return jnp.mean((outputs - batch_targets) ** 2)

        loss_val, grads = jax.value_and_grad(loss_fn)(params)
        return loss_val, grads

    losses = []
    print("\nTraining parameters:")
    print(f"  Batch size: {batch_size}")
    print(f"  Learning rate: {learning_rate:.6f} (Fixed)")

    for epoch in range(max_epochs):
        perm = np.random.permutation(n_train)
        epoch_loss = 0.0
        for batch_idx in range(n_batches):
            start_idx = batch_idx * batch_size
            end_idx = start_idx + batch_size
            batch_inputs = train_inputs[perm[start_idx:end_idx]]
            batch_targets = train_targets[perm[start_idx:end_idx]]
            loss_val, grads = grad_step(trainable_params, batch_inputs, batch_targets)
            optimizer.update(grads)
            trainable_params = {flatten_key(name): state.value for name, state in rnn.states().items() if
                                isinstance(state, bst.ParamState)}
            epoch_loss += float(loss_val)
        epoch_loss /= n_batches
        losses.append(epoch_loss)

        if epoch % print_every == 0:
            valid_outputs, _ = rnn(valid_inputs)
            valid_loss = float(jnp.mean((valid_outputs - valid_targets) ** 2))
            print(f"Epoch {epoch:4d}: train_loss = {epoch_loss:.6f}, "
                  f"valid_loss = {valid_loss:.6f}, lr = {optimizer.current_lr:.6f}")
        if epoch_loss < min_loss:
            print(f"\nReached target loss {min_loss:.2e} at epoch {epoch}")
            break

    # Training complete
    valid_outputs, _ = rnn(valid_inputs)
    final_valid_loss = float(jnp.mean((valid_outputs - valid_targets) ** 2))
    print("\n" + "=" * 70)
    print("Training Complete!")
    print("=" * 70)
    print(f"Final training loss: {epoch_loss:.6f}")
    print(f"Final validation loss: {final_valid_loss:.6f}")
    print(f"Total epochs: {epoch + 1}")
    return losses
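
The training function relies on one pattern: collect the parameters into a flat dictionary, then differentiate a pure loss function with `jax.value_and_grad`. The sketch below isolates that pattern on a toy linear regression. The model, the data, and the plain SGD update are illustrative assumptions; the tutorial itself uses `bts.optim.Adam` for the update.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    """Mean squared error of a linear model with parameters stored in a dict."""
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
def grad_step(params, x, y):
    """Pure function: returns the loss and a dict of gradients with the same keys."""
    return jax.value_and_grad(loss_fn)(params, x, y)

x = jax.random.normal(jax.random.PRNGKey(0), (64, 3))
y = x @ jnp.array([[1.0], [-2.0], [0.5]]) + 0.3

params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))}
first_loss, grads = grad_step(params, x, y)

for _ in range(200):
    loss, grads = grad_step(params, x, y)
    # Plain SGD, standing in for the Adam optimizer used in the tutorial.
    params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

final_loss, _ = grad_step(params, x, y)
print(float(first_loss), float(final_loss))
```

Because `grad_step` takes the parameters as an explicit argument, it has no hidden state and can be compiled with `jax.jit`.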

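## 4. Core Usage: Train the Model and Find Fixed Points

This section reproduces the logic of the `main` function and the `if __name__ == "__main__":` block in the `flipflop_fixed_points.py` script.

The steps are:
1.  Define the task configurations.
2.  Set the parameters and generate data.
3.  Train the model, or load it from a checkpoint if one exists.
4.  Initialize and run `FixedPointFinder`.
5.  Print the results and visualize them.

### 4.1 Step 1: Define the Configuration and Parameters
This code comes from the global `TASK_CONFIGS` dictionary, the `if __name__ == "__main__":` block, and the beginning of the `main` function in `flipflop_fixed_points.py`.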

In [None]:
# Configuration Dictionary
TASK_CONFIGS = {
    "2_bit": {
        "n_bits": 2,
        "n_hidden": 16,
        "n_trials_train": 128,
        "n_inits": 1024,
    },
    "3_bit": {
        "n_bits": 3,
        "n_hidden": 16,
        "n_trials_train": 256,
        "n_inits": 1024,
    },
    "4_bit": {
        "n_bits": 4,
        "n_hidden": 32,
        "n_trials_train": 1024,
        "n_inits": 1024,
    },
}

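# --- Set parameters ---
# (This logic comes from the original script's if __name__ == "__main__" block)
config_to_run = "3_bit"  # Configuration to run
seed_to_use = 8356       # Fixed seed for reproducibility

config_name = config_to_run
seed = seed_to_use

# (This logic comes from the original script's main function)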
if config_name not in TASK_CONFIGS:
    raise ValueError(f"Unknown config_name: {config_name}. Available: {list(TASK_CONFIGS.keys())}")
config = TASK_CONFIGS[config_name]

# Set random seeds
np.random.seed(seed)
random.seed(seed)

print(f"\n--- Running FlipFlop Task ({config_name}) ---")
print(f"Seed: {seed}")

n_bits = config["n_bits"]
n_hidden = config["n_hidden"]
n_trials_train = config["n_trials_train"]
n_inits = config["n_inits"]

n_time = 64
n_trials_valid = 128
n_trials_test = 128
rnn_type = "tanh"
learning_rate = 0.08
batch_size = 128
max_epochs = 500  # (1000 in the original script; 500 runs faster in a notebook)
min_loss = 1e-4

### 4.2 Step 2: Generate Data and Train the Model
This code comes from the `main` function of `flipflop_fixed_points.py`.

In [None]:
# Generate data
data_gen = FlipFlopData(n_bits=n_bits, n_time=n_time, p=0.5, random_seed=seed)
train_data = data_gen.generate_data(n_trials_train)
valid_data = data_gen.generate_data(n_trials_valid)
test_data = data_gen.generate_data(n_trials_test)

# Create RNN model
rnn = FlipFlopRNN(n_inputs=n_bits, n_hidden=n_hidden, n_outputs=n_bits, rnn_type=rnn_type, seed=seed)

# Check for checkpoint
checkpoint_path = f"flipflop_rnn_{config_name}_checkpoint.msgpack"
if load_checkpoint(rnn, checkpoint_path):
    print(f"Loaded model from checkpoint: {checkpoint_path}")
else:
    # Train the RNN
    print(f"No checkpoint found ({checkpoint_path}). Training...")
    losses = train_flipflop_rnn(
        rnn,
        train_data,
        valid_data,
        learning_rate=learning_rate,
        batch_size=batch_size,
        max_epochs=max_epochs,
        min_loss=min_loss,
        print_every=10
    )

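### 4.3 Step 3: Run the Fixed-Point Analysis
This is the **concrete usage** of `FixedPointFinder`, taken from the `main` function.

**Usage notes:**
1.  **Collect state trajectories**: `hiddens_np`, the hidden states the trained RNN actually visits on the test inputs. `FixedPointFinder` **samples** its initial points from these states.
2.  **Initialize `FixedPointFinder`**:
    * `rnn_model`: pass the `rnn` instance.
    * `do_compute_jacobians=True`: must be `True`. This computes the Jacobian `J = dF/dx`.
    * `do_decompose_jacobians=True`: must be `True`. This computes the eigenvalues of `J`, which are used to determine **stability**.
3.  **Run `find_fixed_points`**:
    * `state_traj`: pass `hiddens_np`.
    * `inputs`: we are looking for the "memory" states, that is, the fixed points when there is **no input**. We therefore pass a constant zero vector, `constant_input`.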
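
To illustrate what the two Jacobian options provide, the sketch below computes `J = dF/dx` with `jax.jacobian` for a toy two-dimensional tanh map and classifies two of its fixed points by the largest eigenvalue magnitude. The map `F`, the weight matrix `W`, and the helper `max_abs_eig` are assumptions made for illustration. The criterion used here (stable when every eigenvalue has magnitude below 1) is the standard one for discrete-time maps and is assumed to match how `FixedPointFinder` sets `is_stable`.

```python
import jax
import jax.numpy as jnp
import numpy as np

W = jnp.array([[2.0, 0.0], [0.0, 0.5]])

def F(h):
    """Toy autonomous map, a stand-in for one RNN step with zero input."""
    return jnp.tanh(h @ W)

def max_abs_eig(h_star):
    """Largest eigenvalue magnitude of the Jacobian dF/dh evaluated at h_star."""
    J = jax.jacobian(F)(h_star)
    eigvals = np.linalg.eigvals(np.asarray(J))
    return float(np.max(np.abs(eigvals)))

# Fixed point 1: the origin, since tanh(0) = 0.
origin = jnp.zeros(2)

# Fixed point 2: an attractor, reached here by simply iterating the map.
h = jnp.array([0.5, 0.5])
for _ in range(200):
    h = F(h)

rho_origin = max_abs_eig(origin)   # 2.0, so the origin is unstable
rho_attractor = max_abs_eig(h)     # below 1, so this fixed point is stable
print(rho_origin, rho_attractor)
```

At the origin the Jacobian has one eigenvalue above 1 and one below, which makes it a saddle: states are pushed away along the first axis and pulled in along the second.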

In [None]:
# Fixed Point Analysis
print("\n--- Fixed Point Analysis ---")
inputs_jax = jnp.array(test_data["inputs"])
outputs, hiddens = rnn(inputs_jax)
hiddens_np = np.array(hiddens)

# Find fixed points
finder = FixedPointFinder(
    rnn,
    method="joint",
    max_iters=5000,
    lr_init=0.02,
    tol_q=1e-4,
    final_q_threshold=1e-6,
    tol_unique=1e-2,
    do_compute_jacobians=True,
    do_decompose_jacobians=True,
    outlier_distance_scale=10.0,
    verbose=True,
    super_verbose=True,
)

constant_input = np.zeros((1, n_bits), dtype=np.float32)

unique_fps, all_fps = finder.find_fixed_points(
    state_traj=hiddens_np,
    inputs=constant_input,
    n_inits=n_inits,
    noise_scale=0.4,
)

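### 4.4 Results and Visualization

`find_fixed_points` returns two objects:
* `all_fps`: all results found from the `n_inits` initial points.
* `unique_fps`: **the result we care about most**, the set of distinct fixed points that remain after deduplication with `tol_unique`.

**How to read the results:**
* `unique_fps.n`: the number of unique fixed points found.
* `unique_fps.qstar`: the `q` value of each fixed point. The closer to 0, the better.
* `unique_fps.is_stable`: **(key)** whether each fixed point is stable.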
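
The `q` value can be made concrete with a small optimization sketch. It assumes the usual definition from the fixed-point-finding literature, `q(x) = 0.5 * ||F(x) - x||^2`, which is zero exactly at a fixed point. The toy map `F`, the step size, and the plain gradient-descent loop are illustrative assumptions and do not reproduce the optimizer inside `FixedPointFinder`.

```python
import jax
import jax.numpy as jnp

W = jnp.array([[2.0, 0.0], [0.0, 0.5]])

def F(h):
    """Toy autonomous map, a stand-in for one RNN step with zero input."""
    return jnp.tanh(h @ W)

def q(h):
    """Speed-like objective: zero exactly when h is a fixed point of F."""
    return 0.5 * jnp.sum((F(h) - h) ** 2)

q_and_grad = jax.jit(jax.value_and_grad(q))

h = jnp.array([0.7, 0.2])          # initial guess, e.g. a noisy trajectory state
q_init = float(q(h))
for _ in range(2000):
    _, g = q_and_grad(h)
    h = h - 0.2 * g                # plain gradient descent on q
q_final = float(q(h))

print(q_init, q_final, h)
```

A candidate whose final `q` stays well above zero is not a true fixed point, which is why thresholds such as `tol_q` and `final_q_threshold` appear in the finder's configuration.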

For an N-bit task, we expect to find **2^N stable fixed points**, one for each of the 2^N memory states.
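
As a quick sanity check on that count, the memory states can be enumerated directly. This is only a sketch of the counting argument; the variable names are illustrative.

```python
from itertools import product

n_bits = 3
# Each bit is remembered as -1 or +1, so the memory states are all sign patterns.
memory_states = list(product([-1, +1], repeat=n_bits))

print(len(memory_states))   # 8
print(memory_states[0], memory_states[-1])  # (-1, -1, -1) (1, 1, 1)
```

For the `3_bit` configuration used in this tutorial, a successful analysis should therefore report eight stable fixed points, possibly alongside some unstable ones.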

The code cell below combines the end of the `main` function with the last line of the `if __name__ == "__main__":` block from `flipflop_fixed_points.py`. It prints all analysis results and generates the plots.

In [None]:
# Print results
print("\n--- Fixed Point Analysis Results ---")
unique_fps.print_summary()

if unique_fps.n > 0:
    print("\nDetailed Fixed Point Information (Top 10):")
    print(f"{'#':<4} {'q-value':<12} {'Stability':<12} {'Max |eig|':<12}")
    print("-" * 45)
    for i in range(min(10, unique_fps.n)):
        stability_str = "Stable" if unique_fps.is_stable[i] else "Unstable"
        max_eig = np.abs(unique_fps.eigval_J_xstar[i, 0])
        print(
            f"{i + 1:<4} {unique_fps.qstar[i]:<12.2e} {stability_str:<12} {max_eig:<12.4f}"
        )

    # Visualize fixed points
    save_path_2d = f"flipflop_{config_name}_fixed_points_2d.png"
    config_2d = PlotConfig(
        title=f"FlipFlop Fixed Points ({config_name} - 2D PCA)",
        xlabel="PC 1", ylabel="PC 2", figsize=(10, 8),
        save_path=save_path_2d,  # write the figure so the "Saved" message below is accurate
        show=False
    )
    plot_fixed_points_2d(unique_fps, hiddens_np, config=config_2d)
    print(f"\nSaved 2D plot to: {save_path_2d}")

    save_path_3d = f"flipflop_{config_name}_fixed_points_3d.png"
    config_3d = PlotConfig(
        title=f"FlipFlop Fixed Points ({config_name} - 3D PCA)",
        figsize=(12, 10), 
        save_path=save_path_3d,  # write the figure so the "Saved" message below is accurate
        show=False
    )
    plot_fixed_points_3d(
        unique_fps, hiddens_np, config=config_3d,
        plot_batch_idx=list(range(30)), plot_start_time=10
    )
    print(f"Saved 3D plot to: {save_path_3d}")

print("\n--- Analysis complete ---")

print(f"\n--- Finished configuration: {config_to_run} ---")