In [None]:
from collections.abc import *
from typing import Any, Dict, Final, List, Optional, Tuple

import torch
from torch import Tensor, jit, nn

# Can we serialize models with an attached dictionary?

⟹ only when flattened!

In [None]:
def test_model_attribute(model: nn.Module, attr: str) -> None:
    r"""Check whether `attr` is maintained under torch JIT.

    The attribute is reported at three stages: on the model as given, on the
    scripted model returned by `jit.script`, and on the model obtained by
    saving the scripted model to ``model.pt`` and loading it again.

    Parameters
    ----------
    model: nn.Module
        The module to script and serialize.
    attr: str
        Name of the attribute to look up at every stage.
    """

    def describe(module: nn.Module) -> str:
        # One report line: either the attribute's value or a "missing" notice.
        if not hasattr(module, attr):
            return f"{module}.{attr} does not exist"
        value = getattr(module, attr)
        return f"{module}.{attr}={value}"

    # Stage 1: the plain Python module.
    print(describe(model))

    # Stage 2: the scripted module.
    scripted = jit.script(model)
    print(describe(scripted))

    # Stage 3: the module after a save/load round trip through "model.pt".
    jit.save(scripted, "model.pt")
    restored = jit.load("model.pt")
    print(describe(restored))

In [None]:
from tsdm.utils import flatten_dict

In [None]:
class MyConfig:
    r"""Empty placeholder class; no fields or methods are defined yet."""

In [None]:
?flatten_dict

In [None]:
def flatten_dict(
    d: dict[str, Iterable[Any]], /, *, recursive: bool = True, how: Any = "tuple"
) -> dict[Hashable, Any]:
    r"""Flatten a nested dictionary into a single-level dictionary.

    Top-level entries whose value is not a dictionary keep their key unchanged.
    Entries inside nested dictionaries are keyed by the flat tuple of keys
    leading to them, e.g. ``{"a": {"b": {"c": 1}}}`` becomes
    ``{("a", "b", "c"): 1}``.  An empty nested dictionary contributes no entries.

    Parameters
    ----------
    d: dict
        The dictionary to flatten. Its keys must not be tuples.
    recursive: bool (default=True)
        If true, nested dictionaries are flattened at every depth.
        If false, nested dictionaries are kept as values.
    how: Any (default="tuple")
        Currently unused by this implementation; kept in the signature and
        forwarded to recursive calls.

    Returns
    -------
    dict
        Mapping from key (top level) or flat tuple of keys (nested) to value.

    Raises
    ------
    ValueError
        If any key, at any flattened level, is a tuple.
    """
    result: dict[Hashable, Any] = {}
    for key, item in d.items():
        if isinstance(key, tuple):
            raise ValueError("Keys are not allowed to be tuples!")
        if recursive and isinstance(item, dict):
            flattened = flatten_dict(item, recursive=True, how=how)
            for subkey, subitem in flattened.items():
                # Sub-keys from deeper levels are already tuples: splice them in
                # so the final key is flat rather than a tuple nested in a tuple.
                path = subkey if isinstance(subkey, tuple) else (subkey,)
                result[(key, *path)] = subitem
        else:
            result[key] = item
    return result


def _iter_key_parts(key: tuple) -> Iterator[Hashable]:
    r"""Yield the components of a tuple key, splicing in any nested tuples."""
    for part in key:
        if isinstance(part, tuple):
            yield from _iter_key_parts(part)
        else:
            yield part


def unflatten_dict(
    d: dict[Hashable, Iterable[Any]], recursive: bool = True
) -> dict[Hashable, Any]:
    r"""Rebuild a nested dictionary from a dictionary with tuple keys.

    A tuple key ``(k1, k2, ..., kN)`` is interpreted as a path: its value is
    stored at ``result[k1][k2]...[kN]``, creating intermediate dictionaries as
    needed.  Tuple components that are themselves tuples are spliced into the
    path, so ``("a", ("b", "c"))`` is treated like ``("a", "b", "c")``.
    Non-tuple keys are copied to the top level unchanged.

    Parameters
    ----------
    d: dict
        The flattened dictionary.
    recursive: bool (default=True)
        Currently unused by this implementation; kept in the signature.

    Returns
    -------
    dict
        The nested dictionary.

    Raises
    ------
    ValueError
        If a tuple key is empty, or if a path runs through an entry that
        already holds a non-dictionary value.
    """
    result: dict[Hashable, Any] = {}
    for key, item in d.items():
        if not isinstance(key, tuple):
            result[key] = item
            continue

        path = list(_iter_key_parts(key))
        if not path:
            raise ValueError("Empty tuple keys are not allowed!")

        # Walk (and create) the intermediate dictionaries, then set the leaf.
        *parents, leaf = path
        node = result
        for part in parents:
            node = node.setdefault(part, {})
            if not isinstance(node, dict):
                raise ValueError(
                    f"Key {key!r} conflicts with the non-dict value at {part!r}!"
                )
        node[leaf] = item
    return result

