# Basics

In [1]:
import os
import sys
sys.path.append('../..')

In [2]:
import jax.numpy as jnp
import jax.random as jrandom
import numpy as np
import jax

# Existing code

In [3]:
import searchless_chess.src.transformer as transformer
import searchless_chess.src.utils as utils
import searchless_chess.src.training_utils as training_utils
import searchless_chess.src.engines.neural_engines as neural_engines
import searchless_chess.src.tokenizer as tokenizer


# Hyperparameters of the pretrained predictor. They must describe the same
# architecture as the checkpoint loaded below (presumably the "270M" model, going
# by the checkpoint directory name -- TODO confirm).
policy = 'action_value'
num_layers = 16
embedding_dim = 1024
num_heads = 8
num_return_buckets = 128
# The network emits one logit per return bucket.
output_size = num_return_buckets

predictor_config = transformer.TransformerConfig(
    vocab_size=utils.NUM_ACTIONS,
    output_size=output_size,
    pos_encodings=transformer.PositionalEncodings.LEARNED,
    # "+ 2": presumably one extra token for the action and one for the return
    # bucket appended after the tokenized board -- TODO confirm.
    max_sequence_length=tokenizer.SEQUENCE_LENGTH + 2,
    num_heads=num_heads,
    num_layers=num_layers,
    embedding_dim=embedding_dim,
    apply_post_ln=True,
    apply_qk_layernorm=False,
    use_causal_mask=False,
)

predictor = transformer.build_transformer_predictor(config=predictor_config)

# Only the bucket values are kept; the bucket edges are discarded.
_, return_buckets_values = utils.get_uniform_buckets_edges_values(
      num_return_buckets
)

# NOTE(review): hard-coded absolute, user-specific checkpoint path; this cell
# only runs on a machine that has this directory.
# The freshly initialised parameters are handed to `load_parameters` together
# with the checkpoint location (presumably as a structure template -- TODO confirm).
params = training_utils.load_parameters(
    checkpoint_dir='/home/mhamza/searchless_chess/checkpoints/270M',
    params=predictor.initial_params(
        rng=jrandom.PRNGKey(1),
        targets=np.ones((1, 1), dtype=np.uint32),
    ),
    step=6_400_000,
)

# NOTE(review): the name `engine` is rebound later in this notebook by
# `from searchless_chess.src.engines import engine`; re-running this cell after
# that import would shadow the module with this engine instance.
engine = neural_engines.ENGINE_FROM_POLICY[policy](
  return_buckets_values=return_buckets_values,
  predict_fn=neural_engines.wrap_predict_fn(
    predictor=predictor,
    params=params,
    batch_size=1,
  ),
)

2024-11-06 21:40:47.549131: W external/xla/xla/service/gpu/nvptx_compiler.cc:930] The NVIDIA driver's CUDA version is 12.4 which is older than the PTX compiler version 12.6.77. Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


KeyboardInterrupt: 

In [13]:
def print_transformer_blueprint(config):
    """Prints a human-readable summary of a transformer configuration.

    Args:
      config: Any object exposing the attributes `num_layers`,
        `embedding_dim`, `num_heads`, `output_size`, `vocab_size`,
        `max_sequence_length`, `pos_encodings`, `apply_post_ln`,
        `apply_qk_layernorm` and `use_causal_mask`.
    """
    # (label, attribute name) pairs, in the order they are printed.
    labelled_fields = (
        ("Number of layers", "num_layers"),
        ("Embedding dimension", "embedding_dim"),
        ("Number of heads", "num_heads"),
        ("Output size", "output_size"),
        ("Vocabulary size", "vocab_size"),
        ("Max sequence length", "max_sequence_length"),
        ("Positional encodings", "pos_encodings"),
        ("Apply post layer normalization", "apply_post_ln"),
        ("Apply QK layer normalization", "apply_qk_layernorm"),
        ("Use causal mask", "use_causal_mask"),
    )

    print("Model Blueprint:")
    for label, attribute in labelled_fields:
        print(f"{label}: {getattr(config, attribute)}")
    print("Other configurations can be added here if needed.")

# Print the hyperparameters stored in the predictor config built earlier in
# this notebook.
print_transformer_blueprint(predictor_config)

