In [1]:
# 1. Install the latest JAX TPU wheels (forcing a libtpu update)
!pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

!pip install -U chess

Looking in links: https://storage.googleapis.com/jax-releases/libtpu_releases.html
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m26.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m26.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [2]:
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.training import train_state
import chess
import chess.svg
import numpy as np
import optax
import math
import time
from einops import rearrange
from IPython.display import display, clear_output, HTML



In [3]:
# ==============================================================================
# 1. OPTIMAL UNI-DIRECTIONAL ARCHITECTURE (COMPUTE-BOUND)
# ==============================================================================
import jax
import jax.numpy as jnp
import flax.linen as nn
from einops import rearrange
import numpy as np

# Model / problem-size hyperparameters read by the model classes below.
# NOTE(review): BATCH_SIZE is not referenced anywhere else in the visible
# notebook — confirm whether it is still needed.
BATCH_SIZE = 128
# Channel width used for every embedding and dense layer of the engine.
D_MODEL = 256
# Per-channel state size handed to UniMambaBlock.
N_STATE = 64
# Number of rows in the timestep embedding table.
TIMESTEPS = 100
# Token vocabulary size: empty square, 6 white pieces, 6 black pieces, MASK.
NUM_CLASSES = 14

class UniMambaBlock(nn.Module):
    """Uni-directional selective state-space block followed by a residual LayerNorm.

    The timestep embedding is added to every sequence position, input-dependent
    B, C and delta projections are computed in bfloat16, and a discretised
    linear recurrence is run left-to-right over the sequence axis with
    ``jax.lax.scan``, accumulating the hidden state in float32.

    Attributes:
        d_model: channel width D; must equal the last dimension of ``x``.
        n_state: size N of the recurrent state kept per channel.
    """
    d_model: int
    n_state: int
    
    @nn.compact
    def __call__(self, x, time_emb):
        """Run the block.

        Args:
            x: array of shape (B, L, D).
            time_emb: array broadcast over the sequence axis as
                ``time_emb[:, None, :]`` (i.e. shape (B, D)).

        Returns:
            Array of shape (B, L, D): ``LayerNorm(scan_output + x)`` computed
            with a bfloat16 LayerNorm.
        """
        # Unpacking also asserts that x is rank 3; only the batch size is needed.
        B, _, _ = x.shape
        x_fused = x + time_emb[:, jnp.newaxis, :]
        
        A_log = self.param('A_log', nn.initializers.normal(0.1), (self.d_model, self.n_state))
        A = -jnp.exp(A_log)
        # A is strictly negative, so the division guard has to push the
        # denominator *away* from zero. Adding a positive epsilon (as before)
        # moved it toward zero and could produce inf/NaN once exp(A_log)
        # shrinks to ~1e-7; subtracting keeps |denominator| >= 1e-7.
        A_inv = 1.0 / (A - 1e-7)
        
        # Native BFloat16 Projections
        B_proj = nn.Dense(self.n_state, dtype=jnp.bfloat16)(x_fused) 
        C_proj = nn.Dense(self.n_state, dtype=jnp.bfloat16)(x_fused)
        delta_proj = nn.Dense(self.d_model, dtype=jnp.bfloat16)(x_fused) 
        
        def scan_fn(h, params):
            """One recurrence step: update state h (B, D, N) and emit y_t (B, D)."""
            b_t, c_t, delta_raw_t, x_t = params
            # softplus keeps the step size positive; the recurrence runs in float32.
            delta = nn.softplus(delta_raw_t.astype(jnp.float32))
            
            # Zero-order-hold discretisation: A_bar = exp(delta * A).
            dt_A = jnp.einsum('bd,dn->bdn', delta, A)
            A_bar = jnp.exp(dt_A)
            
            # B_bar = (A_bar - 1) / A * B_t
            B_bar = jnp.einsum('bdn,dn,bn->bdn', (A_bar - 1.0), A_inv, b_t.astype(jnp.float32))
            h_new = A_bar * h + jnp.einsum('bdn,bd->bdn', B_bar, x_t.astype(jnp.float32))
            y_t = jnp.einsum('bdn,bn->bd', h_new, c_t.astype(jnp.float32))
            
            return h_new, y_t.astype(jnp.bfloat16)

        # lax.scan iterates over the leading axis, so move the sequence axis first.
        Xs_fwd = (
            rearrange(B_proj, 'b l ... -> l b ...'),
            rearrange(C_proj, 'b l ... -> l b ...'),
            rearrange(delta_proj, 'b l ... -> l b ...'),
            rearrange(x_fused, 'b l ... -> l b ...')
        )
        
        # float32 state of shape (B, d_model, n_state); at B=128, D=256, N=64
        # that is 128 * 256 * 64 * 4 bytes ~= 8.4 MB.
        h_init = jnp.zeros((B, self.d_model, self.n_state), dtype=jnp.float32)
        
        _, y_f = jax.lax.scan(scan_fn, h_init, Xs_fwd)
        y_out = rearrange(y_f, 'l b d -> b l d')
        
        return nn.LayerNorm(dtype=jnp.bfloat16)(y_out + x)

