In [5]:
from ipywidgets import interactive, interact
import ipywidgets as widgets
from IPython.display import SVG, display
from matplotlib import colormaps

import json
import math
import chess

In [6]:
# Colormap used by get_color() to shade the per-move labels by their share of
# the recorded counts; matplotlib's "PuRd" (purple-red) map.
cmap = colormaps["PuRd"]

def hexify(f):
    """Convert a colour channel in [0, 1] to a two-digit lowercase hex string.

    Args:
        f: channel intensity; must satisfy 0 <= f <= 1.

    Returns:
        Two hex digits, e.g. ``"00"`` for 0.0 and ``"ff"`` for 1.0.

    Raises:
        ValueError: if ``f`` lies outside [0, 1] (NaN included).
    """
    # Explicit check instead of `assert`, which disappears under `python -O`.
    if not 0 <= f <= 1:
        raise ValueError(f"channel value must be in [0, 1], got {f!r}")
    # Round to the nearest byte (as matplotlib.colors.to_hex does); int() alone
    # truncates, turning e.g. 0.999 into "fe" instead of "ff". The outer int()
    # guards against numpy scalars whose __round__ may return a float.
    return f"{int(round(f * 255)):02x}"

def get_color(val):
    """Return an 8-digit ``#rrggbbaa`` hex colour for a share in [0, 1].

    Values below 1e-3 come back as translucent white (alpha 0x77), so near-zero
    shares fade out. Everything else is looked up in the module-level ``cmap``
    and emitted fully opaque; the colormap's own alpha channel is discarded.
    """
    if val < 1e-3:
        return "#ffffff77"
    red, green, blue, _alpha = cmap(val)
    channels = "".join(hexify(component) for component in (red, green, blue))
    return "#" + channels + "ff"

In [7]:
# Load the recorded game trace. The handle is named `trace_file` rather than `f`:
# `f` is the interactive viewer callback defined just below, and reusing that
# name for file handles is what lets later cells silently clobber it.
with open("../trace1.json", encoding="utf-8") as trace_file:
    trace = json.load(trace_file)

outcome = trace["outcome"]
steps = trace["steps"]
    
def f(step=widgets.IntSlider(min=0, max=len(steps), step=1, value=0)):
    """Widget callback: show the position after `step` plies and its move stats.

    The board is rebuilt from the initial position by replaying the first `step`
    recorded moves. For steps inside the trace, the per-child counts stored in
    ``steps[step][2]`` are converted into shares of their total and shown as one
    coloured label per legal move, largest share first. At ``step == len(steps)``
    there is no record left to show, so only the final board is returned.
    """
    board = chess.Board()
    for record in steps[:step]:
        board.push(chess.Move.from_uci(record[0]))

    if step == len(steps):
        return board

    # NOTE(review): statistics are paired with board.legal_moves purely by
    # position, which assumes the trace stores children in python-chess
    # legal-move order — confirm against the trace writer.
    counts = [n for n, _ in steps[step][2]]
    denominator = sum(counts) + 1e-4  # epsilon avoids 0/0 when every count is zero
    shares = [count / denominator for count in counts]
    ranked = sorted(zip(board.legal_moves, shares), key=lambda pair: pair[1], reverse=True)

    move_labels = widgets.HBox([
        widgets.Label(value=move.uci(), style={"background": get_color(share)})
        for move, share in ranked
    ])

    board_view = widgets.Output()
    with board_view:
        display(SVG(data=board._repr_svg_()))
    return widgets.VBox([board_view, move_labels])


# Build and show the slider-driven viewer; the trailing ";" suppresses the
# cell's echoed return value.
interact(f);

