# RGI scratchpad

In [1]:
from typing import Generic, Literal, Sequence, Type, Any
import dataclasses
from pathlib import Path
import importlib

import numpy as np

from rgi.core.base import TGameState, TAction, TPlayerState
from rgi.core.game_registry import GAME_REGISTRY
from rgi.core.game_registry import PLAYER_REGISTRY

from rgi.core import trajectory
from rgi.core import game_runner
from rgi.tests import test_utils

# Play a game and record its trajectory


In [2]:
# Set up the game and bind the generic type names used below to Connect4's concrete types.
from rgi.games.connect4 import connect4

TGameState = connect4.Connect4State
TPlayerState = Literal[None]  # player-state type is Literal[None]: no per-player state
TAction = connect4.Action

game = connect4.Connect4Game()

In [3]:
# Run a simple precomputed trajectory and verify it survives a write/read round trip.
importlib.reload(trajectory)
importlib.reload(game_runner)

players = [
    test_utils.PresetPlayer[TGameState, TAction](actions=[2, 2, 2, 2]),
    test_utils.PresetPlayer[TGameState, TAction](actions=[1, 3, 4, 5]),
]

runner = game_runner.GameRunner(game, players, verbose=False)
original_trajectory = runner.run()

# The file is written without pickled objects, so read it back the same way. This keeps the
# round trip symmetric (it proves the format is pickle-free) and never unpickles on load.
trajectory_path = 'trajectory.npz'
original_trajectory.write(trajectory_path, allow_pickle=False)
reloaded_trajectory = trajectory.GameTrajectory.read(trajectory_path, TGameState, TAction, allow_pickle=False)

equality_checker = test_utils.EqualityChecker()
reloaded_trajectory_equal = equality_checker.check_equality(original_trajectory, reloaded_trajectory)
print('reloaded_trajectory_equal =', reloaded_trajectory_equal)
assert reloaded_trajectory_equal


reloaded_trajectory_equal = True


## Archive Prototype

In [4]:
import numpy as np
import dataclasses
from typing import Any, Type, TypeVar, Callable

T = TypeVar('T')

def dataclass_with_np_eq(*args: Any, **kwargs: Any) -> Callable[[Type[T]], Type[T]]:
    """
    Decorator factory: make a class a dataclass whose ``__eq__`` is numpy-aware.

    Keyword arguments are forwarded to ``dataclasses.dataclass`` except ``eq``, which is
    forced to False. Without this, the generated tuple-based ``__eq__`` would be ambiguous for
    ndarray fields. It is replaced by an ``__eq__`` that compares fields pairwise and uses
    ``np.array_equal`` whenever either side of a field is an ndarray.

    Must be used with parentheses: ``@dataclass_with_np_eq()``.

    Raises:
        TypeError: if applied bare (``@dataclass_with_np_eq``), i.e. called with a class.

    NOTE(review): because ``eq=False`` is passed to dataclasses, ``__hash__`` stays the
    inherited identity hash, so equal instances may hash differently. Don't rely on
    value-based set/dict membership for these instances.
    """
    if args and isinstance(args[0], type):
        raise TypeError(
            "dataclass_with_np_eq must be called with parentheses. "
            "Use @dataclass_with_np_eq() instead of @dataclass_with_np_eq"
        )

    def wrapper(cls: Type[T]) -> Type[T]:
        cls = dataclasses.dataclass(**{**kwargs, 'eq': False})(cls)

        def __eq__(self: T, other: object) -> bool:
            # Unrelated types: let Python try the reflected comparison (falls back to identity).
            if not isinstance(other, type(self)):
                return NotImplemented
            for field in dataclasses.fields(cls):  # type: ignore
                self_val = getattr(self, field.name)
                other_val = getattr(other, field.name)
                # Check both sides: `list != ndarray` yields an array whose truth value is
                # ambiguous, so any ndarray involvement must go through np.array_equal.
                if isinstance(self_val, np.ndarray) or isinstance(other_val, np.ndarray):
                    if not np.array_equal(self_val, other_val):
                        return False
                elif self_val != other_val:
                    return False
            return True

        setattr(cls, '__eq__', __eq__)
        return cls

    return wrapper

In [10]:
import typing

from rgi.core import archive

