In [None]:
import math
from typing import Optional, Union, Callable

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


In [None]:
# Action-value model

In [None]:
class GatingNetwork(nn.Module):
    """Router that turns each input vector into a probability distribution over experts.

    Args:
        input_dim: size of the last axis of the input.
        num_experts: number of experts to score.
    """

    def __init__(self, input_dim, num_experts):
        super().__init__()
        self.gate = nn.Linear(input_dim, num_experts)

    def forward(self, x):
        """Return softmax-normalised expert scores along the last axis of ``x``."""
        logits = self.gate(x)
        return logits.softmax(dim=-1)


class expert(nn.Module):
    """Feed-forward expert MLP: ``dropout2(linear2(dropout1(activation(linear1(x)))))``.

    Args:
        d_model: input and output feature size.
        dim_feedforward: hidden feature size.
        bias: whether both Linear layers learn an additive bias.
        dropout: dropout probability used by both dropout layers.
        activation: callable applied to the hidden layer (e.g. ``F.relu``).
        device: optional device for the Linear parameters.
        dtype: optional dtype for the Linear parameters.
    """

    def __init__(self, d_model, dim_feedforward, bias=True, dropout=0.1,
                 activation=F.relu, device=None, dtype=None):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.linear1 = nn.Linear(d_model, dim_feedforward,
                                 bias=bias, **factory_kwargs)
        self.dropout1 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model,
                                 bias=bias, **factory_kwargs)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, x):
        """Apply the expert along the last axis of ``x``; the output has the same shape."""
        x = self.linear2(self.dropout1(self.activation(self.linear1(x))))
        return self.dropout2(x)


class google_expert(nn.Module):
    """Gated (SwiGLU) feed-forward expert: ``linear_out(silu(linear1(x)) * linear2(x))``.

    Args:
        d_model: input and output feature size.
        dim_feedforward: hidden size of both input projections.
        bias: whether the Linear layers learn an additive bias.
        device: optional device for the parameters.
        dtype: optional dtype for the parameters.
    """

    def __init__(self, d_model, dim_feedforward, bias=True, device=None, dtype=None):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        # Gate branch (passed through SiLU) and value branch.
        self.linear1 = nn.Linear(d_model, dim_feedforward,
                                 bias=bias, **factory_kwargs)
        self.linear2 = nn.Linear(d_model, dim_feedforward,
                                 bias=bias, **factory_kwargs)
        self.linear_out = nn.Linear(dim_feedforward, d_model,
                                    bias=bias, **factory_kwargs)

    def forward(self, x):
        """Apply the gated MLP along the last axis of ``x``; the output has the same shape."""
        gate = F.silu(self.linear1(x))
        value = self.linear2(x)
        return self.linear_out(gate * value)


