# RGI scratchpad

In [1]:
from typing import Generic, Literal, Sequence, Type, Any
import dataclasses
from pathlib import Path
import importlib

import numpy as np

from rgi.core.base import TGameState, TAction, TPlayerState
from rgi.core.game_registry import GAME_REGISTRY
from rgi.core.game_registry import PLAYER_REGISTRY

from rgi.core import trajectory
from rgi.core import game_runner
from rgi.tests import test_utils

# Play game and record trajectory


In [2]:
# Set up game.
from rgi.games.connect4 import connect4
TGameState, TPlayerState, TAction = connect4.Connect4State, Literal[None], connect4.Action

game = connect4.Connect4Game()

In [3]:
# Run simple precomputed trajectory.
importlib.reload(trajectory)
importlib.reload(game_runner)

players = [
    test_utils.PresetPlayer[TGameState, TAction](actions=[2,2,2,2]),
    test_utils.PresetPlayer[TGameState, TAction](actions=[1,3,4,5]),
]

runner = game_runner.GameRunner(game, players, verbose=False)
original_trajectory = runner.run()

# original_trajectory = trajectory.play_game_and_record_trajectory(game, players)
original_trajectory.write('trajectory.npz', allow_pickle=False)
reloaded_trajectory = trajectory.GameTrajectory.read('trajectory.npz', TGameState, TAction, allow_pickle=True)

equality_checker = test_utils.EqualityChecker()
print('reloaded_trajectory_equal =', (reloaded_trajectory_equal := equality_checker.check_equality(original_trajectory, reloaded_trajectory)))
assert reloaded_trajectory_equal


reloaded_trajectory_equal = True


In [9]:
from rgi.players.random_player.random_player import RandomPlayer
random_players = [RandomPlayer[TGameState, TAction]() for _ in range(2)]

random_trajectories = [game_runner.GameRunner(game, random_players, verbose=False).run() for _ in range(1000)]



In [10]:
import os
from rgi.core import archive

trajectory_type = trajectory.GameTrajectory[TGameState, TAction]
trajectory_items = random_trajectories

row_archiver = archive.RowFileArchiver()
row_path = "trajectory_test.rgrows"
row_archiver.write_items(trajectory_items, row_path)
row_archive = row_archiver.read_items(row_path, trajectory_type)

column_path = "trajectory_test.rgcols"
os.remove(column_path)
os.remove(column_path + ".index")

column_archiver = archive.ColumnFileArchiver()
column_archiver.write_sequence(trajectory_type, trajectory_items, column_path)
column_archive = archive.MMapColumnArchive(column_path, trajectory_type)


In [12]:
for i in range(1000):
    assert row_archive[i] == trajectory_items[i]
    assert column_archive[i] == trajectory_items[i]



# Benchmark Row vs Column based lookup

In [13]:
import time
import random
import statistics
from typing import TypeVar, Generic, Sequence

T = TypeVar("T")

@dataclasses.dataclass
class ArchiveBenchmarkResult:
    """Results from benchmarking archive reads."""
    name: str
    sequential_read_ms: list[float]  # milliseconds per item
    random_read_ms: list[float]      # milliseconds per item
    
    def __str__(self) -> str:
        return (
            f"{self.name}:\n"
            f"  Sequential: {statistics.mean(self.sequential_read_ms):.3f}ms/item "
            f"(±{statistics.stdev(self.sequential_read_ms):.3f}ms)\n"
            f"  Random: {statistics.mean(self.random_read_ms):.3f}ms/item "
            f"(±{statistics.stdev(self.random_read_ms):.3f}ms)"
        )

def benchmark_archive(
    archive: archive.Archive[T],
    num_items: int = 1000,
    num_trials: int = 5,
) -> ArchiveBenchmarkResult:
    """Benchmark sequential and random access performance of an archive.
    
    Args:
        archive: Archive to benchmark
        num_items: Number of items to read per trial
        num_trials: Number of trials to run
        
    Returns:
        Benchmark results with timing statistics
    """
    sequential_times: list[float] = []
    random_times: list[float] = []
    
    archive_len = len(archive)
    indices = list(range(min(num_items, archive_len)))
    random_indices = list(range(archive_len))
    
    # Benchmark sequential reads
    for _ in range(num_trials):
        start = time.perf_counter()
        for idx in indices:
            _ = archive[idx]
        elapsed = time.perf_counter() - start
        sequential_times.append((elapsed / num_items) * 1000)  # Convert to ms
        
    # Benchmark random reads
    for _ in range(num_trials):
        random.shuffle(random_indices)
        test_indices = random_indices[:num_items]
        start = time.perf_counter()
        for idx in test_indices:
            _ = archive[idx]
        elapsed = time.perf_counter() - start
        random_times.append((elapsed / num_items) * 1000)  # Convert to ms
    
    return ArchiveBenchmarkResult(
        name=archive.__class__.__name__,
        sequential_read_ms=sequential_times,
        random_read_ms=random_times,
    )

