In [19]:
import numpy as np
import gymnasium as gym
# import gym
import sys
sys.modules['gym'] = gym
import ale_py

In [20]:
from typing import Optional, Dict, Any, List, Callable

In [21]:
# print("ALE version:", ale_py.__version__)
# gym.register_envs(ale_py)

In [22]:
class PongSymmetryAnalyzer:
    """
    A class for analyzing symmetries in Pong environments using RAM observations.
    """
    
    def __init__(self, render_mode: Optional[str] = None):
        """
        Initialize the Pong environment with RAM observations.
        
        Args:
            render_mode: Rendering mode for the environment (None, "human", etc.)
        """
        # Register ALE environments
#         gym.register_envs(ale_py)
        
#         # Create environment with RAM observations
#         self.env = gym.make(
#             "ALE/Pong-v5",
#             obs_type="ram",
#             render_mode=render_mode,
#         )
        
        # RAM indices for extracting game state
        self.PONG_RAM_INDEX = {
            "ball_x": 49,
            "ball_y": 54,
            "enemy_y": 50,
            "player_y": 51,
        }
        
        # Game boundaries
        self.BALL_X_MIN = 50
        self.BALL_X_MAX = 208
        self.BALL_Y_MIN = 44
        self.BALL_Y_MAX = 207
        self.PLAYER_Y_MIN = 38
        self.PLAYER_Y_MAX = 203
        self.ENEMY_Y_MIN = 0
        self.ENEMY_Y_MAX = 208
        
        self.BALL_X_MID = (self.BALL_X_MAX - self.BALL_X_MIN) / 2
        self.BALL_Y_MID = (self.BALL_Y_MAX - self.BALL_Y_MIN) / 2

        self.PLAYER_Y_MID = (self.PLAYER_Y_MAX - self.PLAYER_Y_MIN) / 2
        self.ENEMY_Y_MID = (self.ENEMY_Y_MAX - self.ENEMY_Y_MIN) / 2
        
        # Store sampled states
        self.sampled_states: List[Dict[str, Any]] = []
        
    def ram_to_logic_state(
        self, 
        ram: np.ndarray, 
        prev_state: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Convert RAM observation to logical game state.
        
        Args:
            ram: 128-byte RAM vector from the environment
            prev_state: Previous logical state for velocity calculation
            
        Returns:
            Dictionary containing ball position, velocity, and paddle positions
        """
        # Extract raw positions from RAM
        ball_x_raw = int(ram[self.PONG_RAM_INDEX["ball_x"]])
        ball_y_raw = int(ram[self.PONG_RAM_INDEX["ball_y"]])
        enemy_y_raw = int(ram[self.PONG_RAM_INDEX["enemy_y"]])
        player_y_raw = int(ram[self.PONG_RAM_INDEX["player_y"]])
        
        # Check if ball exists (OCAtari condition)
        ball_exists = (ball_y_raw != 0) and (ball_x_raw > 49)
        
        # Set ball position (None if ball doesn't exist)
        ball_x = ball_x_raw if ball_exists else None
        ball_y = ball_y_raw if ball_exists else None
        
        # Paddle positions
        player_y = player_y_raw
        enemy_y = enemy_y_raw
        
        # Calculate velocity via finite differences
        if (prev_state is not None and 
            prev_state.get("ball_x") is not None and 
            ball_x is not None):
            ball_dx = ball_x - prev_state["ball_x"]
            ball_dy = ball_y - prev_state["ball_y"]
        else:
            ball_dx = 0
            ball_dy = 0
            
        return {
            "ball_x": ball_x,
            "ball_y": ball_y,
            "ball_dx": ball_dx,
            "ball_dy": ball_dy,
            "player_y": player_y,
            "enemy_y": enemy_y,
        }
    
    def sample_states(
        self, 
        num_episodes: int = 5, 
        max_steps_per_episode: int = 1000,
        model: Optional[Any] = None,
        max_states: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """
        Sample states from the environment using either a model or random policy.
        
        Args:
            num_episodes: Number of episodes to run
            max_steps_per_episode: Maximum steps per episode
            model: Optional model with select_action method. If None, uses random policy
            max_states: Maximum number of states to collect. If None, no limit
            
        Returns:
            List of logical game states
        """
        self.sampled_states = []
        
        for ep in range(num_episodes):
            ram, info = self.env.reset()
            prev_state = None
            
            for t in range(max_steps_per_episode):
                # Convert RAM to logical state
                state = self.ram_to_logic_state(ram, prev_state=prev_state)
                self.sampled_states.append(state)
                prev_state = state
                
                # Check if we've reached the maximum number of states
                if max_states is not None and len(self.sampled_states) >= max_states:
                    return self.sampled_states
                
                # Select action using model or random policy
                if model is not None:
                    # Assume model has a select_action method
                    if hasattr(model, 'select_action'):
                        action = model.select_action(ram)
                    elif hasattr(model, 'predict'):
                        action = model.predict(ram)
                    else:
                        # Try calling the model directly
                        action = model(ram)
                else:
                    # Random policy
                    action = self.env.action_space.sample()
                
                # Take step in environment
                ram, reward, terminated, truncated, info = self.env.step(action)
                done = terminated or truncated
                
                if done:
                    break
                    
        return self.sampled_states
    
    def naive_symmetry(self, state_1: Dict[str, Any], state_2: Dict[str, Any]) -> bool:
        """
        Check if two states are symmetric under naive symmetry assumption.
        Ignores scores and opponent paddle position, bins coordinates.
        
        Args:
            state_1: First game state
            state_2: Second game state
            
        Returns:
            True if states are considered symmetric, False otherwise
        """
        # Check if ball x positions and x velocities match
        ballx = state_1["ball_x"] == state_2["ball_x"]
        balldx = state_1["ball_dx"] == state_2["ball_dx"]
        
        if not ballx or not balldx:
            return False
            
        bally = state_1["ball_y"] == state_2["ball_y"]
        balldy = state_1["ball_dy"] == state_2["ball_dy"]
        playery = state_1["player_y"] == state_2["player_y"]
        
        # Case 1: Completely equal
        if bally and balldy and playery:
            return True
        
        # Case 2: Ball off screen, paddle positions symmetric
        if state_1["ball_y"] is None and state_2["ball_y"] is None:
            if (abs(state_1["player_y"] - self.BALL_Y_MID) == 
                abs(state_2["player_y"] - self.BALL_Y_MID)):
                return True
        
        # Case 3: Symmetric reflection about y-axis
        if (state_1["ball_y"] is not None and state_2["ball_y"] is not None):
            ball_y_symmetric = (abs(state_1["ball_y"] - self.BALL_Y_MID) == 
                              abs(state_2["ball_y"] - self.BALL_Y_MID))
            player_y_symmetric = (abs(state_1["player_y"] - self.PLAYER_Y_MID) == 
                                abs(state_2["player_y"] - self.PLAYER_Y_MID))
            ball_dy_opposite = state_1["ball_dy"] == -state_2["ball_dy"]
            
            if ball_y_symmetric and player_y_symmetric and ball_dy_opposite:
                return True
                
        return False
    
    def generate_similarity_matrix(
        self, 
        states: Optional[List[Dict[str, Any]]] = None,
        symmetry_function: Optional[Callable] = None
    ) -> np.ndarray:
        """
        Generate a similarity matrix for the given states using a symmetry function.
        
        Args:
            states: List of states to compare. If None, uses self.sampled_states
            symmetry_function: Function to check symmetry between two states.
                             If None, uses self.naive_symmetry
                             
        Returns:
            Binary similarity matrix where entry (i,j) is 1 if states i and j are symmetric
        """
        if states is None:
            states = self.sampled_states
            
        if symmetry_function is None:
            symmetry_function = self.naive_symmetry
            
        if len(states) == 0:
            raise ValueError("No states provided for similarity matrix generation")
            
        matrix = np.zeros((len(states), len(states)), dtype=int)
        
        for i, state_i in enumerate(states):
            for j, state_j in enumerate(states):
                matrix[i][j] = int(symmetry_function(state_i, state_j))
                
        return matrix
    
    def get_similarity_stats(self, similarity_matrix: np.ndarray) -> Dict[str, float]:
        """
        Compute statistics about the similarity matrix.
        
        Args:
            similarity_matrix: Binary similarity matrix
            
        Returns:
            Dictionary with statistics about symmetries
        """
        n = similarity_matrix.shape[0]
        total_pairs = n * (n - 1) // 2  # Exclude diagonal
        
        # Count symmetric pairs (excluding diagonal)
        symmetric_pairs = (np.sum(similarity_matrix) - n) // 2
        
        return {
            "total_states": n,
            "total_pairs": total_pairs,
            "symmetric_pairs": symmetric_pairs,
            "symmetry_ratio": symmetric_pairs / total_pairs if total_pairs > 0 else 0.0,
            "diagonal_sum": np.sum(np.diag(similarity_matrix)),
        }
    
    def close(self):
        """Close the environment."""
        # self.env.close()
        pass


# # Example usage
# if __name__ == "__main__":
#     # Initialize analyzer
#     analyzer = PongSymmetryAnalyzer()
    
#     # Sample states using random policy
#     states = analyzer.sample_states(
#         num_episodes=3, 
#         max_steps_per_episode=500,
#         max_states=100
#     )
    
#     print(f"Sampled {len(states)} states")
    
#     # Generate similarity matrix
#     similarity_matrix = analyzer.generate_similarity_matrix()
    
#     # Get statistics
#     stats = analyzer.get_similarity_stats(similarity_matrix)
#     print("Similarity Statistics:")
#     for key, value in stats.items():
#         print(f"  {key}: {value}")
    
#     # Close environment
#     analyzer.close()

In [23]:
def naive_symmetry(state_1, state_2):
    """
    In this naive treatment, I ignore: scores, opponent paddle position.
    I bin the x, y coordinates of the ball.
    I bin the y coordinate of the player paddle.
    Function returns True if: 
        - coarse x is equal in both states
        - dx is equal in both states
        - AND ball y and paddle y are equal OR equal on reflection in y=0 line
    """
    ballx = state_1["ball_x"] == state_2["ball_x"]
    balldx = state_1["ball_dx"] == state_2["ball_dx"]

    if not ballx or not balldx:
        return False
        
    bally = state_1["ball_y"] == state_2["ball_y"]
    balldy = state_1["ball_dy"] == state_2["ball_dy"] 

    playery = state_1["player_y"] == state_2["player_y"] 

    # case 1: completely equal
    if bally:
        if balldy:
            if playery:
                return True

    # case 2: ball off screen. paddle flipped
    if state_1["ball_y"] is None:
        if abs(state_1["player_y"] - BALL_Y_MID) == abs(state_2["player_y"] - BALL_Y_MID):
            return True
        
    # case 2: flipped
    if state_1["ball_y"] is not None:
        if abs(state_1["ball_y"] - BALL_Y_MID) == abs(state_2["ball_y"] - BALL_Y_MID):
            if abs(state_1["player_y"] - BALL_Y_MID) == abs(state_2["player_y"] - BALL_Y_MID):
                if state_1["ball_dy"] == -state_2["ball_dy"]:
                    return True

    return False

In [24]:
def get_similarity_matrix(states):
    matrix = np.zeros((len(states), len(states)))
    for i, state_i in enumerate(states):
        for j, state_j in enumerate(states):
            matrix[i][j] = naive_symmetry(state_i, state_j)
    return matrix

In [25]:
PONG_RAM_INDEX = {
    "ball_x":   49,
    "ball_y":   54,
    "enemy_y":  50,
    "player_y": 51,
}

BALL_X_MIN = 50
BALL_X_MAX = 208
BALL_Y_MIN = 44
BALL_Y_MAX = 207
PLAYER_Y_MIN = 38
PLAYER_Y_MAX = 203
ENEMY_Y_MIN = 0
ENEMY_Y_MAX = 208

BALL_X_MID = (BALL_X_MAX - BALL_X_MIN) / 2
BALL_Y_MID = (BALL_Y_MAX - BALL_Y_MIN) / 2

PLAYER_Y_MID = (PLAYER_Y_MAX - PLAYER_Y_MIN) / 2
ENEMY_Y_MID = (ENEMY_Y_MAX - ENEMY_Y_MIN) / 2

In [26]:
def ram_to_logic_state(
    ram: np.ndarray,
    prev_state: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """
    Minimal logic-level state for ALE Pong from a 128-byte RAM vector.

    Returns:
        {
            "ball_x":   int or None,
            "ball_y":   int or None,
            "ball_dx":  int,
            "ball_dy":  int,
            "player_y": int or None,
            "enemy_y":  int or None,
        }
    """
    # Raw positions from RAM (as in OCAtari)
    ball_x_raw   = int(ram[PONG_RAM_INDEX["ball_x"]])
    ball_y_raw   = int(ram[PONG_RAM_INDEX["ball_y"]])
    enemy_y_raw  = int(ram[PONG_RAM_INDEX["enemy_y"]])
    player_y_raw = int(ram[PONG_RAM_INDEX["player_y"]])

    # OCAtari condition for “ball exists”
    ball_exists = (ball_y_raw != 0) and (ball_x_raw > 49)

    # If you want to treat "no ball" explicitly:
    ball_x = ball_x_raw if ball_exists else None
    ball_y = ball_y_raw if ball_exists else None

    # Paddles basically always exist when game is running;
    # if you want to mirror OCAtari, you could add checks similar to ram[50] / ram[51] ranges.
    player_y = player_y_raw
    enemy_y  = enemy_y_raw

    # Velocity via finite differences
    if prev_state is not None and prev_state.get("ball_x") is not None and ball_x is not None:
        ball_dx = ball_x - prev_state["ball_x"]
        ball_dy = ball_y - prev_state["ball_y"]
    else:
        ball_dx = 0
        ball_dy = 0

    return {
        "ball_x":   ball_x,
        "ball_y":   ball_y,
        "ball_dx":  ball_dx,
        "ball_dy":  ball_dy,
        "player_y": player_y,
        "enemy_y":  enemy_y,
    }

In [27]:
def get_stats():
    ball_x_vals = []
    ball_y_vals = []
    
    player_y_vals = []
    enemy_y_vals = []
    
    num_episodes = 5000          # tweak as you like
    max_steps_per_ep = 10000    # to cap runtime
    
    
    def update_stats_from_state(state):
        if state["ball_x"] is not None and state["ball_y"] is not None:
            ball_x_vals.append(state["ball_x"])
            ball_y_vals.append(state["ball_y"])
        player_y_vals.append(state["player_y"])
        enemy_y_vals.append(state["enemy_y"])
    
    
    # --- Run random rollouts and collect stats ---
    
    for ep in range(num_episodes):
        ram, info = env.reset()
        prev_state = None
    
        for t in range(max_steps_per_ep):
            state = ram_to_logic_state(ram, prev_state=prev_state)
            update_stats_from_state(state)
            prev_state = state
    
            action = env.action_space.sample()
            step_out = env.step(action)
    
            ram, reward, terminated, truncated, info = step_out
            done = terminated or truncated
            # ram, reward, done, info = step_out
    
            if done:
                break
    
    env.close()
    
    # --- Compute and print stats ---
    
    def summarize(name, values):
        if len(values) == 0:
            print(f"{name}: no data collected")
            return
        v = np.array(values)
        print(
            f"{name}: min={v.min()}, max={v.max()}, mean={v.mean():.2f}, "
            f"unique={len(np.unique(v))}"
        )
    summarize("ball_x", ball_x_vals)
    summarize("ball_y", ball_y_vals)
    summarize("player_y", player_y_vals)
    summarize("enemy_y", enemy_y_vals)

# RSA Analysis with Hugging Face Models

This section downloads trained models from Hugging Face and performs Representational Similarity Analysis (RSA) on the activations to compare with our behavioral similarity matrix.

In [28]:
# Install required packages for Hugging Face models
# !pip install stable-baselines3[extra] huggingface_sb3 torch torchvision

import torch
import torch.nn as nn
from stable_baselines3 import PPO, DQN, A2C
# import gym as gym_old
from huggingface_sb3 import load_from_hub
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import pearsonr, spearmanr
from scipy.spatial.distance import pdist, squareform
import cv2
from PIL import Image
import warnings
warnings.filterwarnings('ignore')

In [29]:
class ModelActivationExtractor:
    """
    Extract activations from different layers of trained RL models.
    """
    
    def __init__(self, model, model_type='ppo'):
        self.model = model
        self.model_type = model_type.lower()
        self.activations = {}
        self.hooks = []
        
    def register_hooks(self, layer_names=None):
        """Register forward hooks to capture activations."""
        def hook_fn(name):
            def hook(module, input, output):
                if isinstance(output, torch.Tensor):
                    self.activations[name] = output.detach().cpu().numpy()
                elif isinstance(output, (list, tuple)):
                    # Handle cases where output is a tuple/list
                    self.activations[name] = output[0].detach().cpu().numpy()
            return hook
            
        # Get the policy network
        if hasattr(self.model, 'policy'):
            policy_net = self.model.policy
        else:
            policy_net = self.model
            
        # Register hooks for different model architectures
        if hasattr(policy_net, 'features_extractor'):
            # CNN-based models
            cnn = policy_net.features_extractor.cnn
            for i, layer in enumerate(cnn):
                if isinstance(layer, (nn.Conv2d, nn.Linear, nn.ReLU)):
                    name = f'cnn_layer_{i}_{layer.__class__.__name__}'
                    if layer_names is None or name in layer_names:
                        hook = layer.register_forward_hook(hook_fn(name))
                        self.hooks.append(hook)
                        
        # Value and policy heads
        if hasattr(policy_net, 'value_net'):
            for i, layer in enumerate(policy_net.value_net):
                if isinstance(layer, (nn.Linear, nn.ReLU)):
                    name = f'value_net_{i}_{layer.__class__.__name__}'
                    if layer_names is None or name in layer_names:
                        hook = layer.register_forward_hook(hook_fn(name))
                        self.hooks.append(hook)
                        
        if hasattr(policy_net, 'action_net'):
            for i, layer in enumerate(policy_net.action_net):
                if isinstance(layer, (nn.Linear, nn.ReLU)):
                    name = f'action_net_{i}_{layer.__class__.__name__}'
                    if layer_names is None or name in layer_names:
                        hook = layer.register_forward_hook(hook_fn(name))
                        self.hooks.append(hook)
    
    def extract_activations(self, states):
        """Extract activations for a batch of states."""
        self.activations.clear()
        
        # Convert states to tensor
        if isinstance(states, np.ndarray):
            states_tensor = torch.FloatTensor(states)
        else:
            states_tensor = torch.FloatTensor(np.array(states))
            
        # Ensure correct shape (batch_size, channels, height, width)
        if len(states_tensor.shape) == 3:
            states_tensor = states_tensor.unsqueeze(0)
        if states_tensor.shape[-1] == 3:  # If channels last
            states_tensor = states_tensor.permute(0, 3, 1, 2)
            
        # Normalize to [0, 1] if needed
        if states_tensor.max() > 1.0:
            states_tensor = states_tensor / 255.0
            
        with torch.no_grad():
            if self.model_type in ['ppo', 'a2c']:
                # For policy gradient methods
                self.model.policy(states_tensor)
            elif self.model_type == 'dqn':
                # For Q-learning methods
                self.model.q_net(states_tensor)
                
        return dict(self.activations)
    
    def cleanup(self):
        """Remove all registered hooks."""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()

In [44]:
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack, VecTransposeImage

def download_models(env_name: str = "PongNoFrameskip-v4"):
    """Download pre-trained models from Hugging Face."""
    models = {}

    # pre-create environment
    # env = gym.make(env_name, render_mode="human")
    # env = ChannelFirstWrapper(env)
    env = make_atari_env(env_name, n_envs=1, env_kwargs={"render_mode": "human"})
    env = VecFrameStack(env, n_stack=4)
    env = VecTransposeImage(env)

    # 4. Load with CUSTOM_OBJECTS override
    # This replaces the metadata from the file with the spaces from your current env
    custom_objects = {
        "observation_space": env.observation_space,
        "action_space": env.action_space
    }
    
    try:
        # PPO model
        print("Downloading PPO model...")
        ppo_path = load_from_hub(
            repo_id=f"sb3/ppo-{env_name}",
            filename=f"ppo-{env_name}.zip"
        )
        ppo_model = PPO.load(ppo_path, env=env, custom_objects=custom_objects)
        models['ppo'] = ppo_model
        print("PPO model downloaded successfully")
    except Exception as e:
        print(f"Failed to download PPO model: {e}")
        
    try:
        # DQN model
        print("Downloading DQN model...")
        dqn_path = load_from_hub(
            repo_id=f"sb3/dqn-{env_name}", 
            filename=f"dqn-{env_name}.zip"
        )
        dqn_model = DQN.load(dqn_path, env=env, custom_objects=custom_objects)
        models['dqn'] = dqn_model
        print("DQN model downloaded successfully")
    except Exception as e:
        print(f"Failed to download DQN model: {e}")
        
    try:
        # A2C model
        print("Downloading A2C model...")
        a2c_path = load_from_hub(
            repo_id=f"sb3/a2c-{env_name}",
            filename=f"a2c-{env_name}.zip"
        )
        a2c_model = A2C.load(a2c_path, env=env, custom_objects=custom_objects)
        models['a2c'] = a2c_model
        print("A2C model downloaded successfully")
    except Exception as e:
        print(f"Failed to download A2C model: {e}")
        
    return models

In [45]:
def preprocess_state_for_model(state_rgb, target_shape=(84, 84)):
    """
    Preprocess RGB state to match model input requirements.
    Atari models typically expect 84x84 grayscale or stacked frames.
    """
    # Convert to grayscale
    if len(state_rgb.shape) == 3:
        gray = cv2.cvtColor(state_rgb, cv2.COLOR_RGB2GRAY)
    else:
        gray = state_rgb
        
    # Resize to target shape
    resized = cv2.resize(gray, target_shape)
    
    # Normalize to [0, 1]
    normalized = resized.astype(np.float32) / 255.0
    
    return normalized

def prepare_pixel_states(rgb_states, model_type='atari'):
    """
    Prepare pixel states for model input.
    """
    processed_states = []
    
    for state in rgb_states:
        if model_type == 'atari':
            # For Atari models, we need 4 stacked frames
            processed = preprocess_state_for_model(state)
            # Stack the same frame 4 times (approximation)
            stacked = np.stack([processed] * 4, axis=0)
            processed_states.append(stacked)
        else:
            processed_states.append(state)
            
    return np.array(processed_states)

In [46]:
def compute_rsa_matrix(activations):
    """
    Compute RSA matrix from neural activations.
    """
    # Flatten activations if they're multi-dimensional
    if len(activations.shape) > 2:
        activations_flat = activations.reshape(activations.shape[0], -1)
    else:
        activations_flat = activations
        
    # Compute pairwise correlations
    rsa_matrix = np.corrcoef(activations_flat)
    
    return rsa_matrix

def compare_matrices(behavioral_matrix, neural_matrix, method='pearson'):
    """
    Compare behavioral similarity matrix with neural RSA matrix.
    """
    # Get upper triangular indices (excluding diagonal)
    n = behavioral_matrix.shape[0]
    triu_indices = np.triu_indices(n, k=1)
    
    behavioral_flat = behavioral_matrix[triu_indices]
    neural_flat = neural_matrix[triu_indices]
    
    if method == 'pearson':
        corr, p_val = pearsonr(behavioral_flat, neural_flat)
    elif method == 'spearman':
        corr, p_val = spearmanr(behavioral_flat, neural_flat)
    else:
        raise ValueError("Method must be 'pearson' or 'spearman'")
        
    return corr, p_val

def analyze_similarity_groups(behavioral_matrix, neural_matrix):
    """
    Compare average neural similarity for behaviorally similar vs dissimilar states.
    """
    # Get upper triangular indices
    n = behavioral_matrix.shape[0]
    triu_indices = np.triu_indices(n, k=1)
    
    behavioral_flat = behavioral_matrix[triu_indices]
    neural_flat = neural_matrix[triu_indices]
    
    # Split into similar (1) and dissimilar (0) groups based on behavioral matrix
    similar_mask = behavioral_flat == 1
    dissimilar_mask = behavioral_flat == 0
    
    similar_neural = neural_flat[similar_mask]
    dissimilar_neural = neural_flat[dissimilar_mask]
    
    results = {
        'similar_mean': np.mean(similar_neural) if len(similar_neural) > 0 else np.nan,
        'similar_std': np.std(similar_neural) if len(similar_neural) > 0 else np.nan,
        'dissimilar_mean': np.mean(dissimilar_neural) if len(dissimilar_neural) > 0 else np.nan,
        'dissimilar_std': np.std(dissimilar_neural) if len(dissimilar_neural) > 0 else np.nan,
        'similar_count': len(similar_neural),
        'dissimilar_count': len(dissimilar_neural)
    }
    
    # Compute statistical test
    if len(similar_neural) > 0 and len(dissimilar_neural) > 0:
        from scipy.stats import ttest_ind
        t_stat, p_val = ttest_ind(similar_neural, dissimilar_neural)
        results['t_stat'] = t_stat
        results['p_val'] = p_val
    
    return results

In [47]:
# Download models
print("Downloading models from Hugging Face...")
models = download_models()
print(f"Downloaded {len(models)} models: {list(models.keys())}")

Downloading models from Hugging Face...
Downloading PPO model...
PPO model downloaded successfully
Downloading DQN model...
Failed to download DQN model: ReplayBuffer does not support optimize_memory_usage = True and handle_timeout_termination = True simultaneously.
Downloading A2C model...
A2C model downloaded successfully
Downloaded 2 models: ['ppo', 'a2c']


In [58]:
# First, let's collect some pixel states along with our logical states
print("Collecting pixel states...")

# Initialize a fresh environment for pixel collection
# pixel_env = gym.make(
#     "ALE/Pong-v5",
#     obs_type="rgb",  # Get RGB pixels instead of RAM
#     render_mode=None,
# )

pixel_env = gym.make("PongNoFrameskip-v4")
ale = pixel_env.unwrapped.ale

# Collect states with both pixel and logical representations
pixel_states = []
ram_states = []
logical_states = []

num_episodes = 3
max_steps = 500
max_states_to_collect = 100

for ep in range(num_episodes):
    rgb_obs, _ = pixel_env.reset()
    ram_obs = ale.getRAM()
    
    # Also get RAM version for logical state
    # ram_env = gym.make("ALE/Pong-v5", obs_type="ram", render_mode=None)
    # ram_obs, _ = ram_env.reset()
    
    prev_logical = None
    
    for t in range(max_steps):
        if len(pixel_states) >= max_states_to_collect:
            break
            
        # Store pixel state
        pixel_states.append(rgb_obs)
        
        # Get corresponding logical state
        ram_states.append(ram_obs)
        logical_state = ram_to_logic_state(ram_obs, prev_state=prev_logical)
        logical_states.append(logical_state)
        prev_logical = logical_state
        
        # Take same action in both environments
        action = pixel_env.action_space.sample()
        
        # rgb_obs, _, _, _, _ = pixel_env.step(action)
        # aasdad = pixel_env.step(action)
        # import pdb; pdb.set_trace()
        rgb_obs, reward, terminated, truncated, info = pixel_env.step(action)
        ram_obs = ale.getRAM()
        # ram_obs, _, _, _, _ = ram_env.step(action)
        
        if truncated or terminated:
            break
            
    # ram_env.close()
    
    if len(pixel_states) >= max_states_to_collect:
        break

pixel_env.close()

print(f"Collected {len(pixel_states)} pixel states and {len(logical_states)} logical states")

Collecting pixel states...
Collected 100 pixel states and 100 logical states


In [28]:
import gym

pixel_env = gym.make("PongNoFrameskip-v4")
print("env type:", type(pixel_env))
print("unwrapped type:", type(pixel_env.unwrapped))
print("has ale?", hasattr(pixel_env.unwrapped, "ale"))

env type: <class 'gym.wrappers.time_limit.TimeLimit'>
unwrapped type: <class 'gym.envs.atari.environment.AtariEnv'>
has ale? True


In [59]:
# Compute behavioral similarity matrix for our collected states
print("Computing behavioral similarity matrix...")

# Use a subset of states for computational efficiency
n_states = min(50, len(logical_states))
subset_logical = logical_states[:n_states]
subset_pixels = pixel_states[:n_states]

# Create analyzer for the subset
analyzer_subset = PongSymmetryAnalyzer()
behavioral_matrix = analyzer_subset.generate_similarity_matrix(subset_logical)
analyzer_subset.close()

print(f"Behavioral similarity matrix shape: {behavioral_matrix.shape}")

# Get statistics
stats = analyzer_subset.get_similarity_stats(behavioral_matrix)
print("Behavioral similarity stats:")
for key, value in stats.items():
    print(f"  {key}: {value}")

Computing behavioral similarity matrix...
Behavioral similarity matrix shape: (50, 50)
Behavioral similarity stats:
  total_states: 50
  total_pairs: 1225
  symmetric_pairs: 109
  symmetry_ratio: 0.08897959183673469
  diagonal_sum: 50


In [25]:
logical_states[40]

{'ball_x': 192,
 'ball_y': 192,
 'ball_dx': 0,
 'ball_dy': 0,
 'player_y': 192,
 'enemy_y': 192}

In [None]:
# Perform RSA analysis for each model
rsa_results = {}

for model_name, model in models.items():
    print(f"\n=== Analyzing {model_name.upper()} model ===")
    
    try:
        # Prepare pixel states for the model
        processed_pixels = prepare_pixel_states(subset_pixels, 'atari')
        print(f"Processed pixel states shape: {processed_pixels.shape}")

        import pdb; pdb.set_trace()
        
        # Create activation extractor
        extractor = ModelActivationExtractor(model, model_name)
        extractor.register_hooks()
        
        # Extract activations
        print("Extracting activations...")
        activations = extractor.extract_activations(processed_pixels)
        
        layer_results = {}
        
        for layer_name, layer_activations in activations.items():
            print(f"  Processing {layer_name} (shape: {layer_activations.shape})")
            
            # Compute RSA matrix
            rsa_matrix = compute_rsa_matrix(layer_activations)
            
            # Compare with behavioral matrix
            corr_pearson, p_pearson = compare_matrices(behavioral_matrix, rsa_matrix, 'pearson')
            corr_spearman, p_spearman = compare_matrices(behavioral_matrix, rsa_matrix, 'spearman')
            
            # Analyze similarity groups
            group_analysis = analyze_similarity_groups(behavioral_matrix, rsa_matrix)
            
            layer_results[layer_name] = {
                'rsa_matrix': rsa_matrix,
                'pearson_corr': corr_pearson,
                'pearson_p': p_pearson,
                'spearman_corr': corr_spearman,
                'spearman_p': p_spearman,
                'group_analysis': group_analysis
            }
            
            print(f"    Pearson correlation: {corr_pearson:.4f} (p={p_pearson:.4f})")
            print(f"    Spearman correlation: {corr_spearman:.4f} (p={p_spearman:.4f})")
            print(f"    Similar group mean: {group_analysis['similar_mean']:.4f} ± {group_analysis['similar_std']:.4f}")
            print(f"    Dissimilar group mean: {group_analysis['dissimilar_mean']:.4f} ± {group_analysis['dissimilar_std']:.4f}")
            if 'p_val' in group_analysis:
                print(f"    Group difference p-value: {group_analysis['p_val']:.4f}")
        
        rsa_results[model_name] = layer_results
        extractor.cleanup()
        
    except Exception as e:
        print(f"Error analyzing {model_name}: {e}")
        continue


=== Analyzing PPO model ===
Processed pixel states shape: (50, 4, 84, 84)
> [32m/var/folders/w5/wtxn2_3x6jgbtxczqlsckq2r0000gp/T/ipykernel_89467/2842313008.py[39m([92m15[39m)[36m<module>[39m[34m()[39m
[32m     13[39m 
[32m     14[39m         [38;5;66;03m# Create activation extractor[39;00m
[32m---> 15[39m         extractor = ModelActivationExtractor(model, model_name)
[32m     16[39m         extractor.register_hooks()
[32m     17[39m 



ipdb>  n


> [32m/var/folders/w5/wtxn2_3x6jgbtxczqlsckq2r0000gp/T/ipykernel_89467/2842313008.py[39m([92m16[39m)[36m<module>[39m[34m()[39m
[32m     14[39m         [38;5;66;03m# Create activation extractor[39;00m
[32m     15[39m         extractor = ModelActivationExtractor(model, model_name)
[32m---> 16[39m         extractor.register_hooks()
[32m     17[39m 
[32m     18[39m         [38;5;66;03m# Extract activations[39;00m



ipdb>  n


TypeError: 'Linear' object is not iterable
> [32m/var/folders/w5/wtxn2_3x6jgbtxczqlsckq2r0000gp/T/ipykernel_89467/2842313008.py[39m([92m16[39m)[36m<module>[39m[34m()[39m
[32m     14[39m         [38;5;66;03m# Create activation extractor[39;00m
[32m     15[39m         extractor = ModelActivationExtractor(model, model_name)
[32m---> 16[39m         extractor.register_hooks()
[32m     17[39m 
[32m     18[39m         [38;5;66;03m# Extract activations[39;00m



In [None]:
# Visualization of results
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

# Plot behavioral similarity matrix
im1 = axes[0].imshow(behavioral_matrix, cmap='viridis', vmin=0, vmax=1)
axes[0].set_title('Behavioral Similarity Matrix')
axes[0].set_xlabel('State Index')
axes[0].set_ylabel('State Index')
plt.colorbar(im1, ax=axes[0])

# Plot RSA matrices for different models/layers
plot_idx = 1
for model_name, model_results in rsa_results.items():
    if plot_idx >= len(axes):
        break
        
    # Get the first layer's RSA matrix for visualization
    first_layer = list(model_results.keys())[0]
    rsa_matrix = model_results[first_layer]['rsa_matrix']
    
    im = axes[plot_idx].imshow(rsa_matrix, cmap='viridis', vmin=-1, vmax=1)
    axes[plot_idx].set_title(f'{model_name.upper()} - {first_layer}')
    axes[plot_idx].set_xlabel('State Index')
    axes[plot_idx].set_ylabel('State Index')
    plt.colorbar(im, ax=axes[plot_idx])
    plot_idx += 1

# Hide unused subplots
for i in range(plot_idx, len(axes)):
    axes[i].set_visible(False)

plt.tight_layout()
plt.show()

In [None]:
# Summary table of correlations
print("\n=== RSA CORRELATION SUMMARY ===")
print(f"{'Model':<10} {'Layer':<25} {'Pearson':<10} {'Spearman':<10} {'Sim Mean':<10} {'Dissim Mean':<12} {'P-value':<10}")
print("-" * 90)

for model_name, model_results in rsa_results.items():
    for layer_name, layer_results in model_results.items():
        pearson = layer_results['pearson_corr']
        spearman = layer_results['spearman_corr']
        group_analysis = layer_results['group_analysis']
        
        sim_mean = group_analysis['similar_mean']
        dissim_mean = group_analysis['dissimilar_mean']
        p_val = group_analysis.get('p_val', np.nan)
        
        print(f"{model_name:<10} {layer_name:<25} {pearson:<10.4f} {spearman:<10.4f} {sim_mean:<10.4f} {dissim_mean:<12.4f} {p_val:<10.4f}")

In [None]:
# Bar plot comparing similar vs dissimilar groups across models
fig, axes = plt.subplots(1, len(rsa_results), figsize=(5*len(rsa_results), 6))
if len(rsa_results) == 1:
    axes = [axes]

for idx, (model_name, model_results) in enumerate(rsa_results.items()):
    # Get first layer for visualization
    first_layer = list(model_results.keys())[0]
    group_analysis = model_results[first_layer]['group_analysis']
    
    means = [group_analysis['similar_mean'], group_analysis['dissimilar_mean']]
    stds = [group_analysis['similar_std'], group_analysis['dissimilar_std']]
    labels = ['Behaviorally\nSimilar', 'Behaviorally\nDissimilar']
    
    bars = axes[idx].bar(labels, means, yerr=stds, capsize=5, 
                        color=['skyblue', 'lightcoral'], alpha=0.7)
    axes[idx].set_title(f'{model_name.upper()}\n{first_layer}')
    axes[idx].set_ylabel('Neural Similarity')
    axes[idx].grid(True, alpha=0.3)
    
    # Add significance test result
    if 'p_val' in group_analysis:
        p_val = group_analysis['p_val']
        if p_val < 0.001:
            sig_text = '***'
        elif p_val < 0.01:
            sig_text = '**'
        elif p_val < 0.05:
            sig_text = '*'
        else:
            sig_text = 'n.s.'
        
        # Add significance annotation
        y_max = max(means) + max(stds)
        axes[idx].text(0.5, y_max * 1.1, sig_text, ha='center', va='bottom', fontsize=14)
        axes[idx].text(0.5, y_max * 1.15, f'p={p_val:.3f}', ha='center', va='bottom', fontsize=10)

plt.tight_layout()
plt.show()