In [None]:
import gc
import pickle
import sys
from collections import namedtuple


class TupleSplitter:
    r"""Split a flat tuple into a namedtuple of sub-tuples, one per group.

    ``groups`` maps each field name to the list of positions of the input
    tuple that belong to that field. The namedtuple class is generated per
    instance and published in this module's globals under a per-instance
    identifier, which is also set as the class's ``__qualname__`` so that
    pickle can find the class by name.

    Attributes
    ----------
    groups
        The mapping ``field name -> list of indices`` passed to the constructor.
    tuple
        The generated namedtuple class.
    identifier
        Name under which ``tuple`` is stored in the module globals.
    """

    def __init__(self, groups: dict[str, list[int]], *, name="Groups"):
        r"""Create the namedtuple class and register it in the module globals.

        Parameters
        ----------
        groups
            Mapping from field name to the indices collected into that field.
        name
            Type name handed to ``namedtuple``.

        Raises
        ------
        RuntimeError
            If the generated identifier is already present in the globals.
        """
        self.groups = groups
        self.tuple = namedtuple(name, groups)

        # create a unique identifier and store it in globals
        self.identifier = f"_{self.tuple.__name__}_{hash(self)}"
        self.tuple.__qualname__ = self.identifier

        if self.identifier in globals():
            raise RuntimeError(f"A class of name '{self.identifier}' exists!!")
        globals()[self.identifier] = self.tuple

    def __call__(self, x: tuple) -> tuple:
        r"""Return a namedtuple whose fields hold the grouped elements of ``x``.

        Each field is a plain tuple ``(x[k] for k in group)`` for its group.
        """
        return self.tuple(
            **{key: tuple(x[k] for k in group) for key, group in self.groups.items()}
        )

    def __del__(self):
        """Remove this instance's globals entry when the instance is finalized.

        The finalizer never raises and is safe to run more than once. It only
        removes the entry if it still points at this instance's own class, so
        an instance whose ``__init__`` failed (for example on a name collision)
        cannot purge a registration that belongs to another instance.
        """
        own_tuple = getattr(self, "tuple", None)
        if own_tuple is None:
            # __init__ failed before the namedtuple class was created.
            return
        identifier = own_tuple.__qualname__
        if globals().get(identifier) is own_tuple:
            del globals()[identifier]


# Round-trip two split tuples through pickle and check that all types agree.
encoder = TupleSplitter({"a": [0, 1], "b": [2]})
groups1 = encoder(("foo1", "bar1", "baz1"))
groups2 = encoder(("foo2", "bar2", "baz2"))
pickle1 = pickle.dumps(groups1)
pickle2 = pickle.dumps(groups2)
tuple1 = pickle.loads(pickle1)
tuple2 = pickle.loads(pickle2)
assert type(groups1) == type(groups2)
assert type(tuple1) == type(tuple2)
assert type(tuple1) == type(groups1)
assert tuple1 == groups1

# Remember the registered name before dropping the encoder; afterwards the
# name is only reachable through this variable.
identifier = encoder.identifier
del encoder
gc.collect()
print(sys.getrefcount(identifier))
# Deleting the encoder must have removed its entry from the globals.
assert identifier not in globals()
# List the names that remain in the notebook namespace.
dir()

In [None]:
def doit():
    """Build a throwaway ``TupleSplitter`` and split two sample tuples with it.

    Nothing is returned; the splitter and both results go out of scope when
    the function exits. The pickle round-trip is not exercised here.
    """
    splitter = TupleSplitter({"a": [0, 1], "b": [2]})
    inputs = (
        ("foo1", "bar1", "baz1"),
        ("foo2", "bar2", "baz2"),
    )
    # Keep both results alive until the function returns.
    results = [splitter(values) for values in inputs]

In [None]:
from tqdm.autonotebook import trange

In [None]:
import os
import time

import psutil

In [None]:
# Call doit() repeatedly and show the resident memory of this process in the
# progress bar every 10_000 iterations.
process = psutil.Process(os.getpid())  # loop-invariant: create the handle once
for k in (pbar := trange(1_000_000)):
    doit()
    if k % 10_000 == 0:
        memory = process.memory_info().rss  # in bytes
        pbar.set_postfix(memory=f"{memory // 1024**2} MiB")

In [None]:
import copyreg
import pickle
from collections import namedtuple
from collections.abc import Iterable, Iterator, Mapping, Sequence
from typing import Any, Optional

import torch
from pandas import DataFrame, MultiIndex, Series
from torch import Tensor
from torch.utils.data import DataLoader, Dataset

from tsdm.utils.strings import repr_mapping

In [None]:
class TupleDataset(Dataset[tuple[Tensor, ...]]):
    r"""Dataset that serves the same slice of several tensors as a namedtuple.

    The tensors are passed as keyword arguments; the keywords become the field
    names of a per-instance ``Sample`` namedtuple class. That class is stored
    in this module's globals under its ``__qualname__`` so that pickle can
    find it by name.

    Attributes
    ----------
    LEN
        Common length of all tensors.
    tensors
        The keyword mapping ``name -> tensor`` passed to the constructor.
    tuple
        The generated namedtuple class used for samples.
    """

    def __init__(self, **tensors: Tensor):
        r"""Store the tensors and register the sample namedtuple class.

        Raises
        ------
        ValueError
            If no tensor is given.
        AssertionError
            If the tensors do not all have the same length.
        RuntimeError
            If the generated class name is already present in the globals.
        """
        if not tensors:
            # Without this guard, next(iter(...)) leaks a bare StopIteration.
            raise ValueError("TupleDataset requires at least one tensor!")

        first = next(iter(tensors.values()))
        self.LEN = len(first)

        assert all(
            len(tensor) == self.LEN for tensor in tensors.values()
        ), "All tensors must have the same length!"

        self.tensors = tensors
        self.tuple = namedtuple("Sample", tensors.keys())

        # The id-based hash of the fresh class gives a per-instance name.
        tuple_qualname = f"{self.tuple.__name__}{hash(self.tuple)}"
        self.tuple.__qualname__ = tuple_qualname

        if tuple_qualname in globals():
            raise RuntimeError(
                f"A class of name '{tuple_qualname}' already present in globals!!"
            )
        globals()[tuple_qualname] = self.tuple

    def __len__(self):
        r"""Length of the dataset."""
        return self.LEN

    def __getitem__(self, idx) -> tuple[Tensor, ...]:
        r"""Get the same slice from each tensor."""
        return self.tuple(**{key: tensor[idx] for key, tensor in self.tensors.items()})

In [None]:
# Random toy inputs that share a first dimension of 100.
t, x = torch.randn(100), torch.randn(100, 5)

In [None]:
# Build the dataset and check that a single sample can be serialized by pickle.
ds = TupleDataset(t=t, x=x)
sample = next(iter(ds))
pickle.dumps(sample)

# Fetch the first two batches through worker processes, then drain the rest.
dloader = DataLoader(ds, batch_size=10, num_workers=5)
iloader = iter(dloader)
first, second = next(iloader), next(iloader)
for sample in iloader:
    continue

# Show the container type of the first, second and last batch.
tuple(map(type, (first, second, sample)))