Model Blueprint:
Number of layers: 16
Embedding dimension: 1024
Number of heads: 8
Output size: 128
Vocabulary size: 1968
Max sequence length: 79
Positional encodings: PositionalEncodings.LEARNED
Apply post layer normalization: True
Apply QK layer normalization: False
Use causal mask: False
Other configurations can be added here if needed.


In [18]:
def flatten_dict(d, parent_key='', sep='/'):
    """Flattens a nested mapping into a single-level dict.

    Nested keys are joined with `sep`, e.g. `{'a': {'b': 1}}` becomes
    `{'a/b': 1}`. Any `collections.abc.Mapping` value is recursed into (not only
    plain `dict`s), so immutable mapping containers are flattened as well.
    Empty nested mappings contribute no entries.

    Args:
      d: The (possibly nested) mapping to flatten.
      parent_key: Prefix prepended to every key; empty for the top level.
      sep: Separator inserted between the key components.

    Returns:
      A new flat `dict`, in depth-first insertion order of `d`.
    """
    # Function-scope import so this block does not depend on another cell.
    from collections.abc import Mapping

    flat = {}
    for key, value in d.items():
        # Top-level keys are kept as-is; nested keys get the parent prefix.
        path = f"{parent_key}{sep}{key}" if parent_key else key
        if isinstance(value, Mapping):
            flat.update(flatten_dict(value, path, sep=sep))
        else:
            flat[path] = value
    return flat

def print_transformer_layers(predictor, config):
    """Prints the name and shape of every parameter of `predictor`.

    Args:
      predictor: Object exposing `initial_params(rng, targets)`, which returns
        the (nested) parameter structure of the model.
      config: Unused; kept so existing calls keep working.
    """
    rng = jax.random.PRNGKey(0)

    # Dummy token batch used only to trace the parameter structure.
    dummy_input = np.ones((1, tokenizer.SEQUENCE_LENGTH + 2), dtype=np.uint32)

    # `jax.eval_shape` evaluates `initial_params` abstractly: it yields the same
    # nested structure with shape/dtype placeholders instead of allocating and
    # randomly initialising every weight just to read its shape.
    param_shapes = jax.eval_shape(predictor.initial_params, rng, dummy_input)

    # Flatten the nested structure so each parameter is listed on one line.
    flattened_params = flatten_dict(param_shapes)

    print("Transformer Model Layers:")
    for layer_name, param_value in flattened_params.items():
        print(f"Layer: {layer_name}, Shape: {param_value.shape}")

# Print the name and shape of every parameter of the predictor model.
print_transformer_layers(predictor, predictor_config)

Transformer Model Layers:
Layer: embed/embeddings, Shape: (1968, 1024)
Layer: embed_1/embeddings, Shape: (79, 1024)
Layer: layer_norm/scale, Shape: (1024,)
Layer: layer_norm/offset, Shape: (1024,)
Layer: multi_head_dot_product_attention/linear/w, Shape: (1024, 1024)
Layer: multi_head_dot_product_attention/linear_1/w, Shape: (1024, 1024)
Layer: multi_head_dot_product_attention/linear_2/w, Shape: (1024, 1024)
Layer: multi_head_dot_product_attention/linear_3/w, Shape: (1024, 1024)
Layer: layer_norm_1/scale, Shape: (1024,)
Layer: layer_norm_1/offset, Shape: (1024,)
Layer: linear/w, Shape: (1024, 4096)
Layer: linear_1/w, Shape: (1024, 4096)
Layer: linear_2/w, Shape: (4096, 1024)
Layer: layer_norm_2/scale, Shape: (1024,)
Layer: layer_norm_2/offset, Shape: (1024,)
Layer: multi_head_dot_product_attention_1/linear/w, Shape: (1024, 1024)
Layer: multi_head_dot_product_attention_1/linear_1/w, Shape: (1024, 1024)
Layer: multi_head_dot_product_attention_1/linear_2/w, Shape: (1024, 1024)
Layer: multi

In [46]:
# def transformer_decoder_with_intermediate(
#     targets: jax.Array,
#     config: TransformerConfig,
#     capture_intermediates: bool = False,
# ) -> tuple:
#     """Modified transformer decoder to return intermediate outputs.

#     Args:
#         targets: The integer target values, shape [B, T].
#         config: The config to use for the transformer.

