# RGI scratchpad

In [4]:
from typing import Generic, Literal, Sequence, Type, Any
import dataclasses
from pathlib import Path
import importlib

import numpy as np

from rgi.core.base import TGameState, TAction, TPlayerState
from rgi.core.game_registry import GAME_REGISTRY
from rgi.core.game_registry import PLAYER_REGISTRY

from rgi.core import trajectory
from rgi.core import game_runner
from rgi.tests import test_utils

# Play game and record trajectory


In [5]:
# Set up game.
from rgi.games.connect4 import connect4
TGameState, TPlayerState, TAction = connect4.Connect4State, Literal[None], connect4.Action

game = connect4.Connect4Game()

In [6]:
# Run simple precomputed trajectory.
importlib.reload(trajectory)
importlib.reload(game_runner)

players = [
    test_utils.PresetPlayer[TGameState, TAction](actions=[2,2,2,2]),
    test_utils.PresetPlayer[TGameState, TAction](actions=[1,3,4,5]),
]

runner = game_runner.GameRunner(game, players, verbose=False)
original_trajectory = runner.run()

# original_trajectory = trajectory.play_game_and_record_trajectory(game, players)
original_trajectory.write('trajectory.npz', allow_pickle=False)
reloaded_trajectory = trajectory.GameTrajectory.read('trajectory.npz', TGameState, TAction, allow_pickle=True)

equality_checker = test_utils.EqualityChecker()
print('reloaded_trajectory_equal =', (reloaded_trajectory_equal := equality_checker.check_equality(original_trajectory, reloaded_trajectory)))
assert reloaded_trajectory_equal


yoooooooooooooooooooooooooooooooo
### rodo: magic: b'PK\x03\x04-\x00'
### rodo: magic is zip:
### rodo: self._files: ['action_player_ids.npy', 'incremental_rewards.npy', 'final_reward.npy', 'num_players.npy', 'state_board.npy', 'state_current_player.npy', 'state_winner.npy', 'action_.npy']
reloaded_trajectory_equal = True


# Type Scratchpad

In [13]:
xx = [1,2,3]
print(type(xx), xx)

xx: list[int] = [1,2,3]
print(type(xx), xx)

@dataclasses.dataclass
class SimpleData:
    x: list[int]

sd = SimpleData(x=[1,2,3])
print(type(sd), sd)
print(type(sd.x), sd.x)

sd.x = [1,2,3]
print(type(sd), sd)

print(dataclasses.fields(SimpleData))


<class 'list'> [1, 2, 3]
<class 'list'> [1, 2, 3]
<class '__main__.SimpleData'> SimpleData(x=[1, 2, 3])
<class 'list'> [1, 2, 3]
<class '__main__.SimpleData'> SimpleData(x=[1, 2, 3])
(Field(name='x',type=list[int],default=<dataclasses._MISSING_TYPE object at 0x7cead4db8830>,default_factory=<dataclasses._MISSING_TYPE object at 0x7cead4db8830>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD),)


## Archive Prototype

In [7]:
from rgi.core import archive
importlib.reload(archive)

# mm = trajectory_archive.MemoryMappedArchive('test_archive', TGameState, TAction)
#mm.add_trajectory(original_trajectory)
#mm.save()

@dataclasses.dataclass
class TestItem:
    a: int
    b: str
    f: float
    t: tuple[int, str]
    x: np.ndarray
    y: list[int]

list_based_archive = archive.ListBasedArchive(TestItem)
list_based_archive.append(TestItem(a=1, b='hello', f=1.0, t=(1, 'a'), x=np.array([1, 2, 3]), y=[1, 2, 3]))
list_based_archive.append(TestItem(a=2, b='world', f=2.0, t=(2, 'b'), x=np.array([4, 5, 6]), y=[4, 5, 6]))

print(list_based_archive[0])
print(list_based_archive[1])
print(len(list_based_archive))

# archive.MMappedArchive.save(list_based_archive, Path('test_archive.npz'))