def compare_archives(
    row_archive: archive.Archive[T],
    column_archive: archive.Archive[T],
    num_items: int = 1000,
    num_trials: int = 5,
) -> tuple[ArchiveBenchmarkResult, ArchiveBenchmarkResult]:
    """Compare performance between row and column archives.
    
    Args:
        row_archive: Row-based archive to benchmark
        column_archive: Column-based archive to benchmark
        num_items: Number of items to read per trial
        num_trials: Number of trials to run
        
    Returns:
        Tuple of (row_results, column_results)
    """
    row_results = benchmark_archive(row_archive, num_items, num_trials)
    column_results = benchmark_archive(column_archive, num_items, num_trials)
    return row_results, column_results

In [15]:
# Create archives with same data
#row_archive = MMapRowArchive("data.rgrows", item_type)
#column_archive = MMapColumnArchive("data.rgcols", item_type)

# Run benchmark
row_results, col_results = compare_archives(
    row_archive,
    column_archive,
    num_items=1000,
    num_trials=5
)

print(row_results)
print(col_results)


## Benchmark Results
#
# - RowArchive much faster...
# - ColumnArchive was more fun to write :-)
#
# MMapRowArchive:
#   Sequential: 0.055ms/item (±0.015ms)
#   Random: 0.048ms/item (±0.000ms)
# MMapColumnArchive:
#   Sequential: 0.255ms/item (±0.002ms)
#   Random: 0.255ms/item (±0.001ms)

MMapRowArchive:
  Sequential: 0.045ms/item (±0.001ms)
  Random: 0.044ms/item (±0.000ms)
MMapColumnArchive:
  Sequential: 0.260ms/item (±0.001ms)
  Random: 0.259ms/item (±0.005ms)


In [7]:
import typing
from types import GenericAlias
from rgi.core import archive
archive = importlib.reload(archive)

import os
column_path = "trajectory_test.rgcols"
os.remove(column_path)
os.remove(column_path + ".index")

column_archiver = archive.ColumnFileArchiver()
column_archiver.write_sequence(trajectory_type, trajectory_items, column_path)
column_archive = archive.MMapColumnArchive(column_path, trajectory_type)


In [9]:
assert column_archive[0] == trajectory_items[0]


In [11]:
import typing
from types import GenericAlias
from rgi.core import archive
archive = importlib.reload(archive)

column_archiver = archive.ColumnFileArchiver()
column_path = "trajectory_test.rgcols"
# column_archiver.write_sequence(trajectory_type, trajectory_items, column_path)
# column_archive = archive.MMapColumnArchive(column_path, trajectory_type)

item_type = trajectory_type
items = trajectory_items

def resolve_type_vars(field_type: Any) -> Any:
    """Recursively resolve TypeVars in a type to their concrete types."""
    # Direct TypeVar
    if isinstance(field_type, typing.TypeVar):
        type_var_name = field_type.__name__
        for i, param in enumerate(base_type.__parameters__):
            if param.__name__ == type_var_name:
                return type_args[i]
        raise ValueError(f"Could not find type argument for TypeVar {type_var_name}")
    
    # Generic type with potential TypeVar args
    if origin := typing.get_origin(field_type):
        resolved_args = tuple(resolve_type_vars(arg) for arg in typing.get_args(field_type))
        return origin[resolved_args]
        
    return field_type

yy = []
base_type = typing.get_origin(trajectory_type)
type_args = typing.get_args(trajectory_type)

# Process each field
for field in dataclasses.fields(base_type):
    field_type = resolve_type_vars(field.type)
    
    if not isinstance(field_type, (type, types.GenericAlias)):
        raise ValueError(f"Field {field.name} of type {field_type} is not a valid field type in {field_path}")

    field_key = f"{field_path}.{field.name}"
    field_items = [getattr(item, field.name) for item in items]

    # yy.append(self.to_columns(field_key, field_type, field_items))
    

