In [24]:
from ipywidgets import interactive, interact
import ipywidgets as widgets
from IPython.display import SVG, display
from matplotlib import colormaps

import json
import math
import chess

In [25]:
# Sequential purple-red colormap used by get_color to shade move labels by visit share.
cmap = colormaps["PuRd"]

def hexify(f):
    """Convert a colour channel in [0, 1] to a two-digit lowercase hex string.

    Rounds to the nearest 8-bit level (as ``matplotlib.colors.to_hex`` does) rather
    than truncating, which pushed channels one level low (e.g. 0.999 -> "fe").

    Raises:
        AssertionError: if ``f`` lies outside [0, 1].
    """
    assert 0 <= f <= 1, f"channel value out of range: {f}"
    return f"{round(f * 255):02x}"

def get_color(val):
    """Map a value in [0, 1] (here: a move's visit share) to an ``#rrggbbaa`` colour.

    Values below 1e-3 are treated as "unvisited" and get a translucent white;
    anything else is looked up in ``cmap`` and returned fully opaque.
    """
    if val < 1e-3:
        return "#ffffff77"
    red, green, blue, _alpha = cmap(val)
    rgb_hex = "".join(hexify(channel) for channel in (red, green, blue))
    return "#" + rgb_hex + "ff"

In [39]:
# Trace JSON to inspect (a recorded game: "outcome" plus per-move "steps").
# Alternative trace kept for reference: "../runs/69/trace1.json".
filename = "../trace1.json"

In [40]:
# Load the trace. The handle is named ``fp`` so it never collides with the
# widget callback ``f`` defined below.
with open(filename) as fp:
    trace = json.load(fp)

outcome = trace["outcome"]  # recorded game result
steps = trace["steps"]      # one record per move played
    
def f(step=widgets.IntSlider(min=0, max=len(steps), step=1, value=0)):
    """``interact`` callback: render the board after ``step`` moves of the trace.

    For a position that has a search record, the board is shown above a row of
    move labels ordered by visit share and shaded with ``get_color``. At
    ``step == len(steps)`` (the position after the last recorded move) the bare
    ``chess.Board`` is returned instead.
    """
    board = chess.Board()
    for record in steps[:step]:
        board.push(chess.Move.from_uci(record[0]))

    if step == len(steps):
        return board

    visits = [count for count, _ in steps[step][2]]
    total = sum(visits) + 1e-4  # epsilon keeps the division defined with zero visits
    # Legal moves are paired with search children by position, then ranked by share.
    ranked = sorted(
        ((move, count / total) for move, count in zip(board.legal_moves, visits)),
        key=lambda pair: pair[1],
        reverse=True,
    )
    chips = [
        widgets.Label(value=move.uci(), style=dict(background=get_color(share)))
        for move, share in ranked
    ]
    board_view = widgets.Output()
    with board_view:
        display(SVG(data=board._repr_svg_()))
    return widgets.VBox([board_view, widgets.HBox(chips)])

interact(f);

