In [None]:
import torch
import torch.nn as nn
from nemo.core import NeuralModule, ModelPT
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf

class MLP(torch.nn.Module):
    """A single feed-forward unit: square linear projection followed by LayerNorm.

    Args:
        dim: Size of the last (feature) dimension of the input. The output keeps
            the same size, since ``fc`` maps ``dim -> dim``.
    """

    def __init__(self, dim: int = 50):
        super().__init__()

        self.fc = torch.nn.Linear(dim, dim)
        # LayerNorm lives in torch.nn; `torch.LayerNorm` does not exist and made
        # construction fail with AttributeError.
        self.ln = torch.nn.LayerNorm(dim)

    def forward(self, x):
        """Apply the linear projection, then layer-normalize over the last dimension."""
        x = self.fc(x)
        x = self.ln(x)
        return x

class ResidualMLP(torch.nn.Module):
    """A stack of ``MLP`` blocks, each wrapped in a residual (skip) connection.

    Args:
        dim: Feature size passed to every ``MLP``; input and output share it.
        num_layers: Number of ``MLP`` blocks in the stack (0 makes ``forward``
            return its input unchanged).
    """

    def __init__(self, dim: int, num_layers: int):
        super().__init__()

        self.dim = dim
        self.num_layers = num_layers
        self.layers = nn.ModuleList([MLP(dim) for _ in range(num_layers)])

    def forward(self, x):
        """Apply ``x = layer(x) + x`` for every layer in order and return the result."""
        for layer in self.layers:
            # Call the current layer; the old code referenced an undefined `layers`
            # (NameError) and kept a redundant `input` alias that shadowed a builtin.
            x = layer(x) + x
        return x


In [None]:
class SimpleModel(ModelPT):
    """Toy NeMo model: encoder -> decoder -> linear projection.

    ``cfg`` must provide Hydra-instantiable ``encoder`` and ``decoder`` sub-configs
    and an integer ``out_features``. The decoder must expose a ``dim`` attribute;
    in this notebook both sub-configs target ``ResidualMLP``.
    """

    def __init__(self, cfg, trainer=None):
        super().__init__(cfg, trainer=trainer)

        self.encoder = instantiate(cfg.encoder)  # Type ResidualMLP
        self.decoder = instantiate(cfg.decoder)  # Type ResidualMLP
        # The projection consumes the decoder's output, so its input size is decoder.dim.
        self.projection = torch.nn.Linear(self.decoder.dim, cfg.out_features)

    def forward(self, x):
        """Run ``x`` through encoder, decoder and the output projection."""
        y = self.encoder(x)
        z = self.decoder(y)
        out = self.projection(z)
        return out

    @classmethod
    def list_available_models(cls):
        """Return the pretrained checkpoints for this model; there are none."""
        # Declared as a classmethod: without the decorator, calling it on the class
        # fails because `cls` is never bound.
        return []

    def setup_training_data(self, train_data_config):
        """No-op: this demo model does not build a training dataloader."""
        # `self` was missing, so any call through an instance raised TypeError.
        pass

    def setup_validation_data(self, val_data_config):
        """No-op: this demo model does not build a validation dataloader."""
        pass


In [None]:
def get_classpath(cls):
    """Return ``cls`` as a dotted ``<module>.<ClassName>`` string.

    Used below as the Hydra ``_target_`` for instantiable sub-configs.
    """
    module_name = cls.__module__
    class_name = cls.__name__
    return '.'.join((module_name, class_name))

def get_model_config(dim=512):
    """Build the OmegaConf config consumed by ``SimpleModel``.

    Args:
        dim: Feature size used for ``in_features`` and for both the encoder and
            decoder sub-configs.

    Returns:
        A config with ``in_features``, ``out_features`` (10), and ``encoder`` /
        ``decoder`` entries whose ``_target_`` is ``ResidualMLP`` with 4 and 2
        layers respectively.
    """
    residual_target = get_classpath(ResidualMLP)

    def _residual_spec(num_layers):
        # One Hydra-instantiable ResidualMLP entry sharing the outer `dim`.
        return {'_target_': residual_target, 'dim': dim, 'num_layers': num_layers}

    return OmegaConf.create(
        {
            'in_features': dim,
            'out_features': 10,
            'encoder': _residual_spec(4),
            'decoder': _residual_spec(2),
        }
    )

dim = 512
model_cfg = get_model_config(dim)
model = SimpleModel(model_cfg)
model.summarize()

# Smoke-test the forward pass on a random batch; no gradients are needed.
batch_size = 8
with torch.no_grad():
    input_data = torch.randn(batch_size, dim)
    out = model(input_data)
    print(out.shape)

# The projection maps every row to `out_features` values.
assert out.shape == (batch_size, model_cfg.out_features), out.shape

In [None]:
# Show the built-in documentation (docstrings and method signatures) of NeMo's
# `AdapterModuleMixin` class.
from nemo.core import adapter_mixins
help(adapter_mixins.AdapterModuleMixin)

class MLP