# RGI scratchpad

In [18]:
from typing import Generic, Literal, Sequence, Type, Any
import dataclasses
from pathlib import Path
import importlib

import numpy as np

from rgi.core.base import TGameState, TAction, TPlayerState
from rgi.core.game_registry import GAME_REGISTRY
from rgi.core.game_registry import PLAYER_REGISTRY

from rgi.core import trajectory
from rgi.core import game_runner
from rgi.tests import test_utils

# Play game and record trajectory


In [19]:
# Set up game.
from rgi.games.connect4 import connect4
TGameState, TPlayerState, TAction = connect4.Connect4State, Literal[None], connect4.Action

game = connect4.Connect4Game()

In [20]:
# Run simple precomputed trajectory.
importlib.reload(trajectory)
importlib.reload(game_runner)

players = [
    test_utils.PresetPlayer[TGameState, TAction](actions=[2,2,2,2]),
    test_utils.PresetPlayer[TGameState, TAction](actions=[1,3,4,5]),
]

runner = game_runner.GameRunner(game, players, verbose=False)
original_trajectory = runner.run()

# original_trajectory = trajectory.play_game_and_record_trajectory(game, players)
original_trajectory.write('trajectory.npz', allow_pickle=False)
reloaded_trajectory = trajectory.GameTrajectory.read('trajectory.npz', TGameState, TAction, allow_pickle=True)

equality_checker = test_utils.EqualityChecker()
print('reloaded_trajectory_equal =', (reloaded_trajectory_equal := equality_checker.check_equality(original_trajectory, reloaded_trajectory)))
assert reloaded_trajectory_equal


reloaded_trajectory_equal = True


In [34]:
from rgi.players.random_player.random_player import RandomPlayer
random_players = [RandomPlayer[TGameState, TAction]() for _ in range(2)]
random_runner = game_runner.GameRunner(game, players, verbose=False)
t = random_runner.run()
random_trajectories = [random_runner.run() for _ in range(10)]

#random_trajectories


IndexError: list index out of range

## Archive Prototype

# Memory map pickled row format

In [18]:
import pickle
import numpy as np
import os

def write_pickle_archive(data_filename, index_filename, objects):
    """
    Write 'objects' to data_filename as raw pickled bytes (one after another).
    Then write offsets into a .npy array (index_filename), where:
      offsets[i] = byte offset in data_filename for the i-th object
      offsets[-1] = final size of the data file

    We do NOT store the "count" explicitly; we can infer from offsets array size.
    """
    offsets = [0]

    # 1) Write data
    with open(data_filename, "wb") as fdata:
        for obj in objects:
            payload = pickle.dumps(obj, protocol=4)  # or higher protocol if you prefer
            fdata.write(payload)
            offsets.append(fdata.tell())

    # 2) Convert offsets to a NumPy array and save
    #    We'll use uint64 for large file support
    offsets_array = np.array(offsets, dtype=np.uint64)
    np.save(index_filename, offsets_array)  # creates index_filename (e.g. 'myindex.npy')


import pickle
import numpy as np

def open_pickle_archive(data_filename, index_filename):
    """
    Returns (get_item, count):
      - get_item(i): retrieves i-th object from the memory-mapped data file
      - count: number of objects
    """
    # 1) Load the offsets array
    offsets = np.load(index_filename)  # shape = (count+1,)
    count = len(offsets) - 1

    # 2) Memory-map the data file
    mm = np.memmap(data_filename, mode="r")

    def get_item(i):
        """
        Retrieve the i-th object by slicing [offsets[i]:offsets[i+1]] from the memmap
        then unpickling it.
        """
        if i < 0 or i >= count:
            raise IndexError("Index out of range.")
        start = offsets[i]
        end = offsets[i+1]
        # This is a 'memmap' slice, which is a small ndarray of bytes
        chunk = mm[start:end]
        # Convert to a normal Python bytes object so pickle can read it
        return pickle.loads(chunk.tobytes())

    return get_item, count


if __name__ == "__main__":
    # Some random objects
    data = [
        {"name": "Alice", "info": [1,2,3]},
        "Hello, world!",
        (42, 3.14, {"foo": "bar"}),
        ["some", "list", "of", "strings", 999],
    ]

    # Write them out
    write_pickle_archive("mydata.bin", "mydata_idx.npy", data)

    # Read them back
    get_item, total = open_pickle_archive("mydata.bin", "mydata_idx.npy")
    print("Total items:", total)

    for i in range(total):
        obj = get_item(i)
        print(f"Item {i}:", obj)


Total items: 4
Item 0: {'name': 'Alice', 'info': [1, 2, 3]}
Item 1: Hello, world!
Item 2: (42, 3.14, {'foo': 'bar'})
Item 3: ['some', 'list', 'of', 'strings', 999]


In [17]:
import importlib

import rgi.core.archive as archive
importlib.reload(archive)

