In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import json
# -------------------------
# 1. Data loading and environment definition
# -------------------------
def load_concatenated_json(file_path):
    """
    Read a file containing one or more concatenated JSON objects and merge
    them into a single dictionary.

    Objects are decoded back to back; whitespace between them is skipped.
    If a malformed object is encountered, decoding stops there, a
    ``RuntimeWarning`` is issued, and whatever was decoded before that point
    is still merged and returned (best-effort loading, but no longer silent).

    Args:
        file_path: Path of the UTF-8 text file to read.

    Returns:
        ``None`` if no object could be decoded. Otherwise a dict holding every
        key of the first object except ``'traces'``, plus ``'traces'`` (the
        ``'traces'`` lists of all objects concatenated in file order) and
        ``'num_traces'`` (the length of that combined list).
    """
    import warnings  # local import: only needed for the truncation notice

    # Read everything first so the file handle is released before parsing.
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()

    decoder = json.JSONDecoder()
    data_parts = []
    idx = 0
    length = len(content)
    while idx < length:
        # Skip whitespace separating consecutive objects.
        while idx < length and content[idx].isspace():
            idx += 1
        if idx == length:
            break
        try:
            obj, idx = decoder.raw_decode(content, idx)
        except json.JSONDecodeError as exc:
            # Keep what was decoded so far, but tell the caller data was lost.
            warnings.warn(
                f"Stopped parsing {file_path!r} at character {idx}: {exc}. "
                f"Keeping the {len(data_parts)} object(s) decoded so far.",
                RuntimeWarning,
                stacklevel=2,
            )
            break
        data_parts.append(obj)

    if not data_parts:
        return None

    # Metadata comes from the first object; traces come from every object.
    final_data = {k: v for k, v in data_parts[0].items() if k != 'traces'}
    all_traces = [trace for part in data_parts for trace in part.get('traces', [])]

    final_data['traces'] = all_traces
    final_data['num_traces'] = len(all_traces)

    return final_data
class DeepConfEnv:
    """
    Episodic early-stopping environment over recorded confidence traces.

    Each episode replays one randomly chosen trace. The observation is a
    length-1 array holding the trace's ``'group_conf'`` value at the current
    step. Action ``1`` ends the episode; any other action advances one step.
    The episode also ends when an action is taken on the last step.
    """

    def __init__(self, traces):
        self.traces = traces
        self.reset()

    def _observe(self):
        """Return the confidence value at the current step as a 1-element array."""
        return np.array([self.conf_curve[self.t]])

    def reset(self):
        """Select a random trace, rewind to step 0 and return the first observation."""
        chosen = np.random.randint(len(self.traces))
        self.idx = chosen
        self.trace = self.traces[chosen]
        self.conf_curve = np.array(self.trace['group_conf'])
        self.correct = self.trace['ground_truth_correct']
        self.t = 0
        return self._observe()

    def step(self, action):
        """
        Apply ``action`` and return ``(next_state, reward, done)``.

        On termination the reward is +1.0 for a correct trace or -1.0 for an
        incorrect one, minus 0.01 per step already consumed, and the returned
        state is a zero array. Non-terminal steps yield a reward of 0.0.
        """
        if action == 1 or self.t == len(self.conf_curve) - 1:
            outcome = 1.0 if self.correct else -1.0
            # Small per-step penalty so that longer traces score slightly lower.
            return np.zeros(1), outcome - 0.01 * self.t, True

        self.t += 1
        return self._observe(), 0.0, False

# -------------------------
# 2. Policy network
# -------------------------
class PolicyNet(nn.Module):
    """
    Small MLP policy: one input feature -> 32 hidden units (ReLU) -> 2 outputs,
    with a softmax over the last dimension so the outputs are action
    probabilities.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(1, 32),
            nn.ReLU(),
            nn.Linear(32, 2),
            nn.Softmax(dim=-1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the action probabilities produced by the network for ``x``."""
        action_probs = self.net(x)
        return action_probs

# -------------------------
# 3. RL training (REINFORCE)
# -------------------------
def train_rl(traces, num_episodes=1000, gamma=0.99, lr=1e-3):
    """
    Train a stop/continue policy with REINFORCE on recorded traces.

    Each episode rolls the current policy out on one random trace from
    ``DeepConfEnv``, computes the discounted return for every step, and takes
    one Adam step on the loss ``-sum_t log pi(a_t | s_t) * G_t``.

    Args:
        traces: Traces handed to ``DeepConfEnv``; must contain at least one.
        num_episodes: Number of episodes (and optimizer steps) to run.
        gamma: Discount factor used when accumulating returns.
        lr: Adam learning rate.

    Returns:
        The trained ``PolicyNet``.

    Raises:
        ValueError: If ``traces`` is empty.
    """
    if len(traces) == 0:
        raise ValueError("train_rl needs at least one trace to sample episodes from")

    env = DeepConfEnv(traces)
    policy = PolicyNet()
    optimizer = optim.Adam(policy.parameters(), lr=lr)

    for _ in tqdm(range(num_episodes)):
        log_probs, rewards = [], []
        state = env.reset()
        done = False

        while not done:
            state_t = torch.FloatTensor(state).unsqueeze(0)
            dist = torch.distributions.Categorical(policy(state_t))
            action = dist.sample()
            # Keep the log-probability from the rollout itself: no second
            # forward pass is needed later, and no raw log() of a probability
            # that may have saturated to zero.
            log_probs.append(dist.log_prob(action))

            state, reward, done = env.step(action.item())
            rewards.append(reward)

        # Discounted returns, accumulated back to front and then reversed
        # (appending avoids the quadratic cost of inserting at index 0).
        returns = []
        G = 0.0
        for r in reversed(rewards):
            G = r + gamma * G
            returns.append(G)
        returns.reverse()
        returns = torch.FloatTensor(returns)

        # Policy-gradient loss over the whole episode.
        loss = -(torch.cat(log_probs) * returns).sum()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return policy

