# Idea

Could this be turned into a decorator?

In [None]:
from abc import ABCMeta
from collections.abc import *
from dataclasses import KW_ONLY, dataclass, field
from typing import Any, Dict, Final, Optional, TypeVar, Union

from pydantic.dataclasses import dataclass as pydantic_dataclass
from torch import Tensor, jit, nn

In [None]:
Key = TypeVar("Key", bound=Hashable)


def flatten_dict(
    d: dict[str, Any],
    /,
    *,
    recursive: bool = True,
    how: Callable[[Iterable[Key]], Key] = ".".join,
) -> dict[Any, Any]:
    r"""Flatten a nested dictionary into a single-level dictionary.

    The key of a nested entry is built by combining the parent key and the child
    key with ``how``; with the default ``".".join``, ``{"a": {"b": 1}}`` becomes
    ``{"a.b": 1}``.

    Parameters
    ----------
    d: dict
        The (possibly nested) dictionary to flatten. Keys must not be tuples.
    recursive: bool (default=True)
        If true, nested dictionaries are flattened recursively. If false, all
        values (including nested dictionaries) are copied over unchanged.
    how: Callable (default=".".join)
        Combines a ``(parent_key, child_key)`` pair into a single key.

    Returns
    -------
    dict
        A new single-level dictionary.

    Raises
    ------
    ValueError
        If a key is a tuple, or if two entries produce the same flattened key
        (e.g. ``{"a.b": 1, "a": {"b": 2}}``), which would otherwise silently
        drop one of the values.
    """
    result: dict[Any, Any] = {}

    def _insert(key: Any, value: Any) -> None:
        # Never overwrite: a collision means one of the values would be lost.
        if key in result:
            raise ValueError(f"Flattened key {key!r} is produced more than once!")
        result[key] = value

    for key, item in d.items():
        if isinstance(key, tuple):
            raise ValueError("Keys are not allowed to be tuples!")
        if isinstance(item, dict) and recursive:
            # Sub-keys come back already joined; prefix them with the parent key.
            for subkey, subitem in flatten_dict(item, recursive=True, how=how).items():
                _insert(how((key, subkey)), subitem)
        else:
            _insert(key, item)
    return result


def unflatten_dict(
    d: dict[Hashable, Any],
    /,
    *,
    recursive: bool = True,
    how: Callable[[str], Iterable[Hashable]] = str.split,
) -> dict[Any, Any]:
    r"""Rebuild a nested dictionary from a flat one (inverse of ``flatten_dict``).

    Every key is turned into a path of nested keys:

    - tuple keys are used as the path directly, e.g. ``("a", "b")``;
    - string keys are split into a path with ``how`` (if splitting yields
      nothing, e.g. for ``""``, the whole key is kept as a single component);
    - any other key is a path of length one.

    For example, ``{("a", "b"): 1, ("a", "c"): 2}`` becomes
    ``{"a": {"b": 1, "c": 2}}``.

    ``how`` must match the joiner used when flattening. The default
    ``str.split`` splits on whitespace, so keys produced by
    ``flatten_dict(..., how=".".join)`` need ``how=lambda k: k.split(".")``.

    Parameters
    ----------
    d: dict
        The flat dictionary.
    recursive: bool (default=True)
        Currently has no effect; kept for signature symmetry with ``flatten_dict``.
    how: Callable (default=str.split)
        Splits a string key into its path components.

    Returns
    -------
    dict
        A new nested dictionary. Values of ``d`` are inserted as-is (not copied).

    Raises
    ------
    ValueError
        If a key is the empty tuple, or if two keys conflict (one path is a
        prefix of another, or two keys resolve to the same path).
    """
    result: dict[Any, Any] = {}
    # ids of the sub-dictionaries created below. Only these are descended into, so
    # a dict *value* taken from `d` is never mistaken for a branch and mutated.
    branch_ids: set[int] = set()

    for key, item in d.items():
        if isinstance(key, tuple):
            path = key
        elif isinstance(key, str):
            path = tuple(how(key)) or (key,)
        else:
            path = (key,)

        if not path:
            raise ValueError(f"Key {key!r} is an empty path!")

        node = result
        for part in path[:-1]:
            if part not in node:
                node[part] = {}
                branch_ids.add(id(node[part]))
            child = node[part]
            if id(child) not in branch_ids:
                raise ValueError(
                    f"Key {key!r} conflicts with a value already stored at {part!r}!"
                )
            node = child

        leaf = path[-1]
        if leaf in node:
            raise ValueError(f"Key {key!r} conflicts with another entry at {leaf!r}!")
        node[leaf] = item

    return result

