In [1]:
from typing import Sequence
from abstract_dataloader import spec
from typing import Generic


class _Sequential(spec.Pipeline):
    """Pipeline built from a flat sequence of transforms and one pipeline.

    The single `spec.Pipeline` in `steps` marks the collation point:

    - transforms placed before it are applied per-sample, followed by the
      pipeline's own `sample` step;
    - the pipeline's `collate` becomes this pipeline's `collate`;
    - the pipeline's `batch` step, followed by any transforms placed after
      it, are applied per-batch.

    Args:
        steps: transforms and exactly one pipeline, in application order.

    Raises:
        TypeError: if `steps` contains more than one pipeline.
        ValueError: if `steps` contains no pipeline (there would be nothing
            to collate with).
    """

    def __init__(
        self, steps: Sequence[spec.Pipeline | spec.Transform]
    ) -> None:

        self._sample: list[spec.Transform] = []
        self._batch: list[spec.Transform] = []
        collated = False

        for step in steps:
            if isinstance(step, spec.Pipeline):
                if collated:
                    raise TypeError(
                        "Cannot collate multiple times in a sequential "
                        f"pipeline: {[type(s).__name__ for s in steps]}.")
                self._sample.append(step.sample)
                self.collate = step.collate
                self._batch.append(step.batch)
                collated = True
            elif collated:
                # After the collation point: operate on whole batches.
                self._batch.append(step)
            else:
                # Before the collation point: operate on individual samples.
                self._sample.append(step)

        if not collated:
            raise ValueError(
                "A sequential pipeline requires exactly one pipeline step: "
                f"{[type(s).__name__ for s in steps]}.")

    def sample(self, data):
        """Apply the per-sample steps to `data`, in order."""
        for step in self._sample:
            data = step(data)
        return data

    def batch(self, data):
        """Apply the per-batch steps to collated `data`, in order."""
        for step in self._batch:
            data = step(data)
        return data


In [2]:
from collections.abc import Mapping, Sequence
from typing import cast

from abstract_dataloader import abstract, generic, spec

# Recursive description of a pipeline accepted by `_compose`: a single
# pipeline or transform, a sequence of specs chained in order, or a mapping
# of specs applied in parallel per key. The self-references are quoted
# because the alias refers to itself.
NestedSpec = (
    spec.Pipeline
    | spec.Transform
    | Sequence['NestedSpec']
    | Mapping[str, 'NestedSpec']
)


def _compose(transforms: NestedSpec) -> spec.Pipeline | spec.Transform:
    """Recursively turn a nested spec into a single pipeline or transform.

    - Sequence: each element is composed, then the results are chained.
      The result is a `_Sequential` pipeline if any composed element is a
      pipeline, and an `abstract.Transform` otherwise.
    - Mapping: each value is composed, then the results run in parallel by
      key. The result is a `generic.ParallelPipelines` if every value is a
      pipeline, and a `generic.ParallelTransforms` if every value is a
      transform.
    - Anything else is returned unchanged.

    Args:
        transforms: nested specification to compose.

    Returns:
        The composed pipeline or transform.

    Raises:
        TypeError: if a mapping's composed values are neither all pipelines
            nor all transforms, or (via `_Sequential`) if a sequence
            contains more than one pipeline.
    """
    if isinstance(transforms, Sequence):
        steps = [_compose(step) for step in transforms]
        if not any(isinstance(step, spec.Pipeline) for step in steps):
            return abstract.Transform(cast(Sequence[spec.Transform], steps))
        return _Sequential(steps)

    if isinstance(transforms, Mapping):
        branches = {key: _compose(value) for key, value in transforms.items()}
        members = branches.values()

        if all(isinstance(m, spec.Pipeline) for m in members):
            return generic.ParallelPipelines(
                **cast(Mapping[str, spec.Pipeline], branches))
        if all(isinstance(m, spec.Transform) for m in members):
            return generic.ParallelTransforms(
                **cast(Mapping[str, spec.Transform], branches))

        type_desc = {k: type(v).__name__ for k, v in branches.items()}
        raise TypeError(
            "Parallel transforms have mixed types: "
            f"{type_desc}.")

    # Leaf: already a pipeline or transform.
    return transforms


In [3]:
class TestTransform:
    """Identity transform stub used to exercise `_compose`.

    Args:
        name: label returned by `repr`, so composed structures print legibly.
    """

    def __init__(self, name: str):
        self.name = name

    def __call__(self, data):
        """Return `data` unchanged."""
        return data

    def __repr__(self) -> str:
        return self.name


class TestPipeline:
    """Pipeline stub whose sample, collate and batch stages are all identity.

    Args:
        name: label returned by `repr`, so composed structures print legibly.
    """

    def __init__(self, name: str):
        self.name = name

    def __repr__(self) -> str:
        return self.name

    def sample(self, data):
        """Per-sample stage: pass `data` through unchanged."""
        return data

    def collate(self, data):
        """Collation stage: pass `data` through unchanged."""
        return data

    def batch(self, data):
        """Per-batch stage: pass `data` through unchanged."""
        return data



In [22]:
demo_spec = {
    "a": [TestTransform("t1"), TestPipeline("p1")],
    "b": TestPipeline("p2"),
}
demo_pipeline = _compose(demo_spec)
demo_pipeline.collate([{"a": 1, "b": 2}, {"a": 1, "b": 2}])

{'a': [1, 1], 'b': [2, 2]}