In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pickle
import sys


sys.path.insert(1, "..")

from ts_dataset import TSDataset
from base_models import LSTMModel
from metrics import torch_mae as mae


In [None]:
from collections import OrderedDict

import torch
import torch.nn.functional as F

from maml.models.model import Model


def weight_init(module):
    if (isinstance(module, torch.nn.Linear)
        or isinstance(module, torch.nn.Conv2d)):
        torch.nn.init.xavier_uniform_(module.weight)
        module.bias.data.zero_()


class ConvModel(Model):
    """
    NOTE: difference to tf implementation: batch norm scaling is enabled here
    TODO: enable 'non-transductive' setting as per
          https://arxiv.org/abs/1803.02999
    """
    def __init__(self, input_channels, output_size, num_channels=64,
                 kernel_size=3, padding=1, nonlinearity=F.relu,
                 use_max_pool=False, img_side_len=28, verbose=False):
        super(ConvModel, self).__init__()
        self._input_channels = input_channels
        self._output_size = output_size
        self._num_channels = num_channels
        self._kernel_size = kernel_size
        self._nonlinearity = nonlinearity
        self._use_max_pool = use_max_pool
        self._padding = padding
        self._bn_affine = False
        self._reuse = False
        self._verbose = verbose

        if self._use_max_pool:
            self._conv_stride = 1
            self._features_size = 1
            self.features = torch.nn.Sequential(OrderedDict([
                ('layer1_conv', torch.nn.Conv2d(self._input_channels,
                                                self._num_channels,
                                                self._kernel_size,
                                                stride=self._conv_stride,
                                                padding=self._padding)),
                ('layer1_bn', torch.nn.BatchNorm2d(self._num_channels,
                                                   affine=self._bn_affine,
                                                   momentum=0.001)),
                ('layer1_max_pool', torch.nn.MaxPool2d(kernel_size=2,
                                                       stride=2)),
                ('layer1_relu', torch.nn.ReLU(inplace=True)),
                ('layer2_conv', torch.nn.Conv2d(self._num_channels,
                                                self._num_channels*2,
                                                self._kernel_size,
                                                stride=self._conv_stride,
                                                padding=self._padding)),
                ('layer2_bn', torch.nn.BatchNorm2d(self._num_channels*2,
                                                   affine=self._bn_affine,
                                                   momentum=0.001)),
                ('layer2_max_pool', torch.nn.MaxPool2d(kernel_size=2,
                                                       stride=2)),
                ('layer2_relu', torch.nn.ReLU(inplace=True)),
                ('layer3_conv', torch.nn.Conv2d(self._num_channels*2,
                                                self._num_channels*4,
                                                self._kernel_size,
                                                stride=self._conv_stride,
                                                padding=self._padding)),
                ('layer3_bn', torch.nn.BatchNorm2d(self._num_channels*4,
                                                   affine=self._bn_affine,
                                                   momentum=0.001)),
                ('layer3_max_pool', torch.nn.MaxPool2d(kernel_size=2,
                                                       stride=2)),
                ('layer3_relu', torch.nn.ReLU(inplace=True)),
                ('layer4_conv', torch.nn.Conv2d(self._num_channels*4,
                                                self._num_channels*8,
                                                self._kernel_size,
                                                stride=self._conv_stride,
                                                padding=self._padding)),
                ('layer4_bn', torch.nn.BatchNorm2d(self._num_channels*8,
                                                   affine=self._bn_affine,
                                                   momentum=0.001)),
                ('layer4_max_pool', torch.nn.MaxPool2d(kernel_size=2,
                                                       stride=2)),
                ('layer4_relu', torch.nn.ReLU(inplace=True)),
            ]))
        else:
            self._conv_stride = 2
            self._features_size = (img_side_len // 14)**2
            self.features = torch.nn.Sequential(OrderedDict([
                ('layer1_conv', torch.nn.Conv2d(self._input_channels,
                                                self._num_channels,
                                                self._kernel_size,
                                                stride=self._conv_stride,
                                                padding=self._padding)),
                ('layer1_bn', torch.nn.BatchNorm2d(self._num_channels,
                                                   affine=self._bn_affine,
                                                   momentum=0.001)),
                ('layer1_relu', torch.nn.ReLU(inplace=True)),
                ('layer2_conv', torch.nn.Conv2d(self._num_channels,
                                                self._num_channels*2,
                                                self._kernel_size,
                                                stride=self._conv_stride,
                                                padding=self._padding)),
                ('layer2_bn', torch.nn.BatchNorm2d(self._num_channels*2,
                                                   affine=self._bn_affine,
                                                   momentum=0.001)),
                ('layer2_relu', torch.nn.ReLU(inplace=True)),
                ('layer3_conv', torch.nn.Conv2d(self._num_channels*2,
                                                self._num_channels*4,
                                                self._kernel_size,
                                                stride=self._conv_stride,
                                                padding=self._padding)),
                ('layer3_bn', torch.nn.BatchNorm2d(self._num_channels*4,
                                                   affine=self._bn_affine,
                                                   momentum=0.001)),
                ('layer3_relu', torch.nn.ReLU(inplace=True)),
                ('layer4_conv', torch.nn.Conv2d(self._num_channels*4,
                                                self._num_channels*8,
                                                self._kernel_size,
                                                stride=self._conv_stride,
                                                padding=self._padding)),
                ('layer4_bn', torch.nn.BatchNorm2d(self._num_channels*8,
                                                   affine=self._bn_affine,
                                                   momentum=0.001)),
                ('layer4_relu', torch.nn.ReLU(inplace=True)),
            ]))

        self.classifier = torch.nn.Sequential(OrderedDict([
            ('fully_connected', torch.nn.Linear(self._num_channels*8,
                                                self._output_size))
        ]))
        self.apply(weight_init)

    def forward(self, task, params=None, embeddings=None):
        if not self._reuse and self._verbose: print('='*10 + ' Model ' + '='*10)
        if params is None:
            params = OrderedDict(self.named_parameters())

        x = task.x
        if not self._reuse and self._verbose: print('input size: {}'.format(x.size()))
        for layer_name, layer in self.features.named_children():
            weight = params.get('features.' + layer_name + '.weight', None)
            bias = params.get('features.' + layer_name + '.bias', None)
            if 'conv' in layer_name:
                x = F.conv2d(x, weight=weight, bias=bias,
                             stride=self._conv_stride, padding=self._padding)
            elif 'bn' in layer_name:
                x = F.batch_norm(x, weight=weight, bias=bias,
                                 running_mean=layer.running_mean,
                                 running_var=layer.running_var,
                                 training=True)
            elif 'max_pool' in layer_name:
                x = F.max_pool2d(x, kernel_size=2, stride=2)
            elif 'relu' in layer_name:
                x = F.relu(x)
            elif 'fully_connected' in layer_name:
                break
            else:
                raise ValueError('Unrecognized layer {}'.format(layer_name))
            if not self._reuse and self._verbose: print('{}: {}'.format(layer_name, x.size()))

        # in maml network the conv maps are average pooled
        x = x.view(x.size(0), self._num_channels*8, self._features_size)
        if not self._reuse and self._verbose: print('reshape to: {}'.format(x.size()))
        x = torch.mean(x, dim=2)
        if not self._reuse and self._verbose: print('reduce mean: {}'.format(x.size()))
        logits = F.linear(
            x, weight=params['classifier.fully_connected.weight'],
            bias=params['classifier.fully_connected.bias'])
        if not self._reuse and self._verbose: print('logits size: {}'.format(logits.size()))
        if not self._reuse and self._verbose: print('='*27)
        self._reuse = True
        return logits


In [None]:
class LSTM(RNNBase):
    r"""Applies a multi-layer long short-term memory (LSTM) RNN to an input
    sequence.


    For each element in the input sequence, each layer computes the following
    function:

    .. math::
        \begin{array}{ll} \\
            i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
            f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
            g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
            o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
            c_t = f_t \odot c_{t-1} + i_t \odot g_t \\
            h_t = o_t \odot \tanh(c_t) \\
        \end{array}

    where :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the cell
    state at time `t`, :math:`x_t` is the input at time `t`, :math:`h_{t-1}`
    is the hidden state of the layer at time `t-1` or the initial hidden
    state at time `0`, and :math:`i_t`, :math:`f_t`, :math:`g_t`,
    :math:`o_t` are the input, forget, cell, and output gates, respectively.
    :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product.

    In a multilayer LSTM, the input :math:`x^{(l)}_t` of the :math:`l` -th layer
    (:math:`l >= 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by
    dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random
    variable which is :math:`0` with probability :attr:`dropout`.

    Args:
        input_size: The number of expected features in the input `x`
        hidden_size: The number of features in the hidden state `h`
        num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
            would mean stacking two LSTMs together to form a `stacked LSTM`,
            with the second LSTM taking in outputs of the first LSTM and
            computing the final results. Default: 1
        bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
            Default: ``True``
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False``
        dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
            LSTM layer except the last layer, with dropout probability equal to
            :attr:`dropout`. Default: 0
        bidirectional: If ``True``, becomes a bidirectional LSTM. Default: ``False``

    Inputs: input, (h_0, c_0)
        - **input** of shape `(seq_len, batch, input_size)`: tensor containing the features
          of the input sequence.
          The input can also be a packed variable length sequence.
          See :func:`torch.nn.utils.rnn.pack_padded_sequence` or
          :func:`torch.nn.utils.rnn.pack_sequence` for details.
        - **h_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
          containing the initial hidden state for each element in the batch.
          If the LSTM is bidirectional, num_directions should be 2, else it should be 1.
        - **c_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
          containing the initial cell state for each element in the batch.

          If `(h_0, c_0)` is not provided, both **h_0** and **c_0** default to zero.


    Outputs: output, (h_n, c_n)
        - **output** of shape `(seq_len, batch, num_directions * hidden_size)`: tensor
          containing the output features `(h_t)` from the last layer of the LSTM,
          for each `t`. If a :class:`torch.nn.utils.rnn.PackedSequence` has been
          given as the input, the output will also be a packed sequence.

          For the unpacked case, the directions can be separated
          using ``output.view(seq_len, batch, num_directions, hidden_size)``,
          with forward and backward being direction `0` and `1` respectively.
          Similarly, the directions can be separated in the packed case.
        - **h_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
          containing the hidden state for `t = seq_len`.

          Like *output*, the layers can be separated using
          ``h_n.view(num_layers, num_directions, batch, hidden_size)`` and similarly for *c_n*.
        - **c_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
          containing the cell state for `t = seq_len`.

    Attributes:
        weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer
            `(W_ii|W_if|W_ig|W_io)`, of shape `(4*hidden_size, input_size)` for `k = 0`.
            Otherwise, the shape is `(4*hidden_size, num_directions * hidden_size)`
        weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer
            `(W_hi|W_hf|W_hg|W_ho)`, of shape `(4*hidden_size, hidden_size)`
        bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer
            `(b_ii|b_if|b_ig|b_io)`, of shape `(4*hidden_size)`
        bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer
            `(b_hi|b_hf|b_hg|b_ho)`, of shape `(4*hidden_size)`

    .. note::
        All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
        where :math:`k = \frac{1}{\text{hidden\_size}}`

    .. include:: ../cudnn_persistent_rnn.rst

    Examples::

        >>> rnn = nn.LSTM(10, 20, 2)
        >>> input = torch.randn(5, 3, 10)
        >>> h0 = torch.randn(2, 3, 20)
        >>> c0 = torch.randn(2, 3, 20)
        >>> output, (hn, cn) = rnn(input, (h0, c0))
    """

    def __init__(self, *args, **kwargs):
        super(LSTM, self).__init__('LSTM', *args, **kwargs)

    def check_forward_args(self, input: Tensor, hidden: Tuple[Tensor, Tensor], batch_sizes: Optional[Tensor]):
        self.check_input(input, batch_sizes)
        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)

        self.check_hidden_size(hidden[0], expected_hidden_size,
                               'Expected hidden[0] size {}, got {}')
        self.check_hidden_size(hidden[1], expected_hidden_size,
                               'Expected hidden[1] size {}, got {}')

    def permute_hidden(self, hx: Tuple[Tensor, Tensor], permutation: Optional[Tensor]) -> Tuple[Tensor, Tensor]:
        if permutation is None:
            return hx
        return apply_permutation(hx[0], permutation), apply_permutation(hx[1], permutation)

    @overload
    @torch._jit_internal._overload_method  # noqa: F811
    def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
                ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:  # noqa: F811
        pass

    @overload
    @torch._jit_internal._overload_method  # noqa: F811
    def forward(self, input: PackedSequence, hx: Optional[Tuple[Tensor, Tensor]] = None
                ) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]:  # noqa: F811
        pass

    def forward(self, input, hx=None):  # noqa: F811
        orig_input = input
        # xxx: isinstance check needs to be in conditional for TorchScript to compile
        if isinstance(orig_input, PackedSequence):
            input, batch_sizes, sorted_indices, unsorted_indices = input
            max_batch_size = batch_sizes[0]
            max_batch_size = int(max_batch_size)
        else:
            batch_sizes = None
            max_batch_size = input.size(0) if self.batch_first else input.size(1)
            sorted_indices = None
            unsorted_indices = None

        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            zeros = torch.zeros(self.num_layers * num_directions,
                                max_batch_size, self.hidden_size,
                                dtype=input.dtype, device=input.device)
            hx = (zeros, zeros)
        else:
            # Each batch of the hidden state should match the input sequence that
            # the user believes he/she is passing in.
            hx = self.permute_hidden(hx, sorted_indices)

        self.check_forward_args(input, hx, batch_sizes)
        if batch_sizes is None:
            result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
                              self.dropout, self.training, self.bidirectional, self.batch_first)
        else:
            result = _VF.lstm(input, batch_sizes, hx, self._flat_weights, self.bias,
                              self.num_layers, self.dropout, self.training, self.bidirectional)
        output = result[0]
        hidden = result[1:]
        # xxx: isinstance check needs to be in conditional for TorchScript to compile
        if isinstance(orig_input, PackedSequence):
            output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
            return output_packed, self.permute_hidden(hidden, unsorted_indices)
        else:
            return output, self.permute_hidden(hidden, unsorted_indices)