### Testing flattening function

In [None]:
# Three-level nested sample; the same leaf keys repeat at every level, so the
# flattened keys show the full joined path.
test_dict: dict[str, Any] = dict(
    foo=2,
    bar=True,
    baz=0.99,
    nested=dict(
        foo=1,
        bar=True,
        baz=0.99,
        nested=dict(foo=1, bar=True, baz=0.99),
    ),
)

# Flatten with "." as the separator, then display both directions.
flat = flatten_dict(test_dict, how=".".join)
display(flat)
display(unflatten_dict(flat))

In [None]:
def test_model_attribute(model: nn.Module, attr: str) -> None:
    r"""Check whether ``attr`` is maintained under TorchScript.

    Prints the attribute's value, or the fact that it is missing, at three stages:

    - ``INITIAL``: on ``model`` itself;
    - ``JITTED``: after ``jit.script(model)``;
    - ``LOADED``: after saving the scripted model with ``jit.save`` and reloading
      it with ``jit.load``.

    The save/load round-trip goes through a temporary directory, so the current
    working directory is left untouched.

    Parameters
    ----------
    model: nn.Module
        The module to script.
    attr: str
        Name of the attribute to look for.
    """
    import os
    import tempfile

    def _report(stage: str, obj: object) -> None:
        # Print the attribute's value at this stage, or note that it is missing.
        if hasattr(obj, attr):
            print(f"{stage} Model.{attr}={getattr(obj, attr)}")
        else:
            print(f"{stage} Model does not have attribute '{attr}'.")

    _report("INITIAL", model)

    serialized_model = jit.script(model)
    _report("JITTED ", serialized_model)

    # Round-trip through a temporary directory instead of writing "model.pt" into
    # (and possibly overwriting a file in) the current working directory.
    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, "model.pt")
        jit.save(serialized_model, path)
        loaded_model = jit.load(path)
    _report("LOADED ", loaded_model)


# The ``test_`` prefix is descriptive only: this helper takes arguments, so keep
# pytest from collecting it as a test (it would fail on missing fixtures).
test_model_attribute.__test__ = False  # type: ignore[attr-defined]

In [None]:
class DemoModel(nn.Module):
    r"""Identity model used to probe which dictionary attributes survive ``jit.script``."""

    # Nested configuration stored as a class attribute.
    FOO: Dict[str, Any] = {
        "a": 2,
        "b": True,
        "c": 0.99,
        "d": {"a": 1, "b": True, "c": 0.99, "d": {"a": 1}},
    }

    # Flattened view of FOO (keys such as "d.d.a"), assigned per instance in __init__.
    BAR: Dict[str, Union[str, int, bool, float]]
    # An un-flattened alias of FOO is deliberately left disabled:
    # BAZ: Dict[str, Any]

    def __init__(self) -> None:
        super().__init__()
        self.BAR = flatten_dict(self.FOO)
        # self.BAZ = self.FOO

    def forward(self, x: Tensor) -> Tensor:
        r"""Return ``x`` unchanged."""
        return x


# Probe each attribute through script -> save -> load. BAZ is never assigned in
# DemoModel (its lines are commented out), so it acts as the "missing" control.
test_model_attribute(DemoModel(), "FOO")
test_model_attribute(DemoModel(), "BAR")
test_model_attribute(DemoModel(), "BAZ")

# Can we serialize models with attached dictionary?

⟹ only when flattened!

# Observations

- As of **`torch=1.12.1`**, `typing.Dict` works but `dict` doesn't?
- As of **`torch=1.12.1`**, nested dictionaries are not supported.
- As of **`torch=1.12.1`**, tracing only works if
    - The complete dictionary is added in the class 
    - The dictionary is not annotated.
- It makes no difference whether the dictionary is added before or after `super().__init__` is called


## Idea

We need to address the following points for a robust initialization:

1. Input validation: in particular, cast values to the correct type.
    - Could use pydantic
        - Does not seem to support `KW_ONLY` yet.
        - Use regular `DataClasses` in the meantime?
        - ~~Alternative `TypedDict`?~~
            - `TypedDict`s do not allow extra keys; see https://github.com/python/mypy/issues/4617
2. Sub-module compatibility: e.g. module might need to be written such that latent size agrees with output size of other module.
    - If values are given, check for compatibility.
    - If values are not given, use defaults.
        - `if "input_size" in **kwargs:... else ...`
    - Do we want to support mixed inputs? (e.g. `encoder=<some nn.Module>`, `activation=<module_dict>`)
