In [36]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [37]:
from shark.utils import nb_init
nb_init()

INFO | nb_init | Set current dir to chess
INFO | nb_init | You are using Python 3.10.10 (main, Sep 14 2023, 16:59:47) [Clang 14.0.3 (clang-1403.0.22.14.1)]


# Chess environment for Reinforcement Learning

Regardless of the RL agent we wish to train, we need a chess environment in which the agent can play. See also [this popular library](https://gymnasium.farama.org/index.html) for further reference.

Since this repo is built on TorchRL, we extend its [base class](https://pytorch.org/rl/reference/generated/torchrl.envs.EnvBase.html), `EnvBase`, to create the chess environment.

This notebook shows what this environment looks like and how to use it.

## Create a chess environment

To create a chess environment, all you have to do is import the `ChessEnv` class and instantiate it.

First, though, let's look at the additional input parameters the constructor accepts.

In [38]:
from shark.env import ChessEnv

help(ChessEnv.__init__)

Help on function __init__ in module shark.env.chess:

__init__(self, engine_path: str = None, time: float = 5, depth: int = 20, flatten_state: bool = False, play_as: str = 'white', play_vs_engine: bool = True, mate_amplifier: float = 10, softmax: bool = True, worst_reward: float = -1000.0, illegal_amplifier: float = 100, lose_on_illegal_move: bool = True, use_one_hot: bool = True, from_engine: bool = False, **kwargs: Any) -> None
    Args:
        engine_path (str):
            Path to chess engine. This class needs a usable chess engine.
            For example: `stockfish`.
            If not passed, this class will read from the `CHESS_ENGINE_EXECUTABLE` environment variable.
            If not set, an error will be raised.
            Please make sure to install a chess engine like Stockfish, and pass the correct
            installation path here.
    
        time (float, optional):
            Timeout value in seconds for engine.
            When the chess engine is called to va

In [39]:
env = ChessEnv()
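`ChessEnv()` is created here without arguments, so, according to the docstring above, the engine path has to come from the `CHESS_ENGINE_EXECUTABLE` environment variable. The helper below is a hypothetical sketch of that documented lookup order (explicit argument, then environment variable, then an error); it is not `ChessEnv`'s actual code, and raising `ValueError` is an assumption.

```python
import os
from typing import Optional

# Variable name taken from the ChessEnv docstring.
ENGINE_ENV_VAR = "CHESS_ENGINE_EXECUTABLE"


def resolve_engine_path(engine_path: Optional[str] = None) -> str:
    """Hypothetical helper mirroring the documented lookup order for `engine_path`."""
    # 1. An explicitly passed path wins.
    if engine_path:
        return engine_path
    # 2. Otherwise fall back to the environment variable.
    from_env = os.environ.get(ENGINE_ENV_VAR)
    if from_env:
        return from_env
    # 3. Neither is available: fail loudly (exception type is an assumption).
    raise ValueError(f"No chess engine found: pass engine_path or set {ENGINE_ENV_VAR}")
```

In practice you would set the variable before creating the environment, e.g. `os.environ["CHESS_ENGINE_EXECUTABLE"] = "/path/to/stockfish"`, or pass `ChessEnv(engine_path=...)` directly.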

Let's reset the environment (i.e. start a new game) and inspect the output:

In [40]:
state = env.reset()
print(f"Reset state: {state}")

Reset state: TensorDict(
    fields={
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        observation: Tensor(shape=torch.Size([8, 8, 13]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)


The state variable is a `TensorDict`, basically a `dict` of `Tensor` objects.

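The fields are:
* `done`: whether the game is over. It should be `False`, since we just started a new game.
* `observation`: the current state, i.e. the chess board, encoded as an `8 x 8 x 13` tensor. If playing as black, white (played by the chess engine) moves first, so one move has already been played on the board.
* `terminated`: TorchRL's flag for an episode that reached a terminal state (as opposed to being truncated); like `done`, it is `False` right after a reset.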

Let's create a random agent and start playing.

In [41]:
from torchrl.collectors import RandomPolicy

actor = RandomPolicy(env.action_spec)

In [42]:
action = actor(state)

In [43]:
action

TensorDict(
    fields={
        action: Tensor(shape=torch.Size([4032]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        observation: Tensor(shape=torch.Size([8, 8, 13]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

As you can see, the action is a vector of `4032` entries, one for each possible movement from a square on the board to a different square: 64 × 63 = 4032.
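To make the count concrete, the sketch below enumerates every ordered pair of distinct squares and decodes an action index into a move string. The ordering (starting square first, then landing square) and the square numbering (a1 = 0, b1 = 1, ..., h8 = 63, the python-chess convention) are assumptions; `ChessEnv` may index its actions differently.

```python
from typing import Dict, List, Tuple

# Hypothetical action ordering: by starting square, then landing square, skipping from == to.
MOVES: List[Tuple[int, int]] = [
    (from_sq, to_sq)
    for from_sq in range(64)
    for to_sq in range(64)
    if from_sq != to_sq
]
MOVE_TO_INDEX: Dict[Tuple[int, int], int] = {move: idx for idx, move in enumerate(MOVES)}


def square_name(sq: int) -> str:
    # File letter from sq % 8, rank digit from sq // 8 (a1 = 0, h8 = 63).
    return "abcdefgh"[sq % 8] + str(sq // 8 + 1)


def decode(index: int) -> str:
    # Turn an action index into a "from-to" string such as "e2e4".
    from_sq, to_sq = MOVES[index]
    return square_name(from_sq) + square_name(to_sq)


print(len(MOVES), decode(MOVE_TO_INDEX[(12, 28)]))  # prints: 4032 e2e4
```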

Clearly, not all movements are legal: the starting square may be empty, or occupied by an opponent's piece, or the landing square may not be reachable by the selected piece.
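A common way to deal with this in RL is a boolean action mask over the 4032 entries. The sketch below builds one for white's first move, where exactly 20 moves are legal (16 pawn pushes and 4 knight moves). It assumes squares are numbered a1 = 0 through h8 = 63 and actions are ordered by starting square, then landing square; `ChessEnv` may use a different layout, and all helpers here are hypothetical.

```python
from typing import List, Tuple


def move_index(from_sq: int, to_sq: int) -> int:
    # Assumed ordering: from-square major, 63 landing squares each (from == to is skipped).
    return from_sq * 63 + (to_sq - 1 if to_sq > from_sq else to_sq)


def opening_moves_for_white() -> List[Tuple[int, int]]:
    # The 20 legal first moves for white, squares numbered a1 = 0 ... h8 = 63.
    moves = []
    for pawn in range(8, 16):            # pawns on rank 2 (a2..h2)
        moves.append((pawn, pawn + 8))   # single push
        moves.append((pawn, pawn + 16))  # double push
    moves += [(1, 16), (1, 18), (6, 21), (6, 23)]  # knights: b1 -> a3/c3, g1 -> f3/h3
    return moves


def legal_mask(legal_moves: List[Tuple[int, int]], n_actions: int = 4032) -> List[bool]:
    # True at the index of every legal move, False everywhere else.
    mask = [False] * n_actions
    for from_sq, to_sq in legal_moves:
        mask[move_index(from_sq, to_sq)] = True
    return mask


mask = legal_mask(opening_moves_for_white())
print(sum(mask), f"{sum(mask) / len(mask):.2%}")  # prints: 20 0.50%
```

So at the start of a game, only about half a percent of the action space corresponds to legal moves.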

You can also sample a random move directly from the environment:

In [44]:
action = env.sample()

In [45]:
action

TensorDict(
    fields={
        action: Tensor(shape=torch.Size([4032]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

Unlike `RandomPolicy`, which samples from the full action spec whether the move is legal or not, this method only returns legal moves: it either calls the engine to choose a move, or samples a random move among the legal ones.
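To see why this matters, the pure-Python sketch below compares an unconstrained sampler with one restricted to legal indices. It uses a toy mask with 20 legal indices out of 4032 (roughly the situation at the start of a game); it illustrates the idea and is not `env.sample()`'s actual implementation.

```python
import random
from typing import List


def sample_any(n_actions: int, rng: random.Random) -> int:
    # What an unconstrained random policy does: any of the n_actions indices.
    return rng.randrange(n_actions)


def sample_legal(mask: List[bool], rng: random.Random) -> int:
    # Restrict the draw to indices flagged as legal, then pick uniformly among them.
    legal = [i for i, ok in enumerate(mask) if ok]
    if not legal:
        raise ValueError("no legal moves in this position")
    return rng.choice(legal)


# Toy mask: 20 legal actions (indices 0, 200, ..., 3800) out of 4032.
mask = [False] * 4032
for i in range(0, 4000, 200):
    mask[i] = True

rng = random.Random(0)
hits = sum(mask[sample_any(len(mask), rng)] for _ in range(10_000))
print(f"unconstrained sampler picked a legal move {hits} times out of 10000")
print("masked sampler always legal:", all(mask[sample_legal(mask, rng)] for _ in range(1_000)))
```

With roughly a 0.5% chance of picking a legal move per draw, an unconstrained random policy will almost always play an illegal move, which is why restricting sampling to legal moves is so useful.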

We can also use TorchRL's `check_env_specs` utility to verify that the environment is set up correctly:

In [46]:
from torchrl.envs import check_env_specs

check_env_specs(env)

2024-02-20 20:23:56,579 [torchrl][INFO] check_env_specs succeeded!


You can also perform a rollout, i.e. play a given number of steps (here, 3) with random actions, since no policy is passed:

In [47]:
td = env.rollout(3)
td

TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 4032]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([3, 8, 8, 13]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([3]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([3, 8, 8, 13]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, 
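Each entry of the rollout pairs the state and action at step t with the data observed after that step, stored under `next` (new observation, `reward`, `done`, `terminated`). By default, TorchRL's `rollout` also stops early once an episode is done (`break_when_any_done=True`), so it can return fewer steps than requested. The pure-Python sketch below mimics that layout with plain dicts; it is a simplified illustration, not TorchRL's implementation, and the toy environment is hypothetical.

```python
from typing import Any, Callable, Dict, List

State = Dict[str, Any]


def rollout(reset: Callable[[], State],
            step: Callable[[State], State],
            policy: Callable[[State], Any],
            max_steps: int) -> List[State]:
    """Simplified sketch of the loop behind a rollout (not TorchRL's actual code)."""
    trajectory = []
    state = reset()
    for _ in range(max_steps):
        action = policy(state)
        next_state = step({**state, "action": action})  # observation, reward, done
        # Each record keeps the current state, the action, and the post-step data under "next".
        trajectory.append({**state, "action": action, "next": next_state})
        if next_state["done"]:
            break  # stop early, like break_when_any_done=True
        # The next step starts from the new observation (the reward stays in the record).
        state = {k: v for k, v in next_state.items() if k != "reward"}
    return trajectory


def toy_reset() -> State:
    return {"observation": 0, "done": False}


def toy_step(td: State) -> State:
    # Hypothetical toy env: the episode ends when the counter reaches 2.
    obs = td["observation"] + 1
    return {"observation": obs, "reward": 1.0, "done": obs >= 2}


traj = rollout(toy_reset, toy_step, policy=lambda td: "move", max_steps=3)
print(len(traj), [t["next"]["reward"] for t in traj])  # prints: 2 [1.0, 1.0]
```

With the real `TensorDict`, the rewards of the rollout above can be read with the nested key `td["next", "reward"]`.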

To play a full game, keep stepping until the game is over:

In [48]:
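from torchrl.envs.utils import step_mdp

td = env.reset()
while not env.is_game_over():
    td = actor(td)  # or: td = env.sample(), which only returns legal moves
    td = env.step(td)
    td = step_mdp(td)  # move the "next" entries (new observation, flags) to the root for the next turn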

In [54]:
env.board.outcome()

Outcome(termination=<Termination.CHECKMATE: 1>, winner=False)
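`Outcome` comes from python-chess, where `winner` is `True` for white, `False` for black, and `None` for a draw. With the default settings (`play_as='white'`, `play_vs_engine=True`), the result above (`winner=False`, checkmate) means the engine, playing black, checkmated our random agent. The helpers below are a hypothetical sketch for translating an outcome into the agent's point of view.

```python
from typing import Optional


def describe_winner(winner: Optional[bool]) -> str:
    # python-chess encodes colors as booleans: chess.WHITE is True, chess.BLACK is False.
    if winner is None:
        return "draw"
    return "white" if winner else "black"


def agent_result(winner: Optional[bool], play_as: str = "white") -> str:
    # Hypothetical helper: "win", "loss" or "draw" from the agent's side of the board.
    if winner is None:
        return "draw"
    return "win" if describe_winner(winner) == play_as else "loss"


print(agent_result(False))  # prints: loss
```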