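# RGI scratchpad

Scratch notebook for recording and reloading game trajectories and for prototyping archive storage classes.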

In [2]:
from typing import Generic, Literal, Sequence, Type, Any
import dataclasses
from pathlib import Path
import importlib

import numpy as np

from rgi.core.base import TGameState, TAction, TPlayerState
from rgi.core.game_registry import GAME_REGISTRY
from rgi.core.game_registry import PLAYER_REGISTRY

from rgi.core import trajectory
from rgi.core import game_runner
from rgi.tests import test_utils

# Play a game and record its trajectory


In [4]:
# Set up a Connect 4 game.
# The generic names TGameState / TPlayerState / TAction (imported from rgi.core.base)
# are rebound here to the concrete Connect 4 types so later cells can write e.g.
# `PresetPlayer[TGameState, TAction]`.
# NOTE(review): the "Old Prototype" cells further down rebind these names to fresh
# TypeVars; re-run this cell before re-running the trajectory cell.
from rgi.games.connect4 import connect4

TGameState = connect4.Connect4State
TPlayerState = Literal[None]  # no per-player state is used for this game
TAction = connect4.Action

game = connect4.Connect4Game()

In [5]:
# Play a short, fully scripted game and check that its trajectory survives a
# write/read round trip through an .npz file.
importlib.reload(trajectory)
# Reloaded after `trajectory`, presumably so it binds the fresh trajectory module — verify.
importlib.reload(game_runner)

players = [
    test_utils.PresetPlayer[TGameState, TAction](actions=[2, 2, 2, 2]),
    test_utils.PresetPlayer[TGameState, TAction](actions=[1, 3, 4, 5]),
]

runner = game_runner.GameRunner(game, players, verbose=False)
original_trajectory = runner.run()

TRAJECTORY_PATH = 'trajectory.npz'
original_trajectory.write(TRAJECTORY_PATH, allow_pickle=False)
# Read with pickling disabled too: the file was written without pickle, so enabling it
# on read is unnecessary, and unpickling a file can execute arbitrary code.
reloaded_trajectory = trajectory.GameTrajectory.read(TRAJECTORY_PATH, TGameState, TAction, allow_pickle=False)

equality_checker = test_utils.EqualityChecker()
reloaded_trajectory_equal = equality_checker.check_equality(original_trajectory, reloaded_trajectory)
print('reloaded_trajectory_equal =', reloaded_trajectory_equal)
assert reloaded_trajectory_equal, 'trajectory changed after a write/read round trip'


yoooooooooooooooooooooooooooooooo
### rodo: magic: b'PK\x03\x04-\x00'
### rodo: magic is zip:
### rodo: self._files: ['action_player_ids.npy', 'incremental_rewards.npy', 'final_reward.npy', 'num_players.npy', 'state_board.npy', 'state_current_player.npy', 'state_winner.npy', 'action_.npy']
reloaded_trajectory_equal = True


## Archive Prototype

In [19]:
from rgi.core import archive
# Re-execute the `rgi.core.archive` module so edits to it are picked up without a kernel restart.
archive = importlib.reload(archive)

# Disabled experiment, kept for reference: build a
# trajectory_archive.MemoryMappedArchive('test_archive', TGameState, TAction),
# add `original_trajectory` to it, and save it.

@dataclasses.dataclass()
class TestItem:
    """Toy record mixing several field kinds for exercising the archive classes.

    Fields cover a scalar (``a``), a string (``b``), a fixed-size tuple (``t``),
    a numpy array (``x``) and a variable-length list (``y``).
    """

    a: int
    b: str
    t: tuple[int, str]
    x: np.ndarray
    y: list[int]

# Build a small in-memory archive, save it to disk, then reopen it as a memory-mapped archive.
# NOTE(review): the last recorded run of this cell ended with `KeyError: 'game_states'`
# after the list-based archive was printed — presumably in the MMappedArchive save/load
# step; check rgi.core.archive.
sample_items = [
    TestItem(a=1, b='hello', t=(1, 'a'), x=np.array([1, 2, 3]), y=[1, 2, 3]),
    TestItem(a=2, b='world', t=(2, 'b'), x=np.array([4, 5, 6]), y=[4, 5, 6]),
]

list_based_archive = archive.ListBasedArchive(TestItem)
for sample_item in sample_items:
    list_based_archive.add_item(sample_item)

for position in range(2):
    print(list_based_archive[position])
print(len(list_based_archive))

archive_path = Path('test_archive.npz')
archive.MMappedArchive.save(list_based_archive, archive_path)

mmapped_archive = archive.MMappedArchive(archive_path, TestItem)
for position in range(2):
    print(mmapped_archive[position])
print(len(mmapped_archive))


TestItem(a=1, b='hello', t=(1, 'a'), x=array([1, 2, 3]), y=[1, 2, 3])
TestItem(a=2, b='world', t=(2, 'b'), x=array([4, 5, 6]), y=[4, 5, 6])
2


KeyError: 'game_states'

### Old prototypes

Earlier experiments with generic `Archive` and `Batch` classes, kept for reference.

In [None]:
from dataclasses import dataclass
from typing import Generic, TypeVar, Optional
import numpy as np
from numpy.typing import NDArray

# Type parameters for the archive prototypes below: per-game state/action types and
# their array-based "archive" counterparts (see Count21ArchiveState).
# NOTE(review): this rebinds TGameState / TAction, which the trajectory cell at the top
# expects to be the concrete Connect 4 types.
# pylint: disable=invalid-name
TGameState, TAction, TArchiveState, TArchiveAction = (
    TypeVar("TGameState"),
    TypeVar("TAction"),
    TypeVar("TArchiveState"),
    TypeVar("TArchiveAction"),
)
# pylint: enable=invalid-name


@dataclass()
class Count21State:
    """Game state for Count21: the score so far and the current player."""

    score: int
    current_player: int


Count21Action = int  # a Count21 action is a plain int

@dataclass()
class Connect4State:
    """Prototype Connect 4 game state."""

    board: NDArray[np.int8]         # shape (height, width)
    current_player: int
    winner: Optional[int] = None    # winner once the game has ended; None otherwise


@dataclass()
class Count21ArchiveState:
    """Array form of Count21State; presumably one entry per archived state."""

    score: NDArray[np.int64]
    current_player: NDArray[np.int8]


Count21ArchiveAction = NDArray[np.int8]  # array form of Count21Action


@dataclass
class Archive(Generic[TGameState, TAction, TArchiveState, TArchiveAction]):
    """Empty placeholder: an archive generic over game/archive state and action types."""


print('done.')

parameterised_archive = Archive[Count21State, Count21Action, Count21ArchiveState, Count21ArchiveAction]
a = parameterised_archive()

print(a)


In [None]:
import dataclasses
import typing


class Archive(Generic[TGameState, TArchiveState]):
    """Prototype archive that inspects its concrete game-state and archive-state dataclasses.

    The concrete types cannot be read back from ``Archive[A, B]()`` inside ``__init__``:
    ``Archive.__orig_bases__`` only holds the bare TypeVars, and ``typing`` attaches
    ``__orig_class__`` to the instance only after ``__init__`` returns. The types therefore
    come either from explicit constructor arguments or from the parameterised
    ``Archive[A, B]`` base of a subclass.

    Args:
        game_state_type: Dataclass type of one game state (optional when a subclass
            fixes the type parameters).
        archive_state_type: Dataclass type of the archived representation (same rule).

    Raises:
        TypeError: if only one type is passed, or none is passed and no concrete
            ``Archive[A, B]`` base exists.
    """

    def __init__(
        self,
        game_state_type: typing.Optional[type] = None,
        archive_state_type: typing.Optional[type] = None,
    ):
        if (game_state_type is None) != (archive_state_type is None):
            raise TypeError("pass both game_state_type and archive_state_type, or neither")
        if game_state_type is None:
            game_state_type, archive_state_type = self._type_args_from_bases()
        self.game_state_type = game_state_type
        self.archive_state_type = archive_state_type

        # Now we can inspect the fields
        print("Game State fields:", dataclasses.fields(self.game_state_type))
        print("Archive State fields:", dataclasses.fields(self.archive_state_type))

        # And get their type hints
        print("Game State types:", typing.get_type_hints(self.game_state_type))
        print("Archive State types:", typing.get_type_hints(self.archive_state_type))

    @classmethod
    def _type_args_from_bases(cls) -> tuple:
        """Return the concrete type arguments of the nearest parameterised Archive base."""
        for klass in cls.__mro__:
            for base in klass.__dict__.get("__orig_bases__", ()):
                # `__class__` is the class this method was defined in, so the check keeps
                # working even if a later cell rebinds the global name `Archive`.
                if typing.get_origin(base) is __class__:
                    type_args = typing.get_args(base)
                    if not any(isinstance(arg, typing.TypeVar) for arg in type_args):
                        return type_args
        raise TypeError(
            f"{cls.__name__} has no concrete type parameters; pass the types explicitly "
            "or subclass Archive[GameState, ArchiveState]"
        )


# Test it: the subclass's parameterised base carries the concrete types.
class Count21Archive(Archive[Count21State, Count21ArchiveState]):
    """Archive of Count21 states."""


a = Count21Archive()

In [None]:
import dataclasses
import typing


class Archive(Generic[TGameState, TArchiveState]):
    """Prototype archive that pairs game-state fields with archive-state fields by name.

    The concrete types cannot be read back from ``Archive[A, B]()`` inside ``__init__``:
    ``Archive.__orig_bases__`` only holds the bare TypeVars, and ``typing`` attaches
    ``__orig_class__`` only after ``__init__`` returns. They therefore come from explicit
    constructor arguments or from the parameterised ``Archive[A, B]`` base of a subclass.

    Raises:
        TypeError: if the types cannot be determined, or a game-state field has no
            same-named archive-state field.
    """

    def __init__(
        self,
        game_state_type: typing.Optional[type] = None,
        archive_state_type: typing.Optional[type] = None,
    ):
        if (game_state_type is None) != (archive_state_type is None):
            raise TypeError("pass both game_state_type and archive_state_type, or neither")
        if game_state_type is None:
            game_state_type, archive_state_type = self._type_args_from_bases()
        self.game_state_type = game_state_type
        self.archive_state_type = archive_state_type

        # Initialize empty archive arrays based on the fields (column creation is still TODO).
        self.archive = {}
        archive_fields_by_name = {
            archive_field.name: archive_field
            for archive_field in dataclasses.fields(self.archive_state_type)
        }

        # Pair fields by name: zip() over the two field lists would silently mis-pair fields
        # declared in a different order and drop any that do not line up.
        for game_field in dataclasses.fields(self.game_state_type):
            archive_field = archive_fields_by_name.get(game_field.name)
            if archive_field is None:
                raise TypeError(
                    f"{self.archive_state_type.__name__} has no field matching "
                    f"{self.game_state_type.__name__}.{game_field.name}"
                )
            print(f"Field {game_field.name}: {game_field.type} -> {archive_field.type}")
            # TODO: initialize an empty array for this field, with a numpy dtype derived
            # from the archive field's type.

    def add(self, state: TGameState) -> None:
        """Print each field of ``state``; converting it into archive arrays is not implemented yet."""
        for state_field in dataclasses.fields(state):
            value = getattr(state, state_field.name)
            print(f"Adding {state_field.name}: {value}")

    @classmethod
    def _type_args_from_bases(cls) -> tuple:
        """Return the concrete type arguments of the nearest parameterised Archive base."""
        for klass in cls.__mro__:
            for base in klass.__dict__.get("__orig_bases__", ()):
                # `__class__` is the class this method was defined in, so the check keeps
                # working even if a later cell rebinds the global name `Archive`.
                if typing.get_origin(base) is __class__:
                    type_args = typing.get_args(base)
                    if not any(isinstance(arg, typing.TypeVar) for arg in type_args):
                        return type_args
        raise TypeError(
            f"{cls.__name__} has no concrete type parameters; pass the types explicitly "
            "or subclass Archive[GameState, ArchiveState]"
        )


# Test it: the subclass's parameterised base carries the concrete types.
class Count21Archive(Archive[Count21State, Count21ArchiveState]):
    """Archive of Count21 states."""


a = Count21Archive()
a.add(Count21State(score=1, current_player=1))

In [None]:

import dataclasses
import typing

import numpy as np


class Archive(Generic[TGameState, TArchiveState]):
    """Prototype archive holding an (initially empty) instance of the archive-state dataclass.

    The concrete types cannot be read back from ``Archive[A, B]()`` inside ``__init__``:
    ``Archive.__orig_bases__`` only holds the bare TypeVars, and ``typing`` attaches
    ``__orig_class__`` only after ``__init__`` returns. They therefore come from explicit
    constructor arguments or from the parameterised ``Archive[A, B]`` base of a subclass.

    Raises:
        TypeError: if the types cannot be determined.
    """

    def __init__(
        self,
        game_state_type: typing.Optional[type] = None,
        archive_state_type: typing.Optional[type] = None,
    ):
        if (game_state_type is None) != (archive_state_type is None):
            raise TypeError("pass both game_state_type and archive_state_type, or neither")
        if game_state_type is None:
            game_state_type, archive_state_type = self._type_args_from_bases()
        self.game_state_type = game_state_type
        self.archive_state_type = archive_state_type

        # Calling the TypeVar (`TArchiveState()`) is a TypeError, and the concrete archive
        # dataclass has required fields, so build it with one zero-length column per field.
        archive_hints = typing.get_type_hints(self.archive_state_type)
        self.archive = self.archive_state_type(**{
            archive_field.name: self._empty_column(archive_hints[archive_field.name])
            for archive_field in dataclasses.fields(self.archive_state_type)
            if archive_field.init
        })

    def add(self, state: TGameState) -> None:
        """Print the fields of ``state`` and of the archive (placeholder for appending)."""
        print(dataclasses.fields(state))
        print(dataclasses.fields(self.archive))

    @staticmethod
    def _empty_column(annotation: typing.Any) -> np.ndarray:
        """Return a zero-length array for a field annotated ``NDArray[<scalar type>]``.

        ``NDArray[np.int64]`` is ``np.ndarray[<shape>, np.dtype[np.int64]]``, so the scalar
        type is the single argument of the second type argument. Any other annotation
        falls back to an object array.
        """
        ndarray_args = typing.get_args(annotation)
        if typing.get_origin(annotation) is np.ndarray and len(ndarray_args) == 2:
            dtype_args = typing.get_args(ndarray_args[1])
            if len(dtype_args) == 1 and not isinstance(dtype_args[0], typing.TypeVar):
                return np.empty(0, dtype=dtype_args[0])
        return np.empty(0, dtype=object)

    @classmethod
    def _type_args_from_bases(cls) -> tuple:
        """Return the concrete type arguments of the nearest parameterised Archive base."""
        for klass in cls.__mro__:
            for base in klass.__dict__.get("__orig_bases__", ()):
                # `__class__` is the class this method was defined in, so the check keeps
                # working even if a later cell rebinds the global name `Archive`.
                if typing.get_origin(base) is __class__:
                    type_args = typing.get_args(base)
                    if not any(isinstance(arg, typing.TypeVar) for arg in type_args):
                        return type_args
        raise TypeError(
            f"{cls.__name__} has no concrete type parameters; pass the types explicitly "
            "or subclass Archive[GameState, ArchiveState]"
        )


# Test it: the subclass's parameterised base carries the concrete types.
class Count21Archive(Archive[Count21State, Count21ArchiveState]):
    """Archive of Count21 states."""


a = Count21Archive()
a.add(Count21State(score=1, current_player=1))


In [None]:
        
import dataclasses
from typing import Generic, Sequence, Type, TypeVar

import torch

T = TypeVar("T")
TBatch = TypeVar("TBatch", bound="Batch")


class Batch(Generic[T]):
    """Convenience class to convert a sequence of states & actions into a batch.

    Subclasses are dataclasses whose fields are tensors; each field is built by stacking
    the same-named field of every item.

    >>> from dataclasses import dataclass
    >>> import torch
    >>> @dataclass
    ... class GameState:
    ...     score: int
    ...     current_player: int
    >>> @dataclass
    ... class BatchGameState(Batch[GameState]):
    ...     score: torch.Tensor
    ...     current_player: torch.Tensor
    >>> states = [GameState(5, 1), GameState(7, 2)]
    >>> batch = BatchGameState.from_sequence(states)
    >>> len(batch)
    2
    >>> batch
    BatchGameState(score=tensor([5, 7]), current_player=tensor([1, 2]))
    >>> batch[0]
    GameState(score=5, current_player=1)
    """

    # Class of the original items, recorded by from_sequence() so __getitem__ can rebuild them.
    _unbatch_class: Type[T]

    @classmethod
    def from_sequence(cls: Type[TBatch], items: Sequence[T]) -> TBatch:
        """Stack every field of ``items`` that the batch class also declares.

        Raises:
            ValueError: if ``items`` is empty.
        """
        if not items:
            raise ValueError("Cannot create a batch from an empty sequence")

        cls_fields = {batch_field.name for batch_field in dataclasses.fields(cls)}  # type: ignore
        batch_dict = {}
        for item_field in dataclasses.fields(items[0]):  # type: ignore
            if item_field.name not in cls_fields:
                continue
            values = [getattr(item, item_field.name) for item in items]
            # as_tensor accepts both primitives and existing tensors; torch.tensor() on a
            # tensor copies it and emits a UserWarning.
            batch_dict[item_field.name] = torch.stack([torch.as_tensor(value) for value in values])

        batch = cls(**batch_dict)
        batch._unbatch_class = type(items[0])
        return batch

    def __getitem__(self, index: int) -> T:
        """Rebuild the ``index``-th original item from the batched fields."""
        item_dict = {}
        for batch_field in dataclasses.fields(self):  # type: ignore
            value = getattr(self, batch_field.name)[index]
            # Calling the batch field's annotated type (torch.Tensor) on a slice does not
            # give back a primitive. Instead, 0-d slices become Python scalars so that
            # e.g. GameState(score=5, ...) round-trips.
            item_dict[batch_field.name] = value.item() if value.ndim == 0 else value
        return self._unbatch_class(**item_dict)

    def __len__(self) -> int:
        """Number of items in the batch (length of the first field)."""
        return len(getattr(self, dataclasses.fields(self)[0].name))  # type: ignore

In [None]:
import dataclasses
from typing import Generic, Protocol, Sequence, Type, TypeVar

import torch

T = TypeVar("T")
TBatch = TypeVar("TBatch", bound="Batch")


class Batchable(Protocol[T]):
    """Protocol to convert single states & actions into torch.Tensor for batching."""

    @staticmethod
    def from_sequence(items: Sequence[T]) -> "Batchable[T]": ...

    def __getitem__(self, index: int) -> T: ...

    def __len__(self) -> int: ...


class Batch(Generic[T]):
    """Convenience class to convert a sequence of states & actions into a batch.

    Subclasses are dataclasses whose fields are tensors; each field is built by stacking
    the same-named field of every item.

    >>> from dataclasses import dataclass
    >>> import torch
    >>> @dataclass
    ... class GameState:
    ...     score: int
    ...     current_player: int
    >>> @dataclass
    ... class BatchGameState(Batch[GameState]):
    ...     score: torch.Tensor
    ...     current_player: torch.Tensor
    >>> states = [GameState(5, 1), GameState(7, 2)]
    >>> batch = BatchGameState.from_sequence(states)
    >>> len(batch)
    2
    >>> batch
    BatchGameState(score=tensor([5, 7]), current_player=tensor([1, 2]))
    >>> batch[0]
    GameState(score=5, current_player=1)
    """

    # Class of the original items, recorded by from_sequence() so __getitem__ can rebuild them.
    _unbatch_class: Type[T]

    @classmethod
    def from_sequence(cls: Type[TBatch], items: Sequence[T]) -> TBatch:
        """Stack every field of ``items`` that the batch class also declares.

        Raises:
            ValueError: if ``items`` is empty.
        """
        if not items:
            raise ValueError("Cannot create a batch from an empty sequence")

        cls_fields = {batch_field.name for batch_field in dataclasses.fields(cls)}  # type: ignore
        batch_dict = {}
        for item_field in dataclasses.fields(items[0]):  # type: ignore
            if item_field.name not in cls_fields:
                continue
            values = [getattr(item, item_field.name) for item in items]
            # as_tensor accepts both primitives and existing tensors; torch.tensor() on a
            # tensor copies it and emits a UserWarning.
            batch_dict[item_field.name] = torch.stack([torch.as_tensor(value) for value in values])

        batch = cls(**batch_dict)
        batch._unbatch_class = type(items[0])
        return batch

    def __getitem__(self, index: int) -> T:
        """Rebuild the ``index``-th original item from the batched fields."""
        item_dict = {}
        for batch_field in dataclasses.fields(self):  # type: ignore
            value = getattr(self, batch_field.name)[index]
            # Calling the batch field's annotated type (torch.Tensor) on a slice does not
            # give back a primitive. Instead, 0-d slices become Python scalars so that
            # e.g. GameState(score=5, ...) round-trips.
            item_dict[batch_field.name] = value.item() if value.ndim == 0 else value
        return self._unbatch_class(**item_dict)

    def __len__(self) -> int:
        """Number of items in the batch (length of the first field)."""
        return len(getattr(self, dataclasses.fields(self)[0].name))  # type: ignore


@dataclasses.dataclass
class PrimitiveBatch(Generic[T]):
    """A batch class for primitive types like int, float, etc.

    >>> batch = PrimitiveBatch.from_sequence([2,4,6,8])
    >>> len(batch)
    4
    >>> batch
    PrimitiveBatch(values=tensor([2, 4, 6, 8]))
    >>> batch[0]
    2
    """

    values: torch.Tensor

    @classmethod
    def from_sequence(cls: Type["PrimitiveBatch[T]"], items: Sequence[T]) -> "PrimitiveBatch[T]":
        """Pack ``items`` into a single 1-d tensor.

        Raises:
            ValueError: if ``items`` is empty.
        """
        if not items:
            raise ValueError("Cannot create a batch from an empty sequence")

        return cls(values=torch.tensor(items))

    def __getitem__(self, index: int) -> T:
        """Return the ``index``-th value as a Python scalar."""
        return self.values[index].item()  # type: ignore

    def __len__(self) -> int:
        """Number of values in the batch."""
        return self.values.shape[0]


In [None]:
from typing import TypeVar, Generic

TGameState = TypeVar('TGameState')
TArchiveState = TypeVar('TArchiveState')


class Archive(Generic[TGameState, TArchiveState]):
    """Archive whose type arguments are read back from ``Archive[A, B]()``.

    ``typing`` records the subscripted alias on the instance as ``__orig_class__``, but only
    after ``__init__`` has returned. Reading it inside ``__init__`` raises AttributeError,
    so the arguments are exposed through :attr:`type_args` once the object exists.
    """

    @property
    def type_args(self) -> tuple:
        """The ``(game_state_type, archive_state_type)`` used in ``Archive[A, B]()``.

        Raises:
            AttributeError: if the instance was created without subscripting (``Archive()``).
        """
        orig_class = getattr(self, "__orig_class__", None)
        if orig_class is None:
            raise AttributeError("type arguments are only available for instances created as Archive[A, B]()")
        return orig_class.__args__


# Example classes
class Foo:
    """Example game-state type."""


class Bar:
    """Example archive-state type."""


# Create an instance with specific types, then read the types back.
a = Archive[Foo, Bar]()
t_game_state, t_archive_state = a.type_args
print(f"TGameState type: {t_game_state}")
print(f"TArchiveState type: {t_archive_state}")


In [None]:
from dataclasses import dataclass, field
from typing import Type, TypeVar, Generic

TGameState = TypeVar('TGameState')
TArchiveState = TypeVar('TArchiveState')


@dataclass
class Archive(Generic[TGameState, TArchiveState]):
    """Dataclass archive that is handed its game-state and archive-state types explicitly.

    ``@dataclass`` keeps a hand-written ``__init__``, so ``init=False`` on the two type
    fields only records that they are set there. ``repr=False`` keeps them out of
    ``repr()``, while they still take part in the generated ``__eq__``.
    """

    _game_state_type: Type[TGameState] = field(init=False, repr=False)
    _archive_state_type: Type[TArchiveState] = field(init=False, repr=False)

    def __init__(self, game_state_type: Type[TGameState], archive_state_type: Type[TArchiveState]):
        self._game_state_type, self._archive_state_type = game_state_type, archive_state_type
        print(f"TGameState type: {self._game_state_type}")
        print(f"TArchiveState type: {self._archive_state_type}")


# Example dataclasses standing in for real state types.
@dataclass
class Foo:
    name: str = field(default="Foo example")


@dataclass
class Bar:
    value: int = field(default=42)


# The types are passed as ordinary constructor arguments.
a = Archive(Foo, Bar)


In [None]:
from dataclasses import dataclass, field
from typing import Type, TypeVar, Generic

TGameState = TypeVar('TGameState')
TArchiveState = TypeVar('TArchiveState')


class Archive(Generic[TGameState, TArchiveState]):
    """Plain (non-dataclass) archive that is handed its state types explicitly."""

    def __init__(self, game_state_type: Type[TGameState], archive_state_type: Type[TArchiveState]):
        self._game_state_type = game_state_type
        self._archive_state_type = archive_state_type
        messages = (
            f"TGameState type: {self._game_state_type}",
            f"TArchiveState type: {self._archive_state_type}",
        )
        for message in messages:
            print(message)


# Example dataclasses standing in for real state types.
@dataclass
class Foo:
    name: str = field(default="Foo example")


@dataclass
class Bar:
    value: int = field(default=42)


# The types are passed as ordinary constructor arguments.
a = Archive(Foo, Bar)