# mmapped_archive = archive.MMappedArchive(Path('test_archive.npz'), TestItem)
# print(mmapped_archive[0])
# print(mmapped_archive[1])
# print(len(mmapped_archive))


TestItem(a=1, b='hello', f=1.0, t=(1, 'a'), x=array([1, 2, 3]), y=[1, 2, 3])
TestItem(a=2, b='world', f=2.0, t=(2, 'b'), x=array([4, 5, 6]), y=[4, 5, 6])
2


In [11]:
T = TestItem

from dataclasses import is_dataclass, fields
from typing import Any, Dict
import numpy as np

def dataclass_to_npz_dict(obj: Any, prefix: str = "") -> Dict[str, np.ndarray]:
    if not is_dataclass(obj):
        raise TypeError(f"Expected dataclass instance, got {type(obj)}")
        
    result = {}
    for field in fields(obj):
        value = getattr(obj, field.name)
        key = f"{prefix}{field.name}"
        
        # Convert value to numpy array if needed
        if isinstance(value, np.ndarray):
            result[key] = value
        elif isinstance(value, (int, float)):
            result[key] = np.array([value])
        elif isinstance(value, (list, tuple)):
            result[key] = np.array(value)
        elif is_dataclass(value):
            # Recursively handle nested dataclasses
            nested_dict = dataclass_to_npz_dict(value, prefix=f"{key}_")
            result.update(nested_dict)
        else:
            raise TypeError(f"Unsupported type for field {field.name}: {type(value)}")
            
    return result

# Updated save method
def save(item_type: type[T], archive: archive.Archive[T], filepath: Path) -> None:
    """Save archive to file"""
    data_dict = {}
    for i, item in enumerate(archive):
        prefix = f"item_{i}_"
        item_dict = dataclass_to_npz_dict(item, prefix=prefix)
        data_dict.update(item_dict)

    print(data_dict)
    
    np.savez_compressed(filepath, **data_dict)

# save(list_based_archive, Path('test_archive.npz'))

archive_dict = {}
item_type = TestItem
for field in dataclasses.fields(item_type):
    print(field.name, field.type)
    archive_dict[field.name] = np.array([getattr(item, field.name) for item in list_based_archive])

print(archive_dict)

# np.savez_compressed('test_archive.npz', **archive_dict)



a <class 'int'>
b <class 'str'>
f <class 'float'>
t tuple[int, str]
x <class 'numpy.ndarray'>
y list[int]
{'a': array([1, 2]), 'b': array(['hello', 'world'], dtype='<U5'), 'f': array([1., 2.]), 't': array([['1', 'a'],
       ['2', 'b']], dtype='<U21'), 'x': array([[1, 2, 3],
       [4, 5, 6]]), 'y': array([[1, 2, 3],
       [4, 5, 6]])}


In [12]:
# archive_dict = {}
# item_type = TestItem
# input_archive = list_based_archive

import typing
import types
T = typing.TypeVar("T")


