In [3]:
from rewrite_puzzle.RewritePuzzleGame import RewritePuzzleGame as Game
from rewrite_puzzle.pytorch.NNet import NNetWrapper as nn

# Create game and network
g = Game(start_expr="1 + 2 * 3", goal_expr=7, max_steps=20)
nnet = nn(g)

# # Load the best model
# nnet.load_checkpoint('./temp/rewrite_puzzle/', 'best.pth.tar')

# Or load a specific iteration
nnet.load_checkpoint('./temp/rewrite_puzzle/', 'checkpoint_50.pth.tar')

TypeError: exceptions must derive from BaseException

In [None]:
"""
Simple script to visualize what the model is doing at each step.
Shows the model's predictions, action probabilities, and decisions.
"""

import numpy as np
from rewrite_puzzle.RewritePuzzleGame import RewritePuzzleGame
from rewrite_puzzle.pytorch.NNet import NNetWrapper
from MCTS import MCTS
from utils import dotdict

# Load the model
print("Loading model...")
game = RewritePuzzleGame(start_expr="(1 + 2)", goal_expr="(2 + 1)", max_steps=5)
nnet = NNetWrapper(game)
nnet.load_checkpoint('./temp/rewrite_puzzle/', 'temp.pth.tar')

# Setup MCTS
args = dotdict({
    'numMCTSSims': 25,
    'cpuct': 1,
})
mcts = MCTS(game, nnet, args)

# Start a game
print("\n" + "="*60)
print("Starting game visualization")
print("="*60)
board = game.getInitBoard()
step = 0
max_steps = 20

while step < max_steps:
    step += 1
    print(f"\n--- Step {step} ---")
    
    # Decode current state
    board_obj = game._decode_state(board)
    current_expr = str(board_obj.current_expr)
    goal_expr = str(board_obj.goal_expr)
    current_value = board_obj.current_expr.evaluate()
    
    print(f"Current expression: {current_expr}")
    print(f"Goal: {goal_expr}")
    print(f"Steps taken: {board_obj.steps_taken}/{board_obj.max_steps}")
    
    # Get raw neural network prediction (show even if solved)
    canonical_board = game.getCanonicalForm(board, 1)
    pi_raw, v = nnet.predict(canonical_board)
    
    # Extract scalar value (handle numpy array/scalar)
    v_scalar = float(np.array(v).item() if isinstance(v, np.ndarray) else v)
    print(f"Model's value prediction (win probability): {v_scalar:.4f}")
    
    # Check if solved
    if board_obj.is_solved():
        print("✓ SOLVED!")
        # Still show what moves would be available (even though game is solved)
        valid_actions = board_obj.get_all_valid_actions()
        if len(valid_actions) > 0:
            print(f"\nNote: {len(valid_actions)} valid moves were available, but game is already solved.")
        break
    
    # Show raw policy (before MCTS) for top actions
    valids_raw = game.getValidMoves(board, 1)
    pi_raw_masked = pi_raw * valids_raw
    if np.sum(pi_raw_masked) > 0:
        pi_raw_masked = pi_raw_masked / np.sum(pi_raw_masked)  # Normalize
    top_raw_idx = np.argmax(pi_raw_masked)
    print(f"Raw policy (before MCTS) top action probability: {pi_raw_masked[top_raw_idx]:.4f}")
    
    # Get valid moves
    valids = game.getValidMoves(board, 1)
    valid_actions = board_obj.get_all_valid_actions()
    
    print(f"\nValid moves ({len(valid_actions)}):")
    max_positions = game.max_expr_length // 2
    
    # Reset MCTS tree for clean predictions
    mcts = MCTS(game, nnet, args)
    
    # Get MCTS action probabilities
    action_probs = mcts.getActionProb(canonical_board, temp=0)
    
    # Show top moves
    top_n = min(5, len(valid_actions))
    action_scores = []
    
    for rule_idx, path in valid_actions:
        position_idx = len(path) % max_positions
        action = rule_idx * max_positions + position_idx
        if action < len(action_probs):
            prob = action_probs[action]
            # Get rule name safely
            try:
                if hasattr(board_obj, 'rules') and rule_idx < len(board_obj.rules):
                    rule_name = board_obj.rules[rule_idx].name
                else:
                    rule_name = f"Rule{rule_idx}"
            except (AttributeError, IndexError):
                rule_name = f"Rule{rule_idx}"
            action_scores.append((prob, rule_idx, path, rule_name, action))
    
    # Sort by probability
    action_scores.sort(reverse=True, key=lambda x: x[0])
    
    print(f"\nTop {top_n} moves (by model preference):")
    for i, (prob, rule_idx, path, rule_name, action) in enumerate(action_scores[:top_n]):
        print(f"  {i+1}. {rule_name:20s} (path: {path}) - Probability: {prob:.4f}")
    
    # Choose best action
    best_action_idx = np.argmax(action_probs)
    print(f"\nModel's chosen action: {best_action_idx} (probability: {action_probs[best_action_idx]:.4f})")
    
    # Apply the action
    board, _ = game.getNextState(board, 1, best_action_idx)
    
    # Check if game ended
    result = game.getGameEnded(board, 1)
    if result != 0:
        if result == 1:
            print("\n✓ Game won!")
        else:
            print("\n✗ Game lost (max steps reached)")
        break

print("\n" + "="*60)
print("Game finished")
print("="*60)

Loading model...

Starting game visualization

--- Step 1 ---
Current expression: (1 + 2)
Goal: (2 + 1)
Steps taken: 0/5
Model's value prediction (win probability): 0.0006
Raw policy (before MCTS) top action probability: 0.5062

Valid moves (2):

Top 2 moves (by model preference):
  1. eval_add_leaves      (path: []) - Probability: 1.0000
  2. commute_add          (path: []) - Probability: 0.0000

Model's chosen action: 500 (probability: 1.0000)

--- Step 2 ---
Current expression: 3
Goal: (2 + 1)
Steps taken: 1/5
Model's value prediction (win probability): 0.0005
Raw policy (before MCTS) top action probability: 1.0000

Valid moves (0):

Top 0 moves (by model preference):

Model's chosen action: 699 (probability: 1.0000)

--- Step 3 ---
Current expression: 3
Goal: (2 + 1)
Steps taken: 2/5
Model's value prediction (win probability): 0.0005
Raw policy (before MCTS) top action probability: 1.0000

Valid moves (0):

Top 0 moves (by model preference):

Model's chosen action: 699 (probability