# Environment Deep Dive

This notebook covers the Essence Wars environment API in detail.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/christianwissmann85/ai-cardgame/blob/master/notebooks/02_environment.ipynb)

In [None]:
# Colab setup (uncomment if needed)
# !curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
# import os; os.environ['PATH'] = f"{os.environ['HOME']}/.cargo/bin:{os.environ['PATH']}"
# !pip install git+https://github.com/christianwissmann85/ai-cardgame.git

In [None]:
# Setup: Change to repo root directory (required for data files)
import os
from pathlib import Path


def find_repo_root(start=None):
    """Locate the repository root by walking upward from a starting directory.

    The repo root is the nearest directory (the start directory itself, or
    any of its ancestors up to and including the filesystem root) that
    contains a ``data/cards`` entry.

    Args:
        start: Directory (str or Path) to begin searching from. Relative
            paths are interpreted against the current working directory.
            Defaults to the current working directory.

    Returns:
        The matching directory as a ``Path``, or ``None`` if neither the
        start directory nor any ancestor contains ``data/cards``.
    """
    # ``Path.cwd() / start`` leaves an absolute ``start`` unchanged and anchors a
    # relative one at the cwd, so ``.parents`` walks all the way up.
    origin = Path.cwd() if start is None else Path.cwd() / start
    # ``(origin, *origin.parents)`` includes the filesystem root itself.
    for candidate in (origin, *origin.parents):
        if (candidate / 'data' / 'cards').exists():
            return candidate
    return None

# Switch into the repository root so relative data paths (data/cards, ...) resolve.
repo_root = find_repo_root()
if repo_root is None:
    # No ancestor holds data/cards; later cells that read data files may fail.
    print("Warning: Could not find repo root.")
else:
    os.chdir(repo_root)
    print(f"Working directory: {os.getcwd()}")

In [None]:
import time

import numpy as np

from essence_wars._core import STATE_TENSOR_SIZE, PyGame, PyParallelGames

## PyGame API Reference

### Constructor

In [None]:
# Fully specified constructor: an explicit deck for each player plus the game mode
game = PyGame(
    deck1="vex_piercing",   # Player 1 deck (Argentum)
    deck2="alpha_frenzy",   # Player 2 deck (Symbiote)
    game_mode="attrition",  # either "attrition" or "essence_duel"
)

# With no arguments, PyGame falls back to its default decks
game_default = PyGame()

### Core Methods

In [None]:
game.reset(seed=42)  # fixed seed -> reproducible game

# observe(): flat state tensor (documented as a numpy array of shape (326,))
obs = game.observe()
obs_shape = np.array(obs).shape
print(f"Observation shape: {obs_shape}")

# action_mask(): one entry per action index (256); 1.0 = legal, 0.0 = illegal
mask = game.action_mask()
n_legal = int(np.sum(mask))
print(f"Legal actions: {n_legal}")

# step(): apply one action and receive (reward, done)
action = 255  # EndTurn, documented as always legal
reward, done = game.step(action)
print(f"Reward: {reward}, Done: {done}")

### State Queries

In [None]:
game.reset(seed=42)

# Scalar queries about the freshly reset game
print(f"Current player: {game.current_player()}")
print(f"Turn number: {game.turn_number()}")
print(f"Game over: {game.is_done()}")

# Reward from each player's perspective (players are indexed 0 and 1)
for player_id in (0, 1):
    print(f"Player {player_id} reward: {game.get_reward(player_id)}")

### Built-in Agents

In [None]:
game.reset(seed=42)

# Random agent: uniform choice among the currently legal actions
random_action = game.random_action()
print(f"Random action: {random_action}")

# Greedy agent: picks an action by heuristic evaluation
greedy_action = game.greedy_action()
print(f"Greedy action: {greedy_action}")

# MCTS agent: tree search with a configurable simulation budget
n_sims = 100
mcts_action = game.mcts_action(simulations=n_sims)
print(f"MCTS action ({n_sims} sims): {mcts_action}")

### State Cloning (for MCTS)

In [None]:
game.reset(seed=42)

# fork() creates an independent copy of the current game state
game_copy = game.fork()
obs_before = np.array(game.observe())  # snapshot of the shared pre-step state

# Advance only the original by one greedy action
game.step(game.greedy_action())

# A single action (e.g. playing a card) may not change the turn number, so the
# turn counters alone can print identical values. Comparing full observations
# shows that the copy still holds the pre-step state.
print(f"Original turn: {game.turn_number()}")
print(f"Copy turn: {game_copy.turn_number()}")
print(f"Copy matches pre-step state: {np.array_equal(np.array(game_copy.observe()), obs_before)}")
print(f"Original changed: {not np.array_equal(np.array(game.observe()), obs_before)}")

### Static Methods

In [None]:
# Deck names accepted by the deck1 / deck2 constructor arguments
decks = PyGame.list_decks()
deck_count = len(decks)
print(f"Available decks ({deck_count}):")
for deck_name in decks:
    print(f"  - {deck_name}")

## Vectorized Environments

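`PyParallelGames` manages a batch of games behind a single batched API (`reset`, `observe_batch`, `action_mask_batch`, `step_batch`). This is the interface an RL training loop typically consumes.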

In [None]:
# Eight games driven through one batched interface
num_envs = 8
envs = PyParallelGames(num_envs=num_envs, deck1="vex_piercing", deck2="alpha_frenzy")

# One seed per environment: 100, 101, ..., 107
seeds = [100 + offset for offset in range(num_envs)]
envs.reset(seeds)

print(f"Number of environments: {envs.num_envs}")

In [None]:
# observe_batch(): one state tensor per environment -> (num_envs, 326)
obs_batch = envs.observe_batch()
print(f"Observation batch shape: {np.shape(obs_batch)}")

# action_mask_batch(): one legality mask per environment -> (num_envs, 256)
mask_batch = envs.action_mask_batch()
print(f"Action mask batch shape: {np.shape(mask_batch)}")

In [None]:
# Issue EndTurn (action index 255) in every environment at once
actions = np.full(num_envs, 255, dtype=np.uint8)
rewards, dones = envs.step_batch(actions)

print(f"Rewards: {np.array(rewards)}")
print(f"Dones: {np.array(dones)}")

In [None]:
# Per-environment state queries
print(f"Current players: {envs.current_players()}")
print(f"Done flags: {envs.dones()}")

# Re-seed only environment 0
envs.reset_single(idx=0, seed=999)

# Re-seed every environment from one base seed (here 1000, 1001, ..., 1007)
envs.reset_all(base_seed=1000)

## Action Space Details

### Action Types

In [None]:
def describe_action_space():
    """Print a human-readable summary of the 256-way discrete action space.

    For each action family (PlayCard, Attack, UseAbility, EndTurn) this
    shows its index range, the formula mapping parameters to an index, and
    the valid range of each parameter.
    """
    rule = "=" * 60
    sections = (
        ("PlayCard (indices 0-99):", (
            "Formula: hand_index * 10 + target_slot",
            "- hand_index: 0-9 (card position in hand)",
            "- target_slot: 0-4 (creature slot) or 5-6 (support slot)",
            "- For spells: target can be creature/player",
        )),
        ("Attack (indices 100-149):", (
            "Formula: 100 + attacker_slot * 10 + target",
            "- attacker_slot: 0-4 (your creature slot)",
            "- target: 0-4 (enemy creature) or 5 (enemy face)",
        )),
        ("UseAbility (indices 150-249):", (
            "Formula: 150 + slot * 20 + ability_index * 10 + target",
            "- slot: 0-4 (creature with ability)",
            "- ability_index: 0-1 (which ability)",
            "- target: 0-9 (depends on ability)",
        )),
        ("EndTurn (index 255):", (
            "Always legal. Passes to opponent.",
        )),
    )

    out_lines = [rule, "ACTION SPACE (256 discrete actions)", rule, ""]
    for heading, details in sections:
        out_lines.append(heading)
        out_lines.extend("  " + detail for detail in details)
        out_lines.append("")  # blank separator after every section
    print("\n".join(out_lines))

describe_action_space()

In [None]:
def decode_action(idx: int) -> str:
    """Convert a flat action index (0-255) into a human-readable description.

    Index layout:
      * 0-99    PlayCard:   hand_index * 10 + target_slot
      * 100-149 Attack:     100 + attacker_slot * 10 + target
                            (target 0-4 = enemy creature, 5 = enemy face)
      * 150-249 UseAbility: 150 + slot * 20 + ability_index * 10 + target
      * 250-254 reserved
      * 255     EndTurn

    Attack indices whose target digit is 6-9 match no documented target and
    are reported as ``Reserved(idx)``, the same as indices 250-254.

    Args:
        idx: Action index.

    Returns:
        A short description such as ``"Attack(slot[1] -> face)"``.

    Raises:
        ValueError: If ``idx`` is outside the range 0-255.
    """
    if not 0 <= idx <= 255:
        raise ValueError(f"action index must be in 0-255, got {idx}")
    if idx == 255:
        return "EndTurn"
    if idx < 100:
        hand, target = divmod(idx, 10)
        return f"PlayCard(hand[{hand}] -> slot {target})"
    if idx < 150:
        attacker, target = divmod(idx - 100, 10)
        if target < 5:
            return f"Attack(slot[{attacker}] -> creature[{target}])"
        if target == 5:
            return f"Attack(slot[{attacker}] -> face)"
        # Target digits 6-9 are not part of the documented attack encoding.
        return f"Reserved({idx})"
    if idx < 250:
        slot, rest = divmod(idx - 150, 20)
        ability, target = divmod(rest, 10)
        return f"UseAbility(slot[{slot}].ability[{ability}] -> {target})"
    return f"Reserved({idx})"

# Example decoding
examples = [0, 5, 42, 100, 105, 115, 150, 170, 255]
for idx in examples:
    print(f"{idx:3d} -> {decode_action(idx)}")

## State Tensor Layout

In [None]:
def describe_tensor_layout():
    """Print the index layout of the flat state tensor.

    The printed layout is a 6-float global header, then two 75-float player
    blocks, then everything up to STATE_TENSOR_SIZE labelled as card
    embedding IDs.
    """
    pad = " " * 9  # aligns sub-fields under the "[nnn-nnn]" range column

    def show_fields(base, fields):
        # Each field is (first, last, label); last is None for a single float.
        for first, last, label in fields:
            where = str(base + first) if last is None else f"{base + first}-{base + last}"
            print(f"{pad}[{where}] {label}")

    banner = "=" * 60
    print(banner)
    print(f"STATE TENSOR ({STATE_TENSOR_SIZE} floats)")
    print(banner)
    print()

    offset = 0

    print(f"[{offset:3d}-{offset + 5:3d}] Global state (6 floats):")
    show_fields(offset, (
        (0, None, "turn / 30"),
        (1, None, "current_player (0 or 1)"),
        (2, None, "game_over (0 or 1)"),
        (3, None, "winner (-1, 0, or 1)"),
        (4, 5, "reserved"),
    ))
    offset += 6

    player_fields = (
        (0, None, "life / 20"),
        (1, None, "essence / 10"),
        (2, None, "action_points / 3"),
        (3, None, "deck_size / 30"),
        (4, None, "hand_size / 10"),
        (5, 14, "hand card IDs (10 slots)"),
        (15, 64, "creatures (5 slots x 10 floats)"),
        (65, 74, "supports (2 slots x 5 floats)"),
    )
    for player_num in (1, 2):
        print(f"\n[{offset:3d}-{offset + 74:3d}] Player {player_num} state (75 floats):")
        show_fields(offset, player_fields)
        offset += 75

    print(f"\n[{offset:3d}-{STATE_TENSOR_SIZE - 1:3d}] Card embedding IDs ({STATE_TENSOR_SIZE - offset} floats)")

describe_tensor_layout()

### Creature Slot Encoding (10 floats each)

In [None]:
# Meaning of each of the 10 floats that encode one creature slot
creature_fields = [
    "occupied (0 or 1)",
    "attack / 10",
    "health / 10",
    "max_health / 10",
    "can_attack (0 or 1)",
    "exhausted (0 or 1)",
    "silenced (0 or 1)",
    "has_rush (0 or 1)",
    "has_guard (0 or 1)",
    "keywords_bitfield / 65535",
]
print("Creature slot encoding (10 floats):")
for field_idx, field_desc in enumerate(creature_fields):
    print(f"  [{field_idx}] {field_desc}")

## Performance Benchmarks

In [None]:
# Single-environment throughput: observe() and action_mask() calls per second.
# (`time` is imported in the top imports cell.)


def calls_per_second(fn, n_calls):
    """Call ``fn()`` ``n_calls`` times and return the achieved rate in calls/s.

    Timing uses ``time.perf_counter``. Returns ``float("inf")`` if the
    measured duration is zero, avoiding a ZeroDivisionError.
    """
    t0 = time.perf_counter()
    for _ in range(n_calls):
        fn()
    elapsed = time.perf_counter() - t0
    return n_calls / elapsed if elapsed > 0 else float("inf")


# Separate instance so the `game` used in the API walkthrough is not overwritten.
bench_game = PyGame()
bench_game.reset(seed=42)

n_steps = 10000
print(f"Observations: {calls_per_second(bench_game.observe, n_steps):.0f} per second")
print(f"Action masks: {calls_per_second(bench_game.action_mask, n_steps):.0f} per second")

In [None]:
# Throughput of batched observations over a 64-environment batch
num_envs = 64
envs = PyParallelGames(num_envs=num_envs)
envs.reset_all(base_seed=0)

n_batches = 1000
t0 = time.perf_counter()
for _ in range(n_batches):
    envs.observe_batch()
elapsed = time.perf_counter() - t0

# Each batch yields one observation per environment
total_observations = n_batches * num_envs
print(f"Batch observations ({num_envs} envs): {total_observations / elapsed:.0f} per second")

## Next Steps

- **03_dataset_exploration.ipynb** - Explore MCTS training data
- **04_behavioral_cloning.ipynb** - Train a neural network agent