# Pick up edits to the archive module without restarting the kernel.
archive = importlib.reload(archive)

# Earlier experiment with a trajectory archive, kept for reference (not run):
# mm = trajectory_archive.MemoryMappedArchive('test_archive', TGameState, TAction)
# mm.add_trajectory(original_trajectory)
# mm.save()

# Round-trip fixtures: each `original_*` is a sequence of items and the matching
# `*_item_type` is the element type handed to `archive.ArchiveSerializer`.
original_primitive = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
primitive_item_type = int

original_str = ['hello', 'world', 'foo', 'bar', 'baz', 'qux', 'quux', 'corge', 'grault', 'garply']
str_item_type = str

original_list = [[10, 20, 30], [40, 50, 60], [70, 80, 90], [100, 110, 120], [130, 140, 150], [160, 170, 180], [190, 200, 210], [220, 230, 240], [250, 260, 270], [280, 290, 300]]
list_item_type = list[int]

original_list_jagged = [[1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11, 12], [13, 14, 15, 16, 17, 18], [19, 20, 21, 22, 23, 24, 25]]
list_jagged_item_type = list[int]

original_nested_list = [[[1, 2, 3], [4, 5, 6], [7, 8, 9]], [[10, 11, 12], [13, 14, 15], [16, 17, 18]], [[19, 20, 21], [22, 23, 24], [25, 26, 27]], [[28, 29, 30], [31, 32, 33], [34, 35, 36]], [[37, 38, 39], [40, 41, 42], [43, 44, 45]]]
nested_list_item_type = list[list[int]]

original_tuple = [(1, 'a'), (2, 'b'), (3, 'c'), (4, 'd'), (5, 'e'), (6, 'f'), (7, 'g'), (8, 'h'), (9, 'i'), (10, 'j')]
tuple_item_type = tuple[int, str]

# NOTE(review): this is a one-element list holding a single (10, 3) array, so the `[1:3]`
# slice check in the round-trip loop is vacuous for it. TODO confirm whether ten separate
# length-3 arrays were intended.
original_array = [np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15], [16, 17, 18], [19, 20, 21], [22, 23, 24], [25, 26, 27], [28, 29, 30]])]
# np.ndarray is generic over (shape, dtype): the first parameter is the shape type (Any
# here), not the scalar type. This is the same expansion as numpy.typing.NDArray[np.int64].
array_item_type = np.ndarray[Any, np.dtype[np.int64]]


@dataclass_with_np_eq()
class MultiFieldItem:
    """Round-trip fixture item mixing int, str, float, tuple, ndarray and list fields."""

    a: int
    b: str
    f: float
    t: tuple[int, str]
    x: np.ndarray  # compared with np.array_equal by the decorator's __eq__
    y: list[int]

# Ten MultiFieldItem fixtures. Item i (1-based) uses i for `a` and `f`, pairs i with the i-th
# letter in `t`, and stores the i-th consecutive triple (3i-2, 3i-1, 3i) in both `x` and `y`.
original_multifield = [
    MultiFieldItem(
        a=i,
        b=word,
        f=float(i),
        t=(i, letter),
        x=np.array([3 * i - 2, 3 * i - 1, 3 * i]),
        y=[3 * i - 2, 3 * i - 1, 3 * i],
    )
    for i, (word, letter) in enumerate(
        zip(
            ('hello', 'world', 'foo', 'bar', 'baz', 'qux', 'quux', 'corge', 'grault', 'garply'),
            'abcdefghij',
        ),
        start=1,
    )
]
multifield_item_type = MultiFieldItem

@dataclasses.dataclass
class NestedItem:
    """Round-trip fixture nesting two MultiFieldItem values plus a nested list of ints.

    Uses the plain dataclass ``__eq__``: it compares field tuples, so ``t1``/``t2`` are
    compared via MultiFieldItem's own (numpy-aware) ``__eq__``.
    """

    t1: MultiFieldItem
    t2: MultiFieldItem
    list_3d: list[list[int]]  # NOTE(review): despite the name, fixtures store a 2-D (3x3) list

