# CFR vs Random: Cross-Play Analysis

Pit CFR liberals against random fascists (and vice versa), alongside random-vs-random and CFR self-play runs for comparison, to understand the learned strategies.

In [None]:
import sys
sys.path.insert(0, "../..")

import pickle
from pathlib import Path
from collections import defaultdict

from shitler_env.game import ShitlerEnv
from agents.cfr.cfr_agent import CFRAgent, get_legal_actions
from agents.cfr.infoset import get_infoset_key

In [None]:
# Restore the trained CFR agent from its pickled checkpoint.
# NOTE(review): pickle.load can execute arbitrary code from the file, so only
# load checkpoints that come from a trusted training run.
checkpoint_path = Path("checkpoints/cfr_iter_500000.pkl")
with checkpoint_path.open("rb") as f:
    data = pickle.load(f)

# Re-wrap the saved tables in defaultdicts so that an infoset missing from the
# checkpoint yields an empty defaultdict(float) on first access. Entries copied
# from the checkpoint keep whatever mapping type they were saved with.
regret_sums = defaultdict(lambda: defaultdict(float))
regret_sums.update(data["regret_sums"])
strategy_sums = defaultdict(lambda: defaultdict(float))
strategy_sums.update(data["strategy_sums"])

# Attach the restored tables to a fresh agent instance.
cfr_agent = CFRAgent()
cfr_agent.regret_sums = regret_sums
cfr_agent.strategy_sums = strategy_sums

print(f"Loaded {len(data['regret_sums']):,} infosets")

In [None]:
import random

# Role label (as stored in env.roles) -> numeric role id.
ROLE_MAP = dict(lib=0, fasc=1, hitty=2)
# Numeric role id -> human-readable name.
ROLE_NAMES = dict(enumerate(("Liberal", "Fascist", "Hitler")))

def get_roles_list(env):
    """Return the numeric role of seats P0..P4, in seat order.

    Looks up each seat's label in ``env.roles`` and translates it through
    ``ROLE_MAP``.
    """
    numeric_roles = []
    for seat in range(5):
        role_label = env.roles[f"P{seat}"]
        numeric_roles.append(ROLE_MAP[role_label])
    return numeric_roles

def get_random_action(legal_actions):
    """Pick one entry of ``legal_actions`` at random (the random-agent policy)."""
    chosen = random.choice(legal_actions)
    return chosen

def get_cfr_action(cfr_agent, obs, phase, player_idx, legal_actions):
    """Sample an action for ``player_idx`` from the CFR agent's average strategy.

    The infoset key is built from the observation, phase and player index;
    the agent's average strategy for that key over ``legal_actions`` is then
    sampled.
    """
    avg_strategy = cfr_agent.get_average_strategy(
        get_infoset_key(obs, phase, player_idx),
        legal_actions,
    )
    return cfr_agent.sample_action(avg_strategy)

In [None]:
def run_mixed_games(cfr_agent, num_games, cfr_team="liberal"):
    """
    Run games with one team using CFR and the other using random.

    Args:
        cfr_agent: Agent handed to ``get_cfr_action`` for every CFR-controlled seat.
        num_games: Number of games to play.
        cfr_team: "liberal" or "fascist" - which team uses CFR. With
            "fascist", both the fascists and Hitler use CFR.

    Returns:
        dict with "liberal_wins" and "fascist_wins" counts, a
        "win_conditions" defaultdict(int) keyed by how each game ended, and
        "game_lengths", a list with one move count per game.

    Raises:
        ValueError: if ``cfr_team`` is neither "liberal" nor "fascist".
    """
    # Numeric roles (0 = liberal, 1 = fascist, 2 = Hitler) on each team.
    team_roles = {"liberal": (0,), "fascist": (1, 2)}
    if cfr_team not in team_roles:
        # Previously any unrecognised value silently meant "fascist", so a
        # typo produced results for the wrong matchup.
        raise ValueError(
            f"cfr_team must be 'liberal' or 'fascist', got {cfr_team!r}"
        )
    cfr_roles = team_roles[cfr_team]

    results = {
        "liberal_wins": 0,
        "fascist_wins": 0,
        "win_conditions": defaultdict(int),
        "game_lengths": []
    }

    for _ in range(num_games):
        env = ShitlerEnv()
        env.reset()
        roles = get_roles_list(env)

        # Seat indices whose actions come from the CFR agent this game.
        cfr_players = {i for i, r in enumerate(roles) if r in cfr_roles}

        moves = 0
        while not all(env.terminations.values()):
            current_agent = env.agent_selection
            current_idx = env.agents.index(current_agent)
            obs = env.observe(current_agent)
            phase = env.phase

            legal_actions = get_legal_actions(env, current_agent)
            if not legal_actions:
                # Nothing to choose: step with action 0 and do not count it
                # as a move (presumably a no-op for the env - TODO confirm).
                env.step(0)
                continue

            # Choose action based on team
            if current_idx in cfr_players:
                action = get_cfr_action(cfr_agent, obs, phase, current_idx, legal_actions)
            else:
                action = get_random_action(legal_actions)

            env.step(action)
            moves += 1

        results["game_lengths"].append(moves)

        # The first liberal's final reward stands in for the team outcome:
        # a positive reward is counted as a liberal win.
        lib_idx = roles.index(0)
        lib_reward = env.rewards[f"P{lib_idx}"]

        if lib_reward > 0:
            results["liberal_wins"] += 1
            # A liberal win without 5 liberal policies is counted as a
            # Hitler execution.
            if env.lib_policies >= 5:
                results["win_conditions"]["lib_5_policies"] += 1
            else:
                results["win_conditions"]["hitler_executed"] += 1
        else:
            results["fascist_wins"] += 1
            # A fascist win without 6 fascist policies is counted as Hitler
            # being elected chancellor.
            if env.fasc_policies >= 6:
                results["win_conditions"]["fasc_6_policies"] += 1
            else:
                results["win_conditions"]["hitler_chancellor"] += 1

    return results

