In [1]:
!pip install onnx onnxruntime



In [2]:
import argparse
import os
import sys
import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from google.colab import drive
from tqdm.notebook import tqdm

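## Network architecture

Initial sketch: a 64-unit input layer (the board), three hidden layers of 128, 64 and 32 units, and a 64×1 probability vector (softmax) as output.

The network actually implemented in `OthelloNet` below differs from that sketch:
- The input is two 64-cell stone planes (128 values).
- The shared trunk is 512 → 256 → 128 with ReLU.
- The policy head outputs 64 logits; the softmax is applied in the training loss.
- A separate value head (128 → 64 → 1, tanh) predicts the game outcome.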

In [8]:
class H5OthelloDataset(Dataset):
    """
    Dataset for Othello training data stored in HDF5.

    Expects each sample in a group named 'sample_{i}' with:
      - dataset 'board' (64,) int8 or int32
      - dataset 'probs' (64,) float32
      - attribute 'outcome' (scalar int)

    Each item is a tuple ``(x, pi, mask, y)``:
      - x    : float32 tensor (2, 64); channel 0 is 1 where board == 1
               (own stones), channel 1 is 1 where board == -1 (opponent)
      - pi   : float32 tensor (64,), the stored 'probs' vector
      - mask : bool tensor (64,), True where pi > 0
      - y    : float32 scalar tensor, the 'outcome' attribute

    The HDF5 handle is opened lazily, once per process. An h5py handle
    opened in the parent must not be shared with forked DataLoader workers
    (``num_workers > 0``), so each worker opens its own.
    """

    def __init__(self, h5_path):
        self.h5_path = h5_path
        self._file = None  # lazily-opened h5py.File owned by process self._pid
        self._pid = None
        # Read the group names with a short-lived handle, so that no open
        # handle exists to be inherited by DataLoader workers.
        with h5py.File(h5_path, 'r') as h5:
            self.keys = list(h5.keys())
        try:
            # if named sample_0, sample_1, ... sort numerically
            self.keys.sort(key=lambda k: int(k.split('_')[1]))
        except (IndexError, ValueError):
            # any other naming scheme: plain lexicographic order
            self.keys.sort()

    @property
    def file(self):
        """The h5py.File for ``h5_path``, opened on first access in the current process."""
        pid = os.getpid()
        if self._file is None or self._pid != pid:
            # Either never opened, or the handle was inherited from another
            # process: replace it with a handle owned by this process.
            self._file = h5py.File(self.h5_path, 'r')
            self._pid = pid
        return self._file

    def close(self):
        """Close the HDF5 handle if this process opened one; safe to call repeatedly."""
        if self._file is not None and self._pid == os.getpid():
            self._file.close()
        self._file = None
        self._pid = None

    def __getstate__(self):
        # Never pickle the live handle (e.g. for spawn-based workers);
        # the receiving process reopens it on first access.
        state = self.__dict__.copy()
        state['_file'] = None
        state['_pid'] = None
        return state

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        grp = self.file[self.keys[idx]]
        # Load numpy arrays
        board_vec = grp['board'][()]    # shape (64,)
        pi_vec    = grp['probs'][()]    # shape (64,)
        outcome   = grp.attrs['outcome'] # scalar

        # 2-channel encoding of the flat 64-cell board (no 8x8 reshape)
        my_stones  = (board_vec ==  1).astype(np.float32)
        opp_stones = (board_vec == -1).astype(np.float32)
        x = np.stack([my_stones, opp_stones], axis=0)  # (2, 64)

        # Policy target, and a mask of moves with non-zero target probability
        pi   = pi_vec.astype(np.float32)               # (64,)
        mask = (pi > 0)                                # (64,) bool

        # Value target
        y_value = np.float32(outcome)

        return (
            torch.from_numpy(x),               # float32 tensor (2,64)
            torch.from_numpy(pi),              # float32 tensor (64,)
            torch.from_numpy(mask),            # bool tensor (64,)
            torch.tensor(y_value)              # float32 scalar
        )

In [11]:
class OthelloNet(nn.Module):
    """
    Fully-connected policy/value network for Othello.

    Input is a (B, 2, 64) float tensor holding two stone planes of a flat
    64-cell board. Outputs are 64 policy logits and a tanh-squashed scalar
    value per sample.
    """

    # Logit assigned to masked-out moves; the same value train() uses for
    # NEG_INF. A large finite number (rather than -inf) keeps
    # 0 * log_softmax(logits) at 0 instead of NaN in a cross-entropy loss.
    MASK_VALUE = -1e9

    def __init__(self):
        super().__init__()
        # flatten (2,64) → 128
        self.flatten = nn.Flatten()
        # trunk: 128 → 512 → 256 → 128
        self.trunk = nn.Sequential(
            nn.Linear(2 * 64, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
        )
        # policy head: 128 → 64 logits
        self.policy_head = nn.Linear(128, 64)
        # value head: 128 → 64 → 1
        self.value_head = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, x, legal_mask=None):
        """
        Args:
            x: float tensor (B, 2, 64).
            legal_mask: optional bool tensor (B, 64). Where it is False, the
                policy logit is replaced by ``MASK_VALUE``. Using the mask
                in the graph also keeps it as a real input when the model
                is exported with ``torch.onnx.export``; an unused input
                would be pruned.

        Returns:
            logits: (B, 64) policy logits.
            value:  (B,) value in [-1, 1].
        """
        feats = self.trunk(self.flatten(x))          # (B, 128)

        logits = self.policy_head(feats)             # (B, 64)
        if legal_mask is not None:
            logits = logits.masked_fill(~legal_mask, self.MASK_VALUE)

        value = torch.tanh(self.value_head(feats)).squeeze(1)   # (B,)

        return logits, value


def _resolve_device(args):
    """Pick the torch device: ``args.device`` when set, else CUDA if available, else CPU."""
    override = getattr(args, "device", None)
    if override:
        return torch.device(override)
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _ensure_parent_dir(path):
    """Create the directory that will contain ``path``, if ``path`` names one."""
    parent = os.path.dirname(path)
    # os.makedirs('') raises FileNotFoundError, so skip bare file names.
    if parent:
        os.makedirs(parent, exist_ok=True)


def _train_one_epoch(model, loader, optimizer, device, value_weight, desc,
                     neg_inf=-1e9, max_grad_norm=1.0):
    """
    Run one optimisation pass over ``loader``.

    Returns:
        (policy_loss_sum, value_loss_sum): each batch's mean loss multiplied
        by its batch size and summed. Dividing by the dataset size gives
        per-sample averages.
    """
    model.train()
    total_p_loss = 0.0
    total_v_loss = 0.0

    for x, pi, mask, y in tqdm(loader, desc=desc):
        x, pi, mask, y = x.to(device), pi.to(device), mask.to(device), y.to(device)
        # Re-normalize the target over the masked moves; the clamp avoids
        # 0/0 for rows whose target sums to zero.
        pi = pi * mask.float()
        pi = pi / pi.sum(dim=1, keepdim=True).clamp(min=1e-6)

        optimizer.zero_grad()

        logits, v_pred = model(x, mask)
        # Finite fill value (not -inf) so pi * log_probs is 0, not NaN,
        # at masked-out moves.
        logits = logits.masked_fill(~mask, neg_inf)
        log_probs = F.log_softmax(logits, dim=1)

        p_loss = -(pi * log_probs).sum(dim=1).mean()
        v_loss = F.mse_loss(v_pred, y)
        loss   = p_loss + value_weight * v_loss

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        optimizer.step()

        bs = x.size(0)
        total_p_loss += p_loss.item() * bs
        total_v_loss += v_loss.item() * bs

    return total_p_loss, total_v_loss


def _export_onnx(model, onnx_path, device):
    """Export ``model`` to ONNX at ``onnx_path`` with a dynamic batch dimension."""
    _ensure_parent_dir(onnx_path)
    model.eval()
    # Dummy inputs only fix the traced shapes: (1, 2, 64) board, (1, 64) mask.
    dummy_x    = torch.randn(1, 2, 64, device=device)
    dummy_mask = torch.ones(1, 64, dtype=torch.bool, device=device)
    torch.onnx.export(
        model,
        (dummy_x, dummy_mask),
        onnx_path,
        export_params=True,
        opset_version=12,
        do_constant_folding=True,
        input_names=['board', 'legal_mask'],
        output_names=['logits', 'value'],
        dynamic_axes={
            'board':      {0: 'batch_size'},
            'legal_mask': {0: 'batch_size'},
            'logits':     {0: 'batch_size'},
            'value':      {0: 'batch_size'},
        }
    )
    print(f"✅ Exported ONNX model to {onnx_path}")


def train(args):
    """
    Train ``OthelloNet`` on the HDF5 file ``args.data``.

    The weights with the lowest epoch-average training loss are saved to
    ``args.save_path``. If ``args.onnx_path`` is non-empty, those same best
    weights are exported to ONNX.

    Required ``args`` fields: data, batch_size, num_workers, lr,
    weight_decay, lr_step, lr_gamma, epochs, value_weight, save_path,
    onnx_path.

    Optional fields:
      - device: None means CUDA if available, else CPU.
      - detect_anomaly: default False.
      - grad_clip: default 1.0.

    Note: "best" is measured on the training loss; there is no validation split.

    Raises:
        ValueError: if the HDF5 file contains no samples.
    """
    NEG_INF = -1e9
    # Anomaly detection slows every backward pass considerably and is global
    # state, so it is only enabled on request.
    if getattr(args, "detect_anomaly", False):
        torch.autograd.set_detect_anomaly(True)
    device = _resolve_device(args)
    max_grad_norm = getattr(args, "grad_clip", 1.0)

    # --- Dataset & DataLoader ---
    dataset = H5OthelloDataset(args.data)
    num_samples = len(dataset)
    if num_samples == 0:
        raise ValueError(f"No training samples found in {args.data}")
    loader  = DataLoader(
        dataset,
        batch_size   = args.batch_size,
        shuffle      = True,
        num_workers  = args.num_workers,
        pin_memory   = device.type == "cuda",  # pinning only helps host->GPU copies
    )

    # --- Model, Optimizer, Scheduler ---
    model = OthelloNet().to(device)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr           = args.lr,
        weight_decay = args.weight_decay
    )
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=args.lr_step, gamma=args.lr_gamma
    )

    best_loss = float('inf')
    best_state = None
    _ensure_parent_dir(args.save_path)

    # --- Training Loop ---
    try:
        for epoch in range(1, args.epochs + 1):
            sum_p, sum_v = _train_one_epoch(
                model, loader, optimizer, device,
                value_weight=args.value_weight,
                desc=f"Epoch {epoch}/{args.epochs}",
                neg_inf=NEG_INF,
                max_grad_norm=max_grad_norm,
            )
            scheduler.step()

            avg_p = sum_p / num_samples
            avg_v = sum_v / num_samples
            avg_total = avg_p + args.value_weight * avg_v

            print(f"→ Epoch {epoch:2d}: Policy={avg_p:.4f}, Value={avg_v:.4f}, Total={avg_total:.4f}")

            if avg_total < best_loss:
                best_loss = avg_total
                torch.save(model.state_dict(), args.save_path)
                # In-memory copy so the ONNX export uses these same weights.
                best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
                print(f"  ↳ Saved best model to {args.save_path}")
    finally:
        # Release the HDF5 handle when the dataset class supports it.
        close = getattr(dataset, "close", None)
        if callable(close):
            close()

    print("✅ Training complete.")

    # --- ONNX Export ---
    if args.onnx_path:
        if best_state is not None:
            model.load_state_dict(best_state)
        _export_onnx(model, args.onnx_path, device)

In [6]:
# Mount Google Drive so that the HDF5 training file under /content/drive/MyDrive
# (the default --data path in the training cell) is readable. force_remount=True
# asks Colab to remount even if Drive is already mounted, e.g. when re-running this cell.
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [12]:
if __name__ == "__main__":

    training_set = '/content/drive/MyDrive/training_samples_sample.h5'

    p = argparse.ArgumentParser(description="Train & export Othello net from HDF5 data")
    p.add_argument("--data",         type=str,   default=training_set,
                   help="Path to HDF5 file containing training groups")
    p.add_argument("--save-path",    type=str,   default="models/othello_h5.pt",
                   help="Where to save PyTorch model weights")
    p.add_argument("--onnx-path",    type=str,   default="models/othello.onnx",
                   help="Where to save ONNX model (set to empty string to skip)")
    p.add_argument("--epochs",       type=int,   default=20)
    p.add_argument("--batch-size",   type=int,   default=128)
    p.add_argument("--num-workers",  type=int,   default=4)
    p.add_argument("--lr",           type=float, default=1e-3)
    p.add_argument("--weight-decay", type=float, default=1e-4)
    p.add_argument("--lr-step",      type=int,   default=10)
    p.add_argument("--lr-gamma",     type=float, default=0.1)
    p.add_argument("--value-weight", type=float, default=1.0,
                   help="λ for value‐loss scaling")
    p.add_argument("--device",       type=str,   default=None,
                   help="Override device (e.g., 'cpu' or 'cuda:0')")
    p.add_argument("--seed",         type=int,   default=0,
                   help="Seed for the torch and numpy RNGs (weight init, shuffling)")
    p.add_argument("--grad-clip",    type=float, default=1.0,
                   help="Max gradient norm used for clipping")
    p.add_argument("--detect-anomaly", action="store_true",
                   help="Enable torch.autograd anomaly detection (slow; debugging only)")

    # parse_known_args tolerates the extra '-f <kernel.json>' argument that
    # the Jupyter kernel passes on its command line.
    args, err = p.parse_known_args()
    if err:
        print("Ignored Arg: ", err)

    # Seed before building the model/DataLoader so runs are reproducible.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    train(args)

Ignored Arg:  ['-f', '/root/.local/share/jupyter/runtime/kernel-1402e090-c10c-4ec2-b757-9b75bda98e03.json']


Epoch 1/20:   0%|          | 0/2281 [00:00<?, ?it/s]

KeyboardInterrupt: 