# Five NestedItem fixtures: item k pairs consecutive MultiFieldItems (2k, 2k+1) with the
# k-th 3x3 block of the integers 1..45, filled row by row.
original_nested = [
    NestedItem(
        t1=original_multifield[2 * k],
        t2=original_multifield[2 * k + 1],
        list_3d=[[9 * k + 3 * row + col for col in range(1, 4)] for row in range(3)],
    )
    for k in range(5)
]
nested_item_type = NestedItem


# (sequence, item_type) pairs to round-trip through ArchiveSerializer.
# Uncomment entries to exercise additional item types.
test_params = [
    # (original_primitive, primitive_item_type),
    # (original_str, str_item_type),
    # (original_list, list_item_type),
    # (original_list_jagged, list_jagged_item_type),
    # (original_nested_list, nested_list_item_type),
    # (original_tuple, tuple_item_type),
    # (original_array, array_item_type),
    (original_multifield, multifield_item_type),
    # (original_nested, nested_item_type),
]

for original, item_type in test_params:
    # Save the whole sequence, then read it back both in full and as the [1:3] window.
    path = Path(f'{item_type.__name__}.npz')
    serializer = archive.ArchiveSerializer(item_type)
    serializer.save(original, path)

    deserialized_archive = serializer.load_sequence(path)
    deserialized_slice = serializer.load_sequence(path, slice(1, 3))
    expected_slice = original[1:3]

    # Arrays (bare np.ndarray or a parameterized np.ndarray[...] alias) need element-wise
    # comparison; everything else relies on the items' own __eq__.
    compare_as_arrays = item_type == np.ndarray or typing.get_origin(item_type) == np.ndarray
    if not compare_as_arrays:
        assert original == deserialized_archive, f'{item_type.__name__} failed'
        assert expected_slice == deserialized_slice, f'{item_type.__name__} failed'
        # Sanity check on the fixture itself: the slice must differ from the full sequence.
        assert original != expected_slice
    else:
        np.testing.assert_array_equal(original, deserialized_archive)  # type: ignore
        np.testing.assert_array_equal(expected_slice, deserialized_slice)  # type: ignore

    # Memory-mapped access checks, not yet enabled:
    # mmap = serializer.load_mmap(path)
    # assert original[0] == mmap[0]
    # assert original[1:3] == mmap[1:3]


# archive.MMapColumnArchive.save(list_based_archive, Path('test_archive.npz'))

# mmapped_archive = archive.MMapColumnArchive(Path('test_archive.npz'), TestItem)
# print(mmapped_archive[0])
# print(mmapped_archive[1])
# print(len(mmapped_archive))


In [None]:
# Inspect the memory-mapped view of an archive written by the round-trip loop above.
# The path and serializer are rebuilt explicitly instead of reusing the loop's leftover
# `serializer`. This cell reads the file that loop writes; no enabled cell writes
# 'test_archive.npz', since its only writer, MMapColumnArchive.save, is commented out.
mmap_item_type = multifield_item_type
mmap_path = Path(f'{mmap_item_type.__name__}.npz')
mmap_serializer = archive.ArchiveSerializer(mmap_item_type)
mmap = mmap_serializer.load_mmap(mmap_path)
print(mmap[0])


In [17]:
## Custom column based mmap file type

import numpy as np
import struct
import json

MAGIC = b"RGF"         # 3-byte magic string
VERSION = b"\x01"      # 1-byte version


