In [1]:
from Tokenizer import ChessBoardTokenizer

In [7]:
EMBEDDING_DIM = 128
tokenizer = ChessBoardTokenizer(emb_dim=EMBEDDING_DIM)

# Starting position FEN
fen_start = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"

print(f"Tokenizing FEN: {fen_start}")

# Generate the embedded board
board_emb = tokenizer(fen_start)

print(f"\nOutput Embedded Board Shape: {board_emb.shape}")

Tokenizing FEN: rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1

Output Embedded Board Shape: torch.Size([64, 128])


In [40]:
board_emb.shape

torch.Size([8, 8, 128])

In [6]:
from AlphaZeroEncoder import AlphaZeroMoveEncoder

encoder = AlphaZeroMoveEncoder()
print(f"Encoder initialized. Vocabulary Size: {encoder.VOCAB_SIZE}")
print("-" * 35)

# 1. Encode Example (Normal Move)
move_uci_1 = "e2e4"
move_id_1 = encoder.encode(move_uci_1)
print(f"Encoding {move_uci_1} â†’ ID {move_id_1}")

Encoder initialized. Vocabulary Size: 4672
-----------------------------------
Encoding e2e4 â†’ ID 898


In [3]:
import chess

def algebraic_to_uci(board_fen, algebraic_move):
    """
    Convert Stockfish-style algebraic notation (e.g., 'Nf3', 'exd5')
    into UCI notation (e.g., 'g1f3', 'e4d5').
    """
    board = chess.Board(board_fen)
    move = board.parse_san(algebraic_move)  # convert algebraic â†’ Move
    return move.uci()


# Example usage:
fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
moves = ["e4", "Nf3", "d4", "c4", ]

for algebraic in moves:
    uci = algebraic_to_uci(fen, algebraic)
    print(f"{algebraic:5} â†’ {uci}")
    print(f"Encoded ID: {encoder.encode(uci)}")
    # Update the board for the next move if needed


e4    â†’ e2e4
Encoded ID: 898
Nf3   â†’ g1f3
Encoded ID: 474
d4    â†’ d2d4
Encoded ID: 825
c4    â†’ c2c4
Encoded ID: 752


## Dataset

In [4]:
import chess
import random
import torch
import wandb
from datasets import Dataset, load_dataset
from tqdm import tqdm

def random_fen():
    board = chess.Board()
    for _ in range(random.randint(0, 20)):
        if board.is_game_over():
            break
        move = random.choice(list(board.legal_moves))
        board.push(move)
    return board.fen()

def legal_moves(fen):
    board = chess.Board(fen)
    return [board.san(m) for m in board.legal_moves]
def create_random_dataset(n=100):
    data = []
    for _ in tqdm(range(n), desc="Generating FENs"):
        fen = random_fen()
        moves = legal_moves(fen)
        for move in moves:
            # moves_str = " ".join(sorted(move))  # ðŸ‘ˆ sort for order-invariance
            data.append({"input": fen, "output": move})
    return Dataset.from_list(data)

dataset = create_random_dataset()

  from .autonotebook import tqdm as notebook_tqdm
Generating FENs: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 100/100 [00:00<00:00, 1994.01it/s]


In [8]:
dataset['input'][0]

'rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1'

In [27]:
dataset['output'][0]

'Ng5+'

In [28]:
uci = algebraic_to_uci(dataset['input'][0], "Ng5+")
encoder.encode(uci)

1555

In [12]:
from model import ChessGPT
model = ChessGPT()
logit = model(board_emb)

RuntimeError: The expanded size of the tensor (2) must match the existing size (128) at non-singleton dimension 0.  Target sizes: [2].  Tensor sizes: [128]

In [6]:
logit.shape

torch.Size([1, 4672])

## Training

In [8]:
Inputs = []
Outputs = []
for i in range(len(dataset)):
    Inputs.append(tokenizer(dataset['input'][i]))
    Outputs.append(encoder.encode(algebraic_to_uci(dataset['input'][i], dataset['output'][i])))

In [9]:
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data, labels, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        label = self.labels[idx]
        if self.transform:
            sample = self.transform(sample)
        return sample, label
my_dataset = MyDataset(Inputs, Outputs)

In [10]:
from torch.utils.data import DataLoader

dataloader = DataLoader(my_dataset, batch_size=32, shuffle=True, num_workers=0)
    