def serialize_to_dict(field_path: str, item_type: Type | types.GenericAlias, items: Sequence[T]) -> dict[str, Any]:
    
    if item_type in (int, float, str, bool):
        return {field_path: np.array(items)}
    
    if is_dataclass(item_type):
        d = {}
        for field in dataclasses.fields(item_type):
            # add type guard
            field_type = field.type
            if not isinstance(field_type, (Type, types.GenericAlias)):
                raise ValueError(f"Field {field.name} with field_type {field_type} is not a Type. Unable to serialize.")

            field_items = [getattr(item, field.name) for item in items]
            field_key = f"{field_path}.{field.name}"
            d[field_key] = serialize_to_dict(field_key, field_type, field_items)

            # TODO: Handle indexes.
            field_index = f"#{field_path}.{field.name}"
        return d
    
    if isinstance(item_type, types.GenericAlias):
        base_type = typing.get_origin(item_type)
        base_type_args = typing.get_args(item_type)

        # Handle variable length tuples of a single type (e.g. tuple[int, ...]) like they are a list.
        # This is primiarily to handle np.array shapes.
        if base_type is tuple and base_type_args[-1] is Ellipsis:
            if len(base_type_args) != 2:
                raise ValueError(f"Tuple with ellipsis must have exactly 2 elements, got {len(base_type_args)} and type {base_type_args}")
            list_type = list[base_type_args[0]]
            return serialize_to_dict(field_path, list_type, items)

        if base_type is tuple:
            if Ellipsis in base_type_args:
                raise ValueError(f"ellipsis only supported in tuples as a single last element, got type {item_type}")
            d = {}
            for i, t in enumerate(base_type_args):
                tuple_field_path = f"{field_path}.{i}"
                tuple_field_items = [item[i] for item in items]  # type: ignore
                tuple_serialized = serialize_to_dict(tuple_field_path, t, tuple_field_items)
                d.update(tuple_serialized)
            return d

        if base_type is list:
            unrolled_items = [item for item_list in items for item in item_list] # type: ignore
            unrolled_lengths = [len(item_list) for item_list in items]  # type: ignore
            values_dict = serialize_to_dict(f"{field_path}.*", base_type_args[0], unrolled_items)
            length_dict = serialize_to_dict(f"{field_path}.#", int, unrolled_lengths)
            return values_dict | length_dict

    if item_type is np.ndarray:
        flat_values = np.concatenate([arr.flatten() for arr in items]) # type: ignore
        shapes = [arr.shape for arr in items] # type: ignore

        values_dict = {f"{field_path}.*": flat_values}
        shape_dict = serialize_to_dict(f"{field_path}.#", tuple[int, ...], shapes)
        return values_dict | shape_dict
        
    raise NotImplementedError(f"Cannot add fields for field `{field_path}` with non-dataclass type {item_type}")


assert np.array_equal(serialize_to_dict('a', int, range(10))['a'], np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))
assert np.array_equal(serialize_to_dict('a', float, [0.1, 0.2, 0.3])['a'], np.array([0.1, 0.2, 0.3]))
assert np.array_equal(serialize_to_dict('a', str, ['a', 'b', 'c'])['a'], np.array(['a', 'b', 'c']))
assert np.array_equal(serialize_to_dict('a', bool, [True, False, True])['a'], np.array([True, False, True]))


test_items = [
    TestItem(a=1, b='hello', f=1.0, t=(1, 'a'), x=np.array([1, 2, 3]), y=[1, 2, 3]),
    TestItem(a=2, b='world', f=2.0, t=(2, 'b'), x=np.array([4, 5, 6]), y=[4, 5, 6]),
    TestItem(a=3, b='foo', f=3.0, t=(3, 'c'), x=np.array([[7, 8, 9], [77, 88, 99]]), y=[7, 8, 9]),
    TestItem(a=4, b='bar', f=4.0, t=(4, 'd'), x=np.array([10, 11, 12]), y=[10, 11, 12]),
    TestItem(a=5, b='baz', f=5.0, t=(5, 'e'), x=np.array([13, 14, 15]), y=[13, 14, 15]),
]

serialize_to_dict('tt', TestItem, test_items)

{'tt.a': {'tt.a': array([1, 2, 3, 4, 5])},
 'tt.b': {'tt.b': array(['hello', 'world', 'foo', 'bar', 'baz'], dtype='<U5')},
 'tt.f': {'tt.f': array([1., 2., 3., 4., 5.])},
 'tt.t': {'tt.t.0': array([1, 2, 3, 4, 5]),
  'tt.t.1': array(['a', 'b', 'c', 'd', 'e'], dtype='<U1')},
 'tt.x': {'tt.x.*': array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 77, 88, 99, 10, 11, 12, 13, 14,
         15]),
  'tt.x.#.*': array([3, 3, 2, 3, 3, 3]),
  'tt.x.#.#': array([1, 1, 2, 1, 1])},
 'tt.y': {'tt.y.*': array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]),
  'tt.y.#': array([3, 3, 3, 3, 3])}}