def print_results(results, num_games, title):
    """Print a formatted summary of a batch of games.

    Args:
        results: dict with "liberal_wins", "fascist_wins", "win_conditions"
            (mapping of condition name -> count) and "game_lengths" (list of
            per-game move counts).
        num_games: Number of games played; the denominator for every
            percentage.
        title: Heading printed between the two rule lines.

    When ``num_games`` is 0 or no game lengths were recorded, the affected
    figures are printed as 0.0 instead of raising ZeroDivisionError.
    """
    def pct(count):
        # Share of all games, in percent; 0.0 when no games were played.
        return count / num_games * 100 if num_games else 0.0

    print("=" * 60)
    print(title)
    print("=" * 60)
    print(f"Games: {num_games}")
    print()
    print(f"Liberal Win Rate:  {pct(results['liberal_wins']):.1f}%")
    print(f"Fascist Win Rate:  {pct(results['fascist_wins']):.1f}%")
    print()
    print("Win Conditions:")
    # Sorted by condition name so the listing order is stable between runs.
    for cond, count in sorted(results["win_conditions"].items()):
        print(f"  {cond}: {count} ({pct(count):.1f}%)")
    print()
    lengths = results["game_lengths"]
    avg_len = sum(lengths) / len(lengths) if lengths else 0.0
    print(f"Avg Game Length: {avg_len:.1f} moves")

## Baseline: Random vs Random

In [None]:
def run_random_games(num_games):
    """Play ``num_games`` games where every seat acts randomly; tally outcomes.

    Returns a dict with "liberal_wins" and "fascist_wins" counts, a
    "win_conditions" defaultdict(int), and "game_lengths" (moves per game).
    """
    tally = {
        "liberal_wins": 0,
        "fascist_wins": 0,
        "win_conditions": defaultdict(int),
        "game_lengths": [],
    }

    for _ in range(num_games):
        env = ShitlerEnv()
        env.reset()
        roles = get_roles_list(env)

        move_count = 0
        # Keep stepping while at least one agent is not yet terminated.
        while any(not done for done in env.terminations.values()):
            mover = env.agent_selection
            options = get_legal_actions(env, mover)
            if options:
                env.step(get_random_action(options))
                move_count += 1
            else:
                # No legal action: step with 0 and do not count a move.
                env.step(0)

        tally["game_lengths"].append(move_count)

        # The first liberal's reward decides which team is credited.
        first_liberal = roles.index(0)
        if env.rewards[f"P{first_liberal}"] > 0:
            tally["liberal_wins"] += 1
            outcome = "lib_5_policies" if env.lib_policies >= 5 else "hitler_executed"
        else:
            tally["fascist_wins"] += 1
            outcome = "fasc_6_policies" if env.fasc_policies >= 6 else "hitler_chancellor"
        tally["win_conditions"][outcome] += 1

    return tally

# Number of games played per matchup.
NUM_GAMES = 500
random_results = run_random_games(num_games=NUM_GAMES)
print_results(
    results=random_results,
    num_games=NUM_GAMES,
    title="RANDOM vs RANDOM (Baseline)",
)

## CFR Liberals vs Random Fascists

In [None]:
# Liberals play the CFR average strategy; fascists and Hitler act randomly.
cfr_lib_results = run_mixed_games(
    cfr_agent=cfr_agent,
    num_games=NUM_GAMES,
    cfr_team="liberal",
)
print_results(
    results=cfr_lib_results,
    num_games=NUM_GAMES,
    title="CFR LIBERALS vs RANDOM FASCISTS",
)

## Random Liberals vs CFR Fascists