interactive(children=(IntSlider(value=0, description='step', max=300), Output()), _dom_classes=('widget-intera…

In [29]:
# Game result recorded in the trace (termination reason and winner).
outcome

{'termination': 'Checkmate', 'winner': 'Black'}

In [37]:
def replay(trace, step):
    """Rebuild the position at ``step`` and pair each legal move with its search stats.

    Args:
        trace: parsed trace JSON. Each ``trace["steps"][i]`` is read as ``[0]`` = the
            UCI move played from position ``i`` and ``[2]`` = a list of
            ``(visit_count, q_value)`` pairs, one per legal move at position ``i``.
        step: number of moves replayed from the initial position; also the index
            whose search record is returned.

    Returns:
        A list of ``(move, visit_share, q_value, visit_count)`` tuples in
        ``board.legal_moves`` order. ``visit_share`` is ``visit_count`` divided by the
        total visits plus a tiny epsilon, so the shares sum to slightly under 1.

    Raises:
        IndexError: if ``trace["steps"]`` has no record at ``step`` (e.g. the final position).
        ValueError: if the number of recorded children differs from the number of legal moves.
    """
    steps = trace["steps"]
    if not -len(steps) <= step < len(steps):
        raise IndexError(
            f"step {step} has no search record: the trace holds {len(steps)} steps"
        )

    board = chess.Board()
    for record in steps[0:step]:
        board.push(chess.Move.from_uci(record[0]))

    children = steps[step][2]
    legal_moves = list(board.legal_moves)
    # Children are matched to moves purely by position; zip() would silently drop the
    # tail on a count mismatch and mislabel the whole table, so fail loudly instead.
    if len(children) != len(legal_moves):
        raise ValueError(
            f"step {step}: {len(children)} recorded children but "
            f"{len(legal_moves)} legal moves"
        )

    num_acts = [n for n, _ in children]
    q_values = [q for _, q in children]
    # The epsilon keeps the division defined when no child has been visited.
    sum_num = sum(num_acts) + 1e-4
    score = [n / sum_num for n in num_acts]
    return list(zip(legal_moves, score, q_values, num_acts))

In [41]:
# Search statistics at step 230, highest visit share first.
distr = sorted(replay(trace, 230), key=lambda row: row[1], reverse=True)
distr

[(Move.from_uci('a6c4'), 0.11797746181041471, -3.0, 21),
 (Move.from_uci('a3b4'), 0.07303366683501863, 1.0, 13),
 (Move.from_uci('a6c8'), 0.0561797437192451, 0.0, 10),
 (Move.from_uci('a6e2'), 0.0561797437192451, 0.0, 10),
 (Move.from_uci('a6b7'), 0.05056176934732059, 1.0, 9),
 (Move.from_uci('b3c3'), 0.05056176934732059, -1.0, 9),
 (Move.from_uci('a6g6'), 0.04494379497539608, -2.0, 8),
 (Move.from_uci('a6a5'), 0.04494379497539608, 2.0, 8),
 (Move.from_uci('a6b6'), 0.03932582060347157, -1.0, 7),
 (Move.from_uci('a6b5'), 0.03932582060347157, -1.0, 7),
 (Move.from_uci('a6f1'), 0.03932582060347157, -1.0, 7),
 (Move.from_uci('a6e6'), 0.03370784623154706, -2.0, 6),
 (Move.from_uci('a3b2'), 0.03370784623154706, 0.0, 6),
 (Move.from_uci('a6a8'), 0.02808987185962255, -1.0, 5),
 (Move.from_uci('a6a7'), 0.02808987185962255, -1.0, 5),
 (Move.from_uci('a6f6'), 0.02808987185962255, -1.0, 5),
 (Move.from_uci('a6d6'), 0.02808987185962255, -1.0, 5),
 (Move.from_uci('a6d3'), 0.02808987185962255, -1.0, 

In [100]:
# Replay every recorded move to recompute the final position.
# Read the moves from ``trace`` rather than the global ``steps``: other cells may
# rebind ``steps`` to a different structure, which would break ``from_uci`` on re-run.
b = chess.Board()
for record in trace["steps"]:
    b.push(chess.Move.from_uci(record[0]))

In [39]:
# Result of the replayed game according to python-chess (claim_draw=True also counts
# claimable draws); compare with the recorded ``outcome``.
# NOTE(review): this printed None while the trace records a Black checkmate — worth investigating.
print(b.outcome(claim_draw=True))

None


In [40]:
# Recorded game result, for comparison with the replayed outcome above.
outcome

{'termination': 'Checkmate', 'winner': 'Black'}

In [41]:
import json
import chess
import libencoder

# Re-read the same trace the rest of the notebook inspects (``filename``) instead of a
# separately hard-coded path. The handle is ``fp`` so the widget callback ``f`` survives.
with open(filename, "r") as fp:
    trace = json.load(fp)

outcome = trace["outcome"]
# Keep the encoder input under its own name: the replay widget ``f`` reads the global
# ``steps`` as the raw trace records at call time, so rebinding it here would break it.
# NOTE(review): the empty list looks like a placeholder for per-move search data — confirm
# against libencoder.encode's expected input.
encoder_steps = [(chess.Move.from_uci(record[0]), []) for record in trace["steps"]]

ds = libencoder.encode(encoder_steps)

In [42]:
from importlib import reload
import train
import libencoder
# Re-execute the modules so edits made since the first import are picked up without
# restarting the kernel.
reload(train)
# NOTE(review): libencoder is a compiled extension (.so, per the output below). The
# importlib.reload docs note reloading dynamically loaded modules is generally not useful —
# restart the kernel to pick up a rebuilt library.
reload(libencoder)

<module 'libencoder' from '/home2/jiasen/workspace/smart-chess-rust/target/release/libencoder.so'>

In [44]:
import torch
import nn

# Network checkpoint under analysis, loaded for CPU inference.
CHECKPOINT = "../runs/121/tb_logs/chess/version_0/checkpoints/epoch=2-step=3126.ckpt"
model = nn.load_model(device="cpu", checkpoint=CHECKPOINT)

..loading checkpoint:  ../runs/121/tb_logs/chess/version_0/checkpoints/epoch=2-step=3126.ckpt


In [45]:
def infer(fn, index):
    """Run the network on position ``index`` of trace file ``fn``.

    Args:
        fn: path to a trace file readable by ``train.ChessDataset``.
        index: position index within that dataset.

    Returns:
        ``(pi, value)``: ``pi`` is the policy gathered at the position's move indices and
        renormalised to sum to 1; ``value`` is the squeezed value-head output.
    """
    ds = train.ChessDataset(fn)

    board_enc = ds[index][0]
    # NOTE(review): assumes field 3 of a dataset step holds the policy indices of the
    # legal moves — confirm in train.ChessDataset.
    moves_enc = ds.steps[index][3]

    # Pure inference: no_grad avoids building an autograd graph for outputs we only read.
    # NOTE(review): if nn.load_model leaves the model in training mode, call model.eval() first.
    with torch.no_grad():
        log_pi, value = model(board_enc.unsqueeze(0))  # add a batch dimension of size 1

    # The policy output is exponentiated, i.e. treated as log-probabilities.
    pi = torch.exp(log_pi).squeeze()
    pi = pi[moves_enc]
    pi = pi / pi.sum()

    return pi, value.squeeze()

In [46]:
# Handle named ``fp`` so the widget callback ``f`` is not clobbered.
with open(filename) as fp:
    trace = json.load(fp)

# One index drives the network prior, the visit total and the replay, so all three
# describe the same position.
step = 230
prior, value = infer(filename, step)
# Total visits over all children at this position (fed to uct as ``total_n``).
total_n = sum(n for n, _ in trace["steps"][step][2])

distr = replay(trace, step)

skipping cudagraphs for unknown reason


In [47]:
# Visit total at this position and the network's renormalised prior over its moves.
total_n, prior

(178,
 tensor([0.0327, 0.0037, 0.0161, 0.0025, 0.0546, 0.1160, 0.0126, 0.0099, 0.0118,
         0.1965, 0.0166, 0.0030, 0.0776, 0.0054, 0.0264, 0.0088, 0.0096, 0.0798,
         0.0512, 0.0961, 0.0037, 0.0646, 0.0112, 0.0364, 0.0198, 0.0334]))

In [48]:
# Moves in ``distr`` order; the per-child table indexes this list positionally.
next_moves = [move for move, *_ in distr]

In [49]:
def uct(prior, total_n, q_value, current_n, reverse, cpuct):
    """Split a PUCT-style selection score into its parts, formatted for display.

    Args:
        prior: network prior probability of the move.
        total_n: total visits at the parent position.
        q_value: value recorded for the child; divided by ``current_n`` to get the
            exploitation term (presumably an accumulated value sum — confirm in the search code).
        current_n: visit count of the child.
        reverse: if true, the exploitation term is negated.
        cpuct: exploration constant.

    Returns:
        ``(total, exploitation, exploration)`` as strings with two decimals.
    """
    sign = -1 if reverse else 1
    exploit = q_value / (current_n + 1e-5) * sign
    explore = math.sqrt(total_n) / (1 + current_n) * prior * cpuct
    parts = (exploit + explore, exploit, explore)
    return tuple(f"{part:0.2f}" for part in parts)

In [53]:
# One row per child at ``step``: (move, visits, q, prior, uct(total, exploit, explore)).
# Rows are matched to next_moves and prior by position, so both must follow the same
# child order as trace["steps"][step][2].
[
    (
        next_moves[idx],
        visits,
        q_val,
        f"{prior[idx].item():0.3f}",
        uct(prior[idx].item(), total_n, q_val, visits, False, 2.0),
    )
    for idx, (visits, q_val) in enumerate(trace["steps"][step][2])
]

[(Move.from_uci('a6c8'), 10, 0.0, '0.033', ('0.08', '0.00', '0.08')),
 (Move.from_uci('a6a8'), 5, -1.0, '0.004', ('-0.18', '-0.20', '0.02')),
 (Move.from_uci('a6b7'), 9, 1.0, '0.016', ('0.15', '0.11', '0.04')),
 (Move.from_uci('a6a7'), 5, -1.0, '0.003', ('-0.19', '-0.20', '0.01')),
 (Move.from_uci('a6h6'), 3, -1.0, '0.055', ('0.03', '-0.33', '0.36')),
 (Move.from_uci('a6g6'), 8, -2.0, '0.116', ('0.09', '-0.25', '0.34')),
 (Move.from_uci('a6f6'), 5, -1.0, '0.013', ('-0.14', '-0.20', '0.06')),
 (Move.from_uci('a6e6'), 6, -2.0, '0.010', ('-0.30', '-0.33', '0.04')),
 (Move.from_uci('a6d6'), 5, -1.0, '0.012', ('-0.15', '-0.20', '0.05')),
 (Move.from_uci('a6c6'), 3, -1.0, '0.197', ('0.98', '-0.33', '1.31')),
 (Move.from_uci('a6b6'), 7, -1.0, '0.017', ('-0.09', '-0.14', '0.06')),
 (Move.from_uci('a6b5'), 7, -1.0, '0.003', ('-0.13', '-0.14', '0.01')),
 (Move.from_uci('a6a5'), 8, 2.0, '0.078', ('0.48', '0.25', '0.23')),
 (Move.from_uci('a6c4'), 21, -3.0, '0.005', ('-0.14', '-0.14', '0.01')),
 (

In [59]:
# What-if: rescore the a6c4 row (index 13 in the table above, 21 visits, q=-3.0) as if
# the parent had 120 visits and this child 10.
uct(prior[13].item(), 120, -3.0, 10, False, 2.0)

('-0.29', '-0.30', '0.01')

In [52]:
# Total visit count at the inspected position.
total_n

178