3. Module Creation.
    - Should be the responsibility of `from_hyperparameters` / `__new__` / `__init__`.
4. Every Module should sport a default config
    - Other Modules should be able to use this config, e.g. `encoder=MyEncoderModel.Config`
        - If an uninitialized class is given, initialize it with its default dict.
        - If initialized class is given, use it as is.
        - If config / dictionary is given, use it to locate the module and initialize it.
5. Serialization: 
    - The Dataclass / NestedDict should work with arbitrary Optional data.
6. Should positional arguments be allowed or only `*args`?
7. Optional cool stuff:
    - Signatures and automatic signature validation.
        - Only makes sense post-initialization.

## Specification

- Need a class object that maps 1:1 to a nested `dict` / `json` / `toml` file.
    - Implement conversion utilities (⇝ pydantic.)
- Should have a fixed set of required values, but allow optional values (with some naming restrictions)
- Certain values (`__name__`, `__qualname__`, `__module__`) should always be included.

How object is created.

- Option 1: Subclass a `BaseConfig` class that implements the `__name__` logic.
    - How to pass `*args`, `**kwargs`?
    ```python
    class MyModel(nn.Module):
        Config(BaseConfig):
            a: int,
            b: int, 
            *args: Any
            droprate: float = 0.2
            **kwargs: Any
    ```

- Option 2: DataClass / TypedDict
- Option 3: Instantiate a class locally? (How to pass Type Hints?)
- Option 4: Class Factory. (might actually be best?) Issue: doesn't work syntactically.
    
    ```python
    class MyModel(nn.Module):
        Config = create_config(
            a: int,
            b: int, 
            *args: Any
            droprate: float = 0.2
            **kwargs: Any
        )
    ```
    

## How to deal with missing values?

```python
class MyModel(nn.Module):
    class Config:
        input_size: int
        drop_rate: float = 0.5
```

This Model requires getting input_size as an input. 
- How do we initialize it? 
    - What does the default config look like in dictionary form?
    - Should we even be able to serialize it with missing keys?
- What value do we put for missing?


```python
class MyModel:    
    class Config:
        ...
    
class MyOtherModel:

    class Config:
        encoder: MyModel
```

# Solution

We should be able to initialize submodules by either:

1. Passing a `Config`
2. Passing an already initialized model.
3. Older: passing a string. (e.g. `activation='relu'`)


- The config that is actually used should be serializable (currently: cast to nested dict -> flatten).

During init, do the following:

- Validate entries
- Fill-in mandatory fields to submodules from parameters passed to parent class. **"dependent fields"**

## Questions

- Do we want positional-only parameters?
- Since some fields are dependent, how should they be passed?
    - E.g. we want to initialize a model, but some parameters are mandatory!
    - We want `__name__` and `__module__` to be added automatically, if possible.
        - Use protocols?
    
    
- How to easily overwrite fields of the default configuration?
    - Idea: Allow overwriting of fields by passing filesystem like keys:

```python
"Filter/kernel_initialization" : "symmetric".
```


=> We will need to write a custom Config object...

    - initializing it validates fields
    - 


### Example

```python
MODEL_CONFIG = LinODEnet.DefaultConfig(
    input_size: ...
    hidden_size: ...
    Filter = SequentialFilter.DefaultConfig(kernel_init=...)
    __default_updates__ = {"Filter/kernel_initialization" : "symmetric"}.
)


model = Model(**MODEL_CONFIG)


{
    "__name__": "LinODEnet",
    "input_size": TASK.dataset.shape[-1],
    "hidden_size": ARGS.hidden_size,
    "embedding_type": "concat",
    "Filter": filters.SequentialFilter.HP,
    "System": system.LinODECell.HP | {"kernel_initialization": ARGS.kernel_init},
    "Encoder": ResNet.HP,
    "Decoder": ResNet.HP,
    "Embedding": embeddings.ConcatEmbedding.HP,
}
```

In [None]:
import keyword
import types


