# AlphaZero FJSP Validation

Validates the AlphaZero components:
1. Config compiler verification
2. Environment constraint checking
3. GNN forward pass tests
4. MCTS basic tests
5. Side-by-side comparison with random baseline

In [1]:
import sys
import os
import json
from pathlib import Path

# Dynamic search for project root (works on both local and supercomputer)
def find_project_root():
    """Locate the FJSSP-MCTS-Research project root directory.

    Two starting points are tried in order: the directory of this file (only
    when the module defines ``__file__``, i.e. when run as a script rather
    than inside a notebook) and the current working directory. From each
    start, that directory and up to nine of its ancestors are checked for
    the marker file ``utils/input_schemas.py``.

    Returns:
        The first directory containing the marker, as a string. If no marker
        is found, the hardcoded supercomputer path when that directory
        exists, otherwise ``None``.
    """
    # `dir()` called inside a function lists the function's *local* names, so
    # it can never see the module-level `__file__`. Look it up in globals().
    module_file = globals().get("__file__")
    candidates = [
        Path(module_file).resolve().parent if module_file else None,
        Path.cwd(),
    ]
    for start in candidates:
        if start is None:
            continue
        p = start
        # Bounded walk: the start directory plus at most nine parents.
        for _ in range(10):
            if (p / "utils" / "input_schemas.py").exists():
                return str(p)
            p = p.parent
    # Hardcoded fallback for supercomputer
    fallback = "/home/ad.msoe.edu/weinbendera/FJSSP-MCTS-Research"
    if os.path.isdir(fallback):
        return fallback
    return None

# Resolve the repository root once; later cells build every path from it.
project_root = find_project_root()
if project_root is None:
    raise RuntimeError("Could not find project root! Set project_root manually.")
print(f"Project root: {project_root}")

# Put the root at the front of the import path (once) so the project's
# `utils` and `schedulers` packages can be imported below.
already_importable = project_root in sys.path
if not already_importable:
    sys.path.insert(0, project_root)

import numpy as np
import torch

from utils.factory_logic_loader import FactoryLogicLoader
from utils.job_builder import JobBuilder
from utils.input_schemas import ProductRequest
from schedulers.alphazero.env.config_compiler import CompiledConfig
from schedulers.alphazero.env.fjsp_env import FJSPEnv, UNSCHEDULED, IN_PROGRESS, COMPLETED
from schedulers.alphazero.env.graph_builder import GraphBuilder
from schedulers.alphazero.model.gnn import FJSPNet

Project root: c:\Users\weinbendera\Repos\FJSSP Research\FJSSP-MCTS-Research


In [None]:
# Load the factory description and the product requests from the same JSON
# input file, then compile them into the config used by the environment.
data_path = os.path.join(project_root, 'data', 'Input_JSON_Schedule_Optimization.json')
factory_logic = FactoryLogicLoader.load_from_file(data_path)

with open(data_path, 'r', encoding='utf-8') as fh:
    raw_data = json.load(fh)
# Each entry under 'product_requests' supplies the keyword arguments for one request.
product_requests = [
    ProductRequest(**request_fields)
    for request_fields in raw_data['product_requests']
]

job_builder = JobBuilder(factory_logic)
jobs = job_builder.build_jobs(product_requests)
config = CompiledConfig.compile(factory_logic, jobs)

print(f"Compiled: {config.num_ops} ops, {config.num_jobs} jobs, {config.num_machines} machines, {config.num_actions} actions")

## Test 1: Environment Random Rollout (100 episodes)

In [None]:
# Smoke-test the environment: play 100 episodes, picking uniformly at random
# among the legal actions, and require every operation to end up COMPLETED.
env = FJSPEnv(config)
makespans = []

for ep in range(100):
    env.reset()
    while not env.done:
        # Indices of the currently legal actions; sample one and apply it.
        legal_indices = np.where(env.get_legal_actions())[0]
        env.step(np.random.choice(legal_indices))

    assert np.all(env.op_status == COMPLETED), f"Episode {ep}: not all ops completed!"
    makespans.append(env.get_makespan())

# Summary statistics over the collected makespans.
print(f"100 random episodes completed successfully!")
print(f"Makespan - Mean: {np.mean(makespans):.1f}, Min: {np.min(makespans)}, Max: {np.max(makespans)}, Std: {np.std(makespans):.1f}")

## Test 2: get_state/set_state Fuzz Test

In [None]:
# Fuzz-test state save/restore: after restoring a saved state, replaying the
# same action sequence must reproduce the states recorded the first time.
for trial in range(10):
    env.reset()
    np.random.seed(trial)  # per-trial seed for the warmup draws below
    
    # Random warmup
    for _ in range(np.random.randint(5, 30)):  # between 5 and 29 random legal steps
        if env.done: break
        legal = env.get_legal_actions()
        env.step(np.random.choice(np.where(legal)[0]))
    
    # Nothing left to replay if the warmup already finished the episode.
    # NOTE(review): a trial skipped here is still counted in the final
    # "(10 trials)" message.
    if env.done: continue
    saved = env.get_state()  # snapshot that the replay below restores
    
    # Record actions and states
    np.random.seed(trial * 100)  # reseed before sampling the recorded sequence
    actions, recorded_states = [], []
    for _ in range(15):  # at most 15 steps; fewer if the episode ends first
        if env.done: break
        legal = env.get_legal_actions()
        a = np.random.choice(np.where(legal)[0])
        actions.append(a)
        env.step(a)
        recorded_states.append(env.get_state())
    
    # Replay from saved state
    env.set_state(saved)
    for i, a in enumerate(actions):
        env.step(a)
        s = env.get_state()
        # Only these three fields are compared against the recording.
        assert np.array_equal(s.op_status, recorded_states[i].op_status)
        assert np.array_equal(s.machine_busy, recorded_states[i].machine_busy)
        assert s.current_time == recorded_states[i].current_time