In [13]:
assert isinstance(int, Type) and isinstance(float, Type) and isinstance(str, Type) and isinstance(bool, Type)
dataclasses.is_dataclass(TestItem)
# isinstance(tuple[int, str], Type)
tuple[int, str].__class__
list[int].__class__

import types

isinstance(tuple[int, str], types.GenericAlias)


True

In [25]:
import abc
import typing

class Archive(typing.Sequence[T], abc.ABC):
    pass

class AppendableArchive(Archive[T]):
    @abc.abstractmethod
    def append(self, item: T) -> None:
        """Add item to archive."""

# class ListBasedArchive(Archive[T]):
class ListBasedArchive(AppendableArchive[T]):
    def __init__(self, item_type: type[T]):
        """Initialize empty archive."""
        self._item_type = item_type
        self._items: list[T] = []

    @typing.override
    def append(self, item: T) -> None:
        """Add item to archive."""
        self._items.append(item)

    @typing.override
    def __len__(self) -> int:
        return len(self._items)

    @typing.override
    def __getitem__(self, idx: int) -> T:
        return self._items[idx]

    @typing.override
    def __repr__(self) -> str:
        return f"ListBasedArchive(item_type={self._item_type}, len={len(self)}, items[:1]={self._items[:1]})"
    

list_archive = ListBasedArchive(TestItem)

assert isinstance(list_archive, Archive)
assert isinstance(list_archive, Sequence)
print(list_archive)
list_archive.append(TestItem(a=1, b='hello', f=1.0, t=(1, 'a'), x=np.array([1, 2, 3]), y=[1, 2, 3]))
list_archive.append(TestItem(a=2, b='world', f=2.0, t=(2, 'b'), x=np.array([4, 5, 6]), y=[4, 5, 6]))
print(list_archive)
print(list(list_archive))


ListBasedArchive(item_type=<class '__main__.TestItem'>, len=0, items[:1]=[])
ListBasedArchive(item_type=<class '__main__.TestItem'>, len=2, items[:1]=[TestItem(a=1, b='hello', f=1.0, t=(1, 'a'), x=array([1, 2, 3]), y=[1, 2, 3])])
[TestItem(a=1, b='hello', f=1.0, t=(1, 'a'), x=array([1, 2, 3]), y=[1, 2, 3]), TestItem(a=2, b='world', f=2.0, t=(2, 'b'), x=array([4, 5, 6]), y=[4, 5, 6])]


In [22]:
x = [1,2,3]
print(isinstance(x, list))
print(isinstance(x, typing.Sequence))
print(isinstance(x, typing.MutableSequence))
print(isinstance(list_archive, typing.Sequence))


True
True
True
True


In [23]:
# Claude code
T = TestItem

from dataclasses import is_dataclass, fields
from typing import Any, Dict
import numpy as np

def dataclass_to_npz_dict(obj: Any, prefix: str = "") -> Dict[str, np.ndarray]:
    """Convert a dataclass instance to a dictionary of numpy arrays.
    
    Args:
        obj: Dataclass instance to convert
        prefix: Prefix for dictionary keys
        
    Returns:
        Dictionary mapping from field names to numpy arrays
    """
    if not is_dataclass(obj):
        raise TypeError(f"Expected dataclass instance, got {type(obj)}")
        
    result = {}
    for field in fields(obj):
        value = getattr(obj, field.name)
        key = f"{prefix}{field.name}"
        
        # Convert value to numpy array if needed
        if isinstance(value, np.ndarray):
            result[key] = value
        elif isinstance(value, (int, float)):
            result[key] = np.array([value])
        elif isinstance(value, (list, tuple)):
            result[key] = np.array(value)
        elif is_dataclass(value):
            # Recursively handle nested dataclasses
            nested_dict = dataclass_to_npz_dict(value, prefix=f"{key}_")
            result.update(nested_dict)
        else:
            raise TypeError(f"Unsupported type for field {field.name}: {type(value)}")
            
    return result

