In [1]:
import sys
import warnings
import numpy as np
import torch
from scipy import linalg
from sklearn.metrics.pairwise import cosine_similarity
import gymnasium as gym

# ==========================================
# 1. COMPATIBILITY & CONFIGURATION
# ==========================================
# Fix for older SB3 versions referencing 'gym' instead of 'gymnasium'
sys.modules["gym"] = gym 
warnings.filterwarnings("ignore")

try:
    from huggingface_sb3 import load_from_hub
    from stable_baselines3 import DQN, PPO
    from stable_baselines3.common.evaluation import evaluate_policy
except ImportError:
    print("Please install requirements: pip install stable-baselines3 huggingface-sb3 shimmy")

In [2]:
# ==========================================
# 2. LQR ORACLE (Fixed Physics)
# ==========================================
class LQRController:
    """
    Computes the optimal Linear Quadratic Regulator control for CartPole.
    Used as a 'ground truth' to cluster states by their optimal action.
    """
    def __init__(self, env):
        core = env.unwrapped
        # Extract physics constants
        self.g = core.gravity
        self.lp = core.length
        self.mp = core.masspole
        self.mt = core.total_mass
        self.K = self._calculate_gain()
        
    def _calculate_gain(self):
        # Linearize dynamics around the upright fixed point
        # Denominator accounts for the mass ratio effect on angular acceleration
        denom = self.lp * (4.0/3 - self.mp / self.mt)
        a = self.g / denom
        b = -1 / denom
        
        # State Matrix A: [x, x_dot, theta, theta_dot] -> derivatives
        A = np.array([
            [0, 1, 0, 0], 
            [0, 0, 0, 0], 
            [0, 0, 0, 1], 
            [0, 0, a, 0]
        ])
        
        # Control Matrix B: Force -> [x_accel, theta_accel]
        B = np.array([[0], [1/self.mt], [0], [b/self.mt]])
        
        # LQR Cost Matrices
        Q = 5 * np.eye(4) # Penalty for state deviation
        R = np.eye(1)     # Penalty for actuation effort
        
        # Solve Algebraic Riccati Equation
        P = linalg.solve_continuous_are(A, B, Q, R)
        
        # Compute Gain K = R^-1 B^T P
        return np.dot(np.linalg.inv(R), np.dot(B.T, P))
        
    def get_optimal_action(self, state):
        # u = -Kx
        u = -np.dot(self.K, state)[0]
        return 1 if u > 0 else 0

In [3]:
# ==========================================
# 3. LATENT EXTRACTION & SAMPLING
# ==========================================
def extract_representation(model, obs, target_net="default"):
    """
    Hooks into specific PyTorch layers to extract latent activations.
    target_net: 'dqn_q', 'ppo_actor', or 'ppo_critic'
    """
    obs_tensor = torch.as_tensor(obs).float().unsqueeze(0).to(model.device)
    
    # 1. Select the appropriate network
    if target_net == "dqn_q":
        # DQN Q-Network: (0) Linear -> (1) ReLU
        network = model.policy.q_net.q_net
    elif target_net == "ppo_actor":
        # PPO Policy Net: (0) Linear -> (1) Tanh
        network = model.policy.mlp_extractor.policy_net
    elif target_net == "ppo_critic":
        # PPO Value Net: (0) Linear -> (1) Tanh
        network = model.policy.mlp_extractor.value_net
    else:
        raise ValueError(f"Unknown target_net: {target_net}")

    # 2. Forward pass through the selected network
    with torch.no_grad():
        out = network[2](network[1](network[0](obs_tensor)))
    
    return out.cpu().numpy().flatten()

def collect_states(env, model, mode, n_samples):
    """
    Generates a dataset of observations based on the selected mode.
    """
    states = []
    obs, _ = env.reset()
    
    for _ in range(n_samples):
        if mode == "RANDOM":
            # Uniformly sample the state space
            current_obs = env.observation_space.sample()
            states.append(current_obs)
            # Reset occasionally to ensure coverage, though independent sampling doesn't strictly need it
            if _ % 50 == 0: obs, _ = env.reset()
            
        elif mode == "ON_POLICY":
            # Follow the model's trajectory
            states.append(obs)
            action, _ = model.predict(obs, deterministic=True)
            obs, _, term, trunc, _ = env.step(action)
            if term or trunc:
                obs, _ = env.reset()
                
    return np.array(states)