NameError: name 'types' is not defined

## Archive Prototype

# Memory map pickled row format

In [17]:
import importlib

import rgi.core.archive as archive
importlib.reload(archive)

rfa = archive.RowFileArchiver()
fn = rfa.get_lookup_fn("mydata.bin")

print(fn)
fn.sequence_length = 7
print(fn)
print(fn.sequence_length)

<function RowFileArchiver.get_lookup_fn.<locals>.get_item at 0x795218158900>
<function RowFileArchiver.get_lookup_fn.<locals>.get_item at 0x795218158900>
7


# Memmap Example

In [11]:
import numpy as np
import zipfile
import os

filename = "test_data.npz"

# Clean up any old file
if os.path.exists(filename):
    os.remove(filename)

# 1. Create some sample data
arr = np.arange(10, dtype=np.int64)

# 2. Save it as an uncompressed npz
np.savez(filename, x=arr)

# 3. Check the compression type within the npz file
#    We expect 'ZIP_STORED' (i.e. compression type == 0)
with zipfile.ZipFile(filename, "r") as zf:
    for info in zf.infolist():
        print(f"Inside NPZ: {info.filename} compressed?", info.compress_type != 0)

# 4. Load with mmap_mode
#    NOTE: we keep a reference to the returned object to keep the file handle open
with np.load(filename, mmap_mode="r") as npzfile:
    mmapped_arr = npzfile["x"]
    print("Loaded array type:", type(mmapped_arr))
    print("Array contents:", mmapped_arr)


Inside NPZ: x.npy compressed? False
Loaded array type: <class 'numpy.ndarray'>
Array contents: [0 1 2 3 4 5 6 7 8 9]


### Old Prototype

In [None]:
        
class Batch(Generic[T]):
    """Convenience class to convert a sequence of states & actions into a batch.

    >>> from dataclasses import dataclass
    >>> import torch
    >>> @dataclass
    ... class GameState:
    ...     score: int
    ...     current_player: int
    >>> @dataclass
    ... class BatchGameState(Batch[GameState]):
    ...     score: torch.Tensor
    ...     current_player: torch.Tensor
    >>> states = [GameState(5, 1), GameState(7, 2)]
    >>> batch = BatchGameState.from_sequence(states)
    >>> len(batch)
    2
    >>> batch
    BatchGameState(score=tensor([5, 7]), current_player=tensor([1, 2]))
    >>> batch[0]
    GameState(score=5, current_player=1)
    """

    _unbatch_class: Type[T]

    @classmethod
    def from_sequence(cls: Type[TBatch], items: Sequence[T]) -> TBatch:
        if not items:
            raise ValueError("Cannot create a batch from an empty sequence")

        cls_fields = set(f.name for f in fields(cls))  # type: ignore
        batch_dict = {}
        for field in fields(items[0]):  # type: ignore
            if field.name not in cls_fields:
                continue
            values = [getattr(item, field.name) for item in items]
            # We need to handle both primitive values and torch.Tensors here.
            # torch.tensor(primitive_list) is probably more efficient, but doesn't work for tensors.
            batch_dict[field.name] = torch.stack([torch.tensor(value) for value in values])

        batch = cls(**batch_dict)
        batch._unbatch_class = type(items[0])
        return batch

    def __getitem__(self, index: int) -> T:
        item_dict = {field.name: field.type(getattr(self, field.name)[index]) for field in fields(self)}  # type: ignore
        return self._unbatch_class(**item_dict)

    def __len__(self) -> int:
        return len(getattr(self, fields(self)[0].name))  # type: ignore

In [None]:
class Batchable(Protocol[T]):
    """Protocol to convert single states & actions into torch.Tensor for batching."""

    @staticmethod
    def from_sequence(items: Sequence[T]) -> "Batchable[T]": ...

    def __getitem__(self, index: int) -> T: ...

    def __len__(self) -> int: ...


