# DeepSet Implementation

In [None]:
# Notebook setup.
# With 'last_expr_or_assign', each cell displays its last expression and
# also the value of a trailing assignment.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
# Render inline matplotlib figures as SVG.
%config InlineBackend.figure_format = 'svg'
# Re-import edited modules automatically before each cell runs (mode 2).
%load_ext autoreload
%autoreload 2
%matplotlib inline

import logging

# Root logger emits records of level INFO and above.
logging.basicConfig(level=logging.INFO)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Compact array printing: always 4 fixed decimals, no scientific notation.
np.set_printoptions(suppress=True, floatmode="fixed", precision=4)
# Unseeded generator: draws differ on every kernel restart.
rng = np.random.default_rng(seed=None)

In [None]:
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, Final, Optional, Union

import torch
from torch import Tensor, jit, nn

In [None]:
class _SetMean(nn.Module):
    """Default DeepSet aggregation: average over the set axis ``[..., V, E] -> [..., E]``."""

    dim: Final[int]

    def __init__(self, dim: int = -2) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, x: Tensor) -> Tensor:
        return x.mean(dim=self.dim)


class DeepSet(nn.Module):
    """Set model: encode every element, pool over the set, decode the pool.

    The model is permutation-invariant in the set axis as long as the
    aggregation is (the default mean is).

    Construct with the fields of :class:`DeepSet.CONFIG`, positionally or by
    keyword, e.g. ``DeepSet(3, 16, 2)`` or
    ``DeepSet(input_size=3, latent_size=16, output_size=2, encoder=my_net)``.
    """

    @dataclass
    class CONFIG:
        """Hyperparameters of a :class:`DeepSet`.

        Components left as ``None`` are built from the sizes:

        - ``encoder``: ``nn.Linear(input_size, latent_size)``
        - ``aggregation``: mean over the set axis (dim ``-2``)
        - ``decoder``: ``nn.Linear(latent_size, output_size)``
        """

        input_size: int  # D: feature size of each set element
        latent_size: int  # E: size of each element's encoding
        output_size: int  # F: size of the decoded output
        # ``None`` rather than ``{}``: a dict is not a module, and dataclass
        # rejects mutable defaults.
        encoder: Optional[nn.Module] = None
        decoder: Optional[nn.Module] = None
        aggregation: Optional[nn.Module] = None

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Build the model from ``DeepSet.CONFIG(*args, **kwargs)``.

        Raises:
            TypeError: If the arguments do not match the CONFIG fields.
        """
        super().__init__()
        cfg = self.CONFIG(*args, **kwargs)

        # TODO: optionally accept config dicts for the components and build
        # them from config (the earlier draft referenced an
        # ``initialize_from_config`` helper that does not exist yet).
        self.encoder = (
            cfg.encoder
            if cfg.encoder is not None
            else nn.Linear(cfg.input_size, cfg.latent_size)
        )
        self.aggregation = (
            cfg.aggregation if cfg.aggregation is not None else _SetMean(dim=-2)
        )
        self.decoder = (
            cfg.decoder
            if cfg.decoder is not None
            else nn.Linear(cfg.latent_size, cfg.output_size)
        )

        # Plain-dict snapshot of the arguments as given (``None`` for any
        # component that was filled in by default). It is not a deep copy, so
        # user-supplied modules are stored by reference.
        self.CFG = dict(vars(cfg))

    def forward(self, x: Tensor) -> Tensor:
        """Map a batch of sets to one output vector per set.

        Signature: ``[..., V, D] -> [..., F]``, where ``V`` is the set size.

        Components:

          - Encoder: ``[..., D] -> [..., E]``
          - Aggregation: ``[..., V, E] -> [..., E]``
          - Decoder: ``[..., E] -> [..., F]``
        """
        x = self.encoder(x)
        x = self.aggregation(x)
        x = self.decoder(x)
        return x

In [None]:
# NOTE(review): ``?ReduceMixin`` is IPython help (it displays the object's
# docstring); presumably a reference while designing the ``Mean`` reduction below.
from einops.layers import ReduceMixin

?ReduceMixin

In [None]:
class Mean(nn.Module):
    """Average a tensor over one or more dimensions.

    The defaults (``dim=-2``, ``keepdim=False``) pool over the set axis,
    matching the DeepSet aggregation shape ``[..., V, E] -> [..., E]``.

    Args:
        dim: Dimension, or non-empty sequence of dimensions, to average over.
        keepdim: If True, keep each reduced dimension with size 1.

    Raises:
        ValueError: If ``dim`` is an empty sequence.
    """

    # ``Final`` makes these TorchScript constants. They are assigned per
    # instance in ``__init__`` rather than as a shared mutable class default.
    keepdim: Final[bool]
    dim: Final[tuple[int, ...]]

    def __init__(
        self, dim: Union[int, Sequence[int]] = -2, keepdim: bool = False
    ) -> None:
        super().__init__()
        dims = (dim,) if isinstance(dim, int) else tuple(dim)
        if not dims:
            # An empty ``dim`` is ambiguous (no-op vs. reduce-everything);
            # refuse it instead of guessing.
            raise ValueError("Mean needs at least one dimension to reduce over")
        self.dim = dims
        self.keepdim = keepdim

    def forward(self, x: Tensor) -> Tensor:
        """Return the mean of ``x`` over ``self.dim``."""
        return x.mean(dim=self.dim, keepdim=self.keepdim)


In [None]:
from dataclasses import dataclass
from typing import get_type_hints


class MyModule(nn.Module):
    """Scratch module that stores its constructor arguments as a dataclass.

    NOTE(review): this looks like a probe of whether ``jit.script`` accepts a
    module whose ``CFG`` attribute is a dataclass instance. It defines no
    ``forward``, so no method is compiled. TODO confirm the intent.
    """

    @dataclass
    class CONFIG:
        """Hyperparameters accepted by :class:`MyModule`."""

        size: int

    def __init__(self, *args, **kwargs):
        super().__init__()
        # Bind positional/keyword arguments to the CONFIG fields.
        config = self.CONFIG(*args, **kwargs)
        self.CFG = config

In [None]:
# Smoke test: TorchScript-compile a MyModule whose CFG is a dataclass instance.
jit.script(MyModule(size=3))