In [4]:
# ==========================================
# 4. METRIC CALCULATION
# ==========================================
def compute_cosine(v1, v2, mean=None):
    """Compute cosine similarity, optionally centering vectors."""
    v1 = v1.reshape(1, -1)
    v2 = v2.reshape(1, -1)
    
    if mean is not None:
        mean = mean.reshape(1, -1)
        v1 = v1 - mean
        v2 = v2 - mean
        
    return cosine_similarity(v1, v2)[0][0]

def compute_cluster_sim(vectors, mean=None):
    """Average pairwise similarity within a cluster of vectors."""
    if len(vectors) < 2: return 0.0
    mat = np.vstack(vectors)
    if mean is not None:
        mat = mat - mean
    sim_matrix = cosine_similarity(mat)
    # Average of upper triangle (excluding diagonal)
    return np.mean(sim_matrix[np.triu_indices(len(mat), k=1)])

def evaluate_component(model, name, target_net, env, oracle):
    print(f"\n{'='*30}\n ANALYZING: {name}\n{'='*30}")
    
    # 1. Policy Performance Check
    mean_rew, std_rew = evaluate_policy(model, env, n_eval_episodes=5, deterministic=True)
    print(f"Policy Return: {mean_rew:.1f} +/- {std_rew:.1f}")
    
    # 2. Data Collection
    print(f"Sampling {CONFIG['N_SAMPLES']} states ({CONFIG['SAMPLING_MODE']})...")
    states = collect_states(env, model, CONFIG["SAMPLING_MODE"], CONFIG["N_SAMPLES"])
    
    # 3. Compute Latents & Global Mean
    latents = [extract_representation(model, s, target_net) for s in states]
    global_mean = np.mean(latents, axis=0)
    
    # 4. Metric Loop
    metrics = {"sym_raw": [], "sym_cent": [], "glob_raw": [], "glob_cent": []}
    cluster_left, cluster_right = [], []
    
    for obs, vec in zip(states, latents):
        # A. Mirror Symmetry: Sim(f(s), f(-s))
        vec_mirror = extract_representation(model, -obs, target_net)
        metrics["sym_raw"].append(compute_cosine(vec, vec_mirror))
        metrics["sym_cent"].append(compute_cosine(vec, vec_mirror, global_mean))
        
        # B. Global Baseline: Sim(f(s), f(random))
        # Compare current vector against a random vector from our pool
        vec_rand = latents[np.random.randint(len(latents))]
        metrics["glob_raw"].append(compute_cosine(vec, vec_rand))
        metrics["glob_cent"].append(compute_cosine(vec, vec_rand, global_mean))
        
        # C. Action Clustering (LQR Oracle)
        optimal_action = oracle.get_optimal_action(obs)
        if optimal_action == 0:
            cluster_left.append(vec)
        else:
            cluster_right.append(vec)
            
    # 5. Aggregate Results
    c_raw = (compute_cluster_sim(cluster_left) + compute_cluster_sim(cluster_right)) / 2
    c_cent = (compute_cluster_sim(cluster_left, global_mean) + compute_cluster_sim(cluster_right, global_mean)) / 2
    
    # 6. Formatting
    print(f"\n{'-'*65}")
    print(f"{'Metric':<25} | {'Uncentered':<15} | {'Centered':<15}")
    print(f"{'-'*65}")
    print(f"{'Global Baseline':<25} | {np.mean(metrics['glob_raw']):<15.4f} | {np.mean(metrics['glob_cent']):<15.4f}")
    print(f"{'Mirror Symmetry':<25} | {np.mean(metrics['sym_raw']):<15.4f} | {np.mean(metrics['sym_cent']):<15.4f}")
    print(f"{'LQR Action Cluster':<25} | {c_raw:<15.4f} | {c_cent:<15.4f}")
    print(f"{'-'*65}")

In [5]:
# ==========================================
# 5. MAIN EXECUTION
# ==========================================

# --- USER CONFIGURATION ---
CONFIG = {
    "SAMPLING_MODE": "ON_POLICY",  # Options: "ON_POLICY" or "RANDOM"
    "N_SAMPLES": 1000,             # Number of states to analyze
    "ENV_ID": "CartPole-v1",
    "REPOS": {
        "DQN": ("sb3/dqn-CartPole-v1", "dqn-CartPole-v1.zip"),
        "PPO": ("sb3/ppo-CartPole-v1", "ppo-CartPole-v1.zip")
    }
}