def make_dataclass(
    cls_name,
    fields,
    *,
    bases=(),
    namespace=None,
    init=True,
    repr=True,
    eq=True,
    order=False,
    unsafe_hash=False,
    frozen=False,
    match_args=True,
    kw_only=False,
    slots=False,
    **kwds,
):
    """Return a new dynamically created dataclass.
    The dataclass name will be 'cls_name'.  'fields' is an iterable
    of either (name), (name, type) or (name, type, Field) objects. If type is
    omitted, use the string 'typing.Any'.  Field objects are created by
    the equivalent of calling 'field(name, type [, Field-info])'.
      C = make_dataclass('C', ['x', ('y', int), ('z', int, field(init=False))], bases=(Base,))
    is equivalent to:
      @dataclass
      class C(Base):
          x: 'typing.Any'
          y: int
          z: int = field(init=False)
    For the bases and namespace parameters, see the builtin type() function.
    The parameters init, repr, eq, order, unsafe_hash, and frozen are passed to
    dataclass().
    """

    if namespace is None:
        namespace = {}

    # While we're looking through the field names, validate that they
    # are identifiers, are not keywords, and not duplicates.
    seen = set()
    annotations = {}
    defaults = {}
    for item in fields:
        if isinstance(item, str):
            name = item
            tp = "typing.Any"
        elif len(item) == 2:
            (
                name,
                tp,
            ) = item
        elif len(item) == 3:
            name, tp, spec = item
            defaults[name] = spec
        else:
            raise TypeError(f"Invalid field: {item!r}")

        if not isinstance(name, str) or not name.isidentifier():
            raise TypeError(f"Field names must be valid identifiers: {name!r}")
        if keyword.iskeyword(name):
            raise TypeError(f"Field names must not be keywords: {name!r}")
        if name in seen:
            raise TypeError(f"Field name duplicated: {name!r}")

        seen.add(name)
        annotations[name] = tp

    # Update 'ns' with the user-supplied namespace plus our calculated values.
    def exec_body_callback(ns):
        ns.update(namespace)
        ns.update(defaults)
        ns["__annotations__"] = annotations

    # We use `types.new_class()` instead of simply `type()` to allow dynamic creation
    # of generic dataclasses.
    cls = types.new_class(cls_name, bases, kwds, exec_body_callback)

    # Apply the normal decorator.
    return dataclass(
        cls,
        init=init,
        repr=repr,
        eq=eq,
        order=order,
        unsafe_hash=unsafe_hash,
        frozen=frozen,
        match_args=match_args,
        kw_only=kw_only,
        slots=slots,
    )

# The trick!

We create a class `Config`, but this object behaves unlike other Python objects:

1. Configs are dependent classes that should only exist attached to a parent class.
2. Configs are immutable Mapping-Like objects, changing keys returns a different object!
3. Configs are similar to dataclasses. However, a key difference is that they can be initialized with a `MISSING` object.
4. Configs can map **1:1** to nested dictionaries of elementary types
5. Configs can map **1:1** to `json` files.
6. Configs offer functionality to update nested entries easily `{"obj/subobj/subsubobj" : "value"}`
7. Configs can be mapped to flattened / unflattened representation.
8. Configs only allow keys that are not `__dunder__`.
9. Keys ~~`__name__` and `__module__` are automatically added at class definition time.~~
    - Reserve ALLCAPS fields.
10. Configs offer a `validate` option that checks the types.

Usage:


```python
class Foo:
    class Config(BaseConfig):
        input_size: int
        output_size: int
        ...
        
```

Then `Foo.Config` is a type.




In [None]:
from typing import ClassVar