In [2]:
from torch.nn import _VF


In [4]:
dataset_name = "HR"
dataset_name = "POLLUTION"
model_name = "LSTM"

task_size = 50
batch_size = 64
output_dim = 1

batch_size = 20
horizon = 10
meta_learning_rate = 10e-6
learning_rate = 10e-5
n_inner_iter = 1
##test

if dataset_name == "HR":
    window_size = 32
    input_dim = 13
elif dataset_name == "POLLUTION":
    window_size = 5
    input_dim = 14

model = LSTMModel( batch_size=batch_size, seq_len = window_size, input_dim = input_dim, n_layers = 2, hidden_dim = 120, output_dim =1)


In [11]:
class ExtendedLSTMModel(LSTMModel):
    
    def __init__(self, *args, **kwargs):
        super(ExtendedLSTMModel, self).__init__( *args, **kwargs)  
    
    def get_flat_weights(self):
        
        self._flat_weights = [(lambda wn: getattr(self, wn) if hasattr(self, wn) else None)(wn)] 
        return self._flat_weights

In [12]:
model = ExtendedLSTMModel( batch_size=batch_size, seq_len = window_size, input_dim = input_dim, n_layers = 2, hidden_dim = 120, output_dim =1)


In [14]:
model.get_flat_weights()

NameError: name 'wn' is not defined