In [None]:
# Sample dictionary with scalar values and two levels of nesting under "nested".
test_dict: dict[str, Any] = {
    "foo": 2,
    "bar": True,
    "baz": 0.99,
    "nested": {
        "foo": 1,
        "bar": True,
        "baz": 0.99,
        "nested": {"foo": 1, "bar": True, "baz": 0.99},
    },
}

In [None]:
# Flatten the nested class-level dictionary and display the result.
# NOTE: `DemoModel` is defined in a later cell; that cell has to be run first.
flat = flatten_dict(DemoModel.FOO)
flat

In [None]:
# Display the result of un-flattening `flat` again.
unflatten_dict(flat)

In [None]:
class DemoModel(nn.Module):
    r"""Identity module that carries dictionary attributes."""

    # Nested dictionary, stored as an annotated class attribute.
    FOO: Dict[str, Any] = {
        "a": 2,
        "b": True,
        "c": 0.99,
        "d": {"a": 1, "b": True, "c": 0.99, "d": {"a": 1}},
    }

    # Only declared here; assigned per instance in `__init__`.
    BAR: Dict[str, Any]
    # BAZ: Dict[str, Any]

    def __init__(self) -> None:
        r"""Store the flattened `FOO` as `BAR` and initialize the module."""
        # `BAR` is assigned before `super().__init__()` is called.
        self.BAR = flatten_dict(self.FOO)
        super().__init__()
        # self.BAZ = self.FOO

    def forward(self, x: Tensor) -> Tensor:
        r"""Return the input unchanged."""
        return x


# Probe three attributes, each on a fresh instance:
# "FOO" (nested dict on the class), "BAR" (flattened dict set in `__init__`)
# and "BAZ" (its declaration and assignment are commented out in the class).
test_model_attribute(DemoModel(), "FOO")
test_model_attribute(DemoModel(), "BAR")
test_model_attribute(DemoModel(), "BAZ")

# Observations

- As of **`torch=1.12.1`**, `typing.Dict` works but `dict` doesn't?
- As of **`torch=1.12.1`**, nested dictionaries are not supported.
- As of **`torch=1.12.1`**, tracing only works if
    - The complete dictionary is added in the class 
    - The dictionary is not annotated.
- It makes no difference whether the dictionary is added before or after `super().__init__` is called


## Idea

We need to do several things for a robust initialization:

1. Input validation: in particular, cast values to correct type
    - Could use pydantic
        - Does not seem to support `KW_ONLY` yet.
        - Use regular `DataClasses` in the meantime?
        - ~~Alternative `TypedDict`?~~
            - `TypedDict` do not allow extra keys... https://github.com/python/mypy/issues/4617
2. Sub-module compatibility: e.g. module might need to be written such that latent size agrees with output size of other module.
    - If values are given, check for compatibility.
    - If values are not given, use defaults.
        - `if "input_size" in **kwargs:... else ...`
    - Do we want to support mixed inputs? (e.g. `encoder=<some nn.Module>`, `activation=<module_dict>`)
3. Module Creation.
    - Should be the responsibility of `from_hyperparameters` / `__new__` / `__init__`.
4. Every Module should sport a default config
    - Other Modules should be able to use this config, e.g. `encoder=MyEncoderModel.Config`
        - If an uninitialized class is given, then initialize it with its default dict.
        - If initialized class is given, use it as is.
        - If config / dictionary is given, use it to locate the module and initialize it.
5. Serialization: 
    - The Dataclass / NestedDict should work with arbitrary optional data.
6. Should positional arguments be allowed or only `*args`?
7. Optional cool stuff:
    - Signatures and automatic signature validation.
        - Only makes sense post-initialization.

## Specification

- Need a class object that maps 1:1 to a nested `dict` / `json` / `toml` file.
    - Implement conversion utilities (⇝ pydantic.)
- Should have a fixed set of required values, but allow optional values (with some naming restrictions)
- Certain values (`__name__`, `__qualname__`, `__module__`) should always be included.

How object is created.

- Option 1: Subclass a `BaseConfig` class that implements the `__name__` logic.
    - How to pass `*args`, `**kwargs`?
    ```python
    class MyModel(nn.Module):
        Config(BaseConfig):
            a: int,
            b: int, 
            *args: Any
            droprate: float = 0.2
            **kwargs: Any
    ```

- Option 2: DataClass / TypedDict
- Option 3: Instantiate a class locally? (How to pass Type Hints?)
- Option 4: Class Factory. (might actually be best?) Issue: doesn't work syntactically.
    
    ```python
    class MyModel(nn.Module):
        Config = create_config(
            a: int,
            b: int, 
            *args: Any
            droprate: float = 0.2
            **kwargs: Any
        )
    ```
    

## How to deal with missing values?

```python
class MyModel(nn.Module):
    Config:
        input_size: int
        drop_rate: float = 0.5
```

This Model requires getting input_size as an input. 
- How do we initialize it? 
- What does the default config in dictionary form look like?
    - Should we even be able to serialize it with missing keys?
- What value do we put for missing?