#     Returns:
#         A tuple of (final_logits, layer_outputs), where `final_logits` is the
#         output of the model, and `layer_outputs` is a list of the token
#         transformations after each layer.
#     """
#     # Right shift the targets to get the inputs (the first token is now a 0).
#     inputs = shift_right(targets)

#     # Embeds the inputs and adds positional encodings.
#     embeddings = embed_sequences(inputs, config)
#     h = embeddings
    
#     # Initialize a list to collect outputs if intermediates are to be captured
#     layer_outputs = [h] if capture_intermediates else None

#     # Loop through each layer, storing the output after each transformation
#     for _ in range(config.num_layers):
#         attention_input = layer_norm(h)
#         attention = _attention_block(attention_input, config)
#         h += attention

#         mlp_input = layer_norm(h)
#         mlp_output = _mlp_block(mlp_input, config)
#         h += mlp_output

#         # Append the current output to layer_outputs if capturing intermediates
#         if capture_intermediates:
#             layer_outputs.append(h)

#     if config.apply_post_ln:
#         h = layer_norm(h)
#     logits = hk.Linear(config.output_size)(h)
#     final_logits = jnn.log_softmax(logits, axis=-1)

#     # Stack layer_outputs along a new axis if capturing intermediates
#     if capture_intermediates:
#         layer_outputs = jnp.stack(layer_outputs, axis=0)

#     return (final_logits, layer_outputs) if capture_intermediates else final_logits

In [42]:
# import csv

# def save_token_transformations_to_csv(predictor, params, input_sequences, output_csv_path):
#     """Runs the model and saves token transformations after each layer to a CSV file.

#     Args:
#         predictor: The transformer model predictor.
#         params: Model parameters.
#         input_sequences: Array of input sequences to process.
#         output_csv_path: Path to the CSV file where results will be saved.
#     """
#     # Run the model to get the intermediate outputs
#     _, layer_outputs = predictor.predict(params, None, input_sequences)

#     # Reshape and prepare data for saving to CSV
#     batch_size, seq_length, embedding_dim = input_sequences.shape[0], input_sequences.shape[1], layer_outputs[0].shape[-1]
    
#     # Flatten layer outputs to a format suitable for CSV
#     flattened_outputs = []
#     for layer_idx, layer_output in enumerate(layer_outputs):
#         for batch_idx in range(batch_size):
#             for token_idx in range(seq_length):
#                 flattened_row = [layer_idx, batch_idx, token_idx]
#                 flattened_row.extend(layer_output[batch_idx, token_idx, :].tolist())
#                 flattened_outputs.append(flattened_row)
    
#     # Save to CSV
#     with open(output_csv_path, 'w', newline='') as csvfile:
#         csv_writer = csv.writer(csvfile)
#         header = ["Layer", "Batch", "Token"] + [f"Dim_{i}" for i in range(embedding_dim)]
#         csv_writer.writerow(header)
#         csv_writer.writerows(flattened_outputs)

#     print(f"Token transformations saved to {output_csv_path}")

In [None]:
# # Example usage
# output_csv_path = 'token_transformations.csv'
# input_sequences = np.ones((8, tokenizer.SEQUENCE_LENGTH + 2), dtype=np.uint32)  # Example batch of sequences

# predictor = build_transformer_predictor_with_intermediate(predictor_config)
# params = predictor.initial_params(jax.random.PRNGKey(0), input_sequences)

# save_token_transformations_to_csv(predictor, params, input_sequences, output_csv_path)

# My Code

In [3]:
from searchless_chess.src.transformer import TransformerConfig, shift_right, embed_sequences, layer_norm, _attention_block, _mlp_block
import haiku as hk
import jax.nn as jnn
from searchless_chess.src import constants
import functools