env = gym.make(CONFIG["ENV_ID"], render_mode="rgb_array")
oracle = LQRController(env)

def load_hf_model(cls, config_key):
    repo, filename = CONFIG["REPOS"][config_key]
    path = load_from_hub(repo_id=repo, filename=filename)
    return cls.load(
        path, 
        env=env, 
        custom_objects={
            "observation_space": env.observation_space, 
            "action_space": env.action_space
        }
    )

try:
    # 1. DQN Analysis
    dqn_model = load_hf_model(DQN, "DQN")
    evaluate_component(dqn_model, "DQN (Q-Network)", "dqn_q", env, oracle)
    
    # 2. PPO Analysis
    ppo_model = load_hf_model(PPO, "PPO")
    evaluate_component(ppo_model, "PPO (Actor/Policy)", "ppo_actor", env, oracle)
    evaluate_component(ppo_model, "PPO (Critic/Value)", "ppo_critic", env, oracle)
    
except Exception as e:
    print(f"An error occurred: {e}")
finally:
    env.close()

Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.

 ANALYZING: DQN (Q-Network)
Policy Return: 500.0 +/- 0.0
Sampling 1000 states (ON_POLICY)...

-----------------------------------------------------------------
Metric                    | Uncentered      | Centered       
-----------------------------------------------------------------
Global Baseline           | 0.9865          | 0.0531         
Mirror Symmetry           | 0.9677          | -0.5948        
LQR Action Cluster        | 0.9936          | 0.5730         
-----------------------------------------------------------------
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.

 ANALYZING: PPO (Actor/Policy)
Policy Return: 500.0 +/- 0.0
Sampling 1000 states (ON_POLICY)...

-----------------------------------------------------------------
Metric                    | Uncentered      | Centered       
-----------------------------------------------------------------
Global Baseli

In [6]:
# ==========================================
# 5. MAIN EXECUTION
# ==========================================

# --- USER CONFIGURATION ---
CONFIG = {
    "SAMPLING_MODE": "RANDOM",  # Options: "ON_POLICY" or "RANDOM"
    "N_SAMPLES": 1000,             # Number of states to analyze
    "ENV_ID": "CartPole-v1",
    "REPOS": {
        "DQN": ("sb3/dqn-CartPole-v1", "dqn-CartPole-v1.zip"),
        "PPO": ("sb3/ppo-CartPole-v1", "ppo-CartPole-v1.zip")
    }
}

env = gym.make(CONFIG["ENV_ID"], render_mode="rgb_array")
oracle = LQRController(env)

def load_hf_model(cls, config_key):
    repo, filename = CONFIG["REPOS"][config_key]
    path = load_from_hub(repo_id=repo, filename=filename)
    return cls.load(
        path, 
        env=env, 
        custom_objects={
            "observation_space": env.observation_space, 
            "action_space": env.action_space
        }
    )

try:
    # 1. DQN Analysis
    dqn_model = load_hf_model(DQN, "DQN")
    evaluate_component(dqn_model, "DQN (Q-Network)", "dqn_q", env, oracle)
    
    # 2. PPO Analysis
    ppo_model = load_hf_model(PPO, "PPO")
    evaluate_component(ppo_model, "PPO (Actor/Policy)", "ppo_actor", env, oracle)
    evaluate_component(ppo_model, "PPO (Critic/Value)", "ppo_critic", env, oracle)
    
except Exception as e:
    print(f"An error occurred: {e}")
finally:
    env.close()

Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.

 ANALYZING: DQN (Q-Network)
Policy Return: 500.0 +/- 0.0
Sampling 1000 states (RANDOM)...

-----------------------------------------------------------------
Metric                    | Uncentered      | Centered       
-----------------------------------------------------------------
Global Baseline           | 0.3737          | 0.0613         
Mirror Symmetry           | 0.3687          | -0.1906        
LQR Action Cluster        | 0.4855          | 0.1831         
-----------------------------------------------------------------
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.

 ANALYZING: PPO (Actor/Policy)
Policy Return: 500.0 +/- 0.0
Sampling 1000 states (RANDOM)...

-----------------------------------------------------------------
Metric                    | Uncentered      | Centered       
-----------------------------------------------------------------
Global Baseline    