def write_data_then_metadata(filename, arrays_dict, alignment=64):
    """
    Store multiple arrays in 'filename' with the layout:
      - MAGIC + VERSION (4 bytes)
      - raw array data in C order, each array starting at a multiple of `alignment`
        bytes (zero padding in between)
      - metadata block (JSON) listing each array's name, dtype, shape and byte offset
      - 8-byte little-endian unsigned integer: size of the metadata block

    arrays_dict: dict { name_str: np.ndarray }
    alignment:   byte alignment of every array's data offset (>= 1). Readers use the
                 recorded offsets, so any alignment yields a readable file.

    Raises:
      ValueError: if alignment < 1.
      TypeError:  if an array has an object (or object-containing) dtype; its raw bytes
                  would be pointers, not data.
    """
    if alignment < 1:
        raise ValueError(f"alignment must be >= 1, got {alignment}")
    # Validate before opening the file so bad input never leaves a truncated file behind.
    for name, arr in arrays_dict.items():
        if arr.dtype.hasobject:
            raise TypeError(f"Array {name!r} has object dtype {arr.dtype}; cannot store raw bytes.")

    array_infos = []
    with open(filename, "wb") as f:
        # 1) Magic header.
        f.write(MAGIC + VERSION)

        # 2) Each array's raw bytes, padded so its offset is aligned.
        for name, arr in arrays_dict.items():
            f.write(b"\x00" * (-f.tell() % alignment))
            data_offset = f.tell()
            f.write(arr.tobytes(order="C"))
            array_infos.append({
                "name": name,
                "dtype": str(arr.dtype),     # e.g. 'float64', '>f4'
                "shape": list(arr.shape),    # e.g. [3, 4]
                "offset": data_offset,       # absolute byte offset in the file
            })

        # 3) JSON metadata, then 4) its length as the final 8 bytes.
        metadata_bytes = json.dumps({"arrays": array_infos}).encode("utf-8")
        f.write(metadata_bytes)
        f.write(struct.pack("<Q", len(metadata_bytes)))


def load_data_memmap(filename):
    """
    Read the single-file "RGF" format created by write_data_then_metadata().

    Returns a dict { name_str: np.memmap } of read-only maps, in metadata order.

    Raises:
      ValueError: if the header, the metadata length, or any array's byte range is
                  inconsistent with the file.
    """
    # Local import: otherwise this cell only works if another cell happened to import
    # `os` first, which fails on a fresh kernel run top to bottom.
    import os

    with open(filename, "rb") as f:
        # 1) Validate the 4-byte magic + version header.
        header = f.read(4)
        if len(header) < 4:
            raise ValueError("File too short, not a valid RGF file.")
        magic_part, version_part = header[:3], header[3:]
        if magic_part != MAGIC:
            raise ValueError("File does not start with RGF magic bytes.")
        if version_part != VERSION:
            raise ValueError(f"Unsupported version: {version_part}")

        # 2) The last 8 bytes hold the metadata length.
        file_size = os.fstat(f.fileno()).st_size
        if file_size < 12:
            raise ValueError("File too small to contain header + metadata length.")
        f.seek(file_size - 8)
        metadata_size = struct.unpack("<Q", f.read(8))[0]

        # 3) The metadata block sits immediately before that length field.
        metadata_start = file_size - 8 - metadata_size
        if metadata_start < 4:
            raise ValueError("Metadata overlaps with file header or data region.")
        f.seek(metadata_start)
        metadata_dict = json.loads(f.read(metadata_size).decode("utf-8"))

    # 4) Map each array. Reject entries whose bytes fall outside the data region
    #    (between header and metadata) rather than silently mapping unrelated bytes.
    arrays_dict = {}
    for info in metadata_dict["arrays"]:
        name = info["name"]
        dtype = np.dtype(info["dtype"])
        shape = tuple(info["shape"])
        offset = info["offset"]
        nbytes = dtype.itemsize * int(np.prod(shape, dtype=np.int64))
        if offset < 4 or offset + nbytes > metadata_start:
            raise ValueError(f"Array {name!r} lies outside the data region of the file.")
        arrays_dict[name] = np.memmap(
            filename,
            mode="r",      # read-only
            offset=offset,
            shape=shape,
            dtype=dtype,
            order="C",
        )

    return arrays_dict



# Demo: round-trip two arrays of different dtype and rank through an RGF file.
# A seeded generator keeps the printed sample data reproducible across runs.
rng = np.random.default_rng(0)
arrays = {
    "arrA": np.arange(10, dtype=np.float64),
    "arrB": rng.standard_normal((3, 4), dtype=np.float32),
}
write_data_then_metadata("mydata.rgf", arrays)

loaded_arrays = load_data_memmap("mydata.rgf")
for key, mmap_arr in loaded_arrays.items():
    print(f"Name: {key}")
    print("  Type:", type(mmap_arr))
    print("  Dtype, shape:", mmap_arr.dtype, mmap_arr.shape)
    print("  Sample data:", mmap_arr[...])  # or slice


Name: arrA
  Type: <class 'numpy.memmap'>
  Dtype, shape: float64 (10,)
  Sample data: [0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]