```python
class MyModel:    
    class Config:
        ...
    
class MyOtherModel:

    class Config:
        encoder: MyModel
```

In [None]:
def foo(*args, **kwargs):
    r"""Return the received positional and keyword arguments as a pair.

    Returns
    -------
    tuple
        ``(args, kwargs)``: the tuple of positional arguments and the
        dictionary of keyword arguments.
    """
    received = (args, kwargs)
    return received

In [None]:
from pydantic import BaseModel


class User(BaseModel):
    r"""Minimal pydantic model with one required and one defaulted field."""

    # Required field (no default).
    id: int
    # Annotated explicitly: pydantic v2 rejects non-annotated class attributes,
    # and the annotation is equally valid under pydantic v1.
    name: str = "Jane Doe"

In [None]:
# Probe: construct the model with a positional argument (`[]`) and an
# undeclared keyword (`x`).
# NOTE(review): presumably pydantic rejects the positional argument — confirm.
User([], id=2, x=3)

In [None]:
import pydantic

In [None]:
from pydantic import dataclasses as pydantic_dataclasses
from pydantic.dataclasses import dataclass as pydantic_dataclass

In [None]:
@pydantic_dataclass
class Config:
    r"""Hyperparameter container declared as a pydantic dataclass."""

    # Fields without a default value.
    input_size: int
    output_size: int
    # Fields with a default value.
    latent_size: Optional[int] = None
    num_layers: int = 2
    activation: str = "relu"

In [None]:
# Display the class name of the dataclass.
Config.__name__

In [None]:
# NOTE(review): unfinished scratch expression — the trailing `|` has no
# right-hand operand, so this cell does not parse as written.
nn.Module | type[nn.Module] |

In [None]:
# NOTE(review): `ResNet` is defined in a later cell, and both operands are
# strings, for which `-` is not defined — this looks like a scratch note.
ResNet.Config.__qualname__ - ResNet.Config.__name__

In [None]:
class ResNet(nn.ModuleList):
    r"""A ResNet model: a list of blocks applied with residual connections.

    An instance is created either from ready-made blocks, ``ResNet(*blocks)``,
    or from hyperparameters, ``ResNet(**hparams)``, in which case construction
    is delegated to `from_hyperparameters`.  Mixing both is not allowed.
    """

    @pydantic_dataclass
    class Config:
        r"""Hyperparameter container declared as a pydantic dataclass."""

        input_size: int
        output_size: int
        latent_size: Optional[int] = None
        num_layers: int = 2
        activation: str = "relu"

    # Default hyperparameters in dictionary form.
    HP: Final[dict] = {
        "__name__": __qualname__,  # type: ignore[name-defined]
        "__module__": __module__,  # type: ignore[name-defined]
        "input_size": None,
        "num_blocks": 5,
        # "blocks": ResNetBlock.HP,
    }

    def __new__(cls, *blocks, **hparams):
        r"""Create the instance, building it from hyperparameters if given."""
        # Compare truthiness, not lengths: a bitwise XOR of the lengths is
        # non-zero for e.g. one block plus two hyperparameters.
        assert bool(blocks) ^ bool(hparams), "Provide either blocks, or hyperparameters!"

        if hparams:
            return cls.from_hyperparameters(**hparams)

        return super().__new__(cls)

    def __init__(self, *blocks: Any, **hparams: Any) -> None:
        r"""Register the given blocks.

        When hyperparameters were given, `__new__` already returned a fully
        initialized instance, so nothing is left to do here.
        """
        assert bool(blocks) ^ bool(hparams), "Provide either blocks, or hyperparameters!"

        if hparams:
            return
        # `nn.ModuleList` expects a single iterable of modules.
        super().__init__(blocks)

    @classmethod
    def from_hyperparameters(
        cls,
        *,
        input_size: int,
        num_blocks: int = 5,
        block_cfg: Optional[dict] = None,
    ):
        r"""Create a ResNet model from hyperparameters.

        Parameters
        ----------
        input_size: int
            Written into the block configuration if it has an ``"input_size"`` key.
        num_blocks: int (default=5)
            Number of blocks to create.
        block_cfg: Optional[dict]
            Configuration passed to `initialize_from_config` for every block.
            Must be provided, since no default block configuration is available.

        Returns
        -------
        ResNet

        Raises
        ------
        ValueError
            If `block_cfg` is not provided.
        """
        if block_cfg is None:
            raise ValueError(
                "block_cfg must be provided: no default block configuration is available."
            )

        # Work on a copy so the caller's dictionary is not modified.
        block_cfg = dict(block_cfg)
        if "input_size" in block_cfg:
            block_cfg["input_size"] = input_size

        blocks: list[nn.Module] = [
            initialize_from_config(block_cfg) for _ in range(num_blocks)
        ]
        return cls(*blocks)

    @jit.export
    def forward(self, x: Tensor) -> Tensor:
        r"""Forward pass: add each block's output to its input, block by block.

        Parameters
        ----------
        x: Tensor

        Returns
        -------
        Tensor
        """
        for block in self:
            x = x + block(x)
        return x