In [4]:
def transformer_decoder_with_intermediate(
    targets: jax.Array,
    config: TransformerConfig,
) -> tuple[jax.Array, jax.Array]:
  """Runs the transformer decoder and also returns the per-layer hidden states.

  Follows the LLaMa architecture:
  https://github.com/facebookresearch/llama/blob/main/llama/model.py
  Main changes to the original Transformer decoder:
  - Using gating in the MLP block, with SwiGLU activation function.
  - Using normalization before the attention and MLP blocks.

  Args:
    targets: The integer target values, shape [B, T].
    config: The config to use for the transformer.

  Returns:
    A tuple `(layer_outputs, log_probs)`:
      layer_outputs: Hidden states stacked along axis 1, shape
        [B, config.num_layers + 1, T, D]. Entry `i < num_layers` is the hidden
        state fed into layer `i` (entry 0 is the embeddings); the last entry is
        the output of the final layer, after the post layer norm when
        `config.apply_post_ln` is set.
      log_probs: Log-softmax of the output logits over the last axis, shape
        [B, T, config.output_size].
  """
  # Right shift the targets to get the inputs (the first token is now a 0).
  inputs = shift_right(targets)

  # Embeds the inputs and adds positional encodings.
  h = embed_sequences(inputs, config)  # [B, T, D]

  # Hidden state recorded before every layer, plus the final one after the loop.
  layer_outputs = []

  # NOTE: the order in which the Haiku modules are invoked below determines the
  # parameter names, so it must not change if checkpoints are to keep loading.
  for _ in range(config.num_layers):
    layer_outputs.append(h)

    attention_input = layer_norm(h)
    h = h + _attention_block(attention_input, config)

    mlp_input = layer_norm(h)
    h = h + _mlp_block(mlp_input, config)

  if config.apply_post_ln:
    h = layer_norm(h)
  layer_outputs.append(h)
  logits = hk.Linear(config.output_size)(h)

  # Stack the recorded states on a new "layer" axis: [B, L + 1, T, D].
  layer_outputs = jnp.stack(layer_outputs, axis=1)
  assert len(layer_outputs.shape) == 4, f"Expected 4D output, got {layer_outputs.shape}"

  log_probs = jnn.log_softmax(logits, axis=-1)
  return layer_outputs, log_probs


def build_transformer_predictor_with_intermediate(
    config: TransformerConfig,
) -> constants.Predictor:
  """Builds a predictor around `transformer_decoder_with_intermediate`.

  Args:
    config: The transformer config bound to the decoder function.

  Returns:
    A `constants.Predictor` whose `initial_params` / `predict` are the `init` /
    `apply` functions of the Haiku-transformed decoder.
  """
  decoder_fn = functools.partial(
      transformer_decoder_with_intermediate,
      config=config,
  )
  transformed = hk.transform(decoder_fn)
  return constants.Predictor(
      initial_params=transformed.init,
      predict=transformed.apply,
  )


In [5]:
import chess
import haiku as hk
import jax
import jax.nn as jnn
import numpy as np
import scipy.special

from searchless_chess.src import constants
from searchless_chess.src import tokenizer
from searchless_chess.src import utils
from searchless_chess.src.engines import engine

from searchless_chess.src.engines.neural_engines import NeuralEngine

In [6]:
def _update_scores_with_repetitions(
    board: chess.Board,
    scores: np.ndarray,
) -> None:
  """Sets the score of every legal move that allows a repetition draw to 0.5.

  Args:
    board: The current position. Each legal move is pushed and popped again, so
      the board is back in its initial state when the function returns.
    scores: Per-move scores, indexed in the order given by
      `engine.get_ordered_legal_moves(board)`. Modified in place.
  """
  for index, candidate in enumerate(engine.get_ordered_legal_moves(board)):
    board.push(candidate)
    # A fivefold repetition, or a claimable threefold one, counts as a draw.
    leads_to_draw = (
        board.is_fivefold_repetition()
        or board.can_claim_threefold_repetition()
    )
    board.pop()
    if leads_to_draw:
      scores[index] = 0.5

