# Hyperparameter configs for torch modules: dataclasses vs. pydantic

In [None]:
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
%config InlineBackend.figure_format = 'svg'
%load_ext autoreload
%autoreload 2
%matplotlib inline

import logging

logging.basicConfig(level=logging.INFO)

In [None]:
import gc
import logging
import os
from collections.abc import Callable
from copy import deepcopy
from dataclasses import dataclass
from functools import update_wrapper, wraps
from inspect import Parameter, signature
from time import perf_counter_ns
from types import MethodType
from typing import Any, Final, Optional, Union, overload, TypedDict
from typing import Dict, Any

import numpy as np

np.set_printoptions(precision=4, floatmode="fixed", suppress=True)
rng = np.random.default_rng()

In [None]:
import logging
from collections import OrderedDict
from typing import Any, TypeVar

import torch
from torch import Tensor, jit, nn

# from tsdm.models.generic.dense import ReverseDense
# from tsdm.utils import deep_dict_update, initialize_from_config
# from tsdm.utils.decorators import trace

In [None]:
import dataclasses
from dataclasses import dataclass, field

import pydantic
from pydantic import BaseModel
from pydantic import dataclasses as pydantic_dataclasses
from pydantic.dataclasses import dataclass as pydantic_dataclass

# from dataclasses import KW_ONLY

In [None]:
@dataclass
class Config:
    """Hyperparameter record built with the standard-library ``dataclass``.

    Only ``input_size`` and ``output_size`` are required; the remaining
    fields carry defaults.
    """

    # Required sizes (no defaults).
    input_size: int
    output_size: int
    # Optional; None means "not specified".
    latent_size: Optional[int] = None
    num_layers: int = 2
    # Activation given by name as a string.
    activation: str = "relu"

In [None]:
# Convert a Config built from input_size=2, output_size=3 into a plain dict.
dataclasses.asdict(Config(2, 3))

In [None]:
@pydantic_dataclass
class Config:
    """Hyperparameter record declared with pydantic's ``dataclass`` decorator.

    Only ``input_size`` and ``output_size`` are required; the remaining
    fields carry defaults.
    """

    # Required sizes (no defaults).
    input_size: int
    output_size: int
    # Optional; None means "not specified".
    latent_size: Optional[int] = None
    num_layers: int = 2
    # Activation given by name as a string.
    activation: str = "relu"

In [None]:
# Apply the standard-library asdict to a Config built from input_size=2,
# output_size=3.
dataclasses.asdict(Config(2, 3))

## Vanilla DataClasses

In [None]:
class MLP(nn.Sequential):
    """Stack of ``nn.Linear`` layers configured through a nested dataclass.

    All constructor arguments are forwarded to :class:`MLP.Config`; the
    resolved configuration is kept as a plain dict in ``HP``.

    NOTE: ``activation`` is recorded in ``HP`` but no activation layer is
    built, so the stack consists of linear layers only.
    """

    HP: Dict[str, Any]

    @dataclass
    class Config:
        """Hyperparameters accepted by :class:`MLP`."""

        input_size: int
        output_size: int
        latent_size: Optional[int] = None
        num_layers: int = 2
        activation: str = "relu"

    def __init__(self, *args, **kwargs):
        cfg = self.Config(*args, **kwargs)

        # An unspecified latent width defaults to the mean of the in/out sizes.
        if cfg.latent_size is None:
            cfg.latent_size = (cfg.input_size + cfg.output_size) // 2

        self.HP = dataclasses.asdict(cfg)

        width = cfg.latent_size
        # input projection -> num_layers hidden layers -> output projection
        stack = [nn.Linear(cfg.input_size, width)]
        stack.extend(nn.Linear(width, width) for _ in range(cfg.num_layers))
        stack.append(nn.Linear(width, cfg.output_size))

        super().__init__(*stack)

In [None]:
# Round-trip the MLP through TorchScript and then look up HP on the result
# (presumably to see whether the hyperparameter dict survives serialization).
model = MLP(2, 3)
x = torch.randn(7, 2)  # shape (7, 2); last dim matches input_size=2
model(x)  # eager forward pass
scripted = jit.script(model)
scripted(x)  # forward pass through the scripted module
jit.save(scripted, "model")  # written to the relative path "model"
model = jit.load("model")  # rebinds `model` to the reloaded scripted module
model.HP

## Pydantic DataClasses

In [None]:
class MLP(nn.Sequential):
    """Stack of ``nn.Linear`` layers configured through a pydantic dataclass.

    All constructor arguments are forwarded to :class:`MLP.Config`; the
    resolved configuration is kept as a plain dict in ``HP``.

    NOTE: ``activation`` is recorded in ``HP`` but no activation layer is
    built, so the stack consists of linear layers only.
    """

    HP: Dict[str, Any]

    @pydantic_dataclass
    class Config:
        """Hyperparameters accepted by :class:`MLP`."""

        input_size: int
        output_size: int
        latent_size: Optional[int] = None
        num_layers: int = 2
        activation: str = "relu"

    def __init__(self, *args, **kwargs):
        cfg = self.Config(*args, **kwargs)

        # An unspecified latent width defaults to the mean of the in/out sizes.
        if cfg.latent_size is None:
            cfg.latent_size = (cfg.input_size + cfg.output_size) // 2

        self.HP = dataclasses.asdict(cfg)

        width = cfg.latent_size
        # input projection -> num_layers hidden layers -> output projection
        stack = [nn.Linear(cfg.input_size, width)]
        stack.extend(nn.Linear(width, width) for _ in range(cfg.num_layers))
        stack.append(nn.Linear(width, cfg.output_size))

        super().__init__(*stack)

