# Policy Sampling for FDA Analysis

This notebook rolls out π^H to generate bidding states and, at every visited state, records the action distributions of both π^H and π^R for downstream FDA analysis.

**Target:** 1M samples (the configuration cell in Section 3 is currently set to `N_SAMPLES = 100_000`)
**Estimated time:** ~20-30 min on A100

## Output Format
```
policy_samples.npz
├── observations        (N, 480)  bool     — Raw PGX observations (cast to float32 for inference)
├── covariates          (N, 48)   float32  — 48 statistical features
├── pi_H                (N, 38)   float32  — Smoothed π^H probabilities
├── pi_R                (N, 38)   float32  — Smoothed π^R probabilities
├── legal_masks         (N, 38)   bool     — Legal action masks
├── episode_ids         (N,)      int32    — Episode index (for cluster bootstrap)
├── board_ids           (N,)      int32    — Unique board index (without-replacement sampling)
└── timestep_in_episode (N,)      int16    — Call index within the episode (0-based)

metadata.json
├── action_names        — 38 action names: ["Pass", "Dbl", "Rdbl", "1C", ..., "7NT"]
├── ref_action          — ALR reference: "Pass"
├── ref_action_idx      — ALR reference index: 0 (Pass is always action 0 in PGX)
├── action_legal_rates  — Per-action legality rates
├── rare_actions        — Actions with legal_rate < 0.5%
└── states_per_episode  — {min, median, max, mean}
```

## Key Design Choices
1. **Additive Smoothing**: `(p + ε) / (1 + K*ε)` with K = 38 makes every probability strictly positive (required by the ALR transform) while each row still sums to 1
2. **Without-replacement**: Uses `state._hand` fingerprint to avoid duplicate boards
3. **Cluster Bootstrap**: Save `episode_ids` because states from the same episode are correlated; Step 4 needs them for non-i.i.d. inference
4. **Pass = action 0**: PGX encoding order (Pass, Dbl, Rdbl, 1C, ..., 7NT)
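
To make choices 1 and 4 concrete, here is a minimal NumPy sketch of the smoothing step followed by an additive log-ratio (ALR) transform with Pass (index 0) as the reference. It assumes the standard ALR definition, log(p_j / p_ref) for every j other than the reference; `smooth` and `alr` are hypothetical helpers, and the Step 4 FDA code may implement the transform differently. Without smoothing, a deterministic row would contain zeros and `alr` would return `-inf` entries.

```python
import numpy as np

K = 38         # PGX bridge action count (Pass, Dbl, Rdbl, 1C, ..., 7NT)
EPS = 1e-5     # same value as SMOOTHING_EPSILON in the config cell
REF_IDX = 0    # Pass, the ALR reference action


def smooth(p, eps=EPS):
    """Additive smoothing (p + eps) / (1 + K*eps); rows still sum to 1."""
    k = p.shape[-1]
    return (p + eps) / (1.0 + k * eps)


def alr(p, ref_idx=REF_IDX):
    """Additive log-ratio transform: log(p_j / p_ref) for every j != ref_idx."""
    log_p = np.log(p)
    ratios = log_p - log_p[..., [ref_idx]]
    return np.delete(ratios, ref_idx, axis=-1)


# Two policy rows: one deterministic (all mass on 1C = index 3), one uniform
p = np.zeros((2, K))
p[0, 3] = 1.0
p[1, :] = 1.0 / K

p_s = smooth(p)
z = alr(p_s)
print(p_s.sum(axis=1))   # both rows sum to 1 (up to float rounding)
print(z.shape)           # (2, 37): one log-ratio per non-reference action
```

For the deterministic row, the 1C coordinate becomes log((1 + ε) / ε) ≈ 11.5 instead of infinity, which is what keeps the downstream analysis finite.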

## 1. Setup Environment

### ⚠️ IMPORTANT: Two-Phase Setup

**Phase 1 (First time only):**
1. Run cells 2-3 to install dependencies
2. Restart runtime when prompted
3. **DO NOT run cells 2-3 again after restart**

**Phase 2 (After restart):**
1. Skip cells 2-3
2. Start from cell 4 and run all remaining cells

In [None]:
# Check GPU
!nvidia-smi

In [None]:
# Clean up old JAX cuda plugin and install dependencies
!pip uninstall -y jax-cuda12-plugin jax-cuda12-pjrt 2>/dev/null
!pip install -q pgx==2.4.2 dm-haiku optax

print("Setup complete!")

In [None]:
# Verify JAX GPU
import jax
print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")
print(f"Default backend: {jax.default_backend()}")

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Clone or copy project
# Option 1: Clone from GitHub (if public)
# !git clone https://github.com/your-repo/bridge_bidding_interpretability.git

# Option 2: Copy from Drive
# Adjust path as needed
PROJECT_PATH = "/content/drive/MyDrive/bridge_bidding_interpretability"

import os
if os.path.exists(PROJECT_PATH):
    %cd {PROJECT_PATH}
    print(f"Using project at: {PROJECT_PATH}")
else:
    print(f"Project not found at {PROJECT_PATH}")
    print("Please upload the project or adjust PROJECT_PATH")

In [None]:
# Add project to path
import sys
sys.path.insert(0, PROJECT_PATH)

## 2. Load Models & Environment

In [None]:
from pathlib import Path
from pgx.bridge_bidding import BridgeBidding

from src.policy_loader import PolicyWrapper
from src.features.feature_extractor import BridgeFeatureExtractor
from src.sampling.sampler import PolicySampler, SamplingConfig

In [None]:
# Paths
PROJECT_ROOT = Path(PROJECT_PATH)
DDS_PATH = PROJECT_ROOT / "data" / "raw" / "dds_results" / "dds_results_100K_eval.npy"
PI_H_PATH = PROJECT_ROOT / "checkpoints" / "pi_H"
PI_R_PATH = PROJECT_ROOT / "checkpoints" / "pi_R"

# Verify paths exist
print(f"DDS exists: {DDS_PATH.exists()}")
print(f"π^H exists: {PI_H_PATH.exists()}")
print(f"π^R exists: {PI_R_PATH.exists()}")

In [None]:
# Load environment
print("Loading environment...")
env = BridgeBidding(dds_results_table_path=str(DDS_PATH))
print("Environment loaded!")

In [None]:
# Load policies
print("Loading π^H (Human proxy)...")
pi_H = PolicyWrapper(
    PI_H_PATH,
    model_type="DeepMind",
    activation="relu",
    model_file="model-sl.pkl",
)
print(f"π^H loaded: {pi_H}")

print("\nLoading π^R (RL policy)...")
pi_R = PolicyWrapper(
    PI_R_PATH,
    model_type="DeepMind",
    activation="relu",
    model_file="model-pretrained-rl-with-fsp.pkl",
)
print(f"π^R loaded: {pi_R}")

In [None]:
# Create feature extractor
extractor = BridgeFeatureExtractor(normalize=False)
print(f"Feature extractor: {len(extractor.get_feature_names())} features")
print(f"Features: {extractor.get_feature_names()[:10]}...")

## 3. Fast Sampling with JIT

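Only `env.init` and `env.step` are JIT-compiled; the rollout loop and the policy calls still run step by step in Python. At each step the next action is sampled from π^H, restricted to legal actions.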
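
Adding a small constant before the log (as in `log(p + 1e-10)`) leaves illegal actions with a tiny but nonzero chance of being drawn. A minimal JAX sketch of masking them out entirely by giving them a `-inf` logit before calling `jax.random.categorical`; `sample_legal_action` is a hypothetical helper, and the uniform probabilities and three legal actions are illustrative only.

```python
import jax
import jax.numpy as jnp


def sample_legal_action(key, probs, legal_mask):
    """Sample an action index from `probs`, restricted to legal actions.

    Illegal actions get a logit of -inf, so they can never be drawn, even if
    the policy leaves probability mass on them.
    """
    logits = jnp.where(legal_mask, jnp.log(probs), -jnp.inf)
    return jax.random.categorical(key, logits)


# Uniform probabilities over all 38 actions, but only Pass (0), 1C (3), 1D (4) legal
probs = jnp.full(38, 1.0 / 38)
mask = jnp.zeros(38, dtype=bool).at[jnp.array([0, 3, 4])].set(True)

keys = jax.random.split(jax.random.PRNGKey(0), 1000)
actions = jax.vmap(sample_legal_action, in_axes=(0, None, None))(keys, probs, mask)
print(bool(jnp.all(mask[actions])))  # True: no illegal action is ever drawn
```

The same `jnp.where(mask, jnp.log(p), -jnp.inf)` pattern works on a single state inside the Python loop, and `jax.vmap` only matters if sampling is later batched across episodes.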

In [None]:
import numpy as np
import jax.numpy as jnp
from tqdm.auto import tqdm  # Auto picks notebook-friendly version

# Configuration
N_SAMPLES = 100_000
SEED = 42
SMOOTHING_EPSILON = 1e-5
RUN_ID = "100K_pi_H_v2"
OUTPUT_DIR = str(PROJECT_ROOT / "data" / "processed" / "policy_samples")

print(f"Configuration:")
print(f"  N_SAMPLES: {N_SAMPLES:,}")
print(f"  SEED: {SEED}")
print(f"  SMOOTHING_EPSILON: {SMOOTHING_EPSILON}")
print(f"  RUN_ID: {RUN_ID}")
print(f"  Action selection: π^H (realistic state distribution)")

In [None]:
# JIT compile env functions for speed
env_init = jax.jit(env.init)
env_step = jax.jit(env.step)

# Get feature names for later
FEATURE_NAMES = extractor.get_feature_names()
N_FEATURES = len(FEATURE_NAMES)
print(f"Features: {N_FEATURES} total")

def fast_sample_jit(n_samples, seed=42, max_consecutive_dupes=10_000):
    """Roll out π^H and record π^H and π^R probabilities at every visited state.

    Only env.init/env.step are JIT-compiled; the rollout loop runs in Python.
    Boards are sampled without replacement: a deal whose `state._hand`
    fingerprint has already been used is skipped.
    """
    key = jax.random.PRNGKey(seed)
    
    # Warmup: trigger compilation of the env functions and both policies
    print("Warming up JIT...")
    for _ in range(10):
        key, k1, k2 = jax.random.split(key, 3)
        state = env_init(k1)
        obs_f32 = state.observation.astype(jnp.float32)
        mask = state.legal_action_mask
        _ = pi_H.get_probs(obs_f32, mask)
        _ = pi_R.get_probs(obs_f32, mask)
        state = env_step(state, 0, k2)  # action 0 (Pass) is always legal
    print("JIT warmup complete!")
    
    # Storage
    all_obs, all_masks, all_pi_H, all_pi_R = [], [], [], []
    all_ep_ids, all_board_ids, all_timesteps, all_covariates = [], [], [], []
    seen_boards = set()  # `_hand` fingerprints of deals already sampled
    total, ep_id, n_dupes = 0, 0, 0
    
    pbar = tqdm(total=n_samples, desc="Sampling")
    
    while total < n_samples:
        key, init_key = jax.random.split(key)
        state = env_init(init_key)
        
        # Without-replacement: skip deals that were already sampled
        board_fp = np.asarray(state._hand).tobytes()
        if board_fp in seen_boards:
            n_dupes += 1
            if n_dupes > max_consecutive_dupes:
                raise RuntimeError("Board pool nearly exhausted: use a larger DDS table or fewer samples.")
            continue
        n_dupes = 0
        seen_boards.add(board_fp)
        board_id = len(seen_boards) - 1
        timestep = 0
        
        while not state.terminated:
            obs = state.observation
            mask = state.legal_action_mask
            obs_f32 = obs.astype(jnp.float32)
            
            # Get probabilities from both policies
            probs_H, _ = pi_H.get_probs(obs_f32, mask)
            probs_R, _ = pi_R.get_probs(obs_f32, mask)
            
            # Extract features from the observation (not the state)
            feature_dict = extractor.extract(obs)
            features = np.array([feature_dict[name] for name in FEATURE_NAMES], dtype=np.float32)
            
            # Store
            all_obs.append(np.array(obs))
            all_masks.append(np.array(mask))
            all_pi_H.append(np.array(probs_H))
            all_pi_R.append(np.array(probs_R))
            all_ep_ids.append(ep_id)
            all_board_ids.append(board_id)
            all_timesteps.append(timestep)
            all_covariates.append(features)
            
            # Select the next action by sampling from π^H; illegal actions get
            # a -inf logit so they can never be drawn
            key, act_key, step_key = jax.random.split(key, 3)
            logits = jnp.where(mask, jnp.log(probs_H), -jnp.inf)
            action = int(jax.random.categorical(act_key, logits))
            state = env_step(state, action, step_key)
            
            total += 1
            timestep += 1
            pbar.update(1)
            
            if total >= n_samples:
                break
        
        ep_id += 1
    
    pbar.close()
    
    # Apply smoothing over all K actions: (p + ε) / (1 + K*ε)
    K = 38  # number of actions
    pi_H_arr = np.stack(all_pi_H[:n_samples])
    pi_R_arr = np.stack(all_pi_R[:n_samples])
    pi_H_smooth = (pi_H_arr + SMOOTHING_EPSILON) / (1 + K * SMOOTHING_EPSILON)
    pi_R_smooth = (pi_R_arr + SMOOTHING_EPSILON) / (1 + K * SMOOTHING_EPSILON)
    
    return {
        'observations': np.stack(all_obs[:n_samples]),
        'legal_masks': np.stack(all_masks[:n_samples]),
        'pi_H': pi_H_smooth.astype(np.float32),
        'pi_R': pi_R_smooth.astype(np.float32),
        'episode_ids': np.array(all_ep_ids[:n_samples], dtype=np.int32),
        'board_ids': np.array(all_board_ids[:n_samples], dtype=np.int32),
        'timestep_in_episode': np.array(all_timesteps[:n_samples], dtype=np.int16),
        'covariates': np.stack(all_covariates[:n_samples]),
    }

print("fast_sample_jit defined!")

## 4. Run Sampling

In [None]:
# Run sampling!
print("=" * 60)
print(f"Starting sampling: {N_SAMPLES:,} samples")
print(f"Action selection: π^H")
print("=" * 60)

samples = fast_sample_jit(N_SAMPLES, seed=SEED)

print(f"\nSampling complete!")
print(f"  Total samples: {len(samples['episode_ids']):,}")
print(f"  Unique episodes: {len(np.unique(samples['episode_ids'])):,}")

In [None]:
# Quick sanity check - auction level distribution
auction_level_idx = FEATURE_NAMES.index('auction_level')
auction_levels = samples['covariates'][:, auction_level_idx]

print("Auction level distribution:")
for level in range(8):
    pct = np.mean(auction_levels == level) * 100
    if pct > 0.1:
        print(f"  Level {level}: {pct:.1f}%")

# This should show most samples at levels 0-4, NOT mostly 7!
if np.mean(auction_levels >= 6) > 0.3:
    print("\n⚠️ WARNING: Too many high-level contracts! Check action selection.")
else:
    print("\n✓ Auction distribution looks realistic!")

## 5. Verify & Save

In [None]:
# Verify samples
print("Verifying samples...")

# Basic checks
n_samples = len(samples['episode_ids'])
n_episodes = len(np.unique(samples['episode_ids']))

print(f"\n1. Sample counts:")
print(f"   Total samples: {n_samples:,}")
print(f"   Unique episodes: {n_episodes:,}")
print(f"   Avg states/episode: {n_samples/n_episodes:.1f}")

print(f"\n2. Probability checks:")
# Check probabilities sum to 1
pi_H_sums = np.sum(samples['pi_H'], axis=1)
pi_R_sums = np.sum(samples['pi_R'], axis=1)
print(f"   π^H sum range: [{pi_H_sums.min():.6f}, {pi_H_sums.max():.6f}]")
print(f"   π^R sum range: [{pi_R_sums.min():.6f}, {pi_R_sums.max():.6f}]")

# Check min probability (should be ~SMOOTHING_EPSILON due to smoothing)
min_prob_H = samples['pi_H'].min()
min_prob_R = samples['pi_R'].min()
print(f"   π^H min prob: {min_prob_H:.2e} (expected ~{SMOOTHING_EPSILON:.0e})")
print(f"   π^R min prob: {min_prob_R:.2e} (expected ~{SMOOTHING_EPSILON:.0e})")

print(f"\n3. Legal mask consistency:")
# Check that legal actions have reasonable probabilities
legal_probs_H = samples['pi_H'][samples['legal_masks']]
illegal_probs_H = samples['pi_H'][~samples['legal_masks']]
print(f"   Legal action prob range: [{legal_probs_H.min():.6f}, {legal_probs_H.max():.6f}]")
print(f"   Illegal action prob range: [{illegal_probs_H.min():.6f}, {illegal_probs_H.max():.6f}]")

print(f"\n4. Shape verification:")
for key, arr in samples.items():
    print(f"   {key}: {arr.shape} ({arr.dtype})")

print("\n✓ Verification complete!")

In [None]:
# Save samples
import json
import os

os.makedirs(OUTPUT_DIR, exist_ok=True)

# Save .npz
npz_path = os.path.join(OUTPUT_DIR, f"{RUN_ID}_policy_samples.npz")
np.savez_compressed(npz_path, **samples)
print(f"Saved samples to: {npz_path}")

# Create and save metadata
action_names = ["Pass", "Dbl", "Rdbl", 
                "1C", "1D", "1H", "1S", "1NT",
                "2C", "2D", "2H", "2S", "2NT",
                "3C", "3D", "3H", "3S", "3NT",
                "4C", "4D", "4H", "4S", "4NT",
                "5C", "5D", "5H", "5S", "5NT",
                "6C", "6D", "6H", "6S", "6NT",
                "7C", "7D", "7H", "7S", "7NT"]

# Compute per-action legal rates
legal_rates = np.mean(samples['legal_masks'].astype(float), axis=0).tolist()

# Find rare actions (< 0.5% legal rate)
rare_actions = [action_names[i] for i, rate in enumerate(legal_rates) if rate < 0.005]

# Episode statistics
ep_counts = np.bincount(samples['episode_ids'])
states_per_episode = {
    'min': int(ep_counts.min()),
    'max': int(ep_counts.max()),
    'median': float(np.median(ep_counts)),
    'mean': float(np.mean(ep_counts)),
}

metadata = {
    'n_samples': int(n_samples),
    'n_episodes': int(n_episodes),
    'feature_names': FEATURE_NAMES,  # Use the variable we defined
    'action_names': action_names,
    'ref_action': 'Pass',
    'ref_action_idx': 0,
    'action_legal_rates': legal_rates,
    'rare_actions': rare_actions,
    'states_per_episode': states_per_episode,
    'sampling_config': {
        'seed': SEED,
        'smoothing_epsilon': SMOOTHING_EPSILON,
        'behavior_policy': 'pi_H',
    }
}

json_path = os.path.join(OUTPUT_DIR, f"{RUN_ID}_metadata.json")
with open(json_path, 'w') as f:
    json.dump(metadata, f, indent=2)
print(f"Saved metadata to: {json_path}")

print(f"\n✓ All files saved to {OUTPUT_DIR}")

In [None]:
# Summary statistics
print("=" * 60)
print("Sample Summary")
print("=" * 60)

print(f"\nπ^H statistics:")
entropy_H = -np.mean(np.sum(samples['pi_H'] * np.log(samples['pi_H'] + 1e-10), axis=1))
print(f"  Mean entropy: {entropy_H:.3f}")
print(f"  Max prob range: [{np.min(np.max(samples['pi_H'], axis=1)):.3f}, {np.max(np.max(samples['pi_H'], axis=1)):.3f}]")

print(f"\nπ^R statistics:")
entropy_R = -np.mean(np.sum(samples['pi_R'] * np.log(samples['pi_R'] + 1e-10), axis=1))
print(f"  Mean entropy: {entropy_R:.3f}")
print(f"  Max prob range: [{np.min(np.max(samples['pi_R'], axis=1)):.3f}, {np.max(np.max(samples['pi_R'], axis=1)):.3f}]")

print(f"\nAction legal rates:")
for i, name in enumerate(action_names[:10]):
    print(f"  {name}: {legal_rates[i]*100:.1f}%")

print(f"\nCovariate summary (first 5 features):")
for i, name in enumerate(FEATURE_NAMES[:5]):
    values = samples['covariates'][:, i]
    print(f"  {name}: mean={np.mean(values):.2f}, std={np.std(values):.2f}")

## 6. Download (Optional)

If you want to download the file to your local machine:

In [None]:
# Optional: Download to local machine
# from google.colab import files
# files.download(npz_path)

## Done!

The policy samples are saved to:
```
data/processed/policy_samples/<run_id>_policy_samples.npz
data/processed/policy_samples/<run_id>_metadata.json
```

### Output Arrays
| Array | Shape | Dtype | Description |
|-------|-------|-------|-------------|
| observations | (N, 480) | bool | Raw PGX observations |
| covariates | (N, 48) | float32 | 48 statistical features |
| pi_H | (N, 38) | float32 | Smoothed π^H probabilities |
| pi_R | (N, 38) | float32 | Smoothed π^R probabilities |
| legal_masks | (N, 38) | bool | Legal action masks |
| episode_ids | (N,) | int32 | Episode index (for cluster bootstrap) |
| board_ids | (N,) | int32 | Unique board index |
| timestep_in_episode | (N,) | int16 | Call index within the episode (0-based) |
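
A minimal sketch for loading a run in a later step and checking it against the table above. `EXPECTED` and `load_policy_samples` are hypothetical helpers (not part of `src/`); the shapes and dtypes are copied from the table, and the demo round-trips a tiny synthetic run through a temporary directory instead of real data.

```python
import json
import os
import tempfile
import warnings

import numpy as np

# Expected per-row shape and dtype of each array, taken from the table above
EXPECTED = {
    "observations": ((480,), np.bool_),
    "covariates": ((48,), np.float32),
    "pi_H": ((38,), np.float32),
    "pi_R": ((38,), np.float32),
    "legal_masks": ((38,), np.bool_),
    "episode_ids": ((), np.int32),
    "board_ids": ((), np.int32),
    "timestep_in_episode": ((), np.int16),
}


def load_policy_samples(npz_path, json_path):
    """Load a sampling run and check it against the documented format."""
    with np.load(npz_path) as data:
        samples = {name: data[name] for name in data.files}
    with open(json_path) as f:
        metadata = json.load(f)

    n = len(samples["episode_ids"])
    for name, (row_shape, dtype) in EXPECTED.items():
        if name not in samples:
            warnings.warn(f"missing array: {name}")
            continue
        arr = samples[name]
        if arr.shape != (n, *row_shape) or arr.dtype != dtype:
            raise ValueError(f"{name}: got {arr.shape} {arr.dtype}")
    if metadata.get("n_samples", n) != n:
        raise ValueError("metadata n_samples does not match the arrays")
    return samples, metadata


# Round-trip a tiny synthetic run (N = 4) through a temporary directory
with tempfile.TemporaryDirectory() as tmp:
    n = 4
    fake = {name: np.zeros((n, *shape), dtype=dtype) for name, (shape, dtype) in EXPECTED.items()}
    npz_path = os.path.join(tmp, "demo_policy_samples.npz")
    json_path = os.path.join(tmp, "demo_metadata.json")
    np.savez_compressed(npz_path, **fake)
    with open(json_path, "w") as f:
        json.dump({"n_samples": n, "ref_action_idx": 0}, f)
    samples, metadata = load_policy_samples(npz_path, json_path)

print(sorted(samples))
```

A missing array only triggers a warning, so runs saved before an array was added can still be loaded; a wrong shape or dtype raises.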

### Critical Metadata
- `ref_action_idx = 0` (Pass is always action 0 in PGX)
- `action_names = ["Pass", "Dbl", "Rdbl", "1C", ..., "7NT"]`

Next step: Use these samples for FDA analysis in Step 4.
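
Because states from the same episode are correlated, intervals in Step 4 should resample whole episodes rather than individual states. A minimal sketch of a percentile cluster bootstrap for the mean of a per-state statistic, using NumPy only; `cluster_bootstrap_ci` is a hypothetical helper, not the Step 4 code, and the demo uses synthetic data with a shared per-episode offset.

```python
import numpy as np


def cluster_bootstrap_ci(values, episode_ids, n_boot=2000, alpha=0.05, seed=0):
    """Percentile CI for the mean of per-state `values`, resampling whole episodes.

    Episodes (not states) are the resampling unit, so within-episode
    correlation is carried into every bootstrap replicate.
    """
    rng = np.random.default_rng(seed)
    # Collapse to per-episode sums and counts so each resample is cheap
    _, inv = np.unique(episode_ids, return_inverse=True)
    inv = inv.ravel()
    sums = np.bincount(inv, weights=values)
    counts = np.bincount(inv)
    n_ep = len(counts)

    boot = np.empty(n_boot)
    for b in range(n_boot):
        idx = rng.integers(0, n_ep, size=n_ep)  # draw episodes with replacement
        boot[b] = sums[idx].sum() / counts[idx].sum()

    lo, hi = np.quantile(boot, [alpha / 2, 1 - alpha / 2])
    return float(np.mean(values)), (float(lo), float(hi))


# Synthetic example: 300 episodes of 5-20 states sharing a per-episode offset
rng = np.random.default_rng(1)
lengths = rng.integers(5, 21, size=300)
episode_ids = np.repeat(np.arange(300), lengths)
offsets = rng.normal(0.0, 1.0, size=300)
values = offsets[episode_ids] + rng.normal(0.0, 0.1, size=episode_ids.size)

est, (lo, hi) = cluster_bootstrap_ci(values, episode_ids)
print(f"mean={est:.3f}, 95% CI=({lo:.3f}, {hi:.3f})")
```

On real samples, `values` could be any per-state quantity, for example (as an illustrative choice) the total variation distance `0.5 * np.abs(samples['pi_H'] - samples['pi_R']).sum(axis=1)`, paired with `samples['episode_ids']`.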