rfa = archive.RowFileArchiver()
fn = rfa.get_lookup_fn("mydata.bin")

print(fn)
fn.sequence_length = 7
print(fn)
print(fn.sequence_length)

<function RowFileArchiver.get_lookup_fn.<locals>.get_item at 0x795218158900>
<function RowFileArchiver.get_lookup_fn.<locals>.get_item at 0x795218158900>
7


# Memmap Example

In [11]:
import numpy as np
import zipfile
import os

filename = "test_data.npz"

# Clean up any old file
if os.path.exists(filename):
    os.remove(filename)

# 1. Create some sample data
arr = np.arange(10, dtype=np.int64)

# 2. Save it as an uncompressed npz
np.savez(filename, x=arr)

# 3. Check the compression type within the npz file
#    We expect 'ZIP_STORED' (i.e. compression type == 0)
with zipfile.ZipFile(filename, "r") as zf:
    for info in zf.infolist():
        print(f"Inside NPZ: {info.filename} compressed?", info.compress_type != 0)

# 4. Load with mmap_mode
#    NOTE: we keep a reference to the returned object to keep the file handle open
with np.load(filename, mmap_mode="r") as npzfile:
    mmapped_arr = npzfile["x"]
    print("Loaded array type:", type(mmapped_arr))
    print("Array contents:", mmapped_arr)


Inside NPZ: x.npy compressed? False
Loaded array type: <class 'numpy.ndarray'>
Array contents: [0 1 2 3 4 5 6 7 8 9]


### Old Prototype

In [None]:
from dataclasses import dataclass
from typing import Generic, TypeVar, Optional
import numpy as np
from numpy.typing import NDArray

TGameState = TypeVar("TGameState")  # pylint: disable=invalid-name
TAction = TypeVar("TAction")  # pylint: disable=invalid-name
TArchiveState = TypeVar("TArchiveState")  # pylint: disable=invalid-name
TArchiveAction = TypeVar("TArchiveAction")  # pylint: disable=invalid-name


@dataclass
class Count21State:
    score: int
    current_player: int

Count21Action = int

@dataclass
class Connect4State:
    board: NDArray[np.int8]  # (height, width)
    current_player: int
    winner: Optional[int] = None  # The winner, if the game has ended


@dataclass
class Count21ArchiveState:
    score: NDArray[np.int64]
    current_player: NDArray[np.int8]

Count21ArchiveAction = NDArray[np.int8]


@dataclass
class Archive(Generic[TGameState, TAction, TArchiveState, TArchiveAction]):
    pass

print('done.')

a = Archive[Count21State, Count21Action, Count21ArchiveState, Count21ArchiveAction]()

print(a)


In [None]:
        
class Batch(Generic[T]):
    """Convenience class to convert a sequence of states & actions into a batch.

    >>> from dataclasses import dataclass
    >>> import torch
    >>> @dataclass
    ... class GameState:
    ...     score: int
    ...     current_player: int
    >>> @dataclass
    ... class BatchGameState(Batch[GameState]):
    ...     score: torch.Tensor
    ...     current_player: torch.Tensor
    >>> states = [GameState(5, 1), GameState(7, 2)]
    >>> batch = BatchGameState.from_sequence(states)
    >>> len(batch)
    2
    >>> batch
    BatchGameState(score=tensor([5, 7]), current_player=tensor([1, 2]))
    >>> batch[0]
    GameState(score=5, current_player=1)
    """

    _unbatch_class: Type[T]

    @classmethod
    def from_sequence(cls: Type[TBatch], items: Sequence[T]) -> TBatch:
        if not items:
            raise ValueError("Cannot create a batch from an empty sequence")

        cls_fields = set(f.name for f in fields(cls))  # type: ignore
        batch_dict = {}
        for field in fields(items[0]):  # type: ignore
            if field.name not in cls_fields:
                continue
            values = [getattr(item, field.name) for item in items]
            # We need to handle both primitive values and torch.Tensors here.
            # torch.tensor(primitive_list) is probably more efficient, but doesn't work for tensors.
            batch_dict[field.name] = torch.stack([torch.tensor(value) for value in values])

        batch = cls(**batch_dict)
        batch._unbatch_class = type(items[0])
        return batch

    def __getitem__(self, index: int) -> T:
        item_dict = {field.name: field.type(getattr(self, field.name)[index]) for field in fields(self)}  # type: ignore
        return self._unbatch_class(**item_dict)

    def __len__(self) -> int:
        return len(getattr(self, fields(self)[0].name))  # type: ignore

In [None]:
class Batchable(Protocol[T]):
    """Protocol to convert single states & actions into torch.Tensor for batching."""

    @staticmethod
    def from_sequence(items: Sequence[T]) -> "Batchable[T]": ...

    def __getitem__(self, index: int) -> T: ...

    def __len__(self) -> int: ...