In [None]:
# Same TorchScript round-trip as for the vanilla-dataclass MLP, now with the
# MLP class defined in the cell above.
model = MLP(2, 3)
x = torch.randn(7, 2)  # shape (7, 2); last dim matches input_size=2
model(x)  # eager forward pass
scripted = jit.script(model)
scripted(x)  # forward pass through the scripted module
jit.save(scripted, "model")  # written to the relative path "model"
model = jit.load("model")  # rebinds `model` to the reloaded scripted module
model.HP

## Pydantic BaseModel

In [None]:
class MLP(nn.Sequential):
    """Stack of ``nn.Linear`` layers configured through a pydantic ``BaseModel``.

    Constructor arguments are forwarded to :class:`MLP.Config`. Positional
    arguments are matched to the config fields in declaration order, so
    ``MLP(2, 3)`` is equivalent to ``MLP(input_size=2, output_size=3)``.
    The resolved configuration is kept as a plain dict in ``HP``.

    NOTE: ``activation`` is recorded in ``HP`` but no activation layer is
    built, so the stack consists of linear layers only.

    Raises:
        TypeError: if more positional arguments are given than there are
            config fields, or a field is given both positionally and by
            keyword.
    """

    HP: Dict[str, Any]

    class Config(BaseModel):
        """Hyperparameters accepted by :class:`MLP`."""

        input_size: int
        output_size: int
        latent_size: Optional[int] = None
        num_layers: int = 2
        activation: str = "relu"

    def __init__(self, *args, **kwargs):
        # BaseModel.__init__ is keyword-only, so positional arguments have to
        # be mapped onto the field names (declaration order) first.
        field_names = list(self.Config.__annotations__)
        if len(args) > len(field_names):
            raise TypeError(
                f"expected at most {len(field_names)} positional arguments,"
                f" got {len(args)}"
            )
        positional = dict(zip(field_names, args))
        clashes = sorted(positional.keys() & kwargs.keys())
        if clashes:
            raise TypeError(f"got multiple values for argument(s) {clashes}")

        config = self.Config(**positional, **kwargs)

        # An unspecified latent width defaults to the mean of the in/out sizes.
        if config.latent_size is None:
            config.latent_size = (config.input_size + config.output_size) // 2

        # dataclasses.asdict() only accepts dataclass instances; a BaseModel
        # is dumped with model_dump() (pydantic v2) or dict() (pydantic v1).
        to_dict = getattr(config, "model_dump", None) or config.dict
        self.HP = to_dict()

        layers = [nn.Linear(config.input_size, config.latent_size)]

        for _ in range(config.num_layers):
            layers.append(nn.Linear(config.latent_size, config.latent_size))

        layers.append(nn.Linear(config.latent_size, config.output_size))

        super().__init__(*layers)

# Nested Usage 

In [None]:
# Display the sentinel the dataclasses module uses for "no default provided".
dataclasses.MISSING

In [None]:
from typing import TypeVar

In [None]:
@dataclass
class Config:
    """Layer sizes with an optional latent width.

    ``latent_size`` may be omitted; it is then set to ``input_size`` after
    initialisation.
    """

    input_size: int
    output_size: int
    # None means "not specified"; resolved in __post_init__.
    latent_size: Optional[int] = None

    def __post_init__(self):
        """Fill in ``latent_size`` from ``input_size`` when it was not given."""
        # typing.Any is still accepted as an "unset" marker so that callers
        # passing it explicitly keep their previous behaviour.
        if self.latent_size is None or self.latent_size is Any:
            self.latent_size = self.input_size

In [None]:
# Build a Config from the two sizes only, leaving latent_size unspecified.
Config(2, 3)

In [None]:
@dataclass
class Config:
    """Layer sizes with an optional latent width.

    ``latent_size`` may be omitted; it is then set to ``input_size`` after
    initialisation.
    """

    input_size: int
    output_size: int
    # None (rather than typing.Any) marks "not specified", which keeps the
    # default consistent with the annotation.
    latent_size: Optional[int] = None

    def __post_init__(self):
        """Fill in ``latent_size`` from ``input_size`` when it was not given."""
        # typing.Any is still accepted as an "unset" marker so that callers
        # passing it explicitly keep their previous behaviour.
        if self.latent_size is None or self.latent_size is Any:
            self.latent_size = self.input_size


# An explicitly supplied latent_size is kept as given.
conf = Config(2, 3, latent_size=4)

In [None]:
# latent_size omitted: __post_init__ replaces the default with input_size.
Config(2, 3)

In [None]:
class Deepset(nn.Sequential):
    """Work-in-progress module whose config reserves encoder/decoder slots.

    Constructor arguments are forwarded to :class:`Deepset.Config`; the
    resolved configuration is kept as a plain dict in ``HP``. No sub-modules
    are built yet, so the container is empty.
    """

    HP: Dict[str, Any]

    @dataclass
    class Config:
        """Hyperparameters accepted by :class:`Deepset`."""

        input_size: int
        output_size: int
        latent_size: Optional[int] = None
        # NOTE(review): these two fields originally had no annotation, which
        # is a SyntaxError. ``Any`` with a None default is a placeholder;
        # presumably they should hold nested sub-configs — TODO confirm.
        encoder: Any = None
        decoder: Any = None

    def __init__(self, *args, **kwargs) -> None:
        config = self.Config(*args, **kwargs)

        # An unspecified latent width defaults to the mean of the in/out sizes.
        if config.latent_size is None:
            config.latent_size = (config.input_size + config.output_size) // 2

        # Initialise the nn.Module machinery; without this call the instance
        # lacks the internal state that nn.Module relies on.
        super().__init__()

        self.HP = dataclasses.asdict(config)

    
    
    