interactive(children=(IntSlider(value=0, description='step', max=200), Output()), _dom_classes=('widget-intera…

In [83]:
# Display the `outcome` field recorded in the trace file.
outcome

In [10]:
# Inspect the statistics recorded at one ply: replay the game up to `step`, then
# pair each legal move with its count share, its stored q and its raw count,
# largest share first. (uct() below divides q by the count, so the stored q is
# presumably a running sum — confirm.)
step = 160
b = chess.Board()
for move in steps[:step]:
    b.push(chess.Move.from_uci(move[0]))

num_acts = [visits for visits, _ in steps[step][2]]
q_values = [q for _, q in steps[step][2]]
sum_num = sum(num_acts) + 1e-4  # epsilon guards against an all-zero count list
score = [visits / sum_num for visits in num_acts]
distr = sorted(
    zip(b.legal_moves, score, q_values, num_acts),
    key=lambda row: row[1],
    reverse=True,
)
distr

[(Move.from_uci('a1e5'), 0.15999936000256, -2.0, 4),
 (Move.from_uci('a1b2'), 0.11999952000191999, -1.0, 3),
 (Move.from_uci('d3e4'), 0.07999968000128, 0.0, 2),
 (Move.from_uci('d3e3'), 0.07999968000128, 0.0, 2),
 (Move.from_uci('d3c2'), 0.07999968000128, 0.0, 2),
 (Move.from_uci('a1g7'), 0.07999968000128, 0.0, 2),
 (Move.from_uci('a1f6'), 0.07999968000128, 0.0, 2),
 (Move.from_uci('a1d4'), 0.07999968000128, 0.0, 2),
 (Move.from_uci('a1c3'), 0.07999968000128, 0.0, 2),
 (Move.from_uci('f3f4'), 0.07999968000128, 0.0, 2),
 (Move.from_uci('d3d4'), 0.03999984000064, 1.0, 1),
 (Move.from_uci('a1h8'), 0.03999984000064, 1.0, 1)]

In [100]:
# Replay every recorded move from the initial position to reach the final
# position of the traced game (its outcome is printed in the next cell).
# NOTE(review): expects `steps` as loaded from the trace JSON (entry[0] is a UCI
# string). The libencoder cell further down rebinds `steps` to (chess.Move, [])
# pairs, after which this cell fails until the trace-loading cell is re-run.
b = chess.Board()
for move in steps:
    b.push(chess.Move.from_uci(move[0]))

In [7]:
# Ask python-chess for the outcome of the replayed final position, allowing draw
# claims; a printed None means python-chess does not consider the game over.
# print() is kept (rather than a bare expression) so that None is still shown.
print(b.outcome(claim_draw=True))

None


In [8]:
# Display the trace's recorded `outcome` again, for comparison with the
# python-chess outcome of the replayed board printed above.
outcome

In [4]:
import json
import chess
import libencoder

# Re-read the trace and encode its move sequence with the compiled `libencoder`
# extension. The handle is not named `f`: that would rebind the viewer's
# interact callback to a closed file object.
with open("../trace1.json", "r", encoding="utf-8") as trace_file:
    trace = json.load(trace_file)

outcome = trace["outcome"]
# NOTE(review): this rebinds the module-level `steps` that the viewer callback
# `f` and the replay cells above read, changing each entry to a
# (chess.Move, []) pair. Dragging the slider or re-running those cells fails
# until the trace-loading cell is run again. Consider a distinct name if no
# later cell depends on `steps` having this shape.
steps = [(chess.Move.from_uci(step[0]), []) for step in trace["steps"]]

ds = libencoder.encode(steps)

In [14]:
from importlib import reload
import train
import libencoder
# Pick up edits to the project modules without restarting the kernel.
# NOTE(review): a pure-Python `train` is re-executed by reload(), but
# `libencoder` is a compiled extension (the output below shows a .so path), and
# importlib.reload generally cannot load a rebuilt extension in-process —
# restart the kernel after rebuilding it.
reload(train)
reload(libencoder)

<module 'libencoder' from '/home2/jiasen/workspace/smart-chess-rust/target/release/libencoder.so'>

In [11]:
import torch
import nn
# Checkpoint under inspection. The run / version / epoch / step components are
# specific to one training run, so change this constant when switching models.
# (`nn` here is the project's own module imported above, not torch.nn.)
CHECKPOINT_PATH = "../runs/6/tb_logs/chess/version_1/checkpoints/epoch=3-step=2052.ckpt"
model = nn.load_model(device="cpu", checkpoint=CHECKPOINT_PATH)

..loading checkpoint:  ../runs/6/tb_logs/chess/version_1/checkpoints/epoch=3-step=2052.ckpt


In [12]:
def infer(fn, index):
    """Run the loaded `model` on one position of a trace file.

    Args:
        fn: path to a trace file, loaded through ``train.ChessDataset``.
        index: position index within that dataset.

    Returns:
        ``(pi, value)``. ``pi`` is the exponentiated policy output restricted to
        the indices in ``ds.steps[index][3]`` and renormalised to sum to 1.
        ``value`` is the squeezed value-head output. Neither tensor requires
        grad.
    """
    ds = train.ChessDataset(fn)

    board_enc = ds[index][0]
    # NOTE(review): presumably the encoded indices of the legal moves at this
    # position, in the same order as the trace's per-child statistics — confirm.
    moves_enc = ds.steps[index][3]

    # Pure inference: no_grad skips recording an autograd graph that would only
    # be thrown away by .detach() afterwards.
    # NOTE(review): results also depend on whether `model` is in eval mode
    # (dropout / batch-norm); nn.load_model is not visible here — confirm.
    with torch.no_grad():
        log_pi, value = model(board_enc.unsqueeze(0))

    # The exp() implies the policy head emits log-probabilities (log-softmax).
    pi = torch.exp(log_pi).squeeze()

    # Keep only the legal-move entries and renormalise them into a distribution.
    pi = pi[moves_enc]
    pi = pi / pi.sum()

    return pi, value.squeeze()

In [17]:
# Compare the network's prior at one ply with the search statistics stored for it.
# The handle is not named `f`, which would rebind the viewer's interact callback.
with open("../trace1.json", encoding="utf-8") as trace_file:
    trace = json.load(trace_file)

step = 100
prior, value = infer("../trace1.json", step)
# Sum of the per-child counts at this ply (each entry unpacks as (n, q)); used
# as the parent total in uct() below.
total_n = sum(n for n, _ in trace["steps"][step][2])

In [18]:
# Display the parent's total count next to the renormalised prior over legal moves.
total_n, prior

(55,
 tensor([0.0843, 0.0286, 0.0785, 0.0128, 0.0233, 0.0453, 0.0219, 0.0333, 0.0152,
         0.0282, 0.0252, 0.0191, 0.0519, 0.0927, 0.0470, 0.0119, 0.0730, 0.0709,
         0.0147, 0.0299, 0.0233, 0.0230, 0.0237, 0.0220, 0.0530, 0.0232, 0.0240]))

In [19]:
def uct(prior, total_n, q_value, current_n, reverse, cpuct):
    """Split one child's tree-search selection score into its two terms.

    Args:
        prior: policy prior for the child.
        total_n: total count at the parent (the sum of its children's counts).
        q_value: accumulated value of the child; divided by ``current_n`` to get
            the mean.
        current_n: count of the child.
        reverse: when true, flip the sign of the mean-value term.
        cpuct: exploration constant.

    Returns:
        ``(total, award, exploration)`` as strings with two decimals.
    """
    sign = -1 if reverse else 1
    # Mean value; the 1e-5 keeps a child with current_n == 0 finite.
    award = q_value / (current_n + 1e-5) * sign
    # PUCT-style bonus: sqrt(N) / (1 + n) * P * c, evaluated left to right.
    exploration = math.sqrt(total_n) / (1 + current_n) * prior * cpuct
    total = award + exploration
    return tuple(f"{term:0.2f}" for term in (total, award, exploration))

In [20]:
# Per-child breakdown at `step`: (n, q, prior, (uct total, award, exploration)).
# NOTE(review): children are matched to `prior` by position, which assumes the
# trace and the encoder list legal moves in the same order — confirm.
CPUCT = 8  # exploration constant used for this breakdown
[
    (n, q, f"{prior[i].item():0.2f}", uct(prior[i].item(), total_n, q, n, reverse=False, cpuct=CPUCT))
    for i, (n, q) in enumerate(trace["steps"][step][2])
]

[(2, 0.0, '0.08', ('1.67', '0.00', '1.67')),
 (2, 0.0, '0.03', ('0.57', '0.00', '0.57')),
 (3, -1.0, '0.08', ('0.83', '-0.33', '1.16')),
 (2, 0.0, '0.01', ('0.25', '0.00', '0.25')),
 (2, 0.0, '0.02', ('0.46', '0.00', '0.46')),
 (2, 0.0, '0.05', ('0.90', '0.00', '0.90')),
 (3, -1.0, '0.02', ('-0.01', '-0.33', '0.33')),
 (3, -1.0, '0.03', ('0.16', '-0.33', '0.49')),
 (2, 0.0, '0.02', ('0.30', '0.00', '0.30')),
 (2, 0.0, '0.03', ('0.56', '0.00', '0.56')),
 (2, 0.0, '0.03', ('0.50', '0.00', '0.50')),
 (2, 0.0, '0.02', ('0.38', '0.00', '0.38')),
 (2, 0.0, '0.05', ('1.03', '0.00', '1.03')),
 (1, 1.0, '0.09', ('3.75', '1.00', '2.75')),
 (2, 0.0, '0.05', ('0.93', '0.00', '0.93')),
 (2, 0.0, '0.01', ('0.23', '0.00', '0.23')),
 (2, 0.0, '0.07', ('1.44', '0.00', '1.44')),
 (2, 0.0, '0.07', ('1.40', '0.00', '1.40')),
 (3, -1.0, '0.01', ('-0.12', '-0.33', '0.22')),
 (2, 0.0, '0.03', ('0.59', '0.00', '0.59')),
 (1, 1.0, '0.02', ('1.69', '1.00', '0.69')),
 (1, 1.0, '0.02', ('1.68', '1.00', '0.68')),
