# Stochastic AlphaZero Comparison

This notebook compares three approaches on the Game of Pig:

1. **Standard AlphaZero** - Treats dice rolls as part of the environment (hidden stochasticity)
2. **Stochastic AlphaZero** - Explicitly models chance nodes with expectimax in MCTS
3. **Hold20 Baseline** - Simple heuristic: hold when turn total >= 20

We train each AlphaZero variant until it beats Hold20 (or reaches the iteration cap) and compare:
- Win rate per iteration
- Sample efficiency (win rate per million MCTS simulations)

## Step 1: Install Julia (Colab only)

If running on Google Colab, uncomment and run this cell first to install Julia, then restart the runtime:

In [None]:
# For Google Colab - Install Julia
# Uncomment and run this cell FIRST if on Colab

# %%shell
# set -e
# curl -fsSL https://install.julialang.org | sh -s -- --yes
# pip install -q julia
# julia -e 'using Pkg; Pkg.add("IJulia")'
# echo "Restart runtime after this cell completes!"

## Step 2: Install AlphaZero.jl Package

Run the cell below to install dependencies and AlphaZero. This takes ~5-8 minutes the first time due to Julia compilation.

**Note**: Plots.jl is now optional - AlphaZero will run without it. We install it separately for visualization.

In [None]:
# Install AlphaZero.jl (Plots is now optional - no OpenGL required!)
using Pkg

# Install the package (this takes ~5-8 min first time for compilation)
println("Installing AlphaZero.jl from stochastic-mcts branch...")
println("This may take several minutes for compilation...")
Pkg.add(url="https://github.com/sile16/AlphaZero.jl", rev="stochastic-mcts")

# Install Plots for visualization (optional but recommended)
println("Installing Plots for visualization...")
ENV["GKSwstype"] = "100"  # Headless mode for Colab
Pkg.add("Plots")

println("Installation complete!")

In [None]:
# Load packages
println("Loading AlphaZero...")
using AlphaZero
using AlphaZero.Examples: Pig
using Random
using Statistics
using Printf

println("Loading Plots...")
using Plots
gr()  # Use GR backend (works in notebooks)

println("All packages loaded successfully!")

## Define Deterministic Pig Game

This version rolls the die inside `play!`, so standard AlphaZero's MCTS treats ROLL as an ordinary action and never sees the chance event explicitly.
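
To make the distinction concrete, here is a minimal standalone sketch that contrasts a sampled transition with an enumerated one, looking only at the turn total. It does not use AlphaZero.jl, and the helper names `hidden_roll` and `explicit_roll_outcomes` are hypothetical, introduced only for illustration.

```julia
using Random

# Hypothetical helper: sample one die face, as DeterministicPig's play! does.
# Returns the new turn total (0 on a bust).
function hidden_roll(rng::AbstractRNG, turn_total::Int)
    face = rand(rng, 1:6)
    return face == 1 ? 0 : turn_total + face
end

# Hypothetical helper: list every (probability, new turn total) pair,
# which is the information an explicit chance node exposes to the search.
function explicit_roll_outcomes(turn_total::Int)
    return [(1 / 6, face == 1 ? 0 : turn_total + face) for face in 1:6]
end

rng = MersenneTwister(0)
samples = [hidden_roll(rng, 10) for _ in 1:5]   # depends on the RNG seed
outcomes = explicit_roll_outcomes(10)           # always the same six pairs
println("Sampled successors:    ", samples)
println("Enumerated successors: ", outcomes)
```

With the hidden version, a search that revisits the same state and action can land in a different successor each time. The enumerated version always yields the same six weighted successors.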

In [None]:
module DeterministicPig

import AlphaZero.GI
using Random

const TARGET_SCORE = 100
const ROLL = 1
const HOLD = 2

const Player = Bool
const WHITE = true
const BLACK = false

# State without awaiting_dice - dice is rolled immediately
const State = @NamedTuple{
    p1_score::Int,
    p2_score::Int,
    turn_total::Int,
    curplayer::Player
}

const INITIAL_STATE = State((0, 0, 0, WHITE))

struct GameSpec <: GI.AbstractGameSpec end

GI.two_players(::GameSpec) = true
GI.actions(::GameSpec) = [ROLL, HOLD]

# NO chance outcomes - this is "deterministic" from MCTS perspective
GI.num_chance_outcomes(::GameSpec) = 0

function GI.vectorize_state(::GameSpec, state)
    p1_norm = Float32(state.p1_score / TARGET_SCORE)
    p2_norm = Float32(state.p2_score / TARGET_SCORE)
    turn_norm = Float32(state.turn_total / TARGET_SCORE)
    curplayer = state.curplayer == WHITE ? 1f0 : 0f0
    return Float32[p1_norm, p2_norm, turn_norm, curplayer]
end

mutable struct GameEnv <: GI.AbstractGameEnv
    state::State
    rng::AbstractRNG
end

GI.spec(::GameEnv) = GameSpec()
GI.current_state(g::GameEnv) = g.state
GI.set_state!(g::GameEnv, s) = (g.state = s)
GI.white_playing(g::GameEnv) = g.state.curplayer == WHITE

function GI.init(::GameSpec)
    return GameEnv(INITIAL_STATE, Random.default_rng())
end

function GI.init(::GameSpec, state)
    return GameEnv(state, Random.default_rng())
end

function GI.game_terminated(g::GameEnv)
    return g.state.p1_score >= TARGET_SCORE || g.state.p2_score >= TARGET_SCORE
end

function GI.white_reward(g::GameEnv)
    if g.state.p1_score >= TARGET_SCORE
        return 1.0
    elseif g.state.p2_score >= TARGET_SCORE
        return -1.0
    else
        return 0.0
    end
end

# NOT a chance node - stochasticity is hidden
GI.is_chance_node(g::GameEnv) = false

function GI.actions_mask(g::GameEnv)
    return [true, true]  # Both actions always available
end

function GI.play!(g::GameEnv, action)
    s = g.state
    
    if action == ROLL
        # Dice roll happens INSIDE play! - hidden from MCTS
        die_face = rand(g.rng, 1:6)
        
        if die_face == 1
            # Bust - lose turn total, switch players
            g.state = State((s.p1_score, s.p2_score, 0, !s.curplayer))
        else
            # Add to turn total
            g.state = State((s.p1_score, s.p2_score, s.turn_total + die_face, s.curplayer))
        end
    elseif action == HOLD
        # Add turn total to score, switch players
        if s.curplayer == WHITE
            g.state = State((s.p1_score + s.turn_total, s.p2_score, 0, BLACK))
        else
            g.state = State((s.p1_score, s.p2_score + s.turn_total, 0, WHITE))
        end
    end
end

function GI.heuristic_value(g::GameEnv)
    s = g.state
    if s.curplayer == WHITE
        return (s.p1_score - s.p2_score + s.turn_total) / TARGET_SCORE
    else
        return (s.p2_score - s.p1_score + s.turn_total) / TARGET_SCORE
    end
end

end # module DeterministicPig

println("DeterministicPig module defined.")

## Hold20 Player (Baseline)
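
The threshold of 20 follows from an expected-points argument: with turn total `t`, one more roll forfeits `t` with probability 1/6 and otherwise adds the face value (2 to 6). The sketch below checks where the expected gain changes sign; the function name `expected_roll_gain` is hypothetical. Note that this rule maximizes expected points per turn rather than win probability, so Hold20 is a reasonable baseline but not an optimal policy.

```julia
# Hypothetical helper: expected change in turn total from one more roll.
# Faces 2-6 add their value (each with probability 1/6); a 1 forfeits `t`.
expected_roll_gain(t) = (sum(2:6) - t) / 6

for t in (10, 19, 20, 21, 25)
    println("turn_total = $t  expected gain = ", round(expected_roll_gain(t), digits=3))
end
```

The gain is positive below 20, zero at exactly 20, and negative above it.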

In [None]:
# Hold20 player for both game versions
struct Hold20Player <: AbstractPlayer
    threshold::Int
end

Hold20Player() = Hold20Player(20)

function AlphaZero.think(p::Hold20Player, game)
    s = GI.current_state(game)
    turn_total = hasproperty(s, :turn_total) ? s.turn_total : s[3]
    
    if turn_total >= p.threshold
        π = [0.0, 1.0]  # Hold
    else
        π = [1.0, 0.0]  # Roll
    end
    return [1, 2], π
end

AlphaZero.reset!(::Hold20Player) = nothing

println("Hold20Player defined.")

## Evaluation Function

In [None]:
"""
Evaluate `player` against Hold20 over `num_games` games, alternating colors.
Returns the fraction of games won by `player`.
"""
function evaluate_vs_hold20(gspec, player, num_games=100)
    hold20 = Hold20Player()
    wins = 0
    
    for i in 1:num_games
        # Alternate who plays white
        if i % 2 == 1
            trace = play_game(gspec, TwoPlayers(player, hold20))
            final_game = GI.init(gspec, trace.states[end])
            if GI.white_reward(final_game) > 0
                wins += 1
            end
        else
            trace = play_game(gspec, TwoPlayers(hold20, player))
            final_game = GI.init(gspec, trace.states[end])
            if GI.white_reward(final_game) < 0
                wins += 1
            end
        end
    end
    
    return wins / num_games
end

println("Evaluation function defined.")

## Training Configuration

In [None]:
# Light parameters for faster training (adjust for your compute)
const MCTS_ITERS = 50          # MCTS iterations per move
const SELF_PLAY_GAMES = 50     # Games per iteration
const EVAL_GAMES = 50          # Games for evaluation
const MAX_ITERS = 20           # Max training iterations
const WIN_THRESHOLD = 0.55     # Win rate to consider "beating" Hold20

# Network parameters
const NET_WIDTH = 64
const NET_DEPTH = 4

println("Configuration:")
println("  MCTS iterations/turn: $MCTS_ITERS")
println("  Self-play games/iter: $SELF_PLAY_GAMES")
println("  Eval games: $EVAL_GAMES")
println("  Max iterations: $MAX_ITERS")
println("  Win threshold: $(round(Int, WIN_THRESHOLD * 100))%")  # round avoids printing 55.00000000000001
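
With `EVAL_GAMES = 50`, a measured win rate carries noticeable sampling noise, which matters when comparing it against `WIN_THRESHOLD`. The sketch below assumes independent games and a simple binomial model; the helper name `win_rate_stderr` is hypothetical.

```julia
# Hypothetical helper: standard error of a win rate estimated from
# `n` independent games when the true win probability is `p`.
win_rate_stderr(p, n) = sqrt(p * (1 - p) / n)

for n in (50, 200, 1000)
    se = win_rate_stderr(0.55, n)
    println("n = $n games: standard error ≈ ", round(se * 100, digits=1), " percentage points")
end
```

At 50 games the standard error is roughly 7 percentage points, so a single evaluation at 55% is well within the noise of an even match. Raising `EVAL_GAMES` tightens the estimate at the cost of longer evaluation.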

## Training Loop

In [None]:
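"""
Run self-play with AlphaZero-style MCTS and evaluate against Hold20 each iteration.
Returns a Dict with win rates, cumulative simulation counts, and iteration indices.
"""
function train_and_evaluate(gspec, name; max_iters=MAX_ITERS, use_gpu=false)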
    
    println("\n" * "="^60)
    println("Training: $name")
    println("="^60)
    
    # Results storage
    results = Dict(
        :win_rates => Float64[],
        :total_sims => Int[],
        :iterations => Int[]
    )
    
    # Create network (CPU only for compatibility)
    netparams = NetLib.SimpleNetHP(
        width=NET_WIDTH,
        depth_common=NET_DEPTH,
        use_batch_norm=true,
        batch_norm_momentum=1.0
    )
    nn = NetLib.SimpleNet(gspec, netparams)
    
    total_simulations = 0
    
    for iter in 1:max_iters
        println("\nIteration $iter/$max_iters")
        flush(stdout)
        
        # Self-play
        print("  Self-play ($SELF_PLAY_GAMES games)... ")
        flush(stdout)
        mcts_env = MCTS.Env(gspec, nn, cpuct=1.0, noise_ϵ=0.25, noise_α=1.0)
        player = MctsPlayer(mcts_env, niters=MCTS_ITERS, τ=ConstSchedule(1.0))
        
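        # NOTE: the traces are only used to count simulations here. Nothing in
        # this loop updates `nn`, so a learning step on the collected games
        # must be added for the network to improve between iterations.
        for g in 1:SELF_PLAY_GAMES
            trace = play_game(gspec, TwoPlayers(player, player))
            total_simulations += length(trace) * MCTS_ITERS
        end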
        println("done")
        flush(stdout)
        
        # Evaluation (separate MCTS environment without Dirichlet exploration noise)
        print("  Evaluating vs Hold20 ($EVAL_GAMES games)... ")
        flush(stdout)
        eval_env = MCTS.Env(gspec, nn, cpuct=1.0, noise_ϵ=0.0, noise_α=1.0)
        eval_player = MctsPlayer(eval_env, niters=MCTS_ITERS, τ=ConstSchedule(0.2))
        win_rate = evaluate_vs_hold20(gspec, eval_player, EVAL_GAMES)
        println(@sprintf("%.1f%% win rate", win_rate * 100))
        flush(stdout)
        
        # Record results
        push!(results[:win_rates], win_rate)
        push!(results[:total_sims], total_simulations)
        push!(results[:iterations], iter)
        
        # Check if we beat Hold20
        if win_rate >= WIN_THRESHOLD
            println("\n✓ Beat Hold20 at iteration $iter!")
            break
        end
    end
    
    return results
end

println("Training function defined.")

## Run Comparison

In [None]:
# Train Standard AlphaZero (hidden stochasticity)
println("Training Standard AlphaZero (deterministic MCTS)...")
det_gspec = DeterministicPig.GameSpec()
results_standard = train_and_evaluate(det_gspec, "Standard AlphaZero")

In [None]:
# Train Stochastic AlphaZero (explicit chance nodes)
println("Training Stochastic AlphaZero (expectimax MCTS)...")
stoch_gspec = Pig.GameSpec()
results_stochastic = train_and_evaluate(stoch_gspec, "Stochastic AlphaZero")

## Plot Results

In [None]:
# Plot 1: Win Rate vs Iteration
p1 = plot(
    title="Win Rate vs Hold20 by Iteration",
    xlabel="Training Iteration",
    ylabel="Win Rate (%)",
    legend=:bottomright,
    ylims=(0, 100)
)

plot!(p1, results_standard[:iterations], results_standard[:win_rates] .* 100,
    label="Standard AlphaZero", marker=:circle, linewidth=2)
plot!(p1, results_stochastic[:iterations], results_stochastic[:win_rates] .* 100,
    label="Stochastic AlphaZero", marker=:square, linewidth=2)
hline!(p1, [50], label="Break-even", linestyle=:dash, color=:gray)
hline!(p1, [WIN_THRESHOLD * 100], label="Win threshold", linestyle=:dot, color=:green)

display(p1)

In [None]:
# Plot 2: Win Rate vs Total Simulations (Sample Efficiency)
p2 = plot(
    title="Sample Efficiency: Win Rate vs Total MCTS Simulations",
    xlabel="Total MCTS Simulations",
    ylabel="Win Rate (%)",
    legend=:bottomright,
    ylims=(0, 100)
)

plot!(p2, results_standard[:total_sims], results_standard[:win_rates] .* 100,
    label="Standard AlphaZero", marker=:circle, linewidth=2)
plot!(p2, results_stochastic[:total_sims], results_stochastic[:win_rates] .* 100,
    label="Stochastic AlphaZero", marker=:square, linewidth=2)
hline!(p2, [50], label="Break-even", linestyle=:dash, color=:gray)

display(p2)

In [None]:
# Plot 3: Efficiency metric (win_rate / simulations)
p3 = plot(
    title="Learning Efficiency Over Training",
    xlabel="Training Iteration",
    ylabel="Win Rate / Million Simulations",
    legend=:topright
)

efficiency_std = results_standard[:win_rates] ./ (results_standard[:total_sims] ./ 1e6)
efficiency_stoch = results_stochastic[:win_rates] ./ (results_stochastic[:total_sims] ./ 1e6)

plot!(p3, results_standard[:iterations], efficiency_std,
    label="Standard AlphaZero", marker=:circle, linewidth=2)
plot!(p3, results_stochastic[:iterations], efficiency_stoch,
    label="Stochastic AlphaZero", marker=:square, linewidth=2)

display(p3)

In [None]:
# Combined plot
combined = plot(p1, p2, p3, layout=(1, 3), size=(1400, 400))
savefig(combined, "stochastic_comparison_results.png")
println("Plot saved to: stochastic_comparison_results.png")
display(combined)

## Summary Statistics

In [None]:
println("\n" * "="^60)
println("SUMMARY")
println("="^60)

println("\nStandard AlphaZero (hidden stochasticity):")
println("  Final win rate: $(round(results_standard[:win_rates][end] * 100, digits=1))%")
println("  Total simulations: $(results_standard[:total_sims][end])")
println("  Iterations trained: $(length(results_standard[:iterations]))")

println("\nStochastic AlphaZero (explicit chance nodes):")
println("  Final win rate: $(round(results_stochastic[:win_rates][end] * 100, digits=1))%")
println("  Total simulations: $(results_stochastic[:total_sims][end])")
println("  Iterations trained: $(length(results_stochastic[:iterations]))")

# Determine winner
std_beat = findfirst(x -> x >= WIN_THRESHOLD, results_standard[:win_rates])
stoch_beat = findfirst(x -> x >= WIN_THRESHOLD, results_stochastic[:win_rates])

println("\nIterations to beat Hold20 ($(round(Int, WIN_THRESHOLD * 100))% threshold):")
if std_beat !== nothing
    println("  Standard: Iteration $std_beat ($(results_standard[:total_sims][std_beat]) sims)")
else
    println("  Standard: Did not beat Hold20")
end
if stoch_beat !== nothing
    println("  Stochastic: Iteration $stoch_beat ($(results_stochastic[:total_sims][stoch_beat]) sims)")
else
    println("  Stochastic: Did not beat Hold20")
end

# Efficiency comparison
println("\nEfficiency (final win_rate / million sims):")
eff_std = results_standard[:win_rates][end] / (results_standard[:total_sims][end] / 1e6)
eff_stoch = results_stochastic[:win_rates][end] / (results_stochastic[:total_sims][end] / 1e6)
println("  Standard: $(round(eff_std, digits=4))")
println("  Stochastic: $(round(eff_stoch, digits=4))")

## Expected Results

In principle, **Stochastic AlphaZero should outperform Standard AlphaZero** on Pig because:

1. **Better value estimates**: Expectimax backs up the probability-weighted average over all six dice outcomes instead of a single sampled one
2. **Systematic exploration**: MCTS expands every dice outcome explicitly rather than whichever outcomes happen to be sampled
3. **Training signal**: Policy targets come from a search that accounts for the dice distribution explicitly

Standard AlphaZero treats dice rolls as hidden environment dynamics, leading to:
- Noisy value estimates (the same state and action can lead to different successor states)
- Suboptimal exploration (the search does not know which outcomes are possible or how likely they are)
- Confused training signal (high variance in observed rewards for the same decision)
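
The variance argument can be illustrated numerically. The sketch below uses made-up successor values (`child_values` is hypothetical and not taken from a trained network) to compare the exact expectimax backup with the spread of a single sampled backup.

```julia
# Hypothetical value estimates for the six successors of a ROLL
# (index = die face; face 1 is the bust). Not taken from a trained network.
child_values = [-0.30, 0.05, 0.10, 0.15, 0.20, 0.25]
probs = fill(1 / 6, 6)

# Expectimax backup at a chance node: probability-weighted average.
expectimax_value = sum(probs .* child_values)

# A sampled backup sees one child per visit; its spread around the
# expectation is the standard deviation of the child values.
sample_std = sqrt(sum(probs .* (child_values .- expectimax_value) .^ 2))

println("Expectimax backup:       ", round(expectimax_value, digits=3))
println("Std. dev. of one sample: ", round(sample_std, digits=3))
```

Under these assumed values the expectimax backup is the same on every visit. A sampled backup needs many visits to average out to that figure.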