In [1]:
import numpy as np
import gymnasium as gym

In [2]:
from typing import Optional, Dict, Any, List, Callable

In [3]:
# print("ALE version:", ale_py.__version__)
# gym.register_envs(ale_py)

In [4]:
import ale_py


class PongSymmetryAnalyzer:
    """
    Analyze up/down reflection symmetries in ALE Pong using RAM observations.

    The analyzer rolls out a policy in ``ALE/Pong-v5`` (RAM observations),
    converts every RAM vector into a small "logic state" dictionary and
    compares pairs of logic states with a symmetry predicate.

    The RAM offsets, coordinate bounds and mirror lines are class-level
    constants, so they can be read without instantiating an emulator.
    """

    # RAM byte offsets of the quantities used to build a logic state.
    PONG_RAM_INDEX = {
        "ball_x": 49,
        "ball_y": 54,
        "enemy_y": 50,
        "player_y": 51,
    }

    # Inclusive bounds of each raw coordinate.
    BALL_X_MIN = 50
    BALL_X_MAX = 208
    BALL_Y_MIN = 44
    BALL_Y_MAX = 207
    PLAYER_Y_MIN = 38
    PLAYER_Y_MAX = 203
    ENEMY_Y_MIN = 0
    ENEMY_Y_MAX = 208

    # Mirror lines: the midpoint of a range is (MIN + MAX) / 2.
    # (MAX - MIN) / 2 is only the half-width and is not a coordinate inside
    # the range unless MIN happens to be 0.
    BALL_X_MID = (BALL_X_MIN + BALL_X_MAX) / 2
    BALL_Y_MID = (BALL_Y_MIN + BALL_Y_MAX) / 2
    PLAYER_Y_MID = (PLAYER_Y_MIN + PLAYER_Y_MAX) / 2
    ENEMY_Y_MID = (ENEMY_Y_MIN + ENEMY_Y_MAX) / 2

    def __init__(self, render_mode: Optional[str] = None):
        """
        Create the Pong environment with RAM observations.

        Args:
            render_mode: Rendering mode forwarded to ``gym.make``
                (None, "human", etc.)
        """
        # Register ALE environments
        gym.register_envs(ale_py)

        # Create environment with RAM observations
        self.env = gym.make(
            "ALE/Pong-v5",
            obs_type="ram",
            render_mode=render_mode,
        )

        # Per-instance copy so that mutating one analyzer's index map does
        # not leak into other instances through the shared class attribute.
        self.PONG_RAM_INDEX = dict(type(self).PONG_RAM_INDEX)

        # States collected by the most recent sample_states() call
        self.sampled_states: List[Dict[str, Any]] = []

    def ram_to_logic_state(
        self,
        ram: np.ndarray,
        prev_state: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Convert a RAM observation into a logical game state.

        Args:
            ram: RAM vector from the environment, indexed with PONG_RAM_INDEX
            prev_state: Previous logical state, used for the finite-difference
                velocity. Pass None for the first frame of an episode.

        Returns:
            Dictionary with keys "ball_x", "ball_y" (int, or None when the
            ball is absent), "ball_dx", "ball_dy" (int, 0 when no velocity
            can be computed), "player_y" and "enemy_y" (int).
        """
        index = self.PONG_RAM_INDEX

        # Extract raw positions from RAM
        ball_x_raw = int(ram[index["ball_x"]])
        ball_y_raw = int(ram[index["ball_y"]])
        enemy_y_raw = int(ram[index["enemy_y"]])
        player_y_raw = int(ram[index["player_y"]])

        # Check if ball exists (OCAtari condition)
        ball_exists = (ball_y_raw != 0) and (ball_x_raw > 49)

        # Ball position is None on both axes when the ball doesn't exist
        ball_x = ball_x_raw if ball_exists else None
        ball_y = ball_y_raw if ball_exists else None

        # Velocity via finite differences; only defined when the ball is
        # present in both the previous and the current frame.
        if (prev_state is not None and
                prev_state.get("ball_x") is not None and
                ball_x is not None):
            ball_dx = ball_x - prev_state["ball_x"]
            ball_dy = ball_y - prev_state["ball_y"]
        else:
            ball_dx = 0
            ball_dy = 0

        return {
            "ball_x": ball_x,
            "ball_y": ball_y,
            "ball_dx": ball_dx,
            "ball_dy": ball_dy,
            "player_y": player_y_raw,
            "enemy_y": enemy_y_raw,
        }

    def sample_states(
        self,
        num_episodes: int = 5,
        max_steps_per_episode: int = 1000,
        model: Optional[Any] = None,
        max_states: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """
        Sample states from the environment using either a model or random policy.

        Each call discards the previously stored ``self.sampled_states``.

        Args:
            num_episodes: Number of episodes to run
            max_steps_per_episode: Maximum steps per episode
            model: Optional policy. ``model.select_action(ram)`` is used when
                available, then ``model.predict(ram)``, otherwise ``model(ram)``.
                If None, actions are drawn from ``env.action_space.sample()``.
            max_states: Maximum number of states to collect. If None, no limit

        Returns:
            List of logical game states (the same list as ``self.sampled_states``)
        """
        self.sampled_states = []

        for _ in range(num_episodes):
            ram, _info = self.env.reset()
            prev_state = None

            for _ in range(max_steps_per_episode):
                # Convert RAM to logical state
                state = self.ram_to_logic_state(ram, prev_state=prev_state)
                self.sampled_states.append(state)
                prev_state = state

                # Stop as soon as the requested number of states is reached
                if max_states is not None and len(self.sampled_states) >= max_states:
                    return self.sampled_states

                # Select action using model or random policy
                if model is None:
                    action = self.env.action_space.sample()
                elif hasattr(model, 'select_action'):
                    action = model.select_action(ram)
                elif hasattr(model, 'predict'):
                    action = model.predict(ram)
                    # Some libraries' predict() returns (action, hidden_state);
                    # a tuple is never a valid action here, so unwrap it.
                    if isinstance(action, tuple):
                        action = action[0]
                else:
                    # Try calling the model directly
                    action = model(ram)

                # Take step in environment
                ram, _reward, terminated, truncated, _info = self.env.step(action)

                if terminated or truncated:
                    break

        return self.sampled_states

    def naive_symmetry(self, state_1: Dict[str, Any], state_2: Dict[str, Any]) -> bool:
        """
        Check if two states are symmetric under the naive symmetry assumption.

        Scores and the opponent paddle position are ignored; coordinates are
        compared exactly (no binning is applied). The states match when the
        ball's x position and x velocity agree AND one of the following holds:

        1. ball y, ball y-velocity and player paddle y are all equal;
        2. neither state has a ball and the player paddles are at the same
           distance from PLAYER_Y_MID;
        3. both states have a ball, the ball y values are at the same distance
           from BALL_Y_MID, the paddles are at the same distance from
           PLAYER_Y_MID, and the y velocities are opposite.

        Args:
            state_1: First game state
            state_2: Second game state

        Returns:
            True if states are considered symmetric, False otherwise
        """
        # Horizontal position and velocity are not affected by the reflection,
        # so they must match exactly.
        ballx = state_1["ball_x"] == state_2["ball_x"]
        balldx = state_1["ball_dx"] == state_2["ball_dx"]

        if not ballx or not balldx:
            return False

        bally = state_1["ball_y"] == state_2["ball_y"]
        balldy = state_1["ball_dy"] == state_2["ball_dy"]
        playery = state_1["player_y"] == state_2["player_y"]

        # Case 1: Completely equal
        if bally and balldy and playery:
            return True

        # The paddle is mirrored about the paddle's own mid line; its
        # coordinate range differs from the ball's.
        player_y_symmetric = (abs(state_1["player_y"] - self.PLAYER_Y_MID) ==
                              abs(state_2["player_y"] - self.PLAYER_Y_MID))

        # Case 2: Ball off screen in both states, paddle positions symmetric
        if state_1["ball_y"] is None and state_2["ball_y"] is None:
            return player_y_symmetric

        # Case 3: Reflection about the horizontal mid line
        # NOTE(review): the distance test also accepts equal (unmirrored)
        # components, so mixed equal/mirrored combinations pass - confirm
        # that this is intended.
        if state_1["ball_y"] is not None and state_2["ball_y"] is not None:
            ball_y_symmetric = (abs(state_1["ball_y"] - self.BALL_Y_MID) ==
                                abs(state_2["ball_y"] - self.BALL_Y_MID))
            ball_dy_opposite = state_1["ball_dy"] == -state_2["ball_dy"]

            if ball_y_symmetric and player_y_symmetric and ball_dy_opposite:
                return True

        return False

    def generate_similarity_matrix(
        self,
        states: Optional[List[Dict[str, Any]]] = None,
        symmetry_function: Optional[Callable] = None
    ) -> np.ndarray:
        """
        Generate a similarity matrix for the given states using a symmetry function.

        Args:
            states: List of states to compare. If None, uses self.sampled_states
            symmetry_function: Function to check symmetry between two states.
                             If None, uses self.naive_symmetry

        Returns:
            Binary (int) matrix where entry (i, j) is 1 if states i and j are symmetric

        Raises:
            ValueError: If there are no states to compare
        """
        if states is None:
            states = self.sampled_states

        if symmetry_function is None:
            symmetry_function = self.naive_symmetry

        if len(states) == 0:
            raise ValueError("No states provided for similarity matrix generation")

        matrix = np.zeros((len(states), len(states)), dtype=int)

        # Every ordered pair is evaluated because a user-supplied
        # symmetry_function is not guaranteed to be symmetric.
        for i, state_i in enumerate(states):
            for j, state_j in enumerate(states):
                matrix[i, j] = int(symmetry_function(state_i, state_j))

        return matrix

    def get_similarity_stats(self, similarity_matrix: np.ndarray) -> Dict[str, float]:
        """
        Compute statistics about the similarity matrix.

        Args:
            similarity_matrix: Square binary similarity matrix

        Returns:
            Dictionary with keys "total_states", "total_pairs",
            "symmetric_pairs", "symmetry_ratio" and "diagonal_sum".
            "symmetric_pairs" counts unordered off-diagonal pairs, which
            assumes the matrix itself is symmetric.
        """
        n = similarity_matrix.shape[0]
        total_pairs = n * (n - 1) // 2  # Unordered pairs, diagonal excluded

        # Subtract the actual diagonal rather than assuming it is all ones;
        # a custom symmetry function may not be reflexive.
        diagonal_sum = int(np.trace(similarity_matrix))
        symmetric_pairs = int((np.sum(similarity_matrix) - diagonal_sum) // 2)

        return {
            "total_states": n,
            "total_pairs": total_pairs,
            "symmetric_pairs": symmetric_pairs,
            "symmetry_ratio": symmetric_pairs / total_pairs if total_pairs > 0 else 0.0,
            "diagonal_sum": diagonal_sum,
        }

    def close(self):
        """Close the environment."""
        self.env.close()


# # Example usage
# if __name__ == "__main__":
#     # Initialize analyzer
#     analyzer = PongSymmetryAnalyzer()
    
#     # Sample states using random policy
#     states = analyzer.sample_states(
#         num_episodes=3, 
#         max_steps_per_episode=500,
#         max_states=100
#     )
    
#     print(f"Sampled {len(states)} states")
    
#     # Generate similarity matrix
#     similarity_matrix = analyzer.generate_similarity_matrix()
    
#     # Get statistics
#     stats = analyzer.get_similarity_stats(similarity_matrix)
#     print("Similarity Statistics:")
#     for key, value in stats.items():
#         print(f"  {key}: {value}")
    
#     # Close environment
#     analyzer.close()

Sampled 100 states
Similarity Statistics:
  total_states: 100
  total_pairs: 4950
  symmetric_pairs: 14
  symmetry_ratio: 0.0028282828282828283
  diagonal_sum: 100


A.L.E: Arcade Learning Environment (version 0.11.2+ecc1138)
[Powered by Stella]


In [5]:
# Create an analyzer with the default render_mode (None); this opens the Pong environment.
analyzer = PongSymmetryAnalyzer()

In [9]:
# The try/finally guarantees the environment is closed even if sampling or
# the analysis raises.
try:
    states = analyzer.sample_states(
        num_episodes=3,
        max_steps_per_episode=500,
        max_states=10000,
    )

    # Draw the subsample WITHOUT replacement: np.random.choice defaults to
    # replace=True, which yields duplicate states that trivially count as
    # "symmetric" pairs and inflate the statistic.
    subsample_size = min(100, len(states))
    subsample_idx = np.random.choice(len(states), size=subsample_size, replace=False)
    similarity_matrix = analyzer.generate_similarity_matrix(
        [states[i] for i in subsample_idx]
    )

    stats = analyzer.get_similarity_stats(similarity_matrix)
    for key, value in stats.items():
        print(f"  {key}: {value}")
finally:
    analyzer.close()

  total_states: 100
  total_pairs: 4950
  symmetric_pairs: 27
  symmetry_ratio: 0.005454545454545455
  diagonal_sum: 100


In [6]:
# Initialize analyzer
analyzer = PongSymmetryAnalyzer()

# The try/finally guarantees the environment is closed even if sampling or
# the analysis raises.
try:
    # Sample states using random policy
    states = analyzer.sample_states(
        num_episodes=3,
        max_steps_per_episode=500,
        max_states=100,
    )

    print(f"Sampled {len(states)} states")

    # Generate similarity matrix (uses analyzer.sampled_states by default)
    similarity_matrix = analyzer.generate_similarity_matrix()

    # Get statistics
    stats = analyzer.get_similarity_stats(similarity_matrix)
    print("Similarity Statistics:")
    for key, value in stats.items():
        print(f"  {key}: {value}")
finally:
    # Close environment
    analyzer.close()

Sampled 100 states
Similarity Statistics:
  total_states: 100
  total_pairs: 4950
  symmetric_pairs: 8
  symmetry_ratio: 0.0016161616161616162
  diagonal_sum: 100


In [81]:
def naive_symmetry(state_1, state_2):
    """
    Naive up/down symmetry test between two logic states.

    Scores and the opponent paddle position are ignored. Coordinates are
    compared exactly (no binning is applied here).

    Returns True if:
        - ball x is equal in both states
        - ball dx is equal in both states
        - AND one of:
            * ball y, ball dy and player paddle y are all equal;
            * neither state has a ball and the player paddles are at the same
              distance from PLAYER_Y_MID;
            * both states have a ball, ball y is at the same distance from
              BALL_Y_MID in both, the paddles are at the same distance from
              PLAYER_Y_MID, and the dy values are opposite.

    Relies on the module-level BALL_Y_MID and PLAYER_Y_MID constants.
    """
    ballx = state_1["ball_x"] == state_2["ball_x"]
    balldx = state_1["ball_dx"] == state_2["ball_dx"]

    if not ballx or not balldx:
        return False

    bally = state_1["ball_y"] == state_2["ball_y"]
    balldy = state_1["ball_dy"] == state_2["ball_dy"]
    playery = state_1["player_y"] == state_2["player_y"]

    # case 1: completely equal
    if bally and balldy and playery:
        return True

    # The paddle is mirrored about its own mid line, not the ball's: the two
    # coordinates have different ranges.
    paddle_mirrored = (
        abs(state_1["player_y"] - PLAYER_Y_MID) == abs(state_2["player_y"] - PLAYER_Y_MID)
    )

    ball_1_missing = state_1["ball_y"] is None
    ball_2_missing = state_2["ball_y"] is None

    # case 2: ball off screen in BOTH states, paddle flipped
    if ball_1_missing and ball_2_missing:
        return paddle_mirrored

    # case 3: ball present in both states and everything flipped
    if not ball_1_missing and not ball_2_missing:
        ball_mirrored = (
            abs(state_1["ball_y"] - BALL_Y_MID) == abs(state_2["ball_y"] - BALL_Y_MID)
        )
        if ball_mirrored and paddle_mirrored and state_1["ball_dy"] == -state_2["ball_dy"]:
            return True

    return False

In [82]:
def get_similarity_matrix(states):
    """
    Build the pairwise symmetry matrix for ``states``.

    Entry (i, j) is 1.0 when ``naive_symmetry(states[i], states[j])`` is true
    and 0.0 otherwise. The result is a float array of shape
    ``(len(states), len(states))``.
    """
    count = len(states)
    # Row-major evaluation of every ordered pair.
    rows = [
        [naive_symmetry(row_state, col_state) for col_state in states]
        for row_state in states
    ]
    # reshape keeps the (0, 0) shape for an empty input
    return np.array(rows, dtype=float).reshape(count, count)

In [83]:
# Pairwise naive-symmetry matrix for the entries of `sample`.
sm = get_similarity_matrix(sample)

In [52]:
# RAM byte offsets of the quantities used to build a logic state.
PONG_RAM_INDEX = {
    "ball_x":   49,
    "ball_y":   54,
    "enemy_y":  50,
    "player_y": 51,
}

# Inclusive bounds of each raw coordinate.
BALL_X_MIN = 50
BALL_X_MAX = 208
BALL_Y_MIN = 44
BALL_Y_MAX = 207
PLAYER_Y_MIN = 38
PLAYER_Y_MAX = 203
ENEMY_Y_MIN = 0
ENEMY_Y_MAX = 208

# Mirror lines: the midpoint of a range is (MIN + MAX) / 2.
# (MAX - MIN) / 2 is only the half-width and does not lie at the centre of
# the range unless MIN is 0.
BALL_X_MID = (BALL_X_MIN + BALL_X_MAX) / 2
BALL_Y_MID = (BALL_Y_MIN + BALL_Y_MAX) / 2

PLAYER_Y_MID = (PLAYER_Y_MIN + PLAYER_Y_MAX) / 2
ENEMY_Y_MID = (ENEMY_Y_MIN + ENEMY_Y_MAX) / 2

In [53]:
def ram_to_logic_state(
    ram: np.ndarray,
    prev_state: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """
    Decode a Pong RAM vector into a minimal logic-level state.

    Args:
        ram: RAM vector, indexed at the offsets listed in PONG_RAM_INDEX.
        prev_state: Logic state of the previous frame, or None. Used only to
            derive the ball velocity by finite differences.

    Returns:
        {
            "ball_x":   int or None (None when the ball is absent),
            "ball_y":   int or None (None when the ball is absent),
            "ball_dx":  int (0 when no velocity can be computed),
            "ball_dy":  int (0 when no velocity can be computed),
            "player_y": int,
            "enemy_y":  int,
        }
    """
    def read(field):
        # Raw byte at the RAM offset registered for `field`.
        return int(ram[PONG_RAM_INDEX[field]])

    raw_ball_x, raw_ball_y, raw_enemy_y, raw_player_y = (
        read(field) for field in ("ball_x", "ball_y", "enemy_y", "player_y")
    )

    # OCAtari condition for "ball exists", expressed as its negation.
    if raw_ball_y == 0 or raw_ball_x <= 49:
        ball_x = ball_y = None
    else:
        ball_x, ball_y = raw_ball_x, raw_ball_y

    # Velocity defaults to zero and is only overwritten when the ball is
    # present in both the previous and the current frame.
    ball_dx = ball_dy = 0
    if not (prev_state is None or prev_state.get("ball_x") is None or ball_x is None):
        ball_dx = ball_x - prev_state["ball_x"]
        ball_dy = ball_y - prev_state["ball_y"]

    # Paddle coordinates are passed through unchanged.
    return dict(
        ball_x=ball_x,
        ball_y=ball_y,
        ball_dx=ball_dx,
        ball_dy=ball_dy,
        player_y=raw_player_y,
        enemy_y=raw_enemy_y,
    )

In [69]:
# Subsample at most 100 entries WITHOUT replacement: np.random.choice defaults
# to replace=True, which produces duplicate entries that trivially count as
# symmetric pairs in the similarity matrix.
sample_idx = np.random.choice(
    len(ram_states), size=min(100, len(ram_states)), replace=False
)
sample = [ram_states[i] for i in sample_idx]

In [54]:
# Play one random-policy episode, decoding every RAM observation.
env = gym.make(
    "ALE/Pong-v5",
    obs_type="ram",     # RAM observations instead of pixels
    render_mode=None,   # or "human" for visual debug
)

# The try/finally guarantees the emulator is released even if a step raises.
try:
    obs, info = env.reset()
    obs_logic_state = ram_to_logic_state(obs)

    done = False
    while not done:
        action = env.action_space.sample()  # 6 discrete actions
        obs, reward, terminated, truncated, info = env.step(action)
        # Chain the previous logic state so ball_dx / ball_dy are populated.
        obs_logic_state = ram_to_logic_state(obs, obs_logic_state)
        done = terminated or truncated
finally:
    env.close()

In [60]:
states = []       # RGB screens, one per recorded step
ram_states = []   # logic-state dicts, aligned index-by-index with `states`

num_episodes = 5          # tweak as you like
max_steps_per_ep = 10000    # to cap runtime

env = gym.make(
    "ALE/Pong-v5",
    obs_type="ram",     # RAM observations instead of pixels
    render_mode=None,
)
ale = env.unwrapped.ale

# The try/finally guarantees the emulator is released even if a step raises.
try:
    for ep in range(num_episodes):
        ram, info = env.reset()
        prev_state = None

        # The reset observation is recorded by the first loop iteration.
        # Appending it here as well would duplicate the first frame and put a
        # raw RAM array into `ram_states`, which must contain only dicts.
        for t in range(max_steps_per_ep):
            state = ram_to_logic_state(ram, prev_state=prev_state)
            ram_states.append(state)
            states.append(ale.getScreenRGB())
            prev_state = state

            action = env.action_space.sample()
            ram, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated

            if done:
                break
finally:
    env.close()

In [32]:
# Accumulators for the raw coordinates seen during the rollouts.
ball_x_vals, ball_y_vals = [], []
player_y_vals, enemy_y_vals = [], []

# Rollout budget - tweak as you like.
num_episodes = 5_000
max_steps_per_ep = 10_000   # caps the runtime of a single episode


def update_stats_from_state(state):
    """
    Record one logic state's coordinates in the module-level accumulators.

    Ball coordinates are recorded only when both are present (not None);
    paddle coordinates are always recorded.
    """
    ball_missing = state["ball_x"] is None or state["ball_y"] is None
    if not ball_missing:
        ball_x_vals.append(state["ball_x"])
        ball_y_vals.append(state["ball_y"])

    player_y_vals.append(state["player_y"])
    enemy_y_vals.append(state["enemy_y"])


# --- Run random rollouts and collect stats ---

# Create the environment here instead of relying on an `env` left over from
# another cell: the other cells that create `env` also call env.close(), so a
# leaked handle cannot be reused on a fresh top-to-bottom run.
env = gym.make(
    "ALE/Pong-v5",
    obs_type="ram",
    render_mode=None,
)

# The try/finally guarantees the emulator is released even if a step raises.
try:
    for ep in range(num_episodes):
        ram, info = env.reset()
        prev_state = None

        for t in range(max_steps_per_ep):
            state = ram_to_logic_state(ram, prev_state=prev_state)
            update_stats_from_state(state)
            prev_state = state

            action = env.action_space.sample()
            ram, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated

            if done:
                break
finally:
    env.close()

# --- Compute and print stats ---

def summarize(name, values):
    """
    Print min, max, mean (two decimals) and number of distinct values.

    Prints a "no data collected" line instead when ``values`` is empty.
    Returns None in both cases.
    """
    if not len(values):
        print(f"{name}: no data collected")
        return

    arr = np.array(values)
    lowest, highest, average = arr.min(), arr.max(), arr.mean()
    distinct = np.unique(arr).size
    line = (
        f"{name}: min={lowest}, max={highest}, mean={average:.2f}, "
        f"unique={distinct}"
    )
    print(line)

In [29]:
# Report the observed range of every tracked coordinate.
for _label, _values in (
    ("ball_x", ball_x_vals),
    ("ball_y", ball_y_vals),
    ("player_y", player_y_vals),
    ("enemy_y", enemy_y_vals),
):
    summarize(_label, _values)

ball_x: min=50, max=207, mean=153.31, unique=158
ball_y: min=44, max=207, mean=133.28, unique=164
player_y: min=38, max=203, mean=111.68, unique=166
enemy_y: min=0, max=206, mean=117.40, unique=104


In [31]:
# Report the observed range of every tracked coordinate.
for _label, _values in (
    ("ball_x", ball_x_vals),
    ("ball_y", ball_y_vals),
    ("player_y", player_y_vals),
    ("enemy_y", enemy_y_vals),
):
    summarize(_label, _values)

ball_x: min=50, max=208, mean=153.20, unique=159
ball_y: min=44, max=207, mean=132.10, unique=164
player_y: min=38, max=203, mean=111.76, unique=166
enemy_y: min=0, max=206, mean=115.50, unique=104


In [33]:
# Report the observed range of every tracked coordinate.
for _label, _values in (
    ("ball_x", ball_x_vals),
    ("ball_y", ball_y_vals),
    ("player_y", player_y_vals),
    ("enemy_y", enemy_y_vals),
):
    summarize(_label, _values)

ball_x: min=50, max=208, mean=153.04, unique=159
ball_y: min=44, max=207, mean=132.23, unique=164
player_y: min=38, max=203, mean=111.60, unique=166
enemy_y: min=0, max=208, mean=115.74, unique=105
