In [24]:
from torchmeta.modules import MetaModule,MetaSequential, MetaLinear
from collections import OrderedDict
import torch.nn as nn
from torchmeta.modules.utils import get_subdict

In [16]:
class MetaMLPModel(MetaModule):
    """Multi-layer Perceptron architecture from [1].

    Parameters
    ----------
    in_features : int
        Number of input features.

    out_features : int
        Number of classes (output of the model).

    hidden_sizes : list of int
        Size of the intermediate representations. The length of this list
        corresponds to the number of hidden layers.

    References
    ----------
    .. [1] Finn C., Abbeel P., and Levine, S. (2017). Model-Agnostic Meta-Learning
           for Fast Adaptation of Deep Networks. International Conference on
           Machine Learning (ICML) (https://arxiv.org/abs/1703.03400)
    """
    def __init__(self, in_features, out_features, hidden_sizes):
        super(MetaMLPModel, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.hidden_sizes = hidden_sizes

        layer_sizes = [in_features] + hidden_sizes
        self.features = MetaSequential(OrderedDict([('layer{0}'.format(i + 1),
            MetaSequential(OrderedDict([
                ('linear', MetaLinear(hidden_size, layer_sizes[i + 1], bias=True)),
                ('relu', nn.ReLU())
            ]))) for (i, hidden_size) in enumerate(layer_sizes[:-1])]))
        self.classifier = MetaLinear(hidden_sizes[-1], out_features, bias=True)

    def forward(self, inputs, params=None):
        features = self.features(inputs, params=get_subdict(params, 'features'))
        logits = self.classifier(features, params=get_subdict(params, 'classifier'))
        return logits


def ModelMLPSinusoid(hidden_sizes=[80, 80]):
    return MetaMLPModel(25, 1, hidden_sizes)

In [19]:
model=ModelMLPSinusoid()

In [21]:
import torch

In [25]:
model(torch.randn((4,25,25)))

tensor([[[0.2209],
         [0.2337],
         [0.0954],
         [0.2412],
         [0.1335],
         [0.1831],
         [0.1085],
         [0.1577],
         [0.1451],
         [0.1439],
         [0.1881],
         [0.1448],
         [0.2303],
         [0.0690],
         [0.0637],
         [0.2668],
         [0.1589],
         [0.1906],
         [0.1518],
         [0.1583],
         [0.1946],
         [0.2185],
         [0.2122],
         [0.1000],
         [0.1500]],

        [[0.1421],
         [0.1850],
         [0.1602],
         [0.1968],
         [0.2064],
         [0.1478],
         [0.1902],
         [0.0652],
         [0.0474],
         [0.2174],
         [0.0917],
         [0.1739],
         [0.2585],
         [0.0379],
         [0.1374],
         [0.0895],
         [0.1953],
         [0.0492],
         [0.2703],
         [0.1331],
         [0.1900],
         [0.2029],
         [0.0366],
         [0.1524],
         [0.1896]],

        [[0.2917],
         [0.1042],
        

In [None]:
model