Name: arrB
  Type: <class 'numpy.memmap'>
  Dtype, shape: float32 (3, 4)
  Sample data: [[-0.7570919   0.53201723  1.3183448   0.04894583]
 [ 1.2819728  -0.36975688 -0.05891573  1.1025674 ]
 [-0.5578464  -0.84648496 -1.0492903   1.756556  ]]


# Memory-mapped pickled row format

In [18]:
import pickle
import numpy as np
import os

def write_pickle_archive(data_filename, index_filename, objects):
    """
    Pickle each of `objects` (protocol 4) back to back into `data_filename`, then save a
    uint64 offsets array to `index_filename` with np.save (which appends '.npy' if the
    name lacks it).

    The offsets array has len(objects) + 1 entries. Entry i is where object i starts and the
    last entry is the final data-file size, so object i occupies [offsets[i], offsets[i+1]).
    No separate count is stored; it is implied by the offsets length.
    """
    boundaries = [0]
    with open(data_filename, "wb") as data_file:
        for item in objects:
            data_file.write(pickle.dumps(item, protocol=4))
            boundaries.append(data_file.tell())

    # uint64 so offsets stay exact for data files larger than 2**32 bytes.
    np.save(index_filename, np.asarray(boundaries, dtype=np.uint64))


import pickle
import numpy as np

def open_pickle_archive(data_filename, index_filename):
    """
    Open an archive written by write_pickle_archive().

    Returns (get_item, count):
      - get_item(i): unpickles the i-th object from the memory-mapped data file;
                     raises IndexError unless 0 <= i < count
      - count: number of objects

    SECURITY: unpickling can execute arbitrary code; only open archives you trust.
    """
    # 1) Load the offsets array: shape (count + 1,). A malformed empty index yields count 0.
    offsets = np.load(index_filename)
    count = max(len(offsets) - 1, 0)

    # 2) Memory-map the data file. np.memmap refuses to map an empty file, which is exactly
    #    what an archive of zero objects produces, so only map when there is data to read.
    mm = np.memmap(data_filename, mode="r") if count > 0 else None

    def get_item(i):
        """
        Retrieve the i-th object by slicing [offsets[i]:offsets[i+1]] from the memmap
        and unpickling it.
        """
        if i < 0 or i >= count:
            raise IndexError("Index out of range.")
        start, end = int(offsets[i]), int(offsets[i + 1])
        # The memmap slice is a small uint8 ndarray; copy it to bytes for pickle.
        return pickle.loads(mm[start:end].tobytes())

    return get_item, count


if __name__ == "__main__":
    # Round-trip a few heterogeneous objects: write them, reopen, and print each one back.
    data = [
        {"name": "Alice", "info": [1, 2, 3]},
        "Hello, world!",
        (42, 3.14, {"foo": "bar"}),
        ["some", "list", "of", "strings", 999],
    ]
    data_file, index_file = "mydata.bin", "mydata_idx.npy"

    write_pickle_archive(data_file, index_file, data)
    get_item, total = open_pickle_archive(data_file, index_file)
    print("Total items:", total)

    for i in range(total):
        obj = get_item(i)
        print(f"Item {i}:", obj)


Total items: 4
Item 0: {'name': 'Alice', 'info': [1, 2, 3]}
Item 1: Hello, world!
Item 2: (42, 3.14, {'foo': 'bar'})
Item 3: ['some', 'list', 'of', 'strings', 999]


# Memmap Example

In [11]:
import numpy as np
import zipfile
import os

filename = "test_data.npz"

# Start clean so the zip inspection below only sees this run's output.
if os.path.exists(filename):
    os.remove(filename)

# 1. Create some sample data.
arr = np.arange(10, dtype=np.int64)

# 2. Save it as an uncompressed npz (np.savez stores members without compression).
np.savez(filename, x=arr)

# 3. Check each member's compression type; ZIP_STORED means uncompressed.
with zipfile.ZipFile(filename, "r") as zf:
    for info in zf.infolist():
        print(f"Inside NPZ: {info.filename} compressed?", info.compress_type != zipfile.ZIP_STORED)

