In [3]:
# 1. Install the latest JAX TPU wheels (forcing a libtpu update)
!pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

!pip install -U chess

Looking in links: https://storage.googleapis.com/jax-releases/libtpu_releases.html
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m26.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m26.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [4]:
!python -m pip uninstall -y matplotlib
!python -m pip install -U --no-cache-dir --force-reinstall "matplotlib>=3.8" matplotlib-inline


Found existing installation: matplotlib 3.10.8
Uninstalling matplotlib-3.10.8:
  Successfully uninstalled matplotlib-3.10.8
[0mCollecting matplotlib>=3.8
  Downloading matplotlib-3.10.8-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (52 kB)
Collecting matplotlib-inline
  Downloading matplotlib_inline-0.2.1-py3-none-any.whl.metadata (2.3 kB)
Collecting contourpy>=1.0.1 (from matplotlib>=3.8)
  Downloading contourpy-1.3.3-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (5.5 kB)
Collecting cycler>=0.10 (from matplotlib>=3.8)
  Downloading cycler-0.12.1-py3-none-any.whl.metadata (3.8 kB)
Collecting fonttools>=4.22.0 (from matplotlib>=3.8)
  Downloading fonttools-4.61.1-cp312-cp312-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl.metadata (114 kB)
Collecting kiwisolver>=1.3.1 (from matplotlib>=3.8)
  Downloading kiwisolver-1.4.9-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (6.3 kB)
Collect

In [5]:
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.training import train_state
import chess
import chess.svg
import numpy as np
import optax
import math
import time
from einops import rearrange
from IPython.display import display, clear_output, HTML



In [6]:
# ==============================================================================
# 1. OPTIMAL UNI-DIRECTIONAL ARCHITECTURE (COMPUTE-BOUND)
# ==============================================================================
import jax
import jax.numpy as jnp
import flax.linen as nn
from einops import rearrange
import numpy as np

# Model hyper-parameters shared by the model / training cells below.
BATCH_SIZE = 128     # NOTE(review): not referenced in the visible cells (training uses MAX_TRAIN_BATCH) — confirm before removing
D_MODEL = 256        # hidden width of token embeddings and of every UniMambaBlock
N_STATE = 64         # SSM state size per channel; the scan carry h is (batch, D_MODEL, N_STATE)
TIMESTEPS = 100      # number of rows in the timestep embedding table (valid t: 0..TIMESTEPS-1)
NUM_CLASSES = 14  # 0 empty, 1..12 pieces, 13 mask

class UniMambaBlock(nn.Module):
    """Uni-directional selective state-space (Mamba-style) block with residual + LayerNorm.

    The sequence axis is scanned left-to-right (position 0 -> L-1), so the output at
    position i depends only on inputs at positions <= i.

    Attributes:
        d_model: channel width; must equal the last dim of `x` (x_t is contracted against
            per-channel state of size d_model inside the scan).
        n_state: size of the per-channel hidden state.
    """
    d_model: int
    n_state: int
    
    @nn.compact
    def __call__(self, x, time_emb):
        """Apply the block.

        Args:
            x: (B, L, D) activations, with D == d_model.
            time_emb: per-example timestep embedding, broadcast over the sequence axis
                (presumably (B, D) — the engine passes nn.Embed(...)(t) with t of shape (B,)).

        Returns:
            (B, L, D) LayerNorm(scan_output + x) computed with dtype bfloat16.
        """
        B, L, D = x.shape  # only B is used below (for the initial scan state)
        # Inject the timestep embedding into every position before projecting.
        x_fused = x + time_emb[:, jnp.newaxis, :]
        
        # Diagonal state matrix, parameterised so that A = -exp(A_log) is strictly negative;
        # hence exp(delta * A) lies in (0, 1) and the recurrence is stable.
        A_log = self.param('A_log', nn.initializers.normal(0.1), (self.d_model, self.n_state))
        A = -jnp.exp(A_log)
        # NOTE(review): A is always negative, so "+ 1e-7" moves it TOWARD zero instead of away;
        # it only avoids an exact division by zero (and flips sign if |A| < 1e-7).
        A_inv = 1.0 / (A + 1e-7) 
        
        # Input-dependent ("selective") projections: B_t, C_t over the state axis and the
        # step size delta per channel; computed in bfloat16.
        B_proj = nn.Dense(self.n_state, dtype=jnp.bfloat16)(x_fused) 
        C_proj = nn.Dense(self.n_state, dtype=jnp.bfloat16)(x_fused)
        delta_proj = nn.Dense(self.d_model, dtype=jnp.bfloat16)(x_fused) 
        
        def scan_fn(h, params):
            # h: float32 (B, d_model, n_state) carry; params: one time step of each projection.
            b_t, c_t, delta_raw_t, x_t = params
            delta = nn.softplus(delta_raw_t.astype(jnp.float32))  # positive step size per channel
            
            dt_A = jnp.einsum('bd,dn->bdn', delta, A)
            A_bar = jnp.exp(dt_A)
            
            # Zero-order-hold discretisation for diagonal A: B_bar = (exp(delta*A) - 1) / A * B_t.
            B_bar = jnp.einsum('bdn,dn,bn->bdn', (A_bar - 1.0), A_inv, b_t.astype(jnp.float32))
            h_new = A_bar * h + jnp.einsum('bdn,bd->bdn', B_bar, x_t.astype(jnp.float32))
            # Read-out: contract the state with C_t -> (B, d_model).
            y_t = jnp.einsum('bdn,bn->bd', h_new, c_t.astype(jnp.float32))
            return h_new, y_t.astype(jnp.bfloat16)

        # lax.scan iterates over the leading axis, so move the sequence axis to the front.
        Xs_fwd = (
            rearrange(B_proj, 'b l ... -> l b ...'),
            rearrange(C_proj, 'b l ... -> l b ...'),
            rearrange(delta_proj, 'b l ... -> l b ...'),
            rearrange(x_fused, 'b l ... -> l b ...')
        )
        
        h_init = jnp.zeros((B, self.d_model, self.n_state), dtype=jnp.float32)
        _, y_f = jax.lax.scan(scan_fn, h_init, Xs_fwd)
        y_out = rearrange(y_f, 'l b d -> b l d')
        
        # Residual uses the original x (without the timestep embedding).
        return nn.LayerNorm(dtype=jnp.bfloat16)(y_out + x)

class UniMambaChessEngine(nn.Module):
    """Board encoder with a masked-token (MTM) classification head and a value head.

    NOTE(review): every UniMambaBlock scans squares 0 -> 63, so the token at the destination
    square (used by the MTM head when `mask_dest` is given) only sees squares with an index
    <= the destination. The value head mean-pools all 64 positions.
    """
    @nn.compact
    def __call__(self, x_t, t, mask_dest=None): 
        """Encode a batch of boards.

        Args:
            x_t: (B, 64) int piece-class ids in [0, NUM_CLASSES); 13 marks the masked square.
            t: (B,) timestep ids for the timestep embedding (callers in this notebook pass 1).
            mask_dest: optional (B, 64, 1) weights over squares; callers pass a one-hot of the
                destination square so the weighted sum selects that square's token.

        Returns:
            (logits, v): logits (B, NUM_CLASSES) from `mtm_head`; v (B,) tanh value in [-1, 1].
        """
        x = nn.Embed(NUM_CLASSES, D_MODEL, dtype=jnp.bfloat16)(x_t)
        
        # Learned rank / file embeddings: square index i -> rank i // 8, file i % 8,
        # tiled over the batch.
        indices = jnp.arange(64)
        rank_emb = nn.Embed(8, D_MODEL, dtype=jnp.bfloat16)(jnp.tile(indices // 8, (x_t.shape[0], 1)))
        file_emb = nn.Embed(8, D_MODEL, dtype=jnp.bfloat16)(jnp.tile(indices % 8, (x_t.shape[0], 1)))
        
        x = x + rank_emb + file_emb
        x = nn.Dense(D_MODEL, dtype=jnp.bfloat16)(x) 
        
        t_emb = nn.Embed(TIMESTEPS, D_MODEL, dtype=jnp.bfloat16)(t)
        # nn.remat recomputes block activations in the backward pass instead of storing them.
        Block = nn.remat(UniMambaBlock)
            
        for _ in range(4): 
            x = Block(D_MODEL, N_STATE)(x, t_emb)
            x = nn.gelu(x)

        # -----------------------
        # Value head (always built)
        # -----------------------
        pooled = jnp.mean(x.astype(jnp.float32), axis=1)  # [B, D]
        v = nn.Dense(128, dtype=jnp.float32, name="value_fc")(pooled)
        v = nn.gelu(v)
        v = nn.Dense(1, dtype=jnp.float32, name="value_out")(v)
        v = jnp.tanh(v).squeeze(-1)  # [B] in [-1, 1]

        # -----------------------
        # MTM head (same params regardless of branch)
        # -----------------------
        mtm_head = nn.Dense(NUM_CLASSES, dtype=jnp.float32, name="mtm_head")

        if mask_dest is not None:
            # Weighted sum over squares; with a one-hot mask this is the destination token.
            dest_vec = jnp.sum(x * mask_dest.astype(jnp.bfloat16), axis=1)  # [B, D]
            logits = mtm_head(dest_vec)  # [B, 14]
            return logits, v

        # Without a mask the MTM head classifies the mean-pooled representation.
        logits = mtm_head(jnp.mean(x.astype(jnp.float32), axis=1))
        return logits, v

# Module instance shared by calculate_mtm_energies and rl_mtm_train_step (params are passed to .apply).
inference_model = UniMambaChessEngine()

In [7]:
# ==============================================================================
# CELL 2: DENSE ALGEBRAIC KERNELS (100% XLA FUSION)
# ==============================================================================
import jax
import jax.numpy as jnp
import optax
import numpy as np

def board_to_tensor(board):
    """Encode a python-chess board as a flat (64,) int32 array of piece-class ids.

    Index k holds the piece on python-chess square k (0 = a1 ... 63 = h8). Class ids follow
    PIECE_NAMES: 0 = empty, 1..6 = white P/N/B/R/Q/K, 7..12 = black P/N/B/R/Q/K.
    """
    encoded = np.zeros(64, dtype=np.int32)
    for square in range(64):
        piece = board.piece_at(square)
        if not piece:
            continue  # empty square keeps class 0
        # python-chess colours: WHITE is True, BLACK is False -> black ids are shifted by 6.
        encoded[square] = piece.piece_type + (6 if not piece.color else 0)
    return encoded

@jax.jit
def calculate_mtm_energies(frozen_params, candidate_batch, dest_indices, piece_classes):
    """Score candidate post-move boards with the MTM head and the value head.

    The destination square of every candidate is overwritten with the MASK class (13) and
    the model predicts which class belongs there.

    Args:
        frozen_params: parameter tree for `inference_model`.
        candidate_batch: (N, 64) int32 boards (as produced by board_to_tensor).
        dest_indices: (N,) int32 destination square of the move behind each candidate.
        piece_classes: (N,) int32 class id of the piece actually standing on that square.

    Returns:
      energies: [N]      = -log P(correct_piece_at_dest)
      probs:    [N, 14]  = softmax logits
      values:   [N]      = value head output for candidate states (side-to-move in that candidate)
    """
    count = candidate_batch.shape[0]

    dest_onehot = jax.nn.one_hot(dest_indices, 64, dtype=jnp.int32)
    # Keep every square except the destination, which becomes the MASK class.
    masked_boards = candidate_batch * (1 - dest_onehot) + 13 * dest_onehot
    timesteps = jnp.full((count,), 1, dtype=jnp.int32)
    dest_weights = dest_onehot[..., None].astype(jnp.float32)

    logits, values = inference_model.apply(
        {'params': frozen_params}, masked_boards, timesteps, mask_dest=dest_weights
    )

    probs = jax.nn.softmax(logits, axis=-1)
    target_onehot = jax.nn.one_hot(piece_classes, 14, dtype=jnp.float32)
    p_target = (probs * target_onehot).sum(axis=-1)
    # The 1e-7 floor caps the energy at about -log(1e-7) ~= 16.1.
    energies = -jnp.log(p_target + 1e-7)
    return energies, probs, values

# Weight of the value-head Huber loss relative to the MTM cross-entropy.
VALUE_LOSS_WEIGHT = 0.25

@jax.jit
def rl_mtm_train_step(state, masked_boards, dest_indices, target_classes, value_targets):
    """
    Joint supervised step:
      - MTM: predict moved piece class at destination square
      - Value: predict game result (coarse target) from side-to-move perspective

    Args:
        state: TrainState holding params and optimizer state.
        masked_boards: (N, 64) int32 boards whose destination square is already MASK (13).
        dest_indices: (N,) int32 destination squares.
        target_classes: (N,) int32 true class at the destination.
        value_targets: (N,) float32 targets in {-1, 0, 1}.

    Returns:
        (new_state, loss, mtm_loss, v_loss, acc) where
        loss = mtm_loss + VALUE_LOSS_WEIGHT * v_loss and acc is the MTM top-1 accuracy.
    """
    # Inputs that do not depend on the parameters are prepared once, outside the loss.
    batch = masked_boards.shape[0]
    timesteps = jnp.full((batch,), 1, dtype=jnp.int32)
    dest_weights = jnp.expand_dims(jax.nn.one_hot(dest_indices, 64, dtype=jnp.float32), -1)
    target_onehot = jax.nn.one_hot(target_classes, 14, dtype=jnp.float32)

    def objective(params):
        logits, v_pred = inference_model.apply(
            {'params': params}, masked_boards, timesteps, mask_dest=dest_weights
        )
        # Cross-entropy against the one-hot target class.
        log_probs = jax.nn.log_softmax(logits, axis=-1)
        ce = -jnp.sum(target_onehot * log_probs, axis=-1).mean()
        # Huber loss on the tanh value head.
        huber = optax.huber_loss(v_pred, value_targets).mean()
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == target_classes)
        total = ce + VALUE_LOSS_WEIGHT * huber
        return total, (ce, huber, accuracy)

    grad_fn = jax.value_and_grad(objective, has_aux=True)
    (loss, (mtm_loss, v_loss, acc)), grads = grad_fn(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state, loss, mtm_loss, v_loss, acc

In [8]:
# ==============================================================================
# CELL 2B (REPLACE): DONATION + PROFILER ANNOTATIONS (SAFE)
#   - Do NOT donate arg0 (TrainState). Only donate batch arrays.
# ==============================================================================
import jax
from jax import profiler as jprof

# Keep a handle on the un-wrapped step. The guard makes this cell idempotent: re-running it
# must not wrap the wrapper again, but re-running the original cell must be picked up.
if globals().get("_rl_mtm_train_step_wrapper") is not rl_mtm_train_step:
    _rl_mtm_train_step_impl = rl_mtm_train_step  # keep original implementation


def _rl_mtm_train_step_annotated(state, masked_boards, dest_indices, target_classes, value_targets):
    """Run the original train step inside a `jax.named_scope("train_step")`.

    Python code inside a jitted function executes only while JAX traces it, so a host-side
    `jax.profiler.TraceAnnotation` placed here would mark the one-time trace, not each
    device execution. `jax.named_scope` instead prefixes the names of the compiled ops,
    which is what appears on the device timeline of a profile.
    """
    with jax.named_scope("train_step"):
        return _rl_mtm_train_step_impl(state, masked_boards, dest_indices, target_classes, value_targets)


# Donate only the batch inputs, NOT the TrainState, so the caller's state stays valid after a call.
# NOTE(review): no output matches the shape/dtype of these batch buffers, so XLA may warn that
# they were "not usable" for donation; in that case donation is effectively a no-op.
rl_mtm_train_step = jax.jit(_rl_mtm_train_step_annotated, donate_argnums=(1, 2, 3, 4))
_rl_mtm_train_step_wrapper = rl_mtm_train_step


In [9]:
# ==============================================================================
# 3. MCTS NODE STRUCTURE WITH ALPHA-BETA OVERRIDES (FIXED POV)
# ==============================================================================
class MCTSNode:
    """One node of the search tree, holding a single board position.

    All values stored on a node (value_sum, q_value, proven flags) are expressed from the
    point of view of the side to move in *this* node's position; a parent negates them.
    """

    def __init__(self, state, parent=None, move=None, prior=0.0):
        self.state = state            # board at this node
        self.parent = parent          # parent MCTSNode, or None for the root
        self.move = move              # move that led from the parent to this node

        self.prior = float(prior)     # policy prior assigned when the parent was expanded
        self.visits = 0
        self.value_sum = 0.0          # sum of backed-up values, side-to-move POV

        self.children = {}            # move -> MCTSNode
        self.is_expanded = False

        # Solver flags, side-to-move POV.
        self.is_proven_win = False
        self.is_proven_loss = False

        # Dashboard payload: MTM class probabilities at the destination and the true class.
        self.xray_probs = None
        self.xray_piece = None

    @property
    def q_value(self):
        """Mean backed-up value; exactly +1 / -1 once the node is proven won / lost."""
        if self.is_proven_win:
            return 1.0
        if self.is_proven_loss:
            return -1.0
        if self.visits <= 0:
            return 0.0
        return self.value_sum / self.visits

    def ucb_from_parent(self, c_puct):
        """PUCT score of this node as seen by its parent.

        The child's q_value belongs to the opponent, so the parent scores -q_value plus the
        exploration bonus c_puct * prior * sqrt(N_parent) / (1 + N_child). Solved children
        short-circuit to +inf (opponent lost) or -inf (opponent won).
        """
        if self.is_proven_loss:
            return float('inf')
        if self.is_proven_win:
            return -float('inf')

        n_parent = self.parent.visits if self.parent else 1
        n_parent = max(1, n_parent)
        exploration = c_puct * self.prior * (math.sqrt(n_parent) / (1.0 + self.visits))
        return -self.q_value + exploration

def update_solved(n: MCTSNode):
    """Derive `n`'s proven status from its children (side-to-move POV at each node).

    A child that is a proven loss for *its* side to move makes `n` a proven win; if every
    child is a proven win for the opponent, `n` is a proven loss. Leaves are left unchanged.
    """
    kids = list(n.children.values())
    if not kids:
        return
    if any(kid.is_proven_loss for kid in kids):
        n.is_proven_win = True
        n.is_proven_loss = False
    elif all(kid.is_proven_win for kid in kids):
        n.is_proven_win = False
        n.is_proven_loss = True

In [10]:
# ==============================================================================
# 4. THE OMNISCIENT DASHBOARD (UI VISUALIZER) [FIXED SEMANTICS]
# ==============================================================================
PIECE_NAMES = ["Empty", "W_P", "W_N", "W_B", "W_R", "W_Q", "W_K", 
               "B_P", "B_N", "B_B", "B_R", "B_Q", "B_K", "MASK"]


def _dashboard_principal_variation(root, max_depth=4):
    """Arrows for the displayed line: a forced win if one exists, else the most-visited non-losing move."""
    arrows = []
    node = root
    for _ in range(max_depth):
        if not node.children:
            break
        options = list(node.children.values())
        winning = [kid for kid in options if kid.is_proven_loss]  # opponent lost => win for us
        if winning:
            chosen = winning[0]
        else:
            candidates = [kid for kid in options if not kid.is_proven_win] or options
            chosen = max(candidates, key=lambda kid: kid.visits)
        arrows.append(chess.svg.Arrow(chosen.move.from_square, chosen.move.to_square, color="#28a745"))
        node = chosen
    return arrows


def _dashboard_line_html(rank, child, board):
    """HTML card for one candidate root move (visits, root-POV Q, prior and MTM x-ray)."""
    if child.is_proven_loss:
        tag = "<span style='color: green; font-weight: bold;'>[FORCED WIN MOVE]</span>"
    elif child.is_proven_win:
        tag = "<span style='color: red; font-weight: bold;'>[LOSING MOVE]</span>"
    else:
        tag = ""

    parts = [
        "<div style='margin-bottom: 12px; padding: 8px; border-left: 3px solid #0055ff; background: #fdfdfd;'>",
        f"<b>Line {rank+1}: {board.san(child.move)}</b> {tag}<br>",
        f"<span style='color: #666;'>Visits: {child.visits} | Q(root): {-child.q_value:.2f} | Prior: {child.prior*100:.1f}%</span><br>",
    ]
    if child.xray_probs is not None:
        parts.append("<div style='margin-top: 4px; font-family: monospace; font-size: 11px; color: #444;'>")
        parts.append("<b>[MASK] Prediction (Target Sq):</b><br>")
        top2 = np.argsort(child.xray_probs)[::-1][:2]
        for cls in top2:
            mark = " &lt;-- Actual" if cls == child.xray_piece else ""
            parts.append(f"&nbsp;&nbsp;{PIECE_NAMES[cls]}: {child.xray_probs[cls]*100:.1f}%{mark}<br>")
        if child.xray_piece not in top2:
            parts.append(f"&nbsp;&nbsp;... {PIECE_NAMES[child.xray_piece]}: {child.xray_probs[child.xray_piece]*100:.2f}% &lt;-- Actual<br>")
        parts.append("</div>")
    parts.append("</div>")
    return "".join(parts)


def render_dashboard(root, board, phase_title, sim_status, loss_info=""):
    """Replace the cell output with an HTML view of the search state at `root`.

    Shows the board (root children's destination squares shaded by visit share, the
    principal variation as arrows), a root-POV evaluation derived from the most-visited
    child, and cards for the top three children by visits.
    """
    ranked = sorted(root.children.values(), key=lambda kid: kid.visits, reverse=True)

    # Destination-square shading: alpha channel proportional to the visit share.
    total = sum(kid.visits for kid in ranked) + 1e-9
    fill = {kid.move.to_square: f"#0055ff{int((kid.visits / total) * 255):02x}" for kid in ranked}
    arrows = _dashboard_principal_variation(root)

    board_svg = chess.svg.board(
        board, arrows=arrows, fill=fill, size=400,
        lastmove=board.peek() if board.move_stack else None
    )

    # Map the root-POV value of the most-visited move from [-1, 1] to [0, 100].
    if ranked:
        eval_score = (-ranked[0].q_value + 1) / 2 * 100
    else:
        eval_score = 50.0

    if eval_score > 55:
        eval_color = "green"
    elif eval_score < 45:
        eval_color = "red"
    else:
        eval_color = "black"

    html = f"""
    <div style='display:flex; gap:20px; font-family:sans-serif; max-width:900px; margin:auto; padding:10px; border:1px solid #ddd; border-radius:8px;'>
        <div>{board_svg}</div>
        <div style='flex-grow:1;'>
            <h2 style='margin-top:0; color:#333;'>{phase_title}</h2>
            <div style='background:#f8f9fa; padding:10px; border-radius:5px; margin-bottom:10px;'>
                <b>Status:</b> {sim_status} <br>
                <b>Engine Eval (root POV):</b> <span style='color: {eval_color}; font-weight: bold;'>{eval_score:.1f}% Win Probability</span>
                {f"<br><b>Train Loss:</b> {loss_info}" if loss_info else ""}
            </div>
            <h4 style='border-bottom: 1px solid #ccc; padding-bottom: 5px;'>X-Ray Internal Monologue</h4>
            <div style='max-height: 280px; overflow-y: auto; font-size: 13px;'>
    """
    html += "".join(_dashboard_line_html(rank, kid, board) for rank, kid in enumerate(ranked[:3]))
    html += "</div></div></div>"

    clear_output(wait=True)
    display(HTML(html))

In [11]:
# ==============================================================================
# 5. CORE MCTS EXECUTION ALGORITHM (FIXED POV + SOFT PRIORS + VALUE HEAD)
# ==============================================================================
def run_mcts(
    root,
    params,
    num_simulations=100,
    phase_title="Engine Analysis",
    visualize=False,
    c_puct=2.0,
    temp=3.0,           # >1.0 reduces overconfidence
    prior_mix=0.20,     # mix priors with uniform
    add_root_noise=True,
    dirichlet_alpha=0.3,
    dirichlet_eps=0.25,
):
    for i in range(num_simulations):
        node = root
        path = [node]

        # 1) Selection (parent POV is handled inside ucb_from_parent)
        while node.is_expanded and node.children and not (node.is_proven_win or node.is_proven_loss):
            node = max(node.children.values(), key=lambda c: c.ucb_from_parent(c_puct))
            path.append(node)

        # 2) Terminal handling
        if node.is_proven_win:
            leaf_value = 1.0
        elif node.is_proven_loss:
            leaf_value = -1.0
        elif node.state.is_game_over():
            outcome = node.state.outcome()
            if outcome.winner is None:
                leaf_value = 0.0
            else:
                leaf_value = 1.0 if outcome.winner == node.state.turn else -1.0
            if leaf_value == 1.0:
                node.is_proven_win = True
            elif leaf_value == -1.0:
                node.is_proven_loss = True
        else:
            # 3) Expansion using MTM prior + VALUE HEAD evaluation
            moves = list(node.state.legal_moves)
            candidates, dest_indices, piece_classes = [], [], []
            
            for m in moves:
                node.state.push(m)
                candidates.append(board_to_tensor(node.state))
                dest_indices.append(m.to_square)
                p = node.state.piece_at(m.to_square)
                piece_classes.append(p.piece_type + (0 if p.color else 6))
                node.state.pop()
                
            if candidates:
                energies, full_probs, child_values = calculate_mtm_energies(
                    params,
                    jnp.array(candidates, dtype=jnp.int32),
                    jnp.array(dest_indices, dtype=jnp.int32),
                    jnp.array(piece_classes, dtype=jnp.int32),
                )
                energies = np.array(energies)
                full_probs = np.array(full_probs)
                child_values = np.array(child_values)  # from child POV (opponent-to-move)

                # Soft priors: exp(-E/temp), normalize, then mix with uniform
                logits = -energies / float(temp)
                logits = logits - np.max(logits)
                priors = np.exp(logits)
                priors = priors / (np.sum(priors) + 1e-12)
                priors = (1.0 - prior_mix) * priors + prior_mix * (1.0 / len(priors))

                # Root Dirichlet noise to avoid tunnel vision
                if add_root_noise and node is root:
                    noise = np.random.dirichlet([dirichlet_alpha] * len(priors))
                    priors = (1.0 - dirichlet_eps) * priors + dirichlet_eps * noise

                for idx, (m, p) in enumerate(zip(moves, priors)):
                    new_board = node.state.copy()
                    new_board.push(m)
                    child = MCTSNode(new_board, parent=node, move=m, prior=float(p))
                    child.xray_probs = full_probs[idx]
                    child.xray_piece = piece_classes[idx]
                    node.children[m] = child

                node.is_expanded = True

                # Leaf value from node POV:
                # child_values are from opponent POV => node POV is negative of that
                leaf_value = float(-np.sum(priors * child_values))
                leaf_value = float(np.clip(leaf_value, -1.0, 1.0))
            else:
                leaf_value = 0.0

        # 4) Backprop with sign flip + solved propagation
        value = leaf_value
        for n in reversed(path):
            n.visits += 1
            n.value_sum += value
            update_solved(n)
            value = -value

        # 5) UI Hook
        if visualize and i % 10 == 0:
            render_dashboard(root, root.state, phase_title, f"Simulating {i}/{num_simulations}...")

    if visualize:
        render_dashboard(root, root.state, phase_title, "Analysis Complete.")

In [12]:
# ==============================================================================
# CELL 5.5 (REPLACE): EXPERT DATASET LOADER (MULTI-PGN, POST-MOVE MTM, VALUE POV)
#   - Creates samples from the board *after* each move (matches calculate_mtm_energies)
#   - Masks move.to_square in that post-move board
#   - Target class is the actual piece now occupying that square (handles promotion)
#   - Value target matches the side-to-move of the post-move board
# ==============================================================================
import os
import numpy as np
import jax.numpy as jnp
import chess
import chess.pgn

def _result_to_winner(result_str: str):
    if result_str == "1-0": return chess.WHITE
    if result_str == "0-1": return chess.BLACK
    return None

def make_multi_pgn_dataloader(
    pgn_files,
    batches_per_epoch: int,
    batch_size: int = 128,
    seed: int = 0,
    shuffle_files_each_epoch: bool = True,
):
    """Return ``dataloader_factory(epoch=0)``, a generator of MTM/value training batches.

    Every mainline move in `pgn_files` becomes one sample taken from the board *after* the
    move (the same view calculate_mtm_energies scores):
      - masked board: board_to_tensor(post-move board) with the destination square set to 13
      - destination square index
      - target class: the piece now on the destination square (promotions yield the new piece)
      - value target: +1 / -1 if the side to move in the post-move board won / lost, else 0

    Each yielded batch is a tuple of jnp arrays (boards int32 [batch_size, 64], dests int32,
    classes int32, values float32). A trailing partial batch is dropped. When
    `shuffle_files_each_epoch` is set, file order is shuffled with seed ``seed + epoch``.

    Raises:
        FileNotFoundError: immediately, if any of `pgn_files` does not exist.
    """
    missing = [path for path in pgn_files if not os.path.exists(path)]
    if missing:
        print("❌ Missing PGN files:")
        for path in missing:
            print(" -", path)
        raise FileNotFoundError("Some PGN files are missing.")

    def _samples_from_game(game):
        """Yield (masked_board, dest_sq, piece_class, value_target) for each mainline move."""
        winner = _result_to_winner(game.headers.get("Result", "*"))
        board = game.board()
        for move in game.mainline_moves():
            board.push(move)  # post-move board: this is what MCTS evaluates
            dest_sq = move.to_square
            piece = board.piece_at(dest_sq)
            if piece is None:
                continue  # should not happen, but guard anyway

            piece_class = piece.piece_type + (0 if piece.color else 6)
            if winner is None:
                v_t = 0.0
            else:
                v_t = 1.0 if board.turn == winner else -1.0

            masked = board_to_tensor(board)
            masked[dest_sq] = 13  # MASK the destination in the post-move board
            yield masked, dest_sq, piece_class, v_t

    def dataloader_factory(epoch: int = 0):
        order = list(pgn_files)
        if shuffle_files_each_epoch:
            np.random.default_rng(seed + epoch).shuffle(order)

        pending = []  # samples collected for the batch under construction
        yielded = 0
        for path in order:
            with open(path, "r", encoding="utf-8", errors="ignore") as pgn:
                while yielded < batches_per_epoch:
                    game = chess.pgn.read_game(pgn)
                    if game is None:
                        break  # end of this file; move on to the next one
                    for sample in _samples_from_game(game):
                        pending.append(sample)
                        if len(pending) != batch_size:
                            continue
                        boards, dests, classes, values = zip(*pending)
                        pending = []
                        yield (
                            jnp.array(boards, dtype=jnp.int32),
                            jnp.array(dests, dtype=jnp.int32),
                            jnp.array(classes, dtype=jnp.int32),
                            jnp.array(values, dtype=jnp.float32),
                        )
                        yielded += 1
                        if yielded >= batches_per_epoch:
                            return

        if yielded < batches_per_epoch:
            print(f"⚠️ Epoch ended early: yielded {yielded}/{batches_per_epoch} batches.")

    return dataloader_factory


In [13]:
# ==============================================================================
# CELL 6 (REPLACE): TRAINING + HW/SW PROFILING + PLOTS (FIXED WARMUP)
# ==============================================================================
import time
import re
import os
import numpy as np
import jax
import jax.numpy as jnp
from jax import profiler as jprof
from IPython.display import clear_output, display, HTML

MAX_TRAIN_BATCH = 128

def _try_device_memory_stats():
    try:
        return jax.devices()[0].memory_stats()
    except Exception as e1:
        try:
            backend = jax.lib.xla_bridge.get_backend()
            return backend.memory_stats()
        except Exception as e2:
            return {"unavailable": True, "err1": repr(e1), "err2": repr(e2)}

def _hlo_memory_space_report(compiled):
    try:
        hlo = compiled.compiler_ir(dialect="hlo").as_text()
        ms = re.findall(r"memory_space=(\d+)", hlo)
        counts = {}
        for x in ms:
            k = int(x)
            counts[k] = counts.get(k, 0) + 1
        async_copies = len(re.findall(r"copy-start|copy-done|async-copy", hlo))
        return {"memory_space_counts": counts, "async_copy_ops": async_copies}
    except Exception as e:
        return {"error": repr(e)}

def _xla_mem_analysis(compiled):
    try:
        mem = compiled.memory_analysis()
        return {
            "temp_mb": mem.temp_size_in_bytes / (1024**2),
            "args_mb": mem.argument_size_in_bytes / (1024**2),
            "out_mb":  mem.output_size_in_bytes / (1024**2),
        }
    except Exception as e:
        return {"error": repr(e)}

def _safe_plots(history):
    try:
        import matplotlib.pyplot as plt
        plt.figure()
        plt.plot(history["loss_epoch"], label="loss")
        plt.plot(history["mtm_loss_epoch"], label="mtm_loss")
        plt.plot(history["v_loss_epoch"], label="v_loss")
        plt.legend(); plt.title("Loss per epoch"); plt.xlabel("epoch"); plt.ylabel("loss"); plt.show()

        plt.figure()
        plt.plot(history["acc_epoch"], label="MTM Acc (%)")
        plt.legend(); plt.title("Accuracy per epoch"); plt.xlabel("epoch"); plt.ylabel("%"); plt.show()

        plt.figure()
        plt.plot(history["sps_epoch"], label="SPS")
        plt.legend(); plt.title("Throughput per epoch"); plt.xlabel("epoch"); plt.ylabel("samples/s"); plt.show()

        plt.figure()
        plt.plot(history["p50_ms_epoch"], label="p50 ms")
        plt.plot(history["p90_ms_epoch"], label="p90 ms")
        plt.plot(history["p99_ms_epoch"], label="p99 ms")
        plt.legend(); plt.title("Device step latency (ms)"); plt.xlabel("epoch"); plt.ylabel("ms"); plt.show()
    except Exception as e:
        print("⚠️ matplotlib unavailable/broken; skipping plots.")
        print("Reason:", repr(e))

def _render_perf_panel(ep, epochs, step, steps_in_epoch, metrics, hw, memstats):
    """Replace the cell output with an HTML panel of training and hardware metrics.

    Args:
        ep, epochs: current (1-based) epoch and the total number of epochs.
        step, steps_in_epoch: current step and the per-epoch step budget.
        metrics: dict with float entries "loss", "mtm_loss", "v_loss", "acc" (percent), "sps",
            "batch_fetch_ms", "device_ms", "e2e_ms", "p50_ms", "p90_ms", "p99_ms".
        hw: hardware snapshot with "backend", "devices", "compile_s", "xla_temp_mb",
            "xla_args_mb", "xla_out_mb", "hlo_ms", "hlo_async".
        memstats: text shown inside a <pre> block.

    The free-form text fields (backend, devices, HLO summary, memstats — often reprs of dicts
    or exceptions) are HTML-escaped so '<', '>' or '&' in them are shown literally instead of
    being parsed as markup.
    """
    from html import escape

    backend_txt = escape(str(hw["backend"]), quote=False)
    devices_txt = escape(str(hw["devices"]), quote=False)
    hlo_ms_txt = escape(str(hw["hlo_ms"]), quote=False)
    memstats_txt = escape(str(memstats), quote=False)

    html = f"""
    <div style="font-family:system-ui, -apple-system, Segoe UI, Roboto, sans-serif; padding:12px; border:1px solid #ddd; border-radius:10px;">
      <h3 style="margin:0 0 8px 0;">Training Perf (SW + HW)</h3>
      <div style="display:grid; grid-template-columns: 1fr 1fr; gap:12px;">
        <div style="background:#f6f8fa; border-radius:10px; padding:10px;">
          <div><b>Epoch</b> {ep}/{epochs} &nbsp; <b>Step</b> {step}/{steps_in_epoch}</div>
          <div style="margin-top:6px;">
            <b>Loss</b> {metrics["loss"]:.4f} &nbsp;
            <b>MTM</b> {metrics["mtm_loss"]:.4f} &nbsp;
            <b>V</b> {metrics["v_loss"]:.4f} &nbsp;
            <b>Acc</b> {metrics["acc"]:.1f}%
          </div>
          <div style="margin-top:6px;">
            <b>SPS</b> {metrics["sps"]:,.0f} &nbsp;
            <b>Fetch</b> {metrics["batch_fetch_ms"]:.2f}ms &nbsp;
            <b>Device</b> {metrics["device_ms"]:.2f}ms &nbsp;
            <b>E2E</b> {metrics["e2e_ms"]:.2f}ms
          </div>
          <div style="margin-top:6px;">
            <b>Latency p50/p90/p99</b>: {metrics["p50_ms"]:.2f}/{metrics["p90_ms"]:.2f}/{metrics["p99_ms"]:.2f} ms
          </div>
        </div>
        <div style="background:#f6f8fa; border-radius:10px; padding:10px;">
          <div><b>Backend</b>: {backend_txt}</div>
          <div style="margin-top:6px;"><b>Devices</b>:<br><span style="font-size:12px;">{devices_txt}</span></div>
          <div style="margin-top:6px;"><b>Compile</b>: {hw["compile_s"]:.2f}s</div>
          <div style="margin-top:6px;"><b>XLA Mem (MB)</b>: temp={hw["xla_temp_mb"]:.1f} args={hw["xla_args_mb"]:.1f} out={hw["xla_out_mb"]:.1f}</div>
          <div style="margin-top:6px;"><b>HLO memory_space</b>: {hlo_ms_txt}</div>
          <div style="margin-top:6px;"><b>Async copies</b>: {hw["hlo_async"]}</div>
        </div>
      </div>
      <div style="margin-top:10px; background:#fff; border:1px solid #eee; border-radius:10px; padding:10px;">
        <b>Device memory stats</b> (if available):<br>
        <pre style="margin:6px 0 0 0; white-space:pre-wrap; font-size:12px;">{memstats_txt}</pre>
      </div>
    </div>
    """
    clear_output(wait=True)
    display(HTML(html))

def _snapshot_compiled_step(train_state_obj, dummy_batch):
    """AOT-compile rl_mtm_train_step for `dummy_batch` and return the static HW/compiler snapshot."""
    compile_t0 = time.time()
    compiled = rl_mtm_train_step.lower(train_state_obj, *dummy_batch).compile()
    compile_s = time.time() - compile_t0

    xla_mem = _xla_mem_analysis(compiled)
    hlo_rep = _hlo_memory_space_report(compiled)
    dev_str = ", ".join([f"{d.platform}:{getattr(d, 'device_kind', 'device')}" for d in jax.devices()])

    return {
        "backend": jax.default_backend(),
        "devices": dev_str,
        "compile_s": float(compile_s),
        "xla_temp_mb": float(xla_mem.get("temp_mb", float("nan"))),
        "xla_args_mb": float(xla_mem.get("args_mb", float("nan"))),
        "xla_out_mb":  float(xla_mem.get("out_mb", float("nan"))),
        "hlo_ms": str(hlo_rep.get("memory_space_counts", {})),
        "hlo_async": int(hlo_rep.get("async_copy_ops", 0)) if "async_copy_ops" in hlo_rep else -1,
    }


def _latency_percentiles(device_ms):
    """Return (p50, p90, p99) of a non-empty sequence of per-step latencies in ms."""
    arr = np.array(device_ms, dtype=np.float32)
    return (
        float(np.percentile(arr, 50)),
        float(np.percentile(arr, 90)),
        float(np.percentile(arr, 99)),
    )


def train_supervised_phase(
    train_state_obj,
    dataloader_factory,
    epochs=5,
    batches_per_epoch=300,
    ui_every=10,
    trace_dir=None,
    trace_steps=30,
):
    """Supervised MTM + value training loop with live SW/HW telemetry.

    Args:
        train_state_obj: TrainState passed to rl_mtm_train_step.
        dataloader_factory: callable(epoch_index) -> iterator of (boards, dests, classes, values).
        epochs: number of epochs; training stops early if an epoch yields no batches.
        batches_per_epoch: maximum number of steps per epoch.
        ui_every: refresh the perf panel every `ui_every` steps (<= 0 disables the panel).
        trace_dir: if set, a jax.profiler trace covering the first `trace_steps` steps is
            written there.
        trace_steps: number of steps to keep the profiler trace running.

    Returns:
        (train_state_obj, history): the trained state and per-epoch metrics plus a
        "hardware" snapshot.
    """
    dummy_batch = (
        jnp.zeros((MAX_TRAIN_BATCH, 64), dtype=jnp.int32),
        jnp.zeros((MAX_TRAIN_BATCH,), dtype=jnp.int32),
        jnp.zeros((MAX_TRAIN_BATCH,), dtype=jnp.int32),
        jnp.zeros((MAX_TRAIN_BATCH,), dtype=jnp.float32),
    )
    hw = _snapshot_compiled_step(train_state_obj, dummy_batch)

    # Warm-up call so the first timed step excludes any remaining jit dispatch/compile cost.
    # The outputs are DISCARDED on purpose: rebinding train_state_obj here would apply a real
    # optimizer update computed on the all-zero dummy batch. Discarding is safe because
    # rl_mtm_train_step does not donate the state (arg 0); if that ever changes, run the
    # warm-up on a copy of the state instead. (The dummy arrays may be donated; they are not reused.)
    warm_outputs = rl_mtm_train_step(train_state_obj, *dummy_batch)
    jax.block_until_ready(warm_outputs[0].params)
    del warm_outputs

    tracing = False
    traced = 0
    if trace_dir:
        os.makedirs(trace_dir, exist_ok=True)
        jprof.start_trace(trace_dir)
        tracing = True

    history = {
        "epoch": [],
        "loss_epoch": [], "mtm_loss_epoch": [], "v_loss_epoch": [], "acc_epoch": [],
        "sps_epoch": [],
        "p50_ms_epoch": [], "p90_ms_epoch": [], "p99_ms_epoch": [],
        "epoch_time_s": [],
        "hardware": hw,
    }

    try:
        for ep in range(1, epochs + 1):
            loss_list, mtm_list, v_list, acc_list = [], [], [], []
            device_ms_list = []
            samples_seen = 0

            epoch_t0 = time.time()
            dl = dataloader_factory(ep - 1)
            for step in range(1, batches_per_epoch + 1):
                t_fetch0 = time.time()
                try:
                    batch = next(dl)
                except StopIteration:
                    break
                t_fetch1 = time.time()

                m_b, d_i, t_c, v_t = batch
                # Real batch size (read before the call: the batch buffers may be donated).
                batch_n = int(m_b.shape[0])

                t_dev0 = time.time()
                train_state_obj, loss, mtm_loss, v_loss, acc = rl_mtm_train_step(train_state_obj, m_b, d_i, t_c, v_t)
                jax.block_until_ready(train_state_obj.params)
                t_dev1 = time.time()

                fetch_ms = (t_fetch1 - t_fetch0) * 1000.0
                device_ms = (t_dev1 - t_dev0) * 1000.0
                e2e_ms = (t_dev1 - t_fetch0) * 1000.0
                device_ms_list.append(device_ms)
                samples_seen += batch_n

                loss_f = float(loss); mtm_f = float(mtm_loss); v_f = float(v_loss); acc_f = float(acc) * 100.0
                loss_list.append(loss_f); mtm_list.append(mtm_f); v_list.append(v_f); acc_list.append(acc_f)

                if tracing:
                    traced += 1
                    if traced >= trace_steps:
                        jprof.stop_trace()
                        tracing = False

                if ui_every and ui_every > 0 and step % ui_every == 0:
                    p50, p90, p99 = _latency_percentiles(device_ms_list[-200:])
                    sps = batch_n / max(1e-9, (device_ms / 1000.0))
                    memstats = _try_device_memory_stats()

                    _render_perf_panel(
                        ep, epochs, step, batches_per_epoch,
                        metrics={
                            "loss": loss_f, "mtm_loss": mtm_f, "v_loss": v_f, "acc": acc_f,
                            "sps": float(sps),
                            "batch_fetch_ms": float(fetch_ms),
                            "device_ms": float(device_ms),
                            "e2e_ms": float(e2e_ms),
                            "p50_ms": p50, "p90_ms": p90, "p99_ms": p99,
                        },
                        hw=hw,
                        memstats=str(memstats)
                    )

            epoch_s = time.time() - epoch_t0
            if not loss_list:
                break  # the loader produced nothing this epoch

            p50_ep, p90_ep, p99_ep = _latency_percentiles(device_ms_list)
            history["epoch"].append(ep)
            history["loss_epoch"].append(float(np.mean(loss_list)))
            history["mtm_loss_epoch"].append(float(np.mean(mtm_list)))
            history["v_loss_epoch"].append(float(np.mean(v_list)))
            history["acc_epoch"].append(float(np.mean(acc_list)))
            history["sps_epoch"].append(float(samples_seen / max(1e-9, epoch_s)))
            history["p50_ms_epoch"].append(p50_ep)
            history["p90_ms_epoch"].append(p90_ep)
            history["p99_ms_epoch"].append(p99_ep)
            history["epoch_time_s"].append(float(epoch_s))
    finally:
        # Stop the profiler even if training ended (or failed) before trace_steps steps ran;
        # otherwise the trace is never written and the next start_trace() call fails.
        if tracing:
            jprof.stop_trace()

    clear_output(wait=True)
    print("✅ Training complete.")
    print("Hardware snapshot:", history["hardware"])
    if trace_dir:
        print(f"✅ Trace saved to: {trace_dir}")
    _safe_plots(history)
    return train_state_obj, history


In [14]:
# ==============================================================================
# CELL (ADD): INTERACTIVE PLAY LOOP (RESTORES play_interactive)
# ==============================================================================
import chess

def _safe_render(root, board, title, status):
    """Show the current position, using the dashboard renderer when it exists.

    When a ``render_dashboard`` function is defined in this module's globals
    (i.e. the dashboard cell was run), it is called with all four arguments.
    Otherwise a plain-text fallback prints ``"<title> - <status>"`` followed
    by ``print(board)``.
    """
    if "render_dashboard" not in globals():
        # Text fallback: one header line, then the board's str() form.
        print(title, "-", status)
        print(board)
        return
    render_dashboard(root, board, title, status)

def play_interactive(params, engine_plays_black=True, sims_per_move=60):
    """
    Interactive text-mode match vs the MCTS engine, reading moves via ``input()``.

    - engine_plays_black=True  -> Human plays White, engine plays Black (default)
    - engine_plays_black=False -> Engine plays White and moves first
    - sims_per_move controls the MCTS budget (simulations per engine move)

    Type 'quit', 'exit' or 'resign' at the prompt to abort. If no interactive
    input is available (EOF, or a front end where ``input()`` is not
    implemented), the game is aborted instead of re-prompting forever; use
    ``play_with_widgets`` in that case.
    """
    board = chess.Board()
    root = MCTSNode(board.copy())

    while not board.is_game_over():
        engine_to_move = (board.turn == chess.BLACK) if engine_plays_black else (board.turn == chess.WHITE)

        if not engine_to_move:
            _safe_render(root, board, "Interactive Match", "Waiting for Human Input...")
            mv = None
            while mv is None:
                try:
                    move_input = input("Your Move (SAN, e.g., e4, Nf3, O-O) or 'quit': ").strip()
                except (EOFError, NotImplementedError) as e:
                    # NotImplementedError also covers IPython's StdinNotImplementedError.
                    # Retrying cannot succeed, so stop instead of looping forever.
                    print("Game Aborted: no interactive input available.", repr(e))
                    return

                if move_input.lower() in ["quit", "exit", "resign"]:
                    print("Game Aborted.")
                    return

                try:
                    mv = board.parse_san(move_input)
                except ValueError as e:
                    # python-chess reports unparsable / illegal / ambiguous SAN
                    # as ValueError (or its subclasses); only these mean "bad move".
                    print("Invalid move. Try again. Error:", e)

            board.push(mv)

            # Sync/advance root: reuse the searched subtree if this move was expanded.
            if root.children and mv in root.children:
                root = root.children[mv]
                root.parent = None
            else:
                root = MCTSNode(board.copy())
        else:
            # Engine move
            _safe_render(root, board, "Interactive Match", f"Engine thinking... ({sims_per_move} sims)")
            run_mcts(root, params, num_simulations=sims_per_move, phase_title="Interactive Match",
                     visualize=("render_dashboard" in globals()))

            if not root.children:
                print("Engine has no moves (resign/stalemate).")
                break

            # Prefer a forced win if one exists: a child flagged is_proven_loss is
            # treated as a win for the side choosing it.
            forced = [c for c in root.children.values() if getattr(c, "is_proven_loss", False)]
            if forced:
                best = forced[0]
            else:
                # Avoid known-losing moves (child flagged is_proven_win) if possible.
                safe = [c for c in root.children.values() if not getattr(c, "is_proven_win", False)]
                if not safe:
                    safe = list(root.children.values())
                best = max(safe, key=lambda c: c.visits)

            board.push(best.move)
            root = best
            root.parent = None

    _safe_render(root, board, "Match Finished", f"Result: {board.result()}")
    print("Game Over:", board.result())


In [15]:
# ==============================================================================
# CELL (ADD): KAGGLE/JUPYTER FRIENDLY PLAY UI (ipywidgets)
#   - No input() needed
#   - Type SAN move (e.g., e4, Nf3, O-O) and click Submit
# ==============================================================================
import chess
import ipywidgets as widgets
from IPython.display import display, clear_output

def play_with_widgets(params, engine_plays_black=True, sims_per_move=60):
    """
    Jupyter/Kaggle-friendly match vs the MCTS engine (no ``input()`` needed).

    Shows a SAN text box with Submit / Reset buttons. The human types a move
    and clicks Submit; the engine replies within the same callback.

    Args:
        params: model parameters passed through to ``run_mcts``.
        engine_plays_black: True (default) -> human plays White. False -> the
            engine plays White and makes its opening move as soon as the UI is
            shown and again after every Reset.
        sims_per_move: MCTS simulations per engine move.

    All rendering (text fallback, ``render_dashboard`` and the MCTS
    visualisation) goes through one ``widgets.Output``. Output produced
    inside button callbacks is otherwise not displayed in the cell.
    """
    import html  # escape user-typed text before embedding it in HTML widgets

    # Persistent game state inside the closure (rebound by on_reset via nonlocal)
    board = chess.Board()
    root = MCTSNode(board.copy())

    # UI elements
    move_box = widgets.Text(
        value="",
        placeholder="Enter SAN move (e.g., e4, Nf3, O-O)",
        description="Your move:",
        layout=widgets.Layout(width="420px"),
    )
    submit_btn = widgets.Button(description="Submit", button_style="success")
    reset_btn = widgets.Button(description="Reset", button_style="warning")
    status = widgets.HTML(value="")
    out = widgets.Output()

    def human_to_move():
        # The human owns White exactly when the engine plays Black.
        return (board.turn == chess.WHITE) == engine_plays_black

    def turn_message():
        if board.is_game_over():
            return f"Game over: {board.result()}"
        return "Your turn." if human_to_move() else "Engine turn."

    def render(title, msg):
        # Prefer the dashboard if it exists; otherwise print the text board.
        with out:
            if "render_dashboard" in globals():
                render_dashboard(root, board, title, msg)
            else:
                clear_output(wait=True)
                print(title, "-", msg)
                print(board)

    def sync_root_after_move(mv):
        # Reuse the searched subtree when the played move was already expanded.
        nonlocal root
        if root.children and mv in root.children:
            root = root.children[mv]
            root.parent = None
        else:
            root = MCTSNode(board.copy())

    def engine_move_if_needed():
        nonlocal root
        if board.is_game_over() or human_to_move():
            return

        render("Interactive Match", f"Engine thinking... ({sims_per_move} sims)")
        with out:
            run_mcts(root, params, num_simulations=sims_per_move, phase_title="Interactive Match",
                     visualize=("render_dashboard" in globals()))

        if not root.children:
            status.value = "<b>Engine has no moves.</b>"
            return

        # Prefer a forced win if one exists: a child flagged is_proven_loss is
        # treated as a win for the side choosing it.
        forced = [c for c in root.children.values() if getattr(c, "is_proven_loss", False)]
        if forced:
            best = forced[0]
        else:
            # Avoid known-losing moves (child flagged is_proven_win) if possible.
            safe = [c for c in root.children.values() if not getattr(c, "is_proven_win", False)]
            if not safe:
                safe = list(root.children.values())
            best = max(safe, key=lambda c: c.visits)

        board.push(best.move)
        root = best
        root.parent = None

    def engine_reply():
        # Tracebacks raised inside widget callbacks are not shown in the cell,
        # so report engine failures in the status line (not as "invalid move").
        try:
            engine_move_if_needed()
        except Exception as e:
            status.value += f"<br>⚠️ Engine error: {html.escape(repr(e))}"

    def on_submit(_):
        if board.is_game_over():
            status.value = f"<b>Game over:</b> {board.result()}"
            return

        if not human_to_move():
            # Reachable only if a previous engine move did not complete: retry it
            # rather than letting the human play a move for the engine's side.
            status.value = "<b>Engine to move.</b>"
            engine_reply()
            render("Interactive Match", turn_message())
            return

        txt = move_box.value.strip()
        if not txt:
            status.value = "<b>Please enter a SAN move.</b>"
            return

        try:
            mv = board.parse_san(txt)
        except ValueError as e:
            # python-chess raises ValueError (or subclasses) for bad/illegal/ambiguous SAN.
            status.value = (f"❌ Invalid move: <b>{html.escape(txt)}</b>"
                            f"<br><span style='font-size:12px'>{html.escape(str(e))}</span>")
            return

        board.push(mv)
        sync_root_after_move(mv)
        move_box.value = ""
        status.value = f"✅ Played: <b>{html.escape(txt)}</b>"

        # Engine reply
        engine_reply()

        if board.is_game_over():
            status.value += f"<br><b>Game over:</b> {board.result()}"

        render("Interactive Match", turn_message())

    def on_reset(_):
        nonlocal board, root
        board = chess.Board()
        root = MCTSNode(board.copy())
        move_box.value = ""
        status.value = "Reset done."
        # If the engine has White it opens before the human is asked to move.
        engine_reply()
        render("Interactive Match", "New game. " + turn_message())

    submit_btn.on_click(on_submit)
    reset_btn.on_click(on_reset)

    # Show the controls first so engine progress appears inside `out` in place.
    display(widgets.HBox([move_box, submit_btn, reset_btn]))
    display(status)
    display(out)

    # Initial position: let the engine open if it plays White.
    engine_reply()
    render("Interactive Match", "New game. " + turn_message())

# Usage:
# play_with_widgets(master_state.params, engine_plays_black=True, sims_per_move=60)


In [None]:
# ==============================================================================
# CELL 7: EXECUTION & LAUNCH (MULTI-DB TRAINING + PERF PLOTS)
# ==============================================================================
import jax
import jax.numpy as jnp
import optax
from flax.training import train_state
import chess

import os  # local to this cell: used to build and check the PGN paths

# Single seed shared by the dataloader shuffle and the parameter init.
SEED = 42

# 1) The 7 Kaggle PGN shards to train on (numbered copies in one dataset dir).
PGN_DIR = "/kaggle/input/datasets/yonnwoojeong/master-games"
PGN_FILES = [os.path.join(PGN_DIR, "master_games.pgn")] + [
    os.path.join(PGN_DIR, f"master_games ({i}).pgn") for i in range(1, 7)
]

# Fail fast: a wrong path should stop the run here, not inside the loader or mid-training.
_missing_pgns = [p for p in PGN_FILES if not os.path.isfile(p)]
if _missing_pgns:
    raise FileNotFoundError("PGN file(s) not found:\n  " + "\n  ".join(_missing_pgns))

BATCHES_PER_EPOCH = 50   # batches per epoch; controls epoch length (e.g. 200~2000 for longer runs)
EPOCHS = 50

# Build dataloader factory (epoch-aware)
dataloader_factory = make_multi_pgn_dataloader(
    PGN_FILES,
    batches_per_epoch=BATCHES_PER_EPOCH,
    batch_size=MAX_TRAIN_BATCH,
    seed=SEED,
    shuffle_files_each_epoch=True,
)

# --- model init (ensure value head + mtm head params exist) ---
# NOTE: if `master_state` already exists in the kernel, re-running this cell
# continues training from it; restart the kernel to train from scratch.
try:
    _ = master_state.params
    print("✅ Loaded existing TrainState with weights.")
except NameError:
    print("⚠️ No existing weights found. Initializing fresh Uni-Mamba model...")
    key = jax.random.PRNGKey(SEED)
    # Dummy inputs only fix parameter shapes; presumably one token per square — confirm.
    dummy_x = jnp.zeros((1, 64), dtype=jnp.int32)
    dummy_t = jnp.zeros((1,), dtype=jnp.int32)
    dummy_mask = jnp.zeros((1, 64, 1), dtype=jnp.float32).at[:, 0, 0].set(1.0)

    params = inference_model.init(key, dummy_x, dummy_t, mask_dest=dummy_mask)['params']
    master_state = train_state.TrainState.create(
        apply_fn=inference_model.apply,
        params=params,
        tx=optax.adamw(learning_rate=3e-4),
    )
    print("✅ Model initialized.")

# --- train + plot ---
master_state, history = train_supervised_phase(
    master_state,
    dataloader_factory=dataloader_factory,
    epochs=EPOCHS,
    batches_per_epoch=BATCHES_PER_EPOCH,
    ui_every=10,
)

# --- play ---
# Text prompt via input(); use play_with_widgets(master_state.params) if the
# front end does not support input().
play_interactive(master_state.params)

⚠️ No existing weights found. Initializing fresh Uni-Mamba model...


E0000 00:00:1771071776.211944   32641 common_lib.cc:650] Could not set metric server port: INVALID_ARGUMENT: Could not find SliceBuilder port 8471 in any of the 0 ports provided in `tpu_process_addresses`="local"
=== Source Location Trace: === 
learning/45eac/tfrc/runtime/common_lib.cc:238


✅ Model initialized.


See an explanation at https://docs.jax.dev/en/latest/faq.html#buffer-donation.