class BaseConfigMetaclass(ABCMeta):
    """Debug draft of the config metaclass (a later cell redefines it).

    Every class created with this metaclass is turned into a frozen dataclass
    (``eq=False``). While doing so it:

    1. displays the class-creation arguments (debug output);
    2. rejects annotated fields that would shadow dict methods
       (``_FORBIDDEN_KEYS``) and dunder fields;
    3. appends a ``KW_ONLY`` marker ``_`` plus the fields ``NAME`` (default: the
       class's ``__qualname__``) and ``MODULE`` (default: its ``__module__``),
       unless the class already annotates those names.

    NOTE: ``display`` comes from IPython and ``is_dunder`` is defined in a later
    cell; both must exist before a class using this metaclass is created.
    """

    # Class-body names that are never treated as config fields.
    _PROTECTED_KEYS = {
        "__module__",
        "__qualname__",
        "__annotations__",
        "__doc__",
    }

    # fmt: off
    _FORBIDDEN_KEYS = {

        "clear",       #  Removes all the elements from the dictionary
        "copy",        #  Returns a copy of the dictionary
        "fromkeys",    #  Returns a dictionary with the specified keys and value
        "get",         #  Returns the value of the specified key
        "items",       #  Returns a list containing a tuple for each key value pair
        "keys",        #  Returns a list containing the dictionary's keys
        "pop",         #  Removes the element with the specified key
        "popitem",     #  Removes the last inserted key-value pair
        "setdefault",  #  Returns the value of the specified key. If the key does not exist: insert the key, with the specified value
        "update",      #  Updates the dictionary with the specified key-value pairs
        "values",      #  Returns a list of all the values in the dictionary
    }
    # fmt: on

    def __new__(cls, name, bases, attrs, **kwds):
        print(">>>>>>>>>> NEW CALLED <<<<<<<<<<<<<<<")

        # Ensure the class owns an annotations dict that can be patched below.
        if "__annotations__" not in attrs:
            attrs["__annotations__"] = {}

        display(f"{cls=}")
        display(f"{dir(cls)=}")
        display(f"{cls.__qualname__=}")
        display(f"{name=}")
        display(f"{bases=}")
        display(f"{set(attrs)=}")
        display(f"{attrs['__annotations__']=}")

        if "__slots__" in attrs:
            display(f"{attrs['__slots__']=}")
        if "__qualname__" in attrs:
            display(f"{attrs['__qualname__']=}")
        display(f"{kwds=}")

        print(">>>>>>>>>>  CREATED NEW TYPE <<<<<<<<<<<<<<<")
        newtype = super().__new__(cls, name, bases, attrs, **kwds)
        display(f"{newtype=}")
        display(f"{dir(newtype)=}")
        display(f"{newtype.__qualname__=}")

        # Generated fields, keyed by name: (name, type[, default]).
        patched_fields = {
            "_": ("_", KW_ONLY),
            "NAME": ("NAME", str, field(default=attrs["__qualname__"])),
            "MODULE": ("MODULE", str, field(default=attrs["__module__"])),
        }

        # Debug view of the fields: (name, hint[, class-level default]).
        fields = [
            (key, hint, attrs[key]) if key in attrs else (key, hint)
            for key, hint in attrs["__annotations__"].items()
        ]
        for key in patched_fields:
            if key not in attrs["__annotations__"]:
                fields.append(patched_fields[key])

        display(f"{fields=}")
        display(f"{patched_fields=}")

        KEYS = set(attrs["__annotations__"]) - cls._PROTECTED_KEYS

        if KEYS & cls._FORBIDDEN_KEYS:
            raise ValueError(
                f"Keys '{cls._FORBIDDEN_KEYS}' are not allowed! Found: '{KEYS & cls._FORBIDDEN_KEYS}'"
            )

        DUNDER_KEYS = {
            key for key in KEYS if is_dunder(key) and key not in patched_fields
        }
        if DUNDER_KEYS:
            raise ValueError(f"Dunder fields are not allowed, found {DUNDER_KEYS}")

        # Add the generated fields that the class did not declare itself.
        for key in patched_fields:
            if key not in attrs["__annotations__"]:
                newtype.__annotations__[key] = patched_fields[key][1]
                # Last tuple item: the field() default (for "_" it is KW_ONLY itself).
                setattr(newtype, key, patched_fields[key][-1])

        display(f"{newtype.__annotations__=}")
        config_type = dataclass(newtype, eq=False, frozen=True)

        display(f"{config_type=}")

        return config_type

In [None]:
def is_allcaps(s: str) -> bool:
    """Return True if ``s`` is an identifier made only of upper-case letters.

    Underscores and digits disqualify a name: ``"NAME"`` -> True,
    ``"MAX_SIZE"`` -> False.
    """
    if not s.isidentifier():
        return False
    return s.isalpha() and s.isupper()


def is_dunder(s: str) -> bool:
    """Return True if ``s`` is an identifier that starts and ends with ``__``."""
    return s.isidentifier() and s[:2] == s[-2:] == "__"