# 4. Load with mmap_mode.
#    NOTE: np.load does not memory-map members of a .npz archive. Each member is read fully
#    into memory, so the type printed below is numpy.ndarray, not numpy.memmap. The array
#    remains usable after the `with` block closes the archive because it is an in-memory copy.
with np.load(filename, mmap_mode="r") as npzfile:
    loaded_arr = npzfile["x"]
    print("Loaded array type:", type(loaded_arr))
    print("Array contents:", loaded_arr)


Inside NPZ: x.npy compressed? False
Loaded array type: <class 'numpy.ndarray'>
Array contents: [0 1 2 3 4 5 6 7 8 9]


### Old Prototype

In [None]:
from dataclasses import dataclass
from typing import Generic, TypeVar, Optional
import numpy as np
from numpy.typing import NDArray

# Generic type parameters for the archive prototype in this cell.
# NOTE(review): this rebinds TGameState/TAction, which cell [2] bound to the concrete
# connect4 state/action types. Re-running cell [3] after this cell would pass TypeVars.
TGameState = TypeVar("TGameState")  # pylint: disable=invalid-name
TAction = TypeVar("TAction")  # pylint: disable=invalid-name
TArchiveState = TypeVar("TArchiveState")  # pylint: disable=invalid-name
TArchiveAction = TypeVar("TArchiveAction")  # pylint: disable=invalid-name


@dataclass
class Count21State:
    """Prototype per-step state for Count21: a score and the current player."""
    score: int
    current_player: int

# Count21 actions are plain ints in this prototype.
Count21Action = int

@dataclass
class Connect4State:
    """Prototype Connect4 state: board array, current player and an optional winner."""
    board: NDArray[np.int8]  # (height, width)
    current_player: int
    winner: Optional[int] = None  # The winner, if the game has ended


@dataclass
class Count21ArchiveState:
    """Array-valued counterpart of Count21State (one NDArray per field).

    Presumably each array holds one entry per archived state -- TODO confirm.
    """
    score: NDArray[np.int64]
    current_player: NDArray[np.int8]

# Array-valued counterpart of Count21Action.
Count21ArchiveAction = NDArray[np.int8]


@dataclass
class Archive(Generic[TGameState, TAction, TArchiveState, TArchiveAction]):
    """Placeholder archive parameterized by game and archive state/action types; no fields yet."""
    pass

print('done.')

# Parameterize the (still empty) Archive with the Count21 types and show its repr.
Count21Archive = Archive[Count21State, Count21Action, Count21ArchiveState, Count21ArchiveAction]
a = Count21Archive()
print(a)


In [None]:
        
# Self-contained imports/type variables: the original cell used `fields`, `torch` and
# `TBatch` without defining them, so the class failed with NameError at definition time.
# NOTE(review): this definition is repeated (and shadowed) by the next cell.
from dataclasses import fields
from typing import Generic, Sequence, Type, TypeVar

import torch

T = TypeVar("T")
TBatch = TypeVar("TBatch", bound="Batch")


class Batch(Generic[T]):
    """Convenience class to convert a sequence of states & actions into a batch.

    >>> from dataclasses import dataclass
    >>> import torch
    >>> @dataclass
    ... class GameState:
    ...     score: int
    ...     current_player: int
    >>> @dataclass
    ... class BatchGameState(Batch[GameState]):
    ...     score: torch.Tensor
    ...     current_player: torch.Tensor
    >>> states = [GameState(5, 1), GameState(7, 2)]
    >>> batch = BatchGameState.from_sequence(states)
    >>> len(batch)
    2
    >>> batch
    BatchGameState(score=tensor([5, 7]), current_player=tensor([1, 2]))
    >>> batch[0]
    GameState(score=5, current_player=1)
    """

    _unbatch_class: Type[T]

    @classmethod
    def from_sequence(cls: Type[TBatch], items: Sequence[T]) -> TBatch:
        """Stack every field shared by `cls` and the items into one tensor (new dim 0).

        Raises:
            ValueError: if `items` is empty.
        """
        if not items:
            raise ValueError("Cannot create a batch from an empty sequence")

        cls_fields = {f.name for f in fields(cls)}  # type: ignore
        batch_dict = {}
        for field in fields(items[0]):  # type: ignore
            if field.name not in cls_fields:
                continue
            values = [getattr(item, field.name) for item in items]
            # torch.as_tensor accepts both primitive values and existing tensors (without the
            # copy warning torch.tensor emits for tensor input); torch.stack adds the batch dim.
            batch_dict[field.name] = torch.stack([torch.as_tensor(value) for value in values])

        batch = cls(**batch_dict)
        batch._unbatch_class = type(items[0])
        return batch

    def __getitem__(self, index: int) -> T:
        """Rebuild the `index`-th original item.

        0-d slices become Python scalars via `.item()`; higher-rank slices stay tensors.
        The batch class's own field annotations are torch.Tensor, so they cannot be used
        to recover the original scalar types.
        """
        item_dict = {}
        for field in fields(self):  # type: ignore
            value = getattr(self, field.name)[index]
            item_dict[field.name] = value.item() if value.dim() == 0 else value
        return self._unbatch_class(**item_dict)

    def __len__(self) -> int:
        """Size of the leading (batch) dimension of the first field."""
        return len(getattr(self, fields(self)[0].name))  # type: ignore