class ActionValueDebugEngine(NeuralEngine):
  """Neural engine using a function P(r | s, a) that also exposes activations.

  Unlike a plain action-value engine, `self.predict_fn` must return a
  `(layer_outputs, log_probs)` pair (as the function built by
  `my_wrap_predict_fn` does), and `play` returns the activations of the chosen
  move alongside the move itself.
  """

  def analyse(self, board: chess.Board):
    """Returns per-layer activations, bucket log-probs per action, and the FEN.

    Args:
      board: The position to analyse.

    Returns:
      A dict with keys:
        'layer_outputs': activations at the last sequence position, [M x L x E].
        'log_probs': return-bucket log-probs at the last position, [M x V].
        'fen': the FEN string of `board`.
      M is the number of legal moves, in `engine.get_ordered_legal_moves` order.
    """
    # Tokenize the legal actions.
    sorted_legal_moves = engine.get_ordered_legal_moves(board)
    legal_actions = [utils.MOVE_TO_ACTION[x.uci()] for x in sorted_legal_moves]
    legal_actions = np.array(legal_actions, dtype=np.int32)
    legal_actions = np.expand_dims(legal_actions, axis=-1)
    # Tokenize the return buckets.
    dummy_return_buckets = np.zeros((len(legal_actions), 1), dtype=np.int32)
    # Tokenize the board.
    tokenized_fen = tokenizer.tokenize(board.fen()).astype(np.int32)
    sequences = np.stack([tokenized_fen] * len(legal_actions))
    # Create the sequences: one row per legal move, [board | action | bucket].
    sequences = np.concatenate(
        [sequences, legal_actions, dummy_return_buckets],
        axis=1,
    )  # [(M)oves x (S)equence Length]
    # layer_outputs: [M x L x S x E], log_probs: [M x S x V]
    layer_outputs, log_probs = self.predict_fn(sequences)
    # Only the last sequence position is kept for both outputs.
    return {
      'layer_outputs': layer_outputs[:, :, -1],  # [M x L x E]
      'log_probs': log_probs[:, -1],  # [M x V]
      'fen': board.fen(),
    }

  def play(self, board: chess.Board):
    """Selects a move and returns it with its per-layer activations.

    Args:
      board: The position to play from.

    Returns:
      A tuple `(move, layer_outputs)`, where `layer_outputs` is the [L x E]
      slice of the analysis that belongs to the selected move.
    """
    # Run the network once and reuse the analysis for both the move selection
    # and the returned activations.
    analysis = self.analyse(board)
    return_buckets_probs = np.exp(analysis['log_probs'])
    # Expected return per move, from the bucket probabilities and bucket values.
    win_probs = np.inner(return_buckets_probs, self._return_buckets_values)
    _update_scores_with_repetitions(board, win_probs)
    sorted_legal_moves = engine.get_ordered_legal_moves(board)
    if self.temperature is not None:
      # Sample an index (not a move) so it can also index the activations.
      probs = scipy.special.softmax(win_probs / self.temperature, axis=-1)
      best_index = self._rng.choice(np.arange(len(sorted_legal_moves)), p=probs)
    else:
      best_index = np.argmax(win_probs)
    return sorted_legal_moves[best_index], analysis['layer_outputs'][best_index]  # .., [L x E]