In [26]:
class LSTMModel(nn.LSTM):
    def __init__(self, *args, **kwargs):
        super(LSTMModel, self).__init__( *args, **kwargs)  
        
    def forward(self, input, params = None, hx=None, embeddings = None):  # noqa: F811
        
            if params is None:
                params = [(lambda wn: getattr(self, wn) if hasattr(self, wn) else None)(wn)] 
                
            
            orig_input = input
            # xxx: isinstance check needs to be in conditional for TorchScript to compile
            if isinstance(orig_input, PackedSequence):
                input, batch_sizes, sorted_indices, unsorted_indices = input
                max_batch_size = batch_sizes[0]
                max_batch_size = int(max_batch_size)
            else:
                batch_sizes = None
                max_batch_size = input.size(0) if self.batch_first else input.size(1)
                sorted_indices = None
                unsorted_indices = None

            if hx is None:
                num_directions = 2 if self.bidirectional else 1
                zeros = torch.zeros(self.num_layers * num_directions,
                                    max_batch_size, self.hidden_size,
                                    dtype=input.dtype, device=input.device)
                hx = (zeros, zeros)
            else:
                # Each batch of the hidden state should match the input sequence that
                # the user believes he/she is passing in.
                hx = self.permute_hidden(hx, sorted_indices)

            self.check_forward_args(input, hx, batch_sizes)
            if batch_sizes is None:
                result = _VF.lstm(input, hx, params, bias, self.num_layers,
                                  self.dropout, self.training, self.bidirectional, self.batch_first)
            else:
                result = _VF.lstm(input, batch_sizes, hx, params, bias,
                                  self.num_layers, self.dropout, self.training, self.bidirectional)
            output = result[0]
            hidden = result[1:]
            # xxx: isinstance check needs to be in conditional for TorchScript to compile
            if isinstance(orig_input, PackedSequence):
                output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
                return output_packed, self.permute_hidden(hidden, unsorted_indices)
            else:
                return output, self.permute_hidden(hidden, unsorted_indices)

    def get_flat_weights(self):
        
        self._flat_weights = [(lambda wn: getattr(self, wn) if hasattr(self, wn) else None)(wn) for wn in self._flat_weights_names]
        return self._flat_weights
    
    def set_weights_names(self, names):
        self._flat_weights_names = names