# -------------------------
# 4. Evaluation
# -------------------------
def evaluate_policy(policy, traces, num_samples=200):
    """
    Run the policy greedily on the first traces and report where it stops.

    For each of the first ``min(num_samples, len(traces))`` traces the policy
    is queried step by step with the current ``'group_conf'`` value; the
    rollout stops at the first step where the most probable action is 1, or
    at the last step of the trace. A confusion matrix is then displayed.

    Note: the predicted label is currently set equal to the ground-truth
    label (early stopping is assumed not to change correctness), so the
    displayed matrix is diagonal by construction.

    Args:
        policy: Callable mapping a ``(1, 1)`` float tensor to action
            probabilities.
        traces: Sequence of dicts with ``'group_conf'`` and
            ``'ground_truth_correct'`` entries.
        num_samples: Maximum number of traces to evaluate.

    Returns:
        List with the stop step index of every evaluated trace.
    """
    y_true, y_pred = [], []
    stop_positions = []

    # Inference only: gradients are not needed here.
    with torch.no_grad():
        for i in range(min(num_samples, len(traces))):
            trace = traces[i]
            conf_curve = np.array(trace['group_conf'])
            correct = trace['ground_truth_correct']

            t = 0
            while t < len(conf_curve) - 1:
                s = torch.FloatTensor([[conf_curve[t]]])
                probs = policy(s)
                action = torch.argmax(probs).item()
                if action == 1:
                    break
                t += 1

            y_true.append(int(correct))
            # Placeholder: assume early stopping does not affect correctness.
            y_pred.append(int(correct))
            stop_positions.append(t)

    # Pin both classes so the matrix stays 2x2 (matching the two display
    # labels) even when every evaluated trace has the same label.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    disp = ConfusionMatrixDisplay(cm, display_labels=["Wrong", "Correct"])
    disp.plot(cmap="Blues")
    plt.title("RL Policy Early Stop Confusion Matrix")
    plt.show()

    return stop_positions

# -------------------------
# 5. Visualize a few traces
# -------------------------
def plot_sample_traces(traces, policy, num_examples=5):
    """
    Plot the confidence curves of the first traces with the policy's stop step.

    For each of the first ``min(num_examples, len(traces))`` traces the
    ``'group_conf'`` curve is drawn (green if ``'ground_truth_correct'`` is
    truthy, orange otherwise) together with a dashed red vertical line at the
    step where the greedy policy stops (first step whose most probable action
    is 1, or the last step of the trace).

    Args:
        traces: Sequence of dicts with ``'group_conf'`` and
            ``'ground_truth_correct'`` entries.
        policy: Callable mapping a ``(1, 1)`` float tensor to action
            probabilities.
        num_examples: Maximum number of traces to draw.
    """
    plt.figure(figsize=(8, 5))

    # Inference only: gradients are not needed here.
    with torch.no_grad():
        # Clamp to the available traces so a short list does not raise IndexError.
        for i in range(min(num_examples, len(traces))):
            trace = traces[i]
            conf_curve = np.array(trace['group_conf'])
            correct = trace['ground_truth_correct']

            t = 0
            while t < len(conf_curve) - 1:
                s = torch.FloatTensor([[conf_curve[t]]])
                probs = policy(s)
                action = torch.argmax(probs).item()
                if action == 1:
                    break
                t += 1

            color = "green" if correct else "orange"
            plt.plot(conf_curve, color=color, alpha=0.6)
            plt.axvline(t, color="red", linestyle="--", alpha=0.7)

    plt.xlabel("Step")
    plt.ylabel("Group Confidence")
    plt.title("RL Policy Early Stop Visualization")
    plt.show()

# -------------------------
# 6. Main pipeline
# -------------------------
trace_path = 'trace_data/aime_2025_0_full.jsonl'
data = load_concatenated_json(trace_path)
if data is None:
    # load_concatenated_json returns None when no JSON object could be
    # decoded; fail with a clear message instead of a TypeError on indexing.
    raise ValueError(f"No JSON objects could be decoded from {trace_path!r}")
traces = data['traces']

# Train the early-stopping policy on all loaded traces.
policy = train_rl(traces, num_episodes=500)



In [None]:
# Greedy evaluation of the trained policy: displays the confusion matrix and
# returns the step index at which each evaluated trace was stopped.
stop_positions = evaluate_policy(policy, traces)
# Draw the confidence curves of the first traces (num_examples defaults to 5)
# with the policy's stop step marked on each.
plot_sample_traces(traces, policy)