In [14]:
import torch

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.train()

num_epochs = 3
for epoch in range(num_epochs):
    total_loss = 0.0

    for batch_idx, (inputs, targets) in enumerate(dataloader):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()               # 1. Clear old gradients

        logits = model(inputs)              # 2. Forward pass
        loss = torch.nn.functional.cross_entropy(logits, targets)

        loss.backward()                     # 3. Backward pass (once per batch)
        optimizer.step()                    # 4. Update model

        total_loss += loss.item()           # Log scalar, not tensor

        # Print progress occasionally
        if (batch_idx + 1) % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}] | Batch [{batch_idx+1}/{len(dataloader)}] | Loss: {loss.item():.4f}")

    avg_loss = total_loss / len(dataloader)
    print(f"Epoch [{epoch+1}/{num_epochs}] finished. Avg Loss: {avg_loss:.4f}")


RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

In [17]:
import chess
import chess.pgn
from io import StringIO
from typing import List

def pgn_to_fen_sequence(pgn_string: str) -> List[str]:
    """
    Converts a PGN (Portable Game Notation) string containing one game
    into a list of FEN (Forsyth-Edwards Notation) strings, representing
    every position in the game.

    Args:
        pgn_string: A string containing the full PGN of a single chess game.

    Returns:
        A list of FEN strings, starting with the initial position (rnbq.../ w).
        Returns an empty list if the PGN cannot be parsed.
    """
    # Use StringIO to treat the PGN string like a file for the chess.pgn parser
    pgn_io = StringIO(pgn_string)
    
    # Load the game from the PGN file-like object
    game = chess.pgn.read_game(pgn_io)
    
    if game is None:
        print("Error: Could not parse the PGN string.")
        return []

    fen_list = []
    
    # Start with the initial position of the game
    board = game.board()
    
    # 1. Add the starting position FEN (usually the standard starting position)
    fen_list.append(board.fen())

    # Iterate through every move in the game's main variation
    # 'node' starts at the root of the game tree
    for move in game.mainline_moves():
        # Apply the move to the board object
        board.push(move)
        
        # Record the resulting FEN string
        fen_list.append(board.fen())
        
    return fen_list

# === Example Usage ===

# A famous short game: Paul Morphy vs. The Duke and Count (Opera Game)
# Note: The result (1-0) is important for the parser.
opera_game_pgn = """
[Event "Paris Opera"]
[Site "Paris FRA"]
[Date "1858.00.00"]
[Round "-"]
[White "Morphy, Paul"]
[Black "Duke Karl / Count Isouard"]
[Result "1-0"]

1. e4 e5 2. Nf3 d6 3. d4 Bg4 4. dxe5 Bxf3 5. Qxf3 dxe5 6. Bc4 Nf6 7. Qb3 Qe7 
8. Nc3 c6 9. Bg5 b5 10. Nxb5 cxb5 11. Bxb5+ Nbd7 12. O-O-O Rd8 13. Rxd7 Rxd7 
14. Rd1 Qe6 15. Bxd7+ Nxd7 16. Qb8+ Nxb8 17. Rd8# 1-0
"""

# Get the list of FENs
fen_sequence = pgn_to_fen_sequence(opera_game_pgn)

print(f"Total positions captured: {len(fen_sequence)}")
print("-" * 30)
print("First 5 FENs:")
for i, fen in enumerate(fen_sequence[:5]):
    print(f"Move {i}: {fen}")
print("...")
print("Last FEN (Checkmate):")
print(fen_sequence[-1])

Total positions captured: 34
------------------------------
First 5 FENs:
Move 0: rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1
Move 1: rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1
Move 2: rnbqkbnr/pppp1ppp/8/4p3/4P3/8/PPPP1PPP/RNBQKBNR w KQkq - 0 2
Move 3: rnbqkbnr/pppp1ppp/8/4p3/4P3/5N2/PPPP1PPP/RNBQKB1R b KQkq - 1 2
Move 4: rnbqkbnr/ppp2ppp/3p4/4p3/4P3/5N2/PPPP1PPP/RNBQKB1R w KQkq - 0 3
...
Last FEN (Checkmate):
1n1Rkb1r/p4ppp/4q3/4p1B1/4P3/8/PPP2PPP/2K5 b k - 1 17