class BaseConfigMetaclass(ABCMeta):
    """Metaclass that turns every class using it into a frozen config dataclass.

    When a class is created, the metaclass does three things:

    1. It rejects some of the class's own annotated fields:
       - names that would shadow ``dict``/``Mapping`` methods (``_FORBIDDEN_FIELDS``);
       - dunder names;
       - ALLCAPS names (as judged by ``is_allcaps``), which are reserved for
         generated fields.
    2. It appends the keyword-only fields ``NAME`` and ``MODULE``.
       - ``NAME`` defaults to ``__qualname__`` without its last component. For a
         nested config this is the enclosing class (``Foo.Config`` -> ``"Foo"``);
         for a top-level config it is the class itself.
       - ``MODULE`` defaults to the defining module.
    3. It applies ``dataclass(eq=False, frozen=True)`` and returns the class.

    ``is_dunder`` and ``is_allcaps`` must be defined before a class using this
    metaclass is created.

    Raises
    ------
    ValueError
        If a forbidden, dunder or ALLCAPS field is declared.
    """

    # fmt: off
    _FORBIDDEN_FIELDS = {
        "clear",       #  Removes all the elements from the dictionary
        "copy",        #  Returns a copy of the dictionary
        "fromkeys",    #  Returns a dictionary with the specified keys and value
        "get",         #  Returns the value of the specified key
        "items",       #  Returns a list containing a tuple for each key value pair
        "keys",        #  Returns a list containing the dictionary's keys
        "pop",         #  Removes the element with the specified key
        "popitem",     #  Removes the last inserted key-value pair
        "setdefault",  #  Returns the value of the specified key. If the key does not exist: insert the key, with the specified value
        "update",      #  Updates the dictionary with the specified key-value pairs
        "values",      #  Returns a list of all the values in the dictionary
    }
    # fmt: on

    def __new__(cls, name, bases, attrs, **kwds):
        # Give the class its own annotations dict so the generated fields can be added.
        if "__annotations__" not in attrs:
            attrs["__annotations__"] = {}

        newtype = super().__new__(cls, name, bases, attrs, **kwds)
        FIELDS = set(attrs["__annotations__"])

        # check forbidden fields (they would shadow Mapping / dict methods)
        FORBIDDEN_FIELDS = cls._FORBIDDEN_FIELDS & FIELDS
        if FORBIDDEN_FIELDS:
            raise ValueError(
                f"Fields '{cls._FORBIDDEN_FIELDS}' are not allowed! "
                f"Found '{FORBIDDEN_FIELDS}'"
            )

        # check for dunder fields
        DUNDER_FIELDS = {key for key in FIELDS if is_dunder(key)}
        if DUNDER_FIELDS:
            raise ValueError(f"Dunder fields are not allowed! Found '{DUNDER_FIELDS}'.")

        # check all caps fields (reserved for generated fields such as NAME / MODULE)
        ALLCAPS_FIELDS = {key for key in FIELDS if is_allcaps(key)}
        if ALLCAPS_FIELDS:
            raise ValueError(f"ALLCAPS fields are reserved! Found '{ALLCAPS_FIELDS}'.")

        NAME = newtype.__qualname__.rsplit(".", maxsplit=1)[0]
        patched_fields = [
            ("NAME", str, field(default=NAME)),
            ("MODULE", str, field(default=attrs["__module__"])),
        ]
        # NAME / MODULE must be keyword-only. Add a KW_ONLY marker only if the class
        # does not already declare one (under any name): dataclass rejects a second.
        if not any(hint is KW_ONLY for hint in attrs["__annotations__"].values()):
            patched_fields.insert(0, ("_", KW_ONLY))

        for key, hint, *value in patched_fields:
            newtype.__annotations__[key] = hint
            if value:
                setattr(newtype, key, value[0])

        return dataclass(newtype, eq=False, frozen=True)

In [None]:
class BaseConfig(Mapping, metaclass=BaseConfigMetaclass):
    """Base class for configs: a frozen dataclass that is also a read-only Mapping.

    The metaclass turns each subclass into a frozen dataclass. The methods below
    expose the instance's fields (its ``__dict__``) through the ``Mapping``
    protocol, so a config supports ``dict(config)`` and ``**config``. Since the
    metaclass uses ``eq=False``, equality is ``Mapping.__eq__``.
    """

    def __iter__(self):
        return iter(self.__dict__)

    def __getitem__(self, key):
        return self.__dict__[key]

    def __len__(self) -> int:
        return len(self.__dict__)

    def __hash__(self) -> int:
        r"""Return permutation-invariant hash on `items()`.

        Requires every field value to be hashable.
        """
        return hash(frozenset(self.items()))

    def __or__(self, other):
        r"""Return a new config of the same class with ``other``'s entries taking precedence.

        Mirrors ``dict.__or__``: non-mapping operands yield ``NotImplemented``, so
        Python can try the reflected operation or raise ``TypeError``. Keys in
        ``other`` that are not fields of this class make the constructor raise
        ``TypeError``.
        """
        if not isinstance(other, Mapping):
            return NotImplemented
        return self.__class__(**{**self, **other})


# Smoke test: BaseConfig declares no fields of its own, so it can be built from
# the defaults of the fields generated by the metaclass alone.
BaseConfig()

In [None]:
class MyConfig(BaseConfig):
    """Top-level demo config with one field, ``c``."""

    c: str = "test"
    # Declaring e.g. ``__name__: str`` (a dunder) or ``keys: str`` (shadows a
    # Mapping method) here is rejected by the metaclass at class creation.