In [None]:
# Self-contained imports/type variables: the original cell used `Protocol`, `fields`,
# `dataclass`, `torch` and `TBatch` without importing/defining them in this cell.
from dataclasses import dataclass, fields
from typing import Generic, Protocol, Sequence, Type, TypeVar

import torch

T = TypeVar("T")
TBatch = TypeVar("TBatch", bound="Batch")


class Batchable(Protocol[T]):
    """Protocol to convert single states & actions into torch.Tensor for batching."""

    @staticmethod
    def from_sequence(items: Sequence[T]) -> "Batchable[T]": ...

    def __getitem__(self, index: int) -> T: ...

    def __len__(self) -> int: ...


class Batch(Generic[T]):
    """Convenience class to convert a sequence of states & actions into a batch.

    >>> from dataclasses import dataclass
    >>> import torch
    >>> @dataclass
    ... class GameState:
    ...     score: int
    ...     current_player: int
    >>> @dataclass
    ... class BatchGameState(Batch[GameState]):
    ...     score: torch.Tensor
    ...     current_player: torch.Tensor
    >>> states = [GameState(5, 1), GameState(7, 2)]
    >>> batch = BatchGameState.from_sequence(states)
    >>> len(batch)
    2
    >>> batch
    BatchGameState(score=tensor([5, 7]), current_player=tensor([1, 2]))
    >>> batch[0]
    GameState(score=5, current_player=1)
    """

    _unbatch_class: Type[T]

    @classmethod
    def from_sequence(cls: Type[TBatch], items: Sequence[T]) -> TBatch:
        """Stack every field shared by `cls` and the items into one tensor (new dim 0).

        Raises:
            ValueError: if `items` is empty.
        """
        if not items:
            raise ValueError("Cannot create a batch from an empty sequence")

        cls_fields = {f.name for f in fields(cls)}  # type: ignore
        batch_dict = {}
        for field in fields(items[0]):  # type: ignore
            if field.name not in cls_fields:
                continue
            values = [getattr(item, field.name) for item in items]
            # torch.as_tensor accepts both primitive values and existing tensors (without the
            # copy warning torch.tensor emits for tensor input); torch.stack adds the batch dim.
            batch_dict[field.name] = torch.stack([torch.as_tensor(value) for value in values])

        batch = cls(**batch_dict)
        batch._unbatch_class = type(items[0])
        return batch

    def __getitem__(self, index: int) -> T:
        """Rebuild the `index`-th original item.

        0-d slices become Python scalars via `.item()`; higher-rank slices stay tensors.
        The batch class's own field annotations are torch.Tensor, so they cannot be used
        to recover the original scalar types.
        """
        item_dict = {}
        for field in fields(self):  # type: ignore
            value = getattr(self, field.name)[index]
            item_dict[field.name] = value.item() if value.dim() == 0 else value
        return self._unbatch_class(**item_dict)

    def __len__(self) -> int:
        """Size of the leading (batch) dimension of the first field."""
        return len(getattr(self, fields(self)[0].name))  # type: ignore