# Updated save method
def save(archive: archive.Archive[T], filepath: Path) -> None:
    """Save archive to file.
    
    Args:
        archive: Archive to save
        filepath: Path to save to
    """
    data_dict = {}
    for i, item in enumerate(archive):
        prefix = f"item_{i}_"
        item_dict = dataclass_to_npz_dict(item, prefix=prefix)
        data_dict.update(item_dict)

    print(data_dict)
    
    np.savez_compressed(filepath, **data_dict)

save(list_based_archive, Path('test_archive.npz'))

TypeError: Unsupported type for field b: <class 'str'>

### Old Prototype

In [None]:
from dataclasses import dataclass
from typing import Generic, TypeVar, Optional
import numpy as np
from numpy.typing import NDArray

TGameState = TypeVar("TGameState")  # pylint: disable=invalid-name
TAction = TypeVar("TAction")  # pylint: disable=invalid-name
TArchiveState = TypeVar("TArchiveState")  # pylint: disable=invalid-name
TArchiveAction = TypeVar("TArchiveAction")  # pylint: disable=invalid-name


@dataclass
class Count21State:
    score: int
    current_player: int

Count21Action = int

@dataclass
class Connect4State:
    board: NDArray[np.int8]  # (height, width)
    current_player: int
    winner: Optional[int] = None  # The winner, if the game has ended


@dataclass
class Count21ArchiveState:
    score: NDArray[np.int64]
    current_player: NDArray[np.int8]

Count21ArchiveAction = NDArray[np.int8]


@dataclass
class Archive(Generic[TGameState, TAction, TArchiveState, TArchiveAction]):
    pass

print('done.')

a = Archive[Count21State, Count21Action, Count21ArchiveState, Count21ArchiveAction]()

print(a)


In [None]:
class Archive(Generic[TGameState, TArchiveState]):
    def __init__(self):
        # Get the actual types from the class's __orig_bases__ attribute
        generic_base = self.__class__.__orig_bases__[0]
        self.game_state_type = generic_base.__args__[0]
        self.archive_state_type = generic_base.__args__[1]
        
        # Now we can inspect the fields
        print("Game State fields:", fields(self.game_state_type))
        print("Archive State fields:", fields(self.archive_state_type))
        
        # And get their type hints
        print("Game State types:", get_type_hints(self.game_state_type))
        print("Archive State types:", get_type_hints(self.archive_state_type))

# Test it
a = Archive[Count21State, Count21ArchiveState]()

In [23]:
class Archive(Generic[TGameState, TArchiveState]):
    def __init__(self):
        # Get the actual types from the class's __orig_bases__ attribute
        generic_base = self.__class__.__orig_bases__[0]
        self.game_state_type = generic_base.__args__[0]  # This is the actual Count21State class
        self.archive_state_type = generic_base.__args__[1]  # This is the actual Count21ArchiveState class
        
        # Initialize empty archive arrays based on the fields
        self.archive = {}
        game_state_fields = dataclasses.fields(self.game_state_type)
        archive_state_fields = dataclasses.fields(self.archive_state_type)
        
        # Map field names to their numpy types
        for game_field, archive_field in zip(game_state_fields, archive_state_fields):
            print(f"Field {game_field.name}: {game_field.type} -> {archive_field.type}")
            # Initialize empty array for this field
            # We'll need to determine the numpy dtype based on the archive field type

    def add(self, state: TGameState) -> None:
        # Convert each field to the appropriate numpy array
        for field in dataclasses.fields(state):
            value = getattr(state, field.name)
            print(f"Adding {field.name}: {value}")

# Test it
a = Archive[Count21State, Count21ArchiveState]()
a.add(Count21State(score=1, current_player=1))