class Batch(Generic[T]):
    """Convenience class to convert a sequence of states & actions into a batch.

    >>> from dataclasses import dataclass
    >>> import torch
    >>> @dataclass
    ... class GameState:
    ...     score: int
    ...     current_player: int
    >>> @dataclass
    ... class BatchGameState(Batch[GameState]):
    ...     score: torch.Tensor
    ...     current_player: torch.Tensor
    >>> states = [GameState(5, 1), GameState(7, 2)]
    >>> batch = BatchGameState.from_sequence(states)
    >>> len(batch)
    2
    >>> batch
    BatchGameState(score=tensor([5, 7]), current_player=tensor([1, 2]))
    >>> batch[0]
    GameState(score=5, current_player=1)
    """

    _unbatch_class: Type[T]

    @classmethod
    def from_sequence(cls: Type[TBatch], items: Sequence[T]) -> TBatch:
        if not items:
            raise ValueError("Cannot create a batch from an empty sequence")

        cls_fields = set(f.name for f in fields(cls))  # type: ignore
        batch_dict = {}
        for field in fields(items[0]):  # type: ignore
            if field.name not in cls_fields:
                continue
            values = [getattr(item, field.name) for item in items]
            # We need to handle both primitive values and torch.Tensors here.
            # torch.tensor(primitive_list) is probably more efficient, but doesn't work for tensors.
            batch_dict[field.name] = torch.stack([torch.tensor(value) for value in values])

        batch = cls(**batch_dict)
        batch._unbatch_class = type(items[0])
        return batch

    def __getitem__(self, index: int) -> T:
        item_dict = {field.name: field.type(getattr(self, field.name)[index]) for field in fields(self)}  # type: ignore
        return self._unbatch_class(**item_dict)

    def __len__(self) -> int:
        return len(getattr(self, fields(self)[0].name))  # type: ignore

@dataclass
class PrimitiveBatch(Generic[T]):
    """A batch class for primitive types like int, float, etc.

    >>> batch = PrimitiveBatch.from_sequence([2,4,6,8])
    >>> len(batch)
    4
    >>> batch
    PrimitiveBatch(values=tensor([2, 4, 6, 8]))
    >>> batch[0]
    2
    """

    values: torch.Tensor

    @classmethod
    def from_sequence(cls: Type["PrimitiveBatch[T]"], items: Sequence[T]) -> "PrimitiveBatch[T]":
        if not items:
            raise ValueError("Cannot create a batch from an empty sequence")

        return cls(values=torch.tensor(items))

    def __getitem__(self, index: int) -> T:
        return self.values[index].item()  # type: ignore

    def __len__(self) -> int:
        return self.values.shape[0]


In [None]:
from typing import TypeVar, Generic

TGameState = TypeVar('TGameState')
TArchiveState = TypeVar('TArchiveState')

class Archive(Generic[TGameState, TArchiveState]):
    def __init__(self):
        # Access the original class with type parameters
        orig_class = self.__orig_class__
        # Extract the type arguments
        t_game_state, t_archive_state = orig_class.__args__
        print(f"TGameState type: {t_game_state}")
        print(f"TArchiveState type: {t_archive_state}")

# Example classes
class Foo:
    pass

class Bar:
    pass

# Create an instance with specific types
a = Archive[Foo, Bar]()


In [None]:
from dataclasses import dataclass, field
from typing import Type, TypeVar, Generic

TGameState = TypeVar('TGameState')
TArchiveState = TypeVar('TArchiveState')

@dataclass
class Archive(Generic[TGameState, TArchiveState]):
    _game_state_type: Type[TGameState] = field(init=False, repr=False)
    _archive_state_type: Type[TArchiveState] = field(init=False, repr=False)

    def __init__(self, game_state_type: Type[TGameState], archive_state_type: Type[TArchiveState]):
        self._game_state_type = game_state_type
        self._archive_state_type = archive_state_type
        print(f"TGameState type: {self._game_state_type}")
        print(f"TArchiveState type: {self._archive_state_type}")

# Define Foo and Bar as example dataclasses
@dataclass
class Foo:
    name: str = "Foo example"

@dataclass
class Bar:
    value: int = 42

# Instantiate Archive by passing the types
a = Archive(Foo, Bar)


In [None]:
from dataclasses import dataclass, field
from typing import Type, TypeVar, Generic

TGameState = TypeVar('TGameState')
TArchiveState = TypeVar('TArchiveState')

class Archive(Generic[TGameState, TArchiveState]):

    def __init__(self, game_state_type: Type[TGameState], archive_state_type: Type[TArchiveState]):
        self._game_state_type = game_state_type
        self._archive_state_type = archive_state_type
        print(f"TGameState type: {self._game_state_type}")
        print(f"TArchiveState type: {self._archive_state_type}")

# Define Foo and Bar as example dataclasses
@dataclass
class Foo:
    name: str = "Foo example"

@dataclass
class Bar:
    value: int = 42

# Instantiate Archive by passing the types
a = Archive(Foo, Bar)