class SubConfig(MyConfig):
    """Extends ``MyConfig`` with a second field, ``d``."""

    d: str = "a"


class Foo:
    """Owner class holding two nested demo configs."""

    class MyConfig(BaseConfig):
        # Everything after the KW_ONLY marker is keyword-only.
        _: KW_ONLY
        c: str = "test"

    class SubConfig(MyConfig):
        """Nested config adding ``d`` (required) and ``e`` to ``MyConfig``."""

        # ``d`` has no default; ``c`` (inherited) is keyword-only.
        d: str
        e: str = "demo"


# Checks whether a config compares equal to its plain-dict conversion.
Foo.SubConfig(d=1) == dict(Foo.SubConfig(d=1))

In [None]:
# Same experiment with a plain dataclass (no BaseConfig): everything after the
# KW_ONLY marker `_` is keyword-only.
@dataclass
class MyConfig:
    _: KW_ONLY
    c: str = "test"


# `d` has no default. That is allowed after the defaulted `c` only because the
# inherited `c` is keyword-only.
@dataclass
class SubConfig(MyConfig):
    d: str


# Raises TypeError: the required field `d` is not supplied.
SubConfig()

In [None]:
# Own annotations of the plain-dataclass MyConfig, including the `_: KW_ONLY` marker.
MyConfig.__annotations__

In [None]:
# Default-constructed nested config (NAME / MODULE are filled in by the metaclass).
Foo.MyConfig()

In [None]:
# Instance attributes of a default-constructed nested config.
vars(Foo.MyConfig())

In [None]:
# Field objects registered by `dataclass`, including the generated NAME / MODULE.
Foo.MyConfig().__dataclass_fields__

In [None]:
a = {"c": "test"}
b = Foo.MyConfig(**a)

# Attributes a plain dict has that the config instance lacks.
{*dir(a)} - {*dir(b)}

In [None]:
from collections.abc import Mapping
from dataclasses import dataclass


@dataclass(eq=False, frozen=True)
class Foo(Mapping):
    """Frozen dataclass exposing its fields through the read-only Mapping protocol."""

    c: str = "test"

    def __getitem__(self, key):
        # Missing keys raise KeyError, as Mapping.get / __contains__ expect.
        return vars(self)[key]

    def __len__(self) -> int:
        return len(vars(self))

    def __iter__(self):
        return iter(vars(self))

In [None]:
# Compare a plain dict `a` with the Mapping-backed dataclass `b` through every view.
# NOTE(review): lines marked "DeprecationWarning: NotImplemented" presumably pass
# only because the returned NotImplemented sentinel is truthy; this is deprecated
# and newer Python versions may reject it.
a = {"c": "test"}
b = Foo(**a)

print(b.__dict__)
assert b.__eq__(a)  # ✔
assert a.__eq__(b)  # DeprecationWarning: NotImplemented
assert dict(b) == a  # ✔
assert list(b.keys()) == list(a.keys())  # ✔
assert list(b.values()) == list(a.values())  # ✔
assert list(b.items()) == list(a.items())  # ✔
assert b.keys() == a.keys()  # ✔
assert b.values().__eq__(a.values())  # DeprecationWarning: NotImplemented
assert a.values().__eq__(b.values())  # DeprecationWarning: NotImplemented
assert b.values() == a.values()  # ✘  Fails!!
# NOTE: the failing assertion above raises AssertionError, so the remaining
# assertions in this cell never run.
assert b.items() == a.items()  # ✔
assert dict(b.items()) == dict(a.items())  # ✔
assert b == a  # ✔

In [None]:
class Test(Mapping):
    """Read-only Mapping view over a wrapped dictionary ``content``."""

    content: dict

    def __init__(self, content):
        self.content = content

    def __iter__(self):
        return iter(self.content.keys())

    def __len__(self):
        return self.content.__len__()

    def __getitem__(self, key):
        return self.content[key]

# Using Pydantic

In [None]:
class Foo:
    class Bar:
        # NOTE(review): BaseConfig's metaclass already applies `dataclass`, so this
        # decorator processes the class a second time with different options
        # (non-frozen, eq=True) — TODO confirm this is intended.
        @dataclass
        class Config(BaseConfig):
            input_size: int
            output_size: int
            latent_size: Optional[int] = None
            num_layers: int = 2
            # NOTE(review): `Config` is not bound yet while its own class body runs
            # (and no global `Config` exists in this notebook), so this default
            # raises NameError when the cell executes.
            activation: str = Config(1, 2)