NameError: name 'TArchiveState' is not defined

In [None]:

import dataclasses
class Archive(Generic[TGameState, TArchiveState]):
    def __init__(self):
        # Get the actual types from the class's __orig_bases__ attribute
        generic_base = self.__class__.__orig_bases__[0]
        self.game_state_type = generic_base.__args__[0]
        self.archive_state_type = generic_base.__args__[1]

        self.archive = TArchiveState()

    def add(self, state: TGameState) -> None:
        print(dataclasses.fields(state))
        print(dataclasses.fields(self.archive))


a = Archive[Count21State, Count21ArchiveState]()
a.add(Count21State(score=1, current_player=1))


In [None]:
        
class Batch(Generic[T]):
    """Convenience class to convert a sequence of states & actions into a batch.

    >>> from dataclasses import dataclass
    >>> import torch
    >>> @dataclass
    ... class GameState:
    ...     score: int
    ...     current_player: int
    >>> @dataclass
    ... class BatchGameState(Batch[GameState]):
    ...     score: torch.Tensor
    ...     current_player: torch.Tensor
    >>> states = [GameState(5, 1), GameState(7, 2)]
    >>> batch = BatchGameState.from_sequence(states)
    >>> len(batch)
    2
    >>> batch
    BatchGameState(score=tensor([5, 7]), current_player=tensor([1, 2]))
    >>> batch[0]
    GameState(score=5, current_player=1)
    """

    _unbatch_class: Type[T]

    @classmethod
    def from_sequence(cls: Type[TBatch], items: Sequence[T]) -> TBatch:
        if not items:
            raise ValueError("Cannot create a batch from an empty sequence")

        cls_fields = set(f.name for f in fields(cls))  # type: ignore
        batch_dict = {}
        for field in fields(items[0]):  # type: ignore
            if field.name not in cls_fields:
                continue
            values = [getattr(item, field.name) for item in items]
            # We need to handle both primitive values and torch.Tensors here.
            # torch.tensor(primitive_list) is probably more efficient, but doesn't work for tensors.
            batch_dict[field.name] = torch.stack([torch.tensor(value) for value in values])

        batch = cls(**batch_dict)
        batch._unbatch_class = type(items[0])
        return batch

    def __getitem__(self, index: int) -> T:
        item_dict = {field.name: field.type(getattr(self, field.name)[index]) for field in fields(self)}  # type: ignore
        return self._unbatch_class(**item_dict)

    def __len__(self) -> int:
        return len(getattr(self, fields(self)[0].name))  # type: ignore

In [None]:
class Batchable(Protocol[T]):
    """Protocol to convert single states & actions into torch.Tensor for batching."""

    @staticmethod
    def from_sequence(items: Sequence[T]) -> "Batchable[T]": ...

    def __getitem__(self, index: int) -> T: ...

    def __len__(self) -> int: ...


class Batch(Generic[T]):
    """Convenience class to convert a sequence of states & actions into a batch.

    >>> from dataclasses import dataclass
    >>> import torch
    >>> @dataclass
    ... class GameState:
    ...     score: int
    ...     current_player: int
    >>> @dataclass
    ... class BatchGameState(Batch[GameState]):
    ...     score: torch.Tensor
    ...     current_player: torch.Tensor
    >>> states = [GameState(5, 1), GameState(7, 2)]
    >>> batch = BatchGameState.from_sequence(states)
    >>> len(batch)
    2
    >>> batch
    BatchGameState(score=tensor([5, 7]), current_player=tensor([1, 2]))
    >>> batch[0]
    GameState(score=5, current_player=1)
    """

    _unbatch_class: Type[T]

    @classmethod
    def from_sequence(cls: Type[TBatch], items: Sequence[T]) -> TBatch:
        if not items:
            raise ValueError("Cannot create a batch from an empty sequence")

        cls_fields = set(f.name for f in fields(cls))  # type: ignore
        batch_dict = {}
        for field in fields(items[0]):  # type: ignore
            if field.name not in cls_fields:
                continue
            values = [getattr(item, field.name) for item in items]
            # We need to handle both primitive values and torch.Tensors here.
            # torch.tensor(primitive_list) is probably more efficient, but doesn't work for tensors.
            batch_dict[field.name] = torch.stack([torch.tensor(value) for value in values])

        batch = cls(**batch_dict)
        batch._unbatch_class = type(items[0])
        return batch

    def __getitem__(self, index: int) -> T:
        item_dict = {field.name: field.type(getattr(self, field.name)[index]) for field in fields(self)}  # type: ignore
        return self._unbatch_class(**item_dict)

    def __len__(self) -> int:
        return len(getattr(self, fields(self)[0].name))  # type: ignore