In [27]:
model = LSTMModel(10,10)

In [28]:
from collections import OrderedDict


In [29]:
parameters = OrderedDict(model.named_parameters())

parameters_names = parameters.keys()
model.set_weights_names(parameters_names)
model.get_flat_weights()

[Parameter containing:
 tensor([[-0.1864, -0.1020, -0.0542, -0.2225, -0.1711, -0.2549,  0.2816, -0.0213,
           0.1339, -0.2140],
         [-0.0253, -0.0252, -0.2634, -0.2152, -0.2168, -0.1847,  0.2027, -0.1384,
           0.0336, -0.1987],
         [-0.2594,  0.0354,  0.2068, -0.0355, -0.0759,  0.1903,  0.2200, -0.0020,
          -0.2667, -0.1872],
         [ 0.2406,  0.1528,  0.0144,  0.2642, -0.0766, -0.2302,  0.1422,  0.1423,
           0.1520,  0.1857],
         [-0.2977, -0.2405, -0.1424, -0.2855,  0.3041, -0.0335,  0.2057, -0.1081,
           0.0950, -0.1421],
         [ 0.2610,  0.1224, -0.1570, -0.1070, -0.1861, -0.2650,  0.2597, -0.1921,
           0.0733, -0.0582],
         [ 0.0858,  0.2482,  0.3040,  0.1830,  0.0109, -0.0207, -0.0294, -0.0480,
           0.2048, -0.0468],
         [-0.2926, -0.3000, -0.1401, -0.0004, -0.0733, -0.1035,  0.1918,  0.1821,
           0.0571,  0.2439],
         [-0.1654, -0.2939, -0.0889, -0.0926, -0.0903,  0.0923, -0.0872,  0.0194,
       

In [40]:
a = model.get_flat_weights()
a[0]

Parameter containing:
tensor([[-0.1864, -0.1020, -0.0542, -0.2225, -0.1711, -0.2549,  0.2816, -0.0213,
          0.1339, -0.2140],
        [-0.0253, -0.0252, -0.2634, -0.2152, -0.2168, -0.1847,  0.2027, -0.1384,
          0.0336, -0.1987],
        [-0.2594,  0.0354,  0.2068, -0.0355, -0.0759,  0.1903,  0.2200, -0.0020,
         -0.2667, -0.1872],
        [ 0.2406,  0.1528,  0.0144,  0.2642, -0.0766, -0.2302,  0.1422,  0.1423,
          0.1520,  0.1857],
        [-0.2977, -0.2405, -0.1424, -0.2855,  0.3041, -0.0335,  0.2057, -0.1081,
          0.0950, -0.1421],
        [ 0.2610,  0.1224, -0.1570, -0.1070, -0.1861, -0.2650,  0.2597, -0.1921,
          0.0733, -0.0582],
        [ 0.0858,  0.2482,  0.3040,  0.1830,  0.0109, -0.0207, -0.0294, -0.0480,
          0.2048, -0.0468],
        [-0.2926, -0.3000, -0.1401, -0.0004, -0.0733, -0.1035,  0.1918,  0.1821,
          0.0571,  0.2439],
        [-0.1654, -0.2939, -0.0889, -0.0926, -0.0903,  0.0923, -0.0872,  0.0194,
         -0.1670,  0.1744

In [41]:
a = OrderedDict(model.named_parameters())
type(a.get("weight_ih_l0"))

torch.nn.parameter.Parameter

In [30]:
def update_params(self, loss, params):
    """Apply one step of gradient descent on the loss function `loss`,
    with step-size `self._fast_lr`, and returns the updated parameters.
    """
    create_graph = not self._first_order
    grads = torch.autograd.grad(loss, params.values(),
                                create_graph=create_graph, allow_unused=True)
    for (name, param), grad in zip(params.items(), grads):
        if self._inner_loop_grad_clip > 0 and grad is not None:
            grad = grad.clamp(min=-self._inner_loop_grad_clip,
                              max=self._inner_loop_grad_clip)
        if grad is not None:
            params[name] = param - self._fast_lr * grad

    return params

In [31]:

def step(self, adapted_params_list, embeddings_list, val_tasks,
         is_training):
    for optimizer in self._optimizers:
        optimizer.zero_grad()
    post_update_losses = []

    for adapted_params, embeddings, task in zip(
            adapted_params_list, embeddings_list, val_tasks):
        preds = self._model(task, params=adapted_params,
                            embeddings=embeddings)
        loss = self._loss_func(preds, task.y)
        post_update_losses.append(loss)
        self._update_measurements(task, loss, preds)

    mean_loss = torch.mean(torch.stack(post_update_losses))
    if is_training:
        mean_loss.backward()
        if self._alternating:
            self._optimizers[self._alternating_index].step()
            self._alternating_count += 1
            if self._alternating_count % self._alternating_schedules[self._alternating_index] == 0:
                self._alternating_index = (1 - self._alternating_index)
                self._alternating_count = 0
        else:
            self._optimizers[0].step()
            if len(self._optimizers) > 1:
                if self._embedding_grad_clip > 0:
                    _grad_norm = clip_grad_norm_(self._embedding_model.parameters(), self._embedding_grad_clip)
                else:
                    _grad_norm = get_grad_norm(self._embedding_model.parameters())
                    # grad_norm
                    self._grads_mean.append(_grad_norm)
                    self._optimizers[1].step()

    measurements = self._pop_measurements()
    return measurements

In [32]:
def adapt(self, train_tasks):
    adapted_params = []
    embeddings_list = []

    for task in train_tasks:
        params = self._model.param_dict
        embeddings = None
        if self._embedding_model:
            embeddings = self._embedding_model(task)
        for i in range(self._num_updates):
            preds = self._model(task, params=params, embeddings=embeddings)
            loss = self._loss_func(preds, task.y)
            params = self.update_params(loss, params=params)
            if i == 0:
                self._update_measurements(task, loss, preds)
        adapted_params.append(params)
        embeddings_list.append(embeddings)

    measurements = self._pop_measurements()
    return measurements, adapted_params, embeddings_list

In [46]:

def get_grad_norm(parameters, norm_type=2):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
    total_norm = total_norm ** (1. / norm_type)

    return total_norm




class CustomLSTM(nn.LSTM):
    def __init__(self, *args, **kwargs):
        super(CustomLSTM, self).__init__( *args, **kwargs)  
        
    def forward(self, input, params = None, hx=None, embeddings = None):  # noqa: F811
        
            if params is None:
                params = [(lambda wn: getattr(self, wn) if hasattr(self, wn) else None)(wn)] 
                
            
            orig_input = input
            # xxx: isinstance check needs to be in conditional for TorchScript to compile
            if isinstance(orig_input, PackedSequence):
                input, batch_sizes, sorted_indices, unsorted_indices = input
                max_batch_size = batch_sizes[0]
                max_batch_size = int(max_batch_size)
            else:
                batch_sizes = None
                max_batch_size = input.size(0) if self.batch_first else input.size(1)
                sorted_indices = None
                unsorted_indices = None

            if hx is None:
                num_directions = 2 if self.bidirectional else 1
                zeros = torch.zeros(self.num_layers * num_directions,
                                    max_batch_size, self.hidden_size,
                                    dtype=input.dtype, device=input.device)
                hx = (zeros, zeros)
            else:
                # Each batch of the hidden state should match the input sequence that
                # the user believes he/she is passing in.
                hx = self.permute_hidden(hx, sorted_indices)

            self.check_forward_args(input, hx, batch_sizes)
            if batch_sizes is None:
                result = _VF.lstm(input, hx, params, bias, self.num_layers,
                                  self.dropout, self.training, self.bidirectional, self.batch_first)
            else:
                result = _VF.lstm(input, batch_sizes, hx, params, bias,
                                  self.num_layers, self.dropout, self.training, self.bidirectional)
            output = result[0]
            hidden = result[1:]
            # xxx: isinstance check needs to be in conditional for TorchScript to compile
            if isinstance(orig_input, PackedSequence):
                output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
                return output_packed, self.permute_hidden(hidden, unsorted_indices)
            else:
                return output, self.permute_hidden(hidden, unsorted_indices)

    def get_flat_weights(self):
        
        self._flat_weights = [(lambda wn: getattr(self, wn) if hasattr(self, wn) else None)(wn) for wn in self._flat_weights_names]
        return self._flat_weights
    
    def set_weights_names(self, names):
        self._flat_weights_names = names

class LSTMModel(nn.Module):
    
    def __init__(self, batch_size, seq_len, input_dim, n_layers, hidden_dim, output_dim, lin_hidden_dim = 100):
        super(LSTMModel, self).__init__()

        #self.lstm = nn.CustomLSTM(input_dim, hidden_dim, n_layers, batch_first=True)
        #self.linear = nn.Linear(hidden_dim, output_dim)#
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        self.n_layers = n_layers
        #self.hidden = self.init_hidden()
        self.input_dim = input_dim
        self.features = torch.nn.Sequential(OrderedDict([
            ("lstm",  nn.CustomLSTM(input_dim, hidden_dim, n_layers, batch_first=True)),
            ("linear", nn.Linear(hidden_dim, output_dim))]))
        
    def init_hidden(self):
        # This is what we'll initialise our hidden state as
        return (torch.zeros(self.n_layers, self.batch_size, self.hidden_dim),
                torch.zeros(self.n_layers, self.batch_size, self.hidden_dim))
        
    def forward(self, x, params):
        
        if params is None:
            params = OrderedDict(self.named_parameters())

        return x

In [47]:
model = LSTMModel( batch_size=batch_size, seq_len = window_size, input_dim = input_dim, n_layers = 2, hidden_dim = 120, output_dim =1)

AttributeError: module 'torch.nn' has no attribute 'CustomLSTM'