print("get_state/set_state fuzz test (10 trials) PASSED!")

## Test 3: Graph Builder + GNN Forward Pass

In [None]:
# Build the graph encoder and the network, run one forward pass on the
# initial state, and check that the outputs are well formed.
graph_builder = GraphBuilder(config)
model = FJSPNet(
    config=config,
    op_feature_dim=graph_builder.op_feature_dim,
    machine_feature_dim=graph_builder.machine_feature_dim,
    global_feature_dim=graph_builder.global_feature_dim,
)
# Validation is inference-only: switch to eval mode so that any train-time
# layers (e.g. Dropout/BatchNorm) behave deterministically during the checks.
# NOTE(review): assumes FJSPNet is a torch.nn.Module - confirm.
model.eval()

env.reset()
graph = graph_builder.build(env)
legal_mask = env.get_legal_actions()
legal_tensor = torch.tensor(legal_mask, dtype=torch.bool)

# No gradients are needed for a pure forward-pass check.
with torch.no_grad():
    policy, value = model(graph, legal_tensor)

print(f"Policy shape: {policy.shape} (expected: {config.num_actions + 1})")
print(f"Policy sum: {policy.sum().item():.6f} (expected: 1.0)")
print(f"Value: {value.item():.4f} (expected: in [-1, 1])")

# Verify legal mask respected
illegal_probs = policy[~legal_tensor].sum().item()
print(f"Probability on illegal actions: {illegal_probs:.8f} (expected: ~0.0)")

# The policy must sum to 1, the value must lie in [-1, 1], and (almost) no
# probability mass may fall on actions the mask marks as illegal.
assert abs(policy.sum().item() - 1.0) < 1e-5, "Policy doesn't sum to 1!"
assert -1 <= value.item() <= 1, "Value outside [-1, 1]!"
assert illegal_probs < 1e-6, "Illegal actions have nonzero probability!"
print("\nGNN forward pass test PASSED!")

## Test 4: MCTS Basic Test

In [None]:
from schedulers.alphazero.mcts.mcts import MCTS, MCTSConfig

# Run a small MCTS search from the initial state and check that visits only
# land on actions that are legal in that state.
env.reset()
mcts = MCTS(env, graph_builder, model, MCTSConfig(num_simulations=50))

state = env.get_state()
# Capture the root state's legal mask BEFORE searching: the search is handed
# this same `env` and may leave it in a different state, in which case a mask
# read afterwards would not describe the root. Coerce to bool so that `~legal`
# below is a logical NOT.
legal = np.asarray(env.get_legal_actions(), dtype=bool)
visit_counts = mcts.search(state)

print(f"Visit counts shape: {visit_counts.shape}")
print(f"Total visits: {visit_counts.sum():.0f}")
print(f"Nonzero actions: {(visit_counts > 0).sum()}")
print(f"Top 5 actions by visits: {np.argsort(visit_counts)[-5:][::-1]}")
print(f"Top 5 visit counts: {np.sort(visit_counts)[-5:][::-1]}")

assert visit_counts.sum() > 0, "No visits!"
# Only legal actions should have visits
illegal_visits = visit_counts[~legal].sum()
assert illegal_visits == 0, f"Illegal actions got {illegal_visits} visits!"
print("\nMCTS basic test PASSED!")

## Test 5: Full Episode with Trained Model (policy-only, `use_mcts=False`)

In [None]:
from schedulers.alphazero.alphazero_scheduler import AlphaZeroScheduler

# Preferred checkpoint written by the training pipeline (relative to project root).
checkpoint_path = os.path.join(project_root, 'schedulers', 'alphazero', 'checkpoints', 'alphazero_iter0010.pt')
if not os.path.exists(checkpoint_path):
    # Preferred file is missing: fall back to the last '.pt' file (by sorted
    # name) in the checkpoint directory, failing loudly if there is none.
    ckpt_dir = os.path.join(project_root, 'schedulers', 'alphazero', 'checkpoints')
    if not os.path.isdir(ckpt_dir):
        raise FileNotFoundError(f"Checkpoint dir not found: {ckpt_dir}")
    ckpts = sorted(name for name in os.listdir(ckpt_dir) if name.endswith('.pt'))
    if not ckpts:
        raise FileNotFoundError(f"No checkpoints found in {ckpt_dir}")
    checkpoint_path = os.path.join(ckpt_dir, ckpts[-1])
    print(f"Using latest checkpoint: {ckpts[-1]}")

# Build the scheduler from the chosen checkpoint with MCTS disabled
# (policy-only) for fast inference.
scheduler = AlphaZeroScheduler(
    factory_logic=factory_logic,
    product_requests=product_requests,
    model_path=checkpoint_path,
    use_mcts=False,
)

# Schedule all jobs, passing no energy sources, and report the result.
schedule = scheduler.schedule(jobs, energy_sources=[])
print(f"Schedule complete!")
print(f"Operations scheduled: {len(schedule.operations)}")
print(f"Makespan: {schedule.makespan}")

# Render the Gantt chart and show it only when the plot call returned something truthy.
img = schedule.plot_gantt_by_task()
if img:
    display(img)