@dataclass
class PrimitiveBatch(Generic[T]):
    """A batch class for primitive types like int, float, etc.

    >>> batch = PrimitiveBatch.from_sequence([2,4,6,8])
    >>> len(batch)
    4
    >>> batch
    PrimitiveBatch(values=tensor([2, 4, 6, 8]))
    >>> batch[0]
    2
    """

    values: torch.Tensor

    @classmethod
    def from_sequence(cls: Type["PrimitiveBatch[T]"], items: Sequence[T]) -> "PrimitiveBatch[T]":
        if not items:
            raise ValueError("Cannot create a batch from an empty sequence")

        return cls(values=torch.tensor(items))

    def __getitem__(self, index: int) -> T:
        return self.values[index].item()  # type: ignore

    def __len__(self) -> int:
        return self.values.shape[0]


In [None]:
from typing import TypeVar, Generic

TGameState = TypeVar('TGameState')
TArchiveState = TypeVar('TArchiveState')

class Archive(Generic[TGameState, TArchiveState]):
    def __init__(self):
        # Access the original class with type parameters
        orig_class = self.__orig_class__
        # Extract the type arguments
        t_game_state, t_archive_state = orig_class.__args__
        print(f"TGameState type: {t_game_state}")
        print(f"TArchiveState type: {t_archive_state}")

# Example classes
class Foo:
    pass

class Bar:
    pass

# Create an instance with specific types
a = Archive[Foo, Bar]()


In [None]:
from dataclasses import dataclass, field
from typing import Type, TypeVar, Generic

TGameState = TypeVar('TGameState')
TArchiveState = TypeVar('TArchiveState')

@dataclass
class Archive(Generic[TGameState, TArchiveState]):
    _game_state_type: Type[TGameState] = field(init=False, repr=False)
    _archive_state_type: Type[TArchiveState] = field(init=False, repr=False)

    def __init__(self, game_state_type: Type[TGameState], archive_state_type: Type[TArchiveState]):
        self._game_state_type = game_state_type
        self._archive_state_type = archive_state_type
        print(f"TGameState type: {self._game_state_type}")
        print(f"TArchiveState type: {self._archive_state_type}")

# Define Foo and Bar as example dataclasses
@dataclass
class Foo:
    name: str = "Foo example"

@dataclass
class Bar:
    value: int = 42

# Instantiate Archive by passing the types
a = Archive(Foo, Bar)


In [None]:
from dataclasses import dataclass, field
from typing import Type, TypeVar, Generic

TGameState = TypeVar('TGameState')
TArchiveState = TypeVar('TArchiveState')

class Archive(Generic[TGameState, TArchiveState]):

    def __init__(self, game_state_type: Type[TGameState], archive_state_type: Type[TArchiveState]):
        self._game_state_type = game_state_type
        self._archive_state_type = archive_state_type
        print(f"TGameState type: {self._game_state_type}")
        print(f"TArchiveState type: {self._archive_state_type}")

# Define Foo and Bar as example dataclasses
@dataclass
class Foo:
    name: str = "Foo example"

@dataclass
class Bar:
    value: int = 42

# Instantiate Archive by passing the types
a = Archive(Foo, Bar)