@dataclass
class PrimitiveBatch(Generic[T]):
    """A batch class for primitive types like int, float, etc.

    >>> batch = PrimitiveBatch.from_sequence([2,4,6,8])
    >>> len(batch)
    4
    >>> batch
    PrimitiveBatch(values=tensor([2, 4, 6, 8]))
    >>> batch[0]
    2
    """

    values: torch.Tensor

    @classmethod
    def from_sequence(cls: Type["PrimitiveBatch[T]"], items: Sequence[T]) -> "PrimitiveBatch[T]":
        """Pack primitive `items` into a single 1-D tensor.

        Raises:
            ValueError: if `items` is empty.
        """
        if not items:
            raise ValueError("Cannot create a batch from an empty sequence")

        return cls(values=torch.tensor(items))

    def __getitem__(self, index: int) -> T:
        """Return the `index`-th value as a Python scalar."""
        return self.values[index].item()  # type: ignore

    def __len__(self) -> int:
        """Number of values in the batch."""
        return self.values.shape[0]

In [None]:
from typing import TypeVar, Generic, get_args

TGameState = TypeVar('TGameState')
TArchiveState = TypeVar('TArchiveState')

class Archive(Generic[TGameState, TArchiveState]):
    """Generic archive that can report the concrete type arguments it was created with.

    typing assigns ``__orig_class__`` on the instance only *after* ``__init__`` returns, and
    only when the instance is created through a subscripted alias such as
    ``Archive[Foo, Bar]()``. The type arguments therefore cannot be read inside ``__init__``
    (the original version raised AttributeError there); call :meth:`type_args` afterwards.
    """

    def type_args(self) -> tuple:
        """Return the (TGameState, TArchiveState) arguments used at construction.

        Raises:
            TypeError: if the instance was not created as ``Archive[A, B]()``.
        """
        orig_class = getattr(self, '__orig_class__', None)
        if orig_class is None:
            raise TypeError('Archive must be instantiated as Archive[TGameState, TArchiveState]()')
        return get_args(orig_class)

# Example classes
class Foo:
    pass

class Bar:
    pass

# Create an instance with specific types, then read the type arguments back.
a = Archive[Foo, Bar]()
t_game_state, t_archive_state = a.type_args()
print(f"TGameState type: {t_game_state}")
print(f"TArchiveState type: {t_archive_state}")


In [None]:
from dataclasses import dataclass, field
from typing import Type, TypeVar, Generic

TGameState = TypeVar('TGameState')
TArchiveState = TypeVar('TArchiveState')


@dataclass
class Archive(Generic[TGameState, TArchiveState]):
    """Archive that receives its state types as constructor arguments.

    The two type slots are dataclass fields with ``init=False, repr=False``. Because the
    class defines its own ``__init__``, @dataclass keeps that method instead of generating
    one, and it fills both slots.
    """

    _game_state_type: Type[TGameState] = field(init=False, repr=False)
    _archive_state_type: Type[TArchiveState] = field(init=False, repr=False)

    def __init__(self, game_state_type: Type[TGameState], archive_state_type: Type[TArchiveState]):
        self._game_state_type, self._archive_state_type = game_state_type, archive_state_type
        print(f"TGameState type: {self._game_state_type}")
        print(f"TArchiveState type: {self._archive_state_type}")


@dataclass
class Foo:
    """Example game-state type."""
    name: str = "Foo example"


@dataclass
class Bar:
    """Example archive-state type."""
    value: int = 42


# Pass the concrete types explicitly; keywords make the slot each one fills obvious.
a = Archive(game_state_type=Foo, archive_state_type=Bar)


In [None]:
from dataclasses import dataclass, field
from typing import Type, TypeVar, Generic

TGameState = TypeVar('TGameState')
TArchiveState = TypeVar('TArchiveState')


class Archive(Generic[TGameState, TArchiveState]):
    """Plain (non-dataclass) variant that stores the two state types it is constructed with."""

    def __init__(self, game_state_type: Type[TGameState], archive_state_type: Type[TArchiveState]):
        self._game_state_type, self._archive_state_type = game_state_type, archive_state_type
        print(f"TGameState type: {self._game_state_type}")
        print(f"TArchiveState type: {self._archive_state_type}")


@dataclass
class Foo:
    """Example game-state type."""
    name: str = "Foo example"


@dataclass
class Bar:
    """Example archive-state type."""
    value: int = 42


# Pass the concrete types explicitly; keywords make the slot each one fills obvious.
a = Archive(game_state_type=Foo, archive_state_type=Bar)