class TransformerEncoderLayer(nn.Module):
    """Transformer encoder layer whose feed-forward sub-block is a top-k mixture of experts.

    The layer has the same structure as ``torch.nn.TransformerEncoderLayer``:
    self-attention, then a feed-forward block, in pre-norm or post-norm
    arrangement. Here the feed-forward block sends each token to
    ``num_experts_per_tok`` of the ``num_experts`` ``expert`` MLPs. The
    experts are chosen by ``GatingNetwork``, and their outputs are summed,
    weighted by the gate probabilities.

    Args:
        d_model: model width.
        nhead: number of attention heads.
        dim_feedforward: hidden width of each expert.
        dropout: dropout probability (attention, residual branches, experts).
        activation: expert activation, a callable or ``"relu"`` / ``"gelu"``.
        layer_norm_eps: epsilon of both LayerNorms.
        batch_first: if True, inputs are ``(batch, seq, d_model)``.
        norm_first: if True use pre-norm, otherwise post-norm.
        bias: whether Linear / LayerNorm / attention layers learn biases.
        device: optional device for all parameters.
        dtype: optional dtype for all parameters.
        num_experts: number of expert MLPs.
        num_experts_per_tok: number of experts each token is routed to.

    Raises:
        ValueError: if ``num_experts_per_tok`` is not in ``[1, num_experts]``.
        RuntimeError: if ``activation`` is an unsupported string.
    """

    __constants__ = ['norm_first']

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 bias: bool = True, device=None, dtype=None, num_experts: int = 8, num_experts_per_tok: int = 2) -> None:
        # topk() in _ff_block needs 1 <= k <= num_experts; fail early with a clear message.
        if not 1 <= num_experts_per_tok <= num_experts:
            raise ValueError(
                f"num_experts_per_tok must be in [1, num_experts={num_experts}], "
                f"got {num_experts_per_tok}")
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout,
                                               bias=bias, batch_first=batch_first,
                                               **factory_kwargs)

        # Dense feed-forward weights of the stock encoder layer. forward() does
        # not use them (the MoE block replaces them); they stay registered so the
        # parameter / state_dict layout is unchanged.
        self.linear1 = nn.Linear(d_model, dim_feedforward,
                                 bias=bias, **factory_kwargs)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model,
                                 bias=bias, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps,
                                  bias=bias, **factory_kwargs)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps,
                                  bias=bias, **factory_kwargs)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        # Resolve legacy string activations *before* the experts capture them.
        if isinstance(activation, str):
            activation = self._resolve_activation(activation)

        # Mixture of experts: each expert is a complete feed-forward MLP.
        self.experts = nn.ModuleList(
            [expert(d_model, dim_feedforward, bias, dropout, activation)
             for _ in range(num_experts)])

        # Router producing per-token probabilities over the experts.
        self.gatingNetwork = GatingNetwork(d_model, num_experts)

        # The experts and gate are built without factory kwargs; move them to
        # the requested device/dtype so they match the attention weights.
        moe_kwargs = {k: v for k, v in factory_kwargs.items() if v is not None}
        if moe_kwargs:
            self.experts.to(**moe_kwargs)
            self.gatingNetwork.to(**moe_kwargs)
        self.num_experts_per_tok = num_experts_per_tok

        # Kept for parity with torch.nn.TransformerEncoderLayer; no method of
        # this class reads it.
        if activation is F.relu or isinstance(activation, torch.nn.ReLU):
            self.activation_relu_or_gelu = 1
        elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
            self.activation_relu_or_gelu = 2
        else:
            self.activation_relu_or_gelu = 0
        self.activation = activation

    @staticmethod
    def _resolve_activation(name: str) -> Callable[[Tensor], Tensor]:
        """Map a legacy activation name ("relu" / "gelu") to its functional form."""
        if name == "relu":
            return F.relu
        if name == "gelu":
            return F.gelu
        raise RuntimeError(f"activation should be relu/gelu, not {name}")

    def __setstate__(self, state):
        super().__setstate__(state)
        # Older pickles may predate the `activation` attribute.
        if not hasattr(self, 'activation'):
            self.activation = F.relu

    def forward(
            self,
            src: Tensor,
            src_mask: Optional[Tensor] = None,
            src_key_padding_mask: Optional[Tensor] = None,
            is_causal: bool = False) -> Tensor:
        """Self-attention then the MoE feed-forward block, each with a residual connection.

        Args:
            src: input sequence; ``(batch, seq, d_model)`` if ``batch_first``
                else ``(seq, batch, d_model)``.
            src_mask: optional attention mask forwarded to ``self_attn``.
            src_key_padding_mask: optional key padding mask forwarded to ``self_attn``.
            is_causal: causal-mask hint forwarded to ``self_attn``.

        Returns:
            Tensor with the same shape as ``src``.
        """
        src_key_padding_mask = F._canonical_mask(
            mask=src_key_padding_mask,
            mask_name="src_key_padding_mask",
            other_type=F._none_or_dtype(src_mask),
            other_name="src_mask",
            target_type=src.dtype
        )

        src_mask = F._canonical_mask(
            mask=src_mask,
            mask_name="src_mask",
            other_type=None,
            other_name="",
            target_type=src.dtype,
            check_other=False,
        )

        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
        x = src
        if self.norm_first:
            x = x + self._sa_block(self.norm1(x), src_mask,
                                   src_key_padding_mask, is_causal=is_causal)
            x = x + self._ff_block(self.norm2(x))
        else:
            x = self.norm1(x + self._sa_block(x, src_mask,
                           src_key_padding_mask, is_causal=is_causal))
            x = self.norm2(x + self._ff_block(x))

        return x

    def _sa_block(self, x: Tensor,
                  attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
        """Self-attention sub-block followed by dropout (attention weights are not returned)."""
        x = self.self_attn(x, x, x,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           need_weights=False, is_causal=is_causal)[0]
        return self.dropout1(x)

    def _ff_block(self, x: Tensor) -> Tensor:
        """Route each token to its top-k experts and sum their gate-weighted outputs.

        Works for any leading shape. ``x`` is flattened to ``(tokens, d_model)``
        for routing, and the result is reshaped back to ``x.shape``. Each expert
        runs once, on the batch of tokens routed to it.
        """
        k = self.num_experts_per_tok
        # GatingNetwork already returns softmax probabilities; applying softmax
        # again would flatten the routing weights towards uniform.
        gating_scores = self.gatingNetwork(x)
        top_scores, top_indices = gating_scores.topk(k, dim=-1, sorted=False)

        flat_x = x.reshape(-1, x.size(-1))
        flat_scores = top_scores.reshape(-1, k)
        flat_indices = top_indices.reshape(-1, k)

        output = torch.zeros_like(flat_x)
        for expert_id, expert_module in enumerate(self.experts):
            # (token, slot) pairs whose top-k routing selected this expert.
            token_ids, slot_ids = torch.nonzero(flat_indices == expert_id, as_tuple=True)
            if token_ids.numel() == 0:
                continue
            weights = flat_scores[token_ids, slot_ids].unsqueeze(-1)
            contribution = weights * expert_module(flat_x[token_ids])
            # Cast guards against dtype mismatch (e.g. under autocast).
            output = output.index_add(0, token_ids, contribution.to(output.dtype))
        return output.reshape(x.shape)

In [None]:
# Sanity-check the tokenizer on one FEN position: print the tokens and their count.
# NOTE(review): `tokenize` is not defined in any visible cell of this notebook;
# it must be defined or imported before this cell runs (otherwise NameError on a
# fresh kernel). TODO: add or import it explicitly.
sample_fen = 'r2q1rk1/1bp2pbp/1p1p1npB/p1n1p3/P3P3/2NP1N1P/1PP1BPP1/R1Q2RK1 w - - 1 12'
print(tokenize(sample_fen))
print(len(tokenize(sample_fen)))

In [None]:
!pip install python-chess


In [None]:
import chess
import numpy as np
_CHESS_FILE = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']


def _compute_all_possible_actions() -> tuple[dict[str, int], dict[int, str]]:
  """Enumerates every UCI move string that can occur in chess and indexes it.

  Returns:
    A pair ``(move_to_action, action_to_move)`` of mutually inverse dicts.
    Action ids follow enumeration order. For each origin square 0..63, the
    queen-reachable targets come first, then the knight-reachable targets.
    All pawn promotions follow after those.
  """
  board = chess.BaseBoard.empty()
  # On an empty board a queen reaches every rook/bishop/pawn/king target and a
  # knight reaches the rest, so together they cover all non-promotion moves.
  # This includes castling, which is just a king move between two squares.
  probe_pieces = [chess.Piece.from_symbol(symbol) for symbol in ('Q', 'N')]

  def reachable_from(square):
    targets = []
    for piece in probe_pieces:
      board.set_piece_at(square, piece)  # replaces the previous probe piece
      targets.extend(board.attacks(square))
    board.remove_piece_at(square)
    return targets

  all_moves = [
      chess.square_name(origin) + chess.square_name(target)
      for origin in range(64)
      for target in reachable_from(origin)
  ]

  # Promotions happen only from the 2nd rank to the 1st, or the 7th to the 8th.
  # Each is a straight push or a diagonal capture, to one of four pieces.
  for from_rank, to_rank in (('2', '1'), ('7', '8')):
    for file_index, from_file in enumerate(_CHESS_FILE):
      to_files = [from_file]
      if file_index > 0:
        to_files.append(_CHESS_FILE[file_index - 1])
      if file_index < len(_CHESS_FILE) - 1:
        to_files.append(_CHESS_FILE[file_index + 1])
      for to_file in to_files:
        stem = f'{from_file}{from_rank}{to_file}{to_rank}'
        all_moves.extend(stem + piece for piece in 'qrbn')

  move_to_action = {}
  for action, move in enumerate(all_moves):
    assert move not in move_to_action
    move_to_action[move] = action
  action_to_move = {action: move for move, action in move_to_action.items()}

  return move_to_action, action_to_move


MOVE_TO_ACTION, ACTION_TO_MOVE = _compute_all_possible_actions()
NUM_ACTIONS = len(MOVE_TO_ACTION)



In [None]:
# Check and summarise the action vocabulary instead of printing all of MOVE_TO_ACTION.
# Expected size: 1792 queen/knight-reachable moves + 176 promotions
# (2 sides x 22 pawn moves x 4 promotion pieces) = 1968.
assert NUM_ACTIONS == 1968, NUM_ACTIONS
print(NUM_ACTIONS)
print(list(MOVE_TO_ACTION.items())[:10])

In [None]:
import torch.nn.init as init
class cust_embeddings(nn.Module):
    """Scaled token embeddings plus learned or sinusoidal positional embeddings.

    ``forward`` returns ``pos + token_embedding(x) * sqrt(embedding_dim)``.
    Positions run along the last axis of ``x``.

    Args:
        embedding_dim: width of every embedding vector.
        emb_init_scale: std of the truncated-normal init of the token table
            (values are truncated at +/- 2 std).
        use_sinosoidal: if True, use fixed sinusoidal positions stored as the
            buffer ``pos_embedding``. Otherwise learn them in the
            ``nn.Embedding`` ``pos_embeddings``.
        max_timescale: base of the sinusoidal frequency progression.
        vocab_size: number of token ids; ``None`` means ``NUM_ACTIONS``.
        sequence_length: maximum supported sequence length.
    """

    def __init__(self, embedding_dim, emb_init_scale=0.02, use_sinosoidal=False, max_timescale=1e4,
                 vocab_size=None, sequence_length=79):
        super(cust_embeddings, self).__init__()
        self.embedding_dim = embedding_dim
        self.sequence_length = sequence_length
        self.emb_init_scale = emb_init_scale
        self.max_timescale = max_timescale
        if vocab_size is None:
            # Reported to match the reference (Google) implementation.
            # NOTE(review): presumably sized by NUM_ACTIONS so that action ids
            # can be embedded with this table too. Confirm against the tokenizer.
            vocab_size = NUM_ACTIONS

        if use_sinosoidal:
            pos_embeddings = self._get_sinusoidal_position_encoding(
                self.sequence_length, embedding_dim, max_timescale)
            # Non-trainable; as a buffer it follows .to() and is saved in state_dict.
            self.register_buffer("pos_embedding", pos_embeddings)
        else:
            self.pos_embeddings = nn.Embedding(num_embeddings=self.sequence_length,
                                               embedding_dim=embedding_dim)

        self.embeddings_layer = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
        self.init_weights()

    def init_weights(self):
        """Truncated-normal init of the token table, cut off at +/- 2 standard deviations."""
        std = self.emb_init_scale
        # trunc_normal_'s a/b are absolute bounds (default +/-2), not multiples of std.
        init.trunc_normal_(self.embeddings_layer.weight, std=std, a=-2.0 * std, b=2.0 * std)

    def _get_sinusoidal_position_encoding(self, sequence_length, embedding_dim, max_timescale):
        """
        Creates sinusoidal encodings as in the original Transformer paper.

        For each position pos and dimension i:
          PE(pos, 2i)   = sin(pos / (max_timescale^(2i/embedding_dim)))
          PE(pos, 2i+1) = cos(pos / (max_timescale^(2i/embedding_dim)))

        Returns:
            A float32 tensor of shape (sequence_length, embedding_dim).
        """
        # Positions as a column: (sequence_length, 1).
        position = torch.arange(0, sequence_length, dtype=torch.float32).unsqueeze(1)
        # max_timescale^(-2i/embedding_dim) for each even index 2i.
        div_term = torch.exp(
            torch.arange(0, embedding_dim, 2, dtype=torch.float32) * (-math.log(max_timescale) / embedding_dim)
        )
        sinusoid_inp = position * div_term.unsqueeze(0)
        pos_encoding = torch.zeros(sequence_length, embedding_dim)
        pos_encoding[:, 0::2] = torch.sin(sinusoid_inp)
        # With an odd embedding_dim there is one fewer odd column than even column.
        pos_encoding[:, 1::2] = torch.cos(sinusoid_inp[:, : embedding_dim // 2])
        return pos_encoding

    def forward(self, x):
        """Embed token ids ``x`` and add positional embeddings.

        Positions index the last axis of ``x``, so ``(batch, seq)`` ids give a
        ``(batch, seq, embedding_dim)`` result.

        Raises:
            ValueError: if ``x.size(-1)`` exceeds ``sequence_length``.
        """
        seq_len = x.size(-1)
        if seq_len > self.sequence_length:
            raise ValueError(
                f"sequence length {seq_len} exceeds maximum {self.sequence_length}")
        if hasattr(self, "pos_embeddings"):
            pos_embed = self.pos_embeddings(torch.arange(seq_len, device=x.device))
        else:
            pos_embed = self.pos_embedding[:seq_len]
        embed = self.embeddings_layer(x) * math.sqrt(self.embedding_dim)
        return pos_embed + embed

In [None]:
def shift_right(sequences: torch.Tensor) -> torch.Tensor:
    """Shift ``sequences`` one step right along axis 1, inserting a zero BOS entry.

    Position 0 of every row becomes zero and the last element along axis 1 is
    dropped. The output has the same shape, dtype and device as the input.
    Works for token-id sequences ``(batch, time)`` as well as one-hot
    ``(batch, time, vocab)`` inputs.
    """
    # Build the zero BOS slot with the input's dtype and device. Concatenation
    # then keeps the dtype unchanged and does not fail for GPU tensors.
    bos = sequences.new_zeros((sequences.size(0), 1, *sequences.shape[2:]))
    return torch.cat([bos, sequences], dim=1)[:, :-1]

In [None]:
class decoder(nn.Module):
    """Output head: optional LayerNorm followed by a linear projection.

    Args:
        input_size: width of the incoming features.
        output_size: width of the projected output.
        decoder_layernorm: apply ``layer_norm`` before the projection when truthy.
    """

    def __init__(self, input_size, output_size, decoder_layernorm):
        super().__init__()
        self.use_layer_norm = decoder_layernorm
        # The norm is registered even when disabled, so the parameter set is the
        # same either way.
        self.layer_norm = nn.LayerNorm(input_size)
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x):
        """Project ``x`` (last axis ``input_size``) to ``output_size`` features."""
        features = self.layer_norm(x) if self.use_layer_norm else x
        return self.linear(features)
    

In [None]:
class final_model(nn.Module):
    """Action-value network: embeddings -> MoE encoder stack -> decoder head -> log-softmax.

    Args:
        embedding_dim: model width (must be divisible by ``nhead``).
        encoder_layers: number of stacked ``TransformerEncoderLayer`` blocks.
        output_dim: number of output classes per position.
        nhead: attention heads per encoder layer.
        dim_feedforward: hidden width of each expert MLP.
        dropout: dropout probability inside each encoder layer.
        num_experts: experts per encoder layer.
        num_experts_per_tok: experts each token is routed to.
        emb_init_scale: std of the token-embedding initialisation.
    """

    def __init__(self, embedding_dim, encoder_layers, output_dim, nhead=8,
                 dim_feedforward=2048, dropout=0.1, num_experts=8,
                 num_experts_per_tok=2, emb_init_scale=0.02):
        super(final_model, self).__init__()
        self.embeddings = cust_embeddings(embedding_dim, emb_init_scale)
        # cust_embeddings adds positions along the last axis of the id tensor,
        # so (batch, seq) ids become (batch, seq, dim) activations. The encoder
        # layers therefore need batch_first=True.
        self.encoders = nn.ModuleList([
            TransformerEncoderLayer(
                d_model=embedding_dim,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                batch_first=True,
                num_experts=num_experts,
                num_experts_per_tok=num_experts_per_tok,
            )
            for _ in range(encoder_layers)
        ])
        self.decoder = decoder(input_size=embedding_dim, output_size=output_dim,
                               decoder_layernorm=True)

    def forward(self, x):
        """Return per-position log-probabilities over ``output_dim`` classes.

        For ``(batch, seq)`` token ids the result is ``(batch, seq, output_dim)``.
        """
        x = self.embeddings(x)
        for layer in self.encoders:
            x = layer(x)
        logits = self.decoder(x)
        # Log-softmax over the output classes.
        return F.log_softmax(logits, dim=-1)
        
        