In [7]:
def my_wrap_predict_fn(
    predictor: constants.Predictor,
    params: hk.Params,
    batch_size: int = 32,
):
  """Returns a batched prediction function from a predictor and parameters.

  The predictor's `predict` is expected to return a 2-tuple
  `(layer_outputs, log_probs)` whose leading axis is the batch axis.

  Args:
    predictor: Used to predict outputs.
    params: Neural network parameters.
    batch_size: How many sequences to pass to the predictor at once.

  Returns:
    A function mapping `sequences` of shape [M x S] to a tuple
    `(layer_outputs, log_probs)` of shapes [M x L x S x E] and [M x S x V].
  """
  jitted_predict_fn = jax.jit(predictor.predict)

  def fixed_predict_fn(sequences: np.ndarray) -> tuple[jax.Array, jax.Array]:
    """Runs the jitted predictor on exactly `batch_size` sequences."""
    assert sequences.shape[0] == batch_size
    return jitted_predict_fn(
        params=params,
        targets=sequences,
        rng=None,
    )

  def predict_fn(sequences: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Pads `sequences` to a multiple of `batch_size` and predicts in batches."""
    num_sequences = len(sequences)  # M
    # Zero rows are appended so every chunk has exactly `batch_size` rows; the
    # fixed shape lets the jitted function be reused across calls.
    remainder = -num_sequences % batch_size
    padded = np.pad(sequences, ((0, remainder), (0, 0)))  # [(M + R) x S]
    sequences_split = np.split(padded, len(padded) // batch_size)

    all_layer_outputs, all_log_probs = [], []
    for sub_sequences in sequences_split:  # each [B x S]
      # layer_outputs: [B x L x S x E], log_probs: [B x S x V]
      layer_outputs, log_probs = fixed_predict_fn(sub_sequences)
      all_layer_outputs.append(layer_outputs)
      all_log_probs.append(log_probs)

    layer_outputs = np.concatenate(all_layer_outputs, axis=0)  # [(M + R) x L x S x E]
    log_probs = np.concatenate(all_log_probs, axis=0)  # [(M + R) x S x V]
    assert len(log_probs) == len(padded)
    assert len(layer_outputs) == len(padded)
    # Crop the rows that belong to the padding.
    return layer_outputs[:num_sequences], log_probs[:num_sequences]

  return predict_fn


In [8]:
import searchless_chess.src.transformer as transformer
import searchless_chess.src.utils as utils
import searchless_chess.src.training_utils as training_utils
import searchless_chess.src.engines.neural_engines as neural_engines
import searchless_chess.src.tokenizer as tokenizer


# Hyperparameters of the pretrained predictor; they must describe the same
# architecture as the checkpoint loaded below.
policy = 'action_value'
num_layers = 16
embedding_dim = 1024
num_heads = 8
num_return_buckets = 128
# The network emits one logit per return bucket.
output_size = num_return_buckets

predictor_config = transformer.TransformerConfig(
    vocab_size=utils.NUM_ACTIONS,
    output_size=output_size,
    pos_encodings=transformer.PositionalEncodings.LEARNED,
    max_sequence_length=tokenizer.SEQUENCE_LENGTH + 2,
    num_heads=num_heads,
    num_layers=num_layers,
    embedding_dim=embedding_dim,
    apply_post_ln=True,
    apply_qk_layernorm=False,
    use_causal_mask=False,
)

# Same architecture as the stock predictor, but the forward pass also returns
# the hidden state recorded at every layer.
predictor = build_transformer_predictor_with_intermediate(config=predictor_config)

# Only the bucket values are kept; the bucket edges are discarded.
_, return_buckets_values = utils.get_uniform_buckets_edges_values(
    num_return_buckets
)

# The checkpoint location can be overridden with the
# SEARCHLESS_CHESS_CHECKPOINT_DIR environment variable so the notebook is not
# tied to a single machine; the default is the original path.
checkpoint_dir = os.environ.get(
    'SEARCHLESS_CHESS_CHECKPOINT_DIR',
    '/home/mhamza/searchless_chess/checkpoints/270M',
)

params = training_utils.load_parameters(
    checkpoint_dir=checkpoint_dir,
    params=predictor.initial_params(
        rng=jrandom.PRNGKey(1),
        targets=np.ones((1, 1), dtype=np.uint32),
    ),
    step=6_400_000,
)

# Debug engine: `play` returns the chosen move together with its activations.
play_engine = ActionValueDebugEngine(
    return_buckets_values=return_buckets_values,
    predict_fn=my_wrap_predict_fn(
        predictor=predictor,
        params=params,
        batch_size=1,
    ),
)

In [9]:
# @title Play a move with the agent
import chess
# A fresh board (python-chess defaults to the standard starting position).
board = chess.Board()
# `play` returns a tuple: (chosen move, per-layer activations for that move).
outputs = play_engine.play(board)
outputs

(20, 17, 79, 1024)
(20, 17, 79, 1024)


(Move.from_uci('e2e4'),
 array([[ 1.8495966e+00,  5.8102436e+00,  2.9075146e-04, ...,
          1.6825292e+00,  7.2716112e+00,  2.5959897e-01],
        [ 2.8245749e+00,  8.1319637e+00, -2.2196240e+00, ...,
         -3.6888738e+00,  6.8539166e+00,  2.7591591e+00],
        [ 4.2268257e+00,  1.2340109e+01, -2.7381961e+00, ...,
         -9.1484380e-01,  1.6686024e+01,  4.9675984e+00],
        ...,
        [-1.4124680e+00, -3.0761728e+02,  1.5521661e+01, ...,
         -1.8162535e+02, -7.5763474e+01, -4.1706543e+01],
        [-6.3728233e+01, -2.8500610e+02,  2.9011784e+01, ...,
         -1.1035236e+02, -5.6911602e+01, -2.2329647e+01],
        [-8.4748864e-04, -5.3715512e-02,  1.7183146e-04, ...,
          1.3432965e-03, -1.5181114e-03,  3.6411311e-03]], dtype=float32))

In [49]:
# Activations of the chosen move: one row per recorded hidden state
# (num_layers + 1 = 17) by embedding_dim (1024).
outputs[1].shape

(17, 1024)