In [None]:
# Probe: can a plain frozen dataclass declare a dunder field such as `__name__`?
@dataclass(frozen=True)
class Foo:
    __name__: str = "lol"
    a: str = "xd"

In [None]:
class ResNetBlock(nn.Module):
    """ResNet Block Model"""

    # Sketch of a config nested inside the module it configures (not functional yet).
    @pydantic_dataclass
    class Config:
        input_size: int
        bottleneck_size: Optional[int] = None
        num_layers: int = 2
        # NOTE(review): `Config` is not bound yet while its own body executes, so
        # evaluating this annotation presumably raises NameError — TODO confirm.
        activation: str | Config = "relu"

    def __init__(self):
        super().__init__()

        # NOTE(review): nn.Linear needs in/out feature sizes; these placeholder calls
        # would fail as written — presumably meant to be sized from Config.
        self.layerA = nn.Linear()
        self.layerB = nn.Linear()
        self.activation = ...


# NOTE(review): `from_config` is not defined anywhere in this notebook, so this
# placeholder call raises NameError as written.
from_config()

In [None]:
class ResNet(nn.ModuleList):
    r"""A ResNet model: a list of blocks applied residually, ``x = x + block(x)``.

    Construct it either from already-built blocks, ``ResNet(block1, block2, ...)``,
    or from hyperparameters, ``ResNet(input_size=..., block_cfg=...)`` (see
    ``from_hyperparameters``). Mixing both, or passing neither, raises ``ValueError``.
    """

    @pydantic_dataclass
    class Config:
        input_size: int
        output_size: int
        latent_size: Optional[int] = None
        num_layers: int = 2
        activation: str = "relu"

    # NOTE(review): "__name__" / "__module__" are not accepted by
    # from_hyperparameters, so `ResNet(**ResNet.HP)` fails as is.
    HP: Final[dict] = {
        "__name__": __qualname__,  # type: ignore[name-defined]
        "__module__": __module__,  # type: ignore[name-defined]
        "input_size": None,
        "num_blocks": 5,
        # "blocks": ResNetBlock.HP,
    }

    def __new__(cls, *blocks, **hparams):
        r"""Initialize from hyperparameters, or allocate a plain instance for blocks."""
        # Exactly one construction mode. Compare truthiness: `len(a) ^ len(b)` would
        # wrongly accept e.g. one block together with two hyperparameters.
        if bool(blocks) == bool(hparams):
            raise ValueError("Provide either blocks, or hyperparameters!")

        if hparams:
            # Returns a fully built instance. Python still calls
            # __init__(**hparams) on it afterwards, which is a no-op.
            return cls.from_hyperparameters(**hparams)

        return super().__new__(cls)

    def __init__(self, *blocks: Any, **hparams: Any) -> None:
        if bool(blocks) == bool(hparams):
            raise ValueError("Provide either blocks, or hyperparameters!")

        if hparams:
            # Already initialized inside __new__ via from_hyperparameters.
            return
        # ModuleList expects one iterable of modules, not one argument per module.
        super().__init__(blocks)

    @classmethod
    def from_hyperparameters(
        cls,
        *,
        input_size: int,
        num_blocks: int = 5,
        block_cfg: Optional[Mapping[str, Any]] = None,
    ):
        r"""Create a ResNet model from hyperparameters.

        Parameters
        ----------
        input_size: int
            Written into the block config if it has an ``input_size`` entry.
        num_blocks: int (default=5)
            Number of blocks to create; must be positive.
        block_cfg: Mapping
            Required. Config passed to ``initialize_from_config`` for every block.
            It is copied, never modified. There is no default yet
            (``ResNetBlock`` has no ``HP``).

        Raises
        ------
        TypeError
            If ``block_cfg`` is not given.
        """
        if block_cfg is None:
            raise TypeError("from_hyperparameters() requires a 'block_cfg' mapping.")

        # Copy so neither the caller's dict nor a shared default gets mutated.
        cfg = dict(block_cfg)
        if "input_size" in cfg:
            cfg["input_size"] = input_size

        # NOTE(review): `initialize_from_config` is not defined in this notebook;
        # it must be provided elsewhere before this method is called.
        blocks: list[nn.Module] = [
            initialize_from_config(cfg) for _ in range(num_blocks)
        ]
        return cls(*blocks)

    @jit.export
    def forward(self, x: Tensor) -> Tensor:
        r"""Forward pass: apply every block residually, ``x = x + block(x)``.

        Parameters
        ----------
        x: Tensor

        Returns
        -------
        Tensor
        """
        for block in self:
            x = x + block(x)
        return x