In [None]:
def run_cfr_games(cfr_agent, num_games):
    """Play ``num_games`` games in which every seat uses the CFR agent.

    Returns a dict with "liberal_wins" and "fascist_wins" counts, a
    "win_conditions" defaultdict(int), and "game_lengths" (moves per game).
    """
    summary = {
        "liberal_wins": 0,
        "fascist_wins": 0,
        "win_conditions": defaultdict(int),
        "game_lengths": [],
    }

    for _ in range(num_games):
        env = ShitlerEnv()
        env.reset()
        roles = get_roles_list(env)

        n_moves = 0
        # Loop until every agent is flagged as terminated.
        while any(not finished for finished in env.terminations.values()):
            agent_name = env.agent_selection
            seat = env.agents.index(agent_name)
            obs = env.observe(agent_name)
            phase = env.phase

            choices = get_legal_actions(env, agent_name)
            if choices:
                env.step(get_cfr_action(cfr_agent, obs, phase, seat, choices))
                n_moves += 1
            else:
                # No legal action: step with 0 and do not count a move.
                env.step(0)

        summary["game_lengths"].append(n_moves)

        # The first liberal's reward decides which team is credited.
        liberal_seat = roles.index(0)
        if env.rewards[f"P{liberal_seat}"] > 0:
            summary["liberal_wins"] += 1
            how = "lib_5_policies" if env.lib_policies >= 5 else "hitler_executed"
        else:
            summary["fascist_wins"] += 1
            how = "fasc_6_policies" if env.fasc_policies >= 6 else "hitler_chancellor"
        summary["win_conditions"][how] += 1

    return summary

# Fascists and Hitler play the CFR average strategy; liberals act randomly.
# This defines cfr_fasc_results, which the summary table reads.
cfr_fasc_results = run_mixed_games(cfr_agent, NUM_GAMES, cfr_team="fascist")
print_results(cfr_fasc_results, NUM_GAMES, "RANDOM LIBERALS vs CFR FASCISTS")

## CFR vs CFR (Self-Play)

In [None]:
def run_cfr_games(cfr_agent, num_games):
    """Run games with all CFR players.

    Args:
        cfr_agent: Agent handed to ``get_cfr_action`` for every seat.
        num_games: Number of games to play.

    Returns:
        dict with "liberal_wins" and "fascist_wins" counts, a
        "win_conditions" defaultdict(int) keyed by how each game ended, and
        "game_lengths", a list with one move count per game.
    """
    results = {
        "liberal_wins": 0,
        "fascist_wins": 0,
        "win_conditions": defaultdict(int),
        "game_lengths": []
    }

    for _ in range(num_games):
        env = ShitlerEnv()
        env.reset()
        roles = get_roles_list(env)

        moves = 0
        while not all(env.terminations.values()):
            current_agent = env.agent_selection
            current_idx = env.agents.index(current_agent)
            obs = env.observe(current_agent)
            phase = env.phase

            legal_actions = get_legal_actions(env, current_agent)
            if not legal_actions:
                # Nothing to choose: step with action 0 and do not count it
                # as a move (presumably a no-op for the env - TODO confirm).
                env.step(0)
                continue

            action = get_cfr_action(cfr_agent, obs, phase, current_idx, legal_actions)
            env.step(action)
            moves += 1

        results["game_lengths"].append(moves)

        # The first liberal's final reward stands in for the team outcome.
        # Rewards are looked up under "P<i>", the same key format used for
        # env.roles and by the other run_* functions in this notebook
        # (this function previously used "player_<i>").
        lib_idx = roles.index(0)
        lib_reward = env.rewards[f"P{lib_idx}"]

        if lib_reward > 0:
            results["liberal_wins"] += 1
            # A liberal win without 5 liberal policies is counted as a
            # Hitler execution.
            if env.lib_policies >= 5:
                results["win_conditions"]["lib_5_policies"] += 1
            else:
                results["win_conditions"]["hitler_executed"] += 1
        else:
            results["fascist_wins"] += 1
            # A fascist win without 6 fascist policies is counted as Hitler
            # being elected chancellor.
            if env.fasc_policies >= 6:
                results["win_conditions"]["fasc_6_policies"] += 1
            else:
                results["win_conditions"]["hitler_chancellor"] += 1

    return results

# Every seat plays the CFR average strategy.
cfr_self_results = run_cfr_games(cfr_agent=cfr_agent, num_games=NUM_GAMES)
print_results(
    results=cfr_self_results,
    num_games=NUM_GAMES,
    title="CFR vs CFR (Self-Play)",
)

## Summary Comparison

In [None]:
def _summary_row(label, results):
    """Format one table row: matchup label plus its liberal win percentage."""
    liberal_pct = results['liberal_wins'] / NUM_GAMES * 100
    return f"{label:<40} {liberal_pct:>14.1f}%"


# Side-by-side liberal win rates for the four matchups.
print("=" * 70)
print("SUMMARY: Liberal Win Rates")
print("=" * 70)
print(f"{'Matchup':<40} {'Liberal Win %':>15}")
print("-" * 55)
print(_summary_row("Random vs Random (baseline)", random_results))
print(_summary_row("CFR Liberals vs Random Fascists", cfr_lib_results))
print(_summary_row("Random Liberals vs CFR Fascists", cfr_fasc_results))
print(_summary_row("CFR vs CFR (self-play)", cfr_self_results))