class Batch(Generic[T]):
    """Convenience class to convert a sequence of states & actions into a batch.

    >>> from dataclasses import dataclass
    >>> import torch
    >>> @dataclass
    ... class GameState:
    ...     score: int
    ...     current_player: int
    >>> @dataclass
    ... class BatchGameState(Batch[GameState]):
    ...     score: torch.Tensor
    ...     current_player: torch.Tensor
    >>> states = [GameState(5, 1), GameState(7, 2)]
    >>> batch = BatchGameState.from_sequence(states)
    >>> len(batch)
    2
    >>> batch
    BatchGameState(score=tensor([5, 7]), current_player=tensor([1, 2]))
    >>> batch[0]
    GameState(score=5, current_player=1)
    """

    _unbatch_class: Type[T]

    @classmethod
    def from_sequence(cls: Type[TBatch], items: Sequence[T]) -> TBatch:
        if not items:
            raise ValueError("Cannot create a batch from an empty sequence")

        cls_fields = set(f.name for f in fields(cls))  # type: ignore
        batch_dict = {}
        for field in fields(items[0]):  # type: ignore
            if field.name not in cls_fields:
                continue
            values = [getattr(item, field.name) for item in items]
            # We need to handle both primitive values and torch.Tensors here.
            # torch.tensor(primitive_list) is probably more efficient, but doesn't work for tensors.
            batch_dict[field.name] = torch.stack([torch.tensor(value) for value in values])

        batch = cls(**batch_dict)
        batch._unbatch_class = type(items[0])
        return batch

    def __getitem__(self, index: int) -> T:
        item_dict = {field.name: field.type(getattr(self, field.name)[index]) for field in fields(self)}  # type: ignore
        return self._unbatch_class(**item_dict)

    def __len__(self) -> int:
        return len(getattr(self, fields(self)[0].name))  # type: ignore

@dataclass
class PrimitiveBatch(Generic[T]):
    """A batch class for primitive types like int, float, etc.

    >>> batch = PrimitiveBatch.from_sequence([2,4,6,8])
    >>> len(batch)
    4
    >>> batch
    PrimitiveBatch(values=tensor([2, 4, 6, 8]))
    >>> batch[0]
    2
    """

    values: torch.Tensor

    @classmethod
    def from_sequence(cls: Type["PrimitiveBatch[T]"], items: Sequence[T]) -> "PrimitiveBatch[T]":
        if not items:
            raise ValueError("Cannot create a batch from an empty sequence")

        return cls(values=torch.tensor(items))

    def __getitem__(self, index: int) -> T:
        return self.values[index].item()  # type: ignore

    def __len__(self) -> int:
        return self.values.shape[0]


In [None]:
from typing import TypeVar, Generic

TGameState = TypeVar('TGameState')
TArchiveState = TypeVar('TArchiveState')

class Archive(Generic[TGameState, TArchiveState]):
    def __init__(self):
        # Access the original class with type parameters
        orig_class = self.__orig_class__
        # Extract the type arguments
        t_game_state, t_archive_state = orig_class.__args__
        print(f"TGameState type: {t_game_state}")
        print(f"TArchiveState type: {t_archive_state}")

# Example classes
class Foo:
    pass

class Bar:
    pass

# Create an instance with specific types
a = Archive[Foo, Bar]()


In [None]:
from dataclasses import dataclass, field
from typing import Type, TypeVar, Generic

TGameState = TypeVar('TGameState')
TArchiveState = TypeVar('TArchiveState')

@dataclass
class Archive(Generic[TGameState, TArchiveState]):
    _game_state_type: Type[TGameState] = field(init=False, repr=False)
    _archive_state_type: Type[TArchiveState] = field(init=False, repr=False)

    def __init__(self, game_state_type: Type[TGameState], archive_state_type: Type[TArchiveState]):
        self._game_state_type = game_state_type
        self._archive_state_type = archive_state_type
        print(f"TGameState type: {self._game_state_type}")
        print(f"TArchiveState type: {self._archive_state_type}")

# Define Foo and Bar as example dataclasses
@dataclass
class Foo:
    name: str = "Foo example"

@dataclass
class Bar:
    value: int = 42

# Instantiate Archive by passing the types
a = Archive(Foo, Bar)


In [None]:
from dataclasses import dataclass, field
from typing import Type, TypeVar, Generic

TGameState = TypeVar('TGameState')
TArchiveState = TypeVar('TArchiveState')

class Archive(Generic[TGameState, TArchiveState]):

    def __init__(self, game_state_type: Type[TGameState], archive_state_type: Type[TArchiveState]):
        self._game_state_type = game_state_type
        self._archive_state_type = archive_state_type
        print(f"TGameState type: {self._game_state_type}")
        print(f"TArchiveState type: {self._archive_state_type}")

# Define Foo and Bar as example dataclasses
@dataclass
class Foo:
    name: str = "Foo example"

@dataclass
class Bar:
    value: int = 42

# Instantiate Archive by passing the types
a = Archive(Foo, Bar)