class UniMambaChessEngine(nn.Module):
    """Board-token model: embeddings -> 4 rematerialised UniMambaBlocks -> class logits."""

    @nn.compact
    def __call__(self, x_t, t, mask_dest=None):
        """Compute piece-class logits.

        Args:
            x_t: integer board tokens of shape (B, 64).
            t: integer timestep indices of shape (B,).
            mask_dest: optional array multiplied against the (B, 64, D)
                features and summed over the square axis; callers pass a
                one-hot square mask of shape (B, 64, 1).

        Returns:
            float32 logits of shape (B, NUM_CLASSES) when ``mask_dest`` is
            given, otherwise (B, 64, NUM_CLASSES).
        """
        batch = x_t.shape[0]
        squares = jnp.arange(64)
        rank_ids = jnp.tile(squares // 8, (batch, 1))
        file_ids = jnp.tile(squares % 8, (batch, 1))

        # Submodules are instantiated in a fixed order (piece, rank, file
        # embeddings, input dense, time embedding, blocks, head) because
        # @nn.compact derives parameter names from creation order.
        piece_emb = nn.Embed(NUM_CLASSES, D_MODEL, dtype=jnp.bfloat16)(x_t)
        rank_emb = nn.Embed(8, D_MODEL, dtype=jnp.bfloat16)(rank_ids)
        file_emb = nn.Embed(8, D_MODEL, dtype=jnp.bfloat16)(file_ids)
        h = nn.Dense(D_MODEL, dtype=jnp.bfloat16)(piece_emb + rank_emb + file_emb)

        t_emb = nn.Embed(TIMESTEPS, D_MODEL, dtype=jnp.bfloat16)(t)
        Block = nn.remat(UniMambaBlock)

        for _ in range(4):
            h = nn.gelu(Block(D_MODEL, N_STATE)(h, t_emb))

        # With a destination mask, pool the features of the selected square
        # before the classification head; the head itself is the same layer.
        if mask_dest is not None:
            h = jnp.sum(h * mask_dest.astype(jnp.bfloat16), axis=1)

        return nn.Dense(NUM_CLASSES, dtype=jnp.float32)(h)

# Module instance used by the jitted kernels below via `inference_model.apply`;
# parameters are passed in explicitly on every call.
inference_model = UniMambaChessEngine()

In [4]:
# ==============================================================================
# CELL 2: DENSE ALGEBRAIC KERNELS (100% XLA FUSION)
# ==============================================================================
import jax
import jax.numpy as jnp
import optax
import numpy as np

def board_to_tensor(board):
    """Encode a board as a length-64 int32 vector of piece codes.

    Square ``i`` holds 0 when ``board.piece_at(i)`` is falsy; otherwise it holds
    ``piece_type`` for a piece whose ``color`` is truthy and ``piece_type + 6``
    for one whose ``color`` is falsy.
    """
    codes = np.zeros(64, dtype=np.int32)
    for square in range(64):
        piece = board.piece_at(square)
        if not piece:
            continue
        colour_offset = 0 if piece.color else 6
        codes[square] = piece.piece_type + colour_offset
    return codes

@jax.jit
def calculate_mtm_energies(frozen_params, candidate_batch, dest_indices, piece_classes):
    """Score candidate positions by how well the model predicts the masked destination.

    For each candidate board the destination square is overwritten with the
    mask token (13), the model predicts the class at that square, and the
    energy is ``-log(p(actual class) + 1e-7)``.

    Args:
        frozen_params: parameter tree passed to ``inference_model.apply``.
        candidate_batch: integer board tokens, shape (num_candidates, 64).
        dest_indices: destination square index per candidate.
        piece_classes: class id actually present on the destination square.

    Returns:
        ``(energies, probs)``: energies of shape (num_candidates,) and the
        full softmax over the 14 classes, shape (num_candidates, 14).
    """
    # Integer one-hot over the 64 squares so the board can be patched arithmetically.
    dest_onehot = jax.nn.one_hot(dest_indices, 64, dtype=jnp.int32)
    keep = 1 - dest_onehot
    masked_boards = candidate_batch * keep + 13 * dest_onehot

    timesteps = jnp.full((candidate_batch.shape[0],), 1, dtype=jnp.int32)

    logits = inference_model.apply(
        {'params': frozen_params},
        masked_boards,
        timesteps,
        mask_dest=jnp.expand_dims(dest_onehot, -1).astype(jnp.float32),
    )

    probs = jax.nn.softmax(logits, axis=-1)
    # Select the probability of the true class with a one-hot dot product.
    piece_onehot = jax.nn.one_hot(piece_classes, 14, dtype=jnp.float32)
    p_actual = jnp.sum(probs * piece_onehot, axis=-1)

    energies = -jnp.log(p_actual + 1e-7)
    return energies, probs

@jax.jit
def rl_mtm_train_step(state, masked_boards, dest_indices, target_classes):
    """One optimiser step on the masked-destination classification loss.

    Args:
        state: train state exposing ``params`` and ``apply_gradients``.
        masked_boards: integer board tokens, shape (num_samples, 64).
        dest_indices: index of the masked square per sample.
        target_classes: class id the model should predict for that square.

    Returns:
        ``(new_state, loss)`` where ``loss`` is the mean cross-entropy over
        the batch.
    """
    # These inputs do not depend on the parameters, so build them once
    # outside the differentiated function.
    timesteps = jnp.full((masked_boards.shape[0],), 1, dtype=jnp.int32)
    dest_mask = jnp.expand_dims(
        jax.nn.one_hot(dest_indices, 64, dtype=jnp.float32), -1
    )
    targets = jax.nn.one_hot(target_classes, 14, dtype=jnp.float32)

    def loss_fn(params):
        logits = inference_model.apply(
            {'params': params}, masked_boards, timesteps, mask_dest=dest_mask
        )
        per_sample = -jnp.sum(targets * jax.nn.log_softmax(logits), axis=-1)
        return per_sample.mean()

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state, loss

In [5]:
# ==============================================================================
# 3. MCTS NODE STRUCTURE WITH ALPHA-BETA OVERRIDES
# ==============================================================================
class MCTSNode:
    """One position in the search tree together with its search statistics.

    ``is_proven_win`` / ``is_proven_loss`` short-circuit both ``q_value``
    (+1 / -1) and ``ucb`` (+inf / -inf), overriding the sampled statistics.
    """

    def __init__(self, state, parent=None, move=None, prior=0.0):
        # Tree linkage.
        self.state = state
        self.parent = parent
        self.move = move
        self.children = {}
        self.is_expanded = False

        # Search statistics.
        self.prior = prior
        self.visits = 0
        self.value_sum = 0.0

        # Solved-node flags.
        self.is_proven_win = False
        self.is_proven_loss = False

        # Model predictions attached for display.
        self.xray_probs = None
        self.xray_piece = None

    @property
    def q_value(self):
        """Mean backed-up value, or the exact +/-1 when the node is solved."""
        if self.is_proven_win:
            return 1.0
        if self.is_proven_loss:
            return -1.0
        if self.visits > 0:
            return self.value_sum / self.visits
        return 0.0

    def ucb(self, c_puct):
        """PUCT score relative to ``self.parent``; requires a non-None parent."""
        if self.is_proven_win:
            return float('inf')
        if self.is_proven_loss:
            return -float('inf')
        exploration = c_puct * self.prior * (math.sqrt(self.parent.visits) / (1 + self.visits))
        return self.q_value + exploration

In [6]:
# ==============================================================================
# 4. THE OMNISCIENT DASHBOARD (UI VISUALIZER)
# ==============================================================================
# Display label for each of the 14 class ids: index 0 is an empty square,
# 1-6 / 7-12 follow the `piece_type` / `piece_type + 6` encoding produced by
# board_to_tensor, and 13 is the mask token written by calculate_mtm_energies.
PIECE_NAMES = ["Empty", "W_P", "W_N", "W_B", "W_R", "W_Q", "W_K", 
               "B_P", "B_N", "B_B", "B_R", "B_Q", "B_K", "MASK"]

def render_dashboard(root, board, phase_title, sim_status, loss_info=""):
    """Render the search state as an HTML/SVG panel in the notebook output.

    Args:
        root: search node whose children are visualised.
        board: board used for the SVG and for SAN formatting of the
            children's moves.
        phase_title: heading text.
        sim_status: status line text.
        loss_info: optional loss text; the row is omitted when empty.

    Side effects:
        Clears the current cell output and displays the new panel.
    """
    ranked = sorted(root.children.values(), key=lambda c: c.visits, reverse=True)
    fill, arrows = {}, []

    if ranked:
        # Heatmap: destination squares tinted by share of visits (alpha channel).
        visit_total = sum(c.visits for c in ranked) + 1e-5
        fill = {
            c.move.to_square: f"#0055ff{int((c.visits / visit_total) * 255):02x}"
            for c in ranked
        }

        # Principal variation: follow the most visited non-losing child, up to 4 plies.
        cursor = root
        for _ in range(4):
            if not cursor.children:
                break
            candidates = [c for c in cursor.children.values() if not c.is_proven_loss]
            if not candidates:
                break
            cursor = max(candidates, key=lambda c: c.visits)
            arrows.append(chess.svg.Arrow(cursor.move.from_square, cursor.move.to_square, color="#28a745"))

    last_move = board.peek() if board.move_stack else None
    board_svg = chess.svg.board(board, arrows=arrows, fill=fill, size=400, lastmove=last_move)

    # Overall evaluation: q-value of the most visited child mapped to 0-100.
    if ranked:
        eval_score = (ranked[0].q_value + 1) / 2 * 100
    else:
        eval_score = 50.0

    if eval_score > 55:
        eval_color = "green"
    elif eval_score < 45:
        eval_color = "red"
    else:
        eval_color = "black"

    loss_row = f"<br><b>RL Update Loss:</b> {loss_info}" if loss_info else ""

    # Fragments are collected in order and joined once at the end.
    parts = [f"""
    <div style='display: flex; gap: 20px; font-family: sans-serif; max-width: 900px; margin: auto; padding: 10px; border: 1px solid #ddd; border-radius: 8px;'>
        <div>{board_svg}</div>
        <div style='flex-grow: 1;'>
            <h2 style='margin-top: 0; color: #333;'>{phase_title}</h2>
            <div style='background: #f8f9fa; padding: 10px; border-radius: 5px; margin-bottom: 10px;'>
                <b>Status:</b> {sim_status} <br>
                <b>Engine Eval:</b> <span style='color: {eval_color}; font-weight: bold;'>{eval_score:.1f}% Win Probability</span>
                {loss_row}
            </div>
            <h4 style='border-bottom: 1px solid #ccc; padding-bottom: 5px;'>X-Ray Internal Monologue</h4>
            <div style='max-height: 280px; overflow-y: auto; font-size: 13px;'>
    """]

    # Detail cards for the three most visited lines.
    for rank, c in enumerate(ranked[:3], start=1):
        if c.is_proven_win:
            ab_flag = "<span style='color: green; font-weight: bold;'>[FORCED WIN DETECTED]</span>"
        elif c.is_proven_loss:
            ab_flag = "<span style='color: red; font-weight: bold;'>[FORCED LOSS PRUNED]</span>"
        else:
            ab_flag = ""

        parts.append("<div style='margin-bottom: 12px; padding: 8px; border-left: 3px solid #0055ff; background: #fdfdfd;'>")
        parts.append(f"<b>Line {rank}: {board.san(c.move)}</b> {ab_flag}<br>")
        parts.append(f"<span style='color: #666;'>Visits: {c.visits} | Q-Value: {c.q_value:.2f} | Mamba-2 Prior: {c.prior*100:.1f}%</span><br>")

        # Model prediction for the masked destination square of this move.
        if c.xray_probs is not None:
            parts.append("<div style='margin-top: 4px; font-family: monospace; font-size: 11px; color: #444;'>")
            parts.append("<b>[MASK] Prediction (Target Sq):</b><br>")
            top_preds = np.argsort(c.xray_probs)[::-1][:2]
            for idx in top_preds:
                mark = " &lt;-- Actual" if idx == c.xray_piece else ""
                parts.append(f"&nbsp;&nbsp;{PIECE_NAMES[idx]}: {c.xray_probs[idx]*100:.1f}%{mark}<br>")
            # Always show the true class, even when it is outside the top two.
            if c.xray_piece not in top_preds:
                parts.append(f"&nbsp;&nbsp;... {PIECE_NAMES[c.xray_piece]}: {c.xray_probs[c.xray_piece]*100:.2f}% &lt;-- Actual<br>")
            parts.append("</div>")
        parts.append("</div>")

    parts.append("</div></div></div>")

    clear_output(wait=True)
    display(HTML("".join(parts)))

In [7]:
# ==============================================================================
# 5. CORE MCTS EXECUTION ALGORITHM
# ==============================================================================
def run_mcts(root, params, num_simulations=100, phase_title="Engine Analysis", visualize=False):
    """Grow the search tree under ``root`` in place with model-guided MCTS.

    Perspective convention: a parent picks the child with the highest
    ``ucb`` / ``q_value`` and treats ``is_proven_win`` as +inf, so every
    statistic stored on a node (``value_sum``, ``is_proven_win``,
    ``is_proven_loss``) is expressed from the point of view of the player who
    made ``node.move`` -- the side to move at ``node.parent``. Leaf
    evaluations and terminal results are converted to that perspective before
    they are backed up.

    Args:
        root: MCTSNode to search from; its subtree and statistics are mutated.
        params: model parameters forwarded to ``calculate_mtm_energies``.
        num_simulations: number of select/expand/back-up iterations.
        phase_title: heading passed to ``render_dashboard``.
        visualize: when True, redraw the dashboard every 10 simulations and
            once more at the end.

    Returns:
        None.
    """
    for i in range(num_simulations):
        node = root
        
        # 1. Selection & Alpha-Beta Overrides
        while node.is_expanded and node.children:
            valid_children = [c for c in node.children.values() if not c.is_proven_loss]
            if not valid_children:
                # Every move available here is a proven loss for the side to
                # move, so the player who moved into this node has a forced win.
                node.is_proven_win = True
                break
            node = max(valid_children, key=lambda c: c.ucb(c_puct=2.0))
            
        if node.is_proven_win or node.is_proven_loss:
            # Solved node: back up its exact value so the path still
            # accumulates visits (the final move choice is visit-based).
            value = 1.0 if node.is_proven_win else -1.0
        elif node.state.is_game_over():
            outcome = node.state.outcome()
            if outcome.winner is None:
                value = 0.0
            elif outcome.winner == node.state.turn:
                # The side to move won, i.e. the move into this node lost.
                value = -1.0
                node.is_proven_loss = True
            else:
                # The side to move has been beaten (e.g. checkmated): the move
                # leading here is a proven win for the player who made it.
                value = 1.0
                node.is_proven_win = True
        else:
            # 2. Expansion (TPU Deterministic MTM)
            moves = list(node.state.legal_moves)
            candidates, dest_indices, piece_classes = [], [], []
            
            for m in moves:
                node.state.push(m)
                candidates.append(board_to_tensor(node.state))
                dest_indices.append(m.to_square)
                p = node.state.piece_at(m.to_square)
                piece_classes.append(p.piece_type + (0 if p.color else 6))
                node.state.pop()
                
            if candidates:
                # TPU Call
                energies, full_probs = calculate_mtm_energies(
                    params, jnp.array(candidates), jnp.array(dest_indices, dtype=jnp.int32), jnp.array(piece_classes, dtype=jnp.int32)
                )
                energies, full_probs = np.array(energies), np.array(full_probs)
                
                # Temperature sharpening
                priors = np.exp(-energies / 0.5)
                priors /= np.sum(priors)
                
                # Tree instantiation
                for idx, (m, p) in enumerate(zip(moves, priors)):
                    new_board = node.state.copy()
                    new_board.push(m)
                    child_node = MCTSNode(new_board, parent=node, move=m, prior=p)
                    child_node.xray_probs = full_probs[idx]
                    child_node.xray_piece = piece_classes[idx]
                    node.children[m] = child_node
                    
                node.is_expanded = True
                
                # V(S) = 2 * exp(-E) - 1 of the best candidate move measures
                # how good the position is for the side to move *at this node*.
                # Node statistics are kept from the perspective of the player
                # who moved into the node, hence the negation.
                value = -float(np.max(2.0 * np.exp(-energies) - 1.0))
            else: 
                value = 0.0
            
        # 3. Minimax Backpropagation (value flips sign at every ply)
        curr = node
        while curr is not root and curr.parent is not None:
            curr.visits += 1
            curr.value_sum += value
            # A winning reply exists at the parent, so the move that led to
            # the parent is a proven loss for whoever played it.
            if curr.is_proven_win: curr.parent.is_proven_loss = True
            value = -value 
            curr = curr.parent
        root.visits += 1
        
        # UI Hook (Render during search)
        if visualize and i % 10 == 0:
            render_dashboard(root, root.state, phase_title, f"Simulating {i}/{num_simulations}...")
            
    if visualize:
        render_dashboard(root, root.state, phase_title, "Analysis Complete.")

In [8]:
# ==============================================================================
# CELL 6: MASSIVELY PARALLEL RLHF PRE-TRAINING (TRACKING & VISUALIZATION)
# ==============================================================================
import time
import numpy as np
import jax.numpy as jnp
import jax
import chess
from IPython.display import clear_output
import matplotlib.pyplot as plt 

# Number of self-play games advanced in lock-step per epoch.
NUM_PARALLEL_GAMES = 128
# Search simulations run for every game before each move is played.
SIMS_PER_MOVE = 30

# Micro-batch sizes for the two jitted kernels.
# NOTE(review): the original comments state these bounds were derived for a
# 16 GB HBM device; nothing in the visible notebook verifies that — confirm
# on the target hardware.
# Largest chunk passed to calculate_mtm_energies in one call.
MAX_INFERENCE_BATCH = 512  # Forward pass is O(B*D*N) -> Safe to push high
# Largest chunk passed to rl_mtm_train_step in one call.
MAX_TRAIN_BATCH = 128      # Backward pass is O(L*B*D*N) -> Must be constrained

def batched_self_play_rlhf(train_state_obj, num_epochs=5, visualize_game=True):
    """Run lock-step self-play games and train on the winners' moves.

    Each epoch starts ``NUM_PARALLEL_GAMES`` fresh games. Before every move,
    ``SIMS_PER_MOVE`` search simulations are run per game, with the leaf
    evaluations of all games batched into ``calculate_mtm_energies`` calls of
    at most ``MAX_INFERENCE_BATCH`` positions. After the games stop (game over
    or the move cap), the moves played by the winning side of every decisive
    game are used for gradient steps of at most ``MAX_TRAIN_BATCH`` samples.

    Perspective convention: a parent picks the child with the highest
    ``ucb`` / ``q_value``, so all node statistics and proven flags are stored
    from the point of view of the player who moved *into* the node. Leaf
    evaluations and terminal results are converted accordingly.

    Args:
        train_state_obj: train state holding ``params``; updated functionally.
        num_epochs: number of self-play + training rounds.
        visualize_game: when True, render game 0 in the dashboard; otherwise
            print a one-line progress indicator.

    Returns:
        ``(train_state_obj, history)`` where ``history`` has the lists
        ``'loss'``, ``'sps'`` and ``'epoch_time'``. Epochs without a decisive
        game add no entry.
    """
    def _backpropagate(leaf, value, root):
        """Negamax backup of ``value`` from ``leaf`` up to (excluding) ``root``."""
        curr = leaf
        while curr is not root and curr.parent is not None:
            curr.visits += 1
            curr.value_sum += value
            # A winning reply exists at the parent, so the move that led to
            # the parent is a proven loss for whoever played it.
            if curr.is_proven_win:
                curr.parent.is_proven_loss = True
            value = -value
            curr = curr.parent
        root.visits += 1

    def _score_terminal(node):
        """Flag a game-over node and return its value for the player who moved into it."""
        outcome = node.state.outcome()
        if outcome.winner is None:
            return 0.0
        if outcome.winner == node.state.turn:
            node.is_proven_loss = True
            return -1.0
        # The side to move has been beaten (e.g. checkmated).
        node.is_proven_win = True
        return 1.0

    print(f"--- üöÄ INITIATING {NUM_PARALLEL_GAMES} PARALLEL RLHF GAMES ---")
    
    # --------------------------------------------------------------------------
    # 1. HARDWARE PROFILING (AOT COMPILATION CHECK)
    # --------------------------------------------------------------------------
    print(f"\n[Hardware] Compiling JIT Train Step (Batch Size: {MAX_TRAIN_BATCH}) and profiling memory bounds...")
    dummy_boards = jnp.zeros((MAX_TRAIN_BATCH, 64), dtype=jnp.int32)
    dummy_dests = jnp.zeros((MAX_TRAIN_BATCH,), dtype=jnp.int32)
    dummy_targets = jnp.zeros((MAX_TRAIN_BATCH,), dtype=jnp.int32)
    
    compiled_train_step = rl_mtm_train_step.lower(
        train_state_obj, dummy_boards, dummy_dests, dummy_targets
    ).compile()
    
    mem_profile = compiled_train_step.memory_analysis()
    temp_mem_mb = mem_profile.temp_size_in_bytes / (1024**2)
    
    print(f"[Hardware] TPU Backward Pass Temp Allocation: {temp_mem_mb:.2f} MB")
    if temp_mem_mb < 8000.0: # ~8GB is extremely safe for a 16GB TPU
        print("[Hardware] ‚úÖ Safe BPTT Memory Pinning Successful.")
    else:
        print("[Hardware] ‚ö†Ô∏è WARNING: Elevated HBM spillage detected in backward pass.")
        
    print("\n[Hardware] Warming up MXU registers...\n")
    # The result is discarded, so the warm-up step does not change the weights.
    warmup_out = rl_mtm_train_step(train_state_obj, dummy_boards, dummy_dests, dummy_targets)
    jax.block_until_ready(warmup_out)
    time.sleep(1) 

    # --------------------------------------------------------------------------
    # 2. SOFTWARE & HARDWARE METRICS TRACKING
    # --------------------------------------------------------------------------
    history = {'loss': [], 'sps': [], 'epoch_time': []}
    
    for epoch in range(num_epochs):
        start_time = time.time()
        
        boards = [chess.Board() for _ in range(NUM_PARALLEL_GAMES)]
        roots = [MCTSNode(b.copy()) for b in boards]
        game_histories = [[] for _ in range(NUM_PARALLEL_GAMES)]
        active_games = set(range(NUM_PARALLEL_GAMES))
        
        move_number = 1
        total_forward_passes = 0
        forward_pass_time = 0.0
        
        while active_games and move_number < 80:
            # --- LIVE VISUALIZATION (Watching Game 0) ---
            if visualize_game and 0 in active_games and (move_number % 2 == 1):
                render_dashboard(
                    roots[0], boards[0], 
                    f"Self-Play Epoch {epoch+1} | Move {move_number}", 
                    f"Active Games: {len(active_games)} | Simulating Engine Lines..."
                )
            elif not visualize_game:
                print(f"Epoch {epoch+1} | Move {move_number} | Active Games: {len(active_games)}...", end="\r")
            
            for sim in range(SIMS_PER_MOVE):
                batch_candidates, batch_dest, batch_pieces, ptrs = [], [], [], []
                
                # A. Vectorized Selection
                for i in active_games:
                    node = roots[i]
                    while node.is_expanded and node.children:
                        valid = [c for c in node.children.values() if not c.is_proven_loss]
                        if not valid: 
                            # Every move here is a proven loss for the side to
                            # move, i.e. a forced win for the player who moved
                            # into this node.
                            node.is_proven_win = True
                            break
                        node = max(valid, key=lambda c: c.ucb(c_puct=2.0))
                        
                    if node.is_proven_win or node.is_proven_loss:
                        # Solved: back up the exact value so visits still accumulate.
                        _backpropagate(node, 1.0 if node.is_proven_win else -1.0, roots[i])
                        continue

                    if node.state.is_game_over():
                        # Terminal leaf: credit the result instead of skipping it.
                        _backpropagate(node, _score_terminal(node), roots[i])
                        continue
                        
                    # Generate the legal moves once and reuse the list for
                    # candidate construction and for expansion.
                    legal = list(node.state.legal_moves)
                    classes = []
                    for m in legal:
                        node.state.push(m)
                        batch_candidates.append(board_to_tensor(node.state))
                        batch_dest.append(m.to_square)
                        p = node.state.piece_at(m.to_square)
                        classes.append(p.piece_type + (0 if p.color else 6))
                        node.state.pop()
                    batch_pieces.extend(classes)
                        
                    ptrs.append((i, node, legal, classes))
                
                # B. TPU Acceleration (Using INFERENCE micro-batch size)
                if batch_candidates:
                    total_evals = len(batch_candidates)
                    all_energies, all_probs = [], []
                    
                    t0 = time.time() 
                    for chunk_idx in range(0, total_evals, MAX_INFERENCE_BATCH):
                        chunk_end = min(chunk_idx + MAX_INFERENCE_BATCH, total_evals)
                        
                        c_chunk = jnp.array(batch_candidates[chunk_idx:chunk_end])
                        d_chunk = jnp.array(batch_dest[chunk_idx:chunk_end], dtype=jnp.int32)
                        p_chunk = jnp.array(batch_pieces[chunk_idx:chunk_end], dtype=jnp.int32)
                        
                        e_chunk, prob_chunk = calculate_mtm_energies(
                            train_state_obj.params, c_chunk, d_chunk, p_chunk
                        )
                        jax.block_until_ready(e_chunk)
                        
                        all_energies.append(np.array(e_chunk))
                        all_probs.append(np.array(prob_chunk))
                        
                    forward_pass_time += (time.time() - t0)
                    total_forward_passes += total_evals
                    
                    energies = np.concatenate(all_energies, axis=0)
                    full_probs = np.concatenate(all_probs, axis=0)
                    
                    # C. Vectorized Expansion
                    idx_offset = 0
                    for (game_id, node, legal, classes) in ptrs:
                        num_moves = len(legal)
                        node_energies = energies[idx_offset : idx_offset + num_moves]
                        node_probs = full_probs[idx_offset : idx_offset + num_moves]
                        idx_offset += num_moves
                        
                        # Temperature sharpening
                        priors = np.exp(-node_energies / 0.5)
                        priors /= np.sum(priors)
                        
                        for move_idx, m in enumerate(legal):
                            new_board = node.state.copy()
                            new_board.push(m)
                            child = MCTSNode(new_board, parent=node, move=m, prior=priors[move_idx])
                            child.xray_probs = node_probs[move_idx]
                            child.xray_piece = classes[move_idx]
                            node.children[m] = child
                            
                        node.is_expanded = True
                        # 2*exp(-E)-1 of the best candidate rates the position
                        # for the side to move at `node`; negate it to store it
                        # from the perspective of the player who moved into it.
                        value = -float(np.max(2.0 * np.exp(-node_energies) - 1.0))
                        _backpropagate(node, value, roots[game_id])

            # D. Advance All Boards
            finished = []
            for i in active_games:
                b, root = boards[i], roots[i]
                children = list(root.children.values())
                if not children:
                    finished.append(i)
                    continue

                # Prefer a forced win, then any move not proven lost. If every
                # move is proven lost, still play one so the game reaches a
                # real terminal position and yields an outcome for training.
                forced_wins = [c for c in children if c.is_proven_win]
                not_lost = [c for c in children if not c.is_proven_loss]
                best_node = max(forced_wins or not_lost or children, key=lambda c: c.visits)
                
                # Training sample: position after the move with the
                # destination square replaced by the mask token.
                m_board = board_to_tensor(best_node.state)
                m_board[best_node.move.to_square] = 13
                game_histories[i].append((m_board, best_node.move.to_square, best_node.xray_piece, b.turn))
                
                b.push(best_node.move)
                roots[i] = best_node
                roots[i].parent = None
                
                if b.is_game_over(): finished.append(i)
                    
            for i in finished: active_games.remove(i)
            move_number += 1
            
        # E. Expert Iteration Gradient Update (Using TRAIN micro-batch size)
        # Keep only the moves played by the side that won a decisive game.
        opt_data = []
        for i, b in enumerate(boards):
            out = b.outcome()
            if out and out.winner is not None:
                opt_data.extend([d for d in game_histories[i] if d[3] == out.winner])
                
        if opt_data:
            m_b, d_i, t_c = map(jnp.array, zip(*[(d[0], d[1], d[2]) for d in opt_data]))
            losses = []
            
            for b_start in range(0, len(opt_data), MAX_TRAIN_BATCH):
                b_end = min(b_start + MAX_TRAIN_BATCH, len(opt_data))
                train_state_obj, loss = rl_mtm_train_step(
                    train_state_obj, 
                    m_b[b_start:b_end], 
                    d_i[b_start:b_end], 
                    t_c[b_start:b_end]
                )
                jax.block_until_ready(train_state_obj.params)
                losses.append(loss)
                
            duration = time.time() - start_time
            avg_loss = float(np.mean(losses))
            sps = total_forward_passes / forward_pass_time if forward_pass_time > 0 else 0
            
            history['loss'].append(avg_loss)
            history['sps'].append(sps)
            history['epoch_time'].append(duration)
            
            if not visualize_game:
                print(f"\nEpoch {epoch+1} Completed in {duration:.1f}s | RLHF Loss: {avg_loss:.4f} | Throughput: {sps:,.0f} SPS\n")
            else:
                render_dashboard(
                    roots[0], boards[0], 
                    f"Epoch {epoch+1} Completed", 
                    f"Trained on {len(opt_data)} optimums | Throughput: {sps:,.0f} SPS",
                    loss_info=f"{avg_loss:.4f}"
                )
                time.sleep(2) 
            
    return train_state_obj, history

In [9]:
# ==============================================================================
# CELL 7: EXECUTION & LAUNCH
# ==============================================================================
import jax
import jax.numpy as jnp
import optax
from flax.training import train_state
import chess
import time
import matplotlib.pyplot as plt

# --- 1. DEFINE INTERACTIVE GAME LOOP ---
def play_interactive(params):
    """Launches the Human vs Mamba-2 Match with the Omniscient Dashboard.

    The human plays White by typing moves in Standard Algebraic Notation; the
    engine answers as Black after a 60-simulation search. The search tree is
    reused across moves whenever the move played is already a child of the
    current root.

    Args:
        params: model parameters forwarded to ``run_mcts``.

    Returns:
        None. Returns early when the user types quit/exit/resign or the
        prompt is interrupted (EOF / KeyboardInterrupt).
    """
    board = chess.Board()
    root = MCTSNode(board.copy())

    while not board.is_game_over():
        if board.turn == chess.WHITE:
            render_dashboard(root, board, "Interactive Match", "Waiting for Human Input...")
            while True:
                try:
                    # Strip so that stray whitespace does not make a valid move unparsable.
                    move_input = input("Your Move (e.g., e4, Nf3): ").strip()
                except (EOFError, KeyboardInterrupt):
                    # Interrupting the prompt is treated as abandoning the game
                    # rather than surfacing a traceback.
                    print("Game Aborted.")
                    return

                if move_input.lower() in ["quit", "exit", "resign"]:
                    print("Game Aborted.")
                    return

                try:
                    move = board.push_san(move_input)
                except ValueError:
                    print("Invalid move format. Please use Standard Algebraic Notation (e.g., e4).")
                    continue

                # Re-root the tree on the played move, or start a fresh tree
                # when the move was never expanded.
                if move in root.children:
                    root = root.children[move]
                    root.parent = None 
                else:
                    root = MCTSNode(board.copy())
                break
        else:
            run_mcts(root, params, num_simulations=60, phase_title="Interactive Match", visualize=True)
            valid_children = [c for c in root.children.values() if not c.is_proven_loss]
            if not valid_children:
                print("Mamba-2 Resigns.")
                break
            # A child flagged as a forced win is played regardless of its visit
            # count, matching MCTSNode.ucb which ranks such a child above all others.
            forced_wins = [c for c in valid_children if c.is_proven_win]
            best_node = max(forced_wins or valid_children, key=lambda c: c.visits)
            board.push(best_node.move)
            root = best_node
            root.parent = None

    render_dashboard(root, board, "Match Finished", f"Result: {board.result()}")
    print("Game Over")

# --- 2. STATE INITIALIZATION ---
# Reuse `master_state` if it already exists in the kernel (e.g. from an earlier
# run of this cell) so trained weights survive re-execution; on a fresh kernel
# the name lookup raises NameError and a new model is initialised instead.
try:
    _ = master_state.params
    print("‚úÖ Loaded existing TrainState with trained weights.")
except NameError:
    print("‚ö†Ô∏è No existing weights found. Initializing fresh Uni-Mamba model...")
    # Fixed seed, so a fresh initialisation is reproducible.
    key = jax.random.PRNGKey(42)
    # Dummy inputs used only for parameter initialisation: one board of
    # 64 tokens and one timestep index.
    dummy_x = jnp.zeros((1, 64), dtype=jnp.int32)
    dummy_t = jnp.zeros((1,), dtype=jnp.int32)
    params = inference_model.init(key, dummy_x, dummy_t)['params']
    
    master_state = train_state.TrainState.create(
        apply_fn=inference_model.apply, 
        params=params, 
        tx=optax.adamw(learning_rate=3e-4)
    )
    print("‚úÖ Model initialized.")

# --- 3. MISSION CONTROL ---

# [MODE A] MASSIVELY PARALLEL PRE-TRAINING
# By setting visualize_game=True, the engine will render the ongoing self-play matches
# `metrics` holds one entry per epoch that actually trained (see the
# 'loss' / 'sps' / 'epoch_time' lists returned by batched_self_play_rlhf).
master_state, metrics = batched_self_play_rlhf(master_state, num_epochs=10, visualize_game=True)

# Visualize Hardware & Software Learning Curves
plt.figure(figsize=(12, 4))
# Left panel: mean training loss per epoch.
plt.subplot(1, 2, 1)
plt.plot(metrics['loss'], marker='o', color='red')
plt.title('RLHF Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Cross-Entropy Loss')

# Right panel: evaluated positions per second of forward-pass time.
plt.subplot(1, 2, 2)
plt.plot(metrics['sps'], marker='o', color='blue')
plt.title('Hardware Throughput (SPS)')
plt.xlabel('Epoch')
plt.ylabel('Steps Per Second')
plt.tight_layout()
plt.show()

# [MODE B] INTERACTIVE PLAY against the updated weights.
# This call blocks on keyboard input; comment it out for unattended runs.
play_interactive(master_state.params)

KeyboardInterrupt: 