In [2]:
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import nn
from torch import Tensor
from torch.nn import Module
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from typing import Optional, Any, Union, Callable, Tuple
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.modules.linear import Linear
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.normalization import LayerNorm
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.overrides import (
    has_torch_function, has_torch_function_unary, has_torch_function_variadic,
    handle_torch_function)
from torch.nn.functional import _mha_shape_check, _in_projection_packed, softmax, dropout, linear
from torch.nn.modules.transformer import _get_clones
import math, copy

%matplotlib inline 
from matplotlib import pyplot as plt
import gym
import gym_sokoban
import numpy as np
from torchbeast.torchbeast.atari_wrappers import *
from torchbeast.torchbeast import atari_wrappers
from gym.wrappers import TimeLimit
from torchbeast.torchbeast.resnet import ResNet, BasicBlock


In [11]:
class TransformerEncoderLayer(Module):
    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer, can be a string
            ("relu" or "gelu") or a unary callable. Default: relu
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: if ``True``, layer norm is done prior to attention and feedforward
            operations, respectivaly. Otherwise it's done after. Default: ``False`` (after).

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)

    Alternatively, when ``batch_first`` is ``True``:
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
        >>> src = torch.rand(32, 10, 512)
        >>> out = encoder_layer(src)

   
    """
    __constants__ = ['batch_first', 'norm_first']

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                            **factory_kwargs)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            activation = _get_activation_fn(activation)

        # We can't test self.activation in forward() in TorchScript,
        # so stash some information about it instead.
        if activation is F.relu:
            self.activation_relu_or_gelu = 1
        elif activation is F.gelu:
            self.activation_relu_or_gelu = 2
        else:
            self.activation_relu_or_gelu = 0
        self.activation = activation

    def __setstate__(self, state):
        super(TransformerEncoderLayer, self).__setstate__(state)
        if not hasattr(self, 'activation'):
            self.activation = F.relu


    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                concat_k: Optional[Tensor] = None,
                concat_v: Optional[Tensor] = None,) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """

        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
       
        x = src
        if self.norm_first:
            attn, k, v = self._sa_block(self.norm1(x), src_mask, src_key_padding_mask, concat_k, concat_v)
            x = x + attn
            x = x + self._ff_block(self.norm2(x))
        else:
            attn, k, v = self._sa_block(x, src_mask, src_key_padding_mask, concat_k, concat_v)
            x = self.norm1(x + attn)
            x = self.norm2(x + self._ff_block(x))
        return x, k, v

    # self-attention block
    def _sa_block(self, x: Tensor,
                  attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor],
                  concat_k: Optional[Tensor], concat_v: Optional[Tensor]) -> Tensor:
        x, _, k, v = self.self_attn(x, x, x,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           need_weights=False,
                           concat_k=concat_k,
                           concat_v=concat_v)
        return self.dropout1(x), k, v

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)

class MultiheadAttention(Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    Multi-Head Attention is defined as:

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O

    where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.

    Args:
        embed_dim: Total dimension of the model.
        num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
            across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
        dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
        bias: If specified, adds bias to input / output projection layers. Default: ``True``.
        add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
        add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
            Default: ``False``.
        kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
        vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).

    Examples::

        >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)

    """
    __constants__ = ['batch_first']
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False,
                 kdim=None, vdim=None, batch_first=False, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        if self._qkv_same_embed_dim is False:
            self.q_proj_weight = Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
            self.k_proj_weight = Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs))
            self.v_proj_weight = Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs))
            self.register_parameter('in_proj_weight', None)
        else:
            self.in_proj_weight = Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)

        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
            self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if '_qkv_same_embed_dim' not in state:
            state['_qkv_same_embed_dim'] = True

        super(MultiheadAttention, self).__setstate__(state)

    def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None,
                need_weights: bool = True, attn_mask: Optional[Tensor] = None,
                average_attn_weights: bool = True, concat_k: Optional[Tensor] = None,
                concat_v: Optional[Tensor] = None,) -> Tuple[Tensor, Optional[Tensor]]:
        r"""
    Args:
        query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
            or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
            :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
            Queries are compared against key-value pairs to produce the output.
            See "Attention Is All You Need" for more details.
        key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
            or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
            :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
            See "Attention Is All You Need" for more details.
        value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
            ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
            sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
            See "Attention Is All You Need" for more details.
        key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
            to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
            Binary and byte masks are supported.
            For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
            the purpose of attention. For a byte mask, a non-zero value indicates that the corresponding ``key``
            value will be ignored.
        need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
            Default: ``True``.
        attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
            :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
            :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
            broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
            Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
            corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
            corresponding position is not allowed to attend. For a float mask, the mask values will be added to
            the attention weight.
        average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
            heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
            effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)

    Outputs:
        - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
          :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
          where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
          embedding dimension ``embed_dim``.
        - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
          returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
          :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
          :math:`S` is the source sequence length. If ``average_weights=False``, returns attention weights per
          head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.

        .. note::
            `batch_first` argument is ignored for unbatched inputs.
        """
        is_batched = query.dim() == 3
        
        any_nested = query.is_nested or key.is_nested or value.is_nested
        assert not any_nested, ("MultiheadAttention does not support NestedTensor outside of its fast path. " +
                                f"The fast path was not hit because {why_not_fast_path}")

        if self.batch_first and is_batched:
            # make sure that the transpose op does not affect the "is" property
            if key is value:
                if query is key:
                    query = key = value = query.transpose(1, 0)
                else:
                    query, key = [x.transpose(1, 0) for x in (query, key)]
                    value = key
            else:
                query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

        if not self._qkv_same_embed_dim:
            attn_output, attn_output_weights, k, v = multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight, average_attn_weights=average_attn_weights,
                concat_k=concat_k, concat_v=concat_v)
        else:
            attn_output, attn_output_weights, k, v = multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, average_attn_weights=average_attn_weights,
                concat_k=concat_k, concat_v=concat_v)
        if self.batch_first and is_batched:
            return attn_output.transpose(1, 0), attn_output_weights, k, v
        else:
            return attn_output, attn_output_weights, k, v
    

def multi_head_attention_forward(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    embed_dim_to_check: int,
    num_heads: int,
    in_proj_weight: Optional[Tensor],
    in_proj_bias: Optional[Tensor],
    bias_k: Optional[Tensor],
    bias_v: Optional[Tensor],
    add_zero_attn: bool,
    dropout_p: float,
    out_proj_weight: Tensor,
    out_proj_bias: Optional[Tensor],
    training: bool = True,
    key_padding_mask: Optional[Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[Tensor] = None,
    use_separate_proj_weight: bool = False,
    q_proj_weight: Optional[Tensor] = None,
    k_proj_weight: Optional[Tensor] = None,
    v_proj_weight: Optional[Tensor] = None,
    static_k: Optional[Tensor] = None,
    static_v: Optional[Tensor] = None,
    average_attn_weights: bool = True,
    concat_k: Optional[Tensor] = None,
    concat_v: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
    r"""
    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        embed_dim_to_check: total dimension of the model.
        num_heads: parallel attention heads.
        in_proj_weight, in_proj_bias: input projection weight and bias.
        bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
                       value sequences at dim=1.
        dropout_p: probability of an element to be zeroed.
        out_proj_weight, out_proj_bias: the output projection weight and bias.
        training: apply dropout if is ``True``.
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention. This is an binary mask. When the value is True,
            the corresponding value on the attention layer will be filled with -inf.
        need_weights: output attn_output_weights.
        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
            the batches while a 3D mask allows to specify a different mask for the entries of each batch.
        use_separate_proj_weight: the function accept the proj. weights for query, key,
            and value in different forms. If false, in_proj_weight will be used, which is
            a combination of q_proj_weight, k_proj_weight, v_proj_weight.
        q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
        static_k, static_v: static key and value used for attention operators.
        average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across heads.
            Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an effect
            when ``need_weights=True.``. Default: True
    Shape:
        Inputs:
        - query: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
          the embedding dimension.
        - key: :math:`(S, E)` or :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - value: :math:`(S, E)` or :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the source sequence length.
          If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
          will be unchanged. If a BoolTensor is provided, the positions with the
          value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
        - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
          3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
          S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
          positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
          while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
          are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
          is provided, it will be added to the attention weight.
        - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
        - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
        Outputs:
        - attn_output: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
          E is the embedding dimension.
        - attn_output_weights: Only returned when ``need_weights=True``. If ``average_attn_weights=True``, returns
          attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
          :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
          :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
          head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`.
    """
    tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
    if has_torch_function(tens_ops):
        return handle_torch_function(
            multi_head_attention_forward,
            tens_ops,
            query,
            key,
            value,
            embed_dim_to_check,
            num_heads,
            in_proj_weight,
            in_proj_bias,
            bias_k,
            bias_v,
            add_zero_attn,
            dropout_p,
            out_proj_weight,
            out_proj_bias,
            training=training,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
            use_separate_proj_weight=use_separate_proj_weight,
            q_proj_weight=q_proj_weight,
            k_proj_weight=k_proj_weight,
            v_proj_weight=v_proj_weight,
            static_k=static_k,
            static_v=static_v,
            average_attn_weights=average_attn_weights,
            concat_k=concat_k,
            concat_v=concat_v
        )

    is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)

    # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
    # is batched, run the computation and before returning squeeze the
    # batch dimension so that the output doesn't carry this temporary batch dimension.
    if not is_batched:
        # unsqueeze if the input is unbatched
        query = query.unsqueeze(1)
        key = key.unsqueeze(1)
        value = value.unsqueeze(1)
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask.unsqueeze(0)

    # set up shape vars
    tgt_len, bsz, embed_dim = query.shape
    src_len, _, _ = key.shape
    if concat_k is not None:
        src_len = concat_k.shape[1]
    assert embed_dim == embed_dim_to_check, \
        f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
    if isinstance(embed_dim, torch.Tensor):
        # embed_dim can be a tensor when JIT tracing
        head_dim = embed_dim.div(num_heads, rounding_mode='trunc')
    else:
        head_dim = embed_dim // num_heads
    assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
    if use_separate_proj_weight:
        # allow MHA to have different embedding dimensions when separate projection weights are used
        assert key.shape[:2] == value.shape[:2], \
            f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
    else:
        assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"

    #
    # compute in-projection
    #
    if not use_separate_proj_weight:
        assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
        q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
    else:
        assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
        assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
        assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
        if in_proj_bias is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = in_proj_bias.chunk(3)
        q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v)

    # prep attention mask
    if attn_mask is not None:
        if attn_mask.dtype == torch.uint8:
            warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
            attn_mask = attn_mask.to(torch.bool)
        else:
            assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \
                f"Only float, byte, and bool types are supported for attn_mask, not {attn_mask.dtype}"
        # ensure attn_mask's dim is 3
        if attn_mask.dim() == 2:
            correct_2d_size = (tgt_len, src_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
            attn_mask = attn_mask.unsqueeze(0)
        elif attn_mask.dim() == 3:
            correct_3d_size = (bsz * num_heads, tgt_len, src_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
        else:
            raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

    # prep key padding mask
    if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
        warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
        key_padding_mask = key_padding_mask.to(torch.bool)

    # add bias along batch dimension (currently second)
    if bias_k is not None and bias_v is not None:
        assert static_k is None, "bias cannot be added to static key."
        assert static_v is None, "bias cannot be added to static value."
        k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
        v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))
    else:
        assert bias_k is None
        assert bias_v is None

    #
    # reshape q, k, v for multihead attention and make em batch first
    #
    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    if static_k is None:
        k = k.contiguous().view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert static_k.size(0) == bsz * num_heads, \
            f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
        assert static_k.size(2) == head_dim, \
            f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
        k = static_k
    if static_v is None:
        v = v.contiguous().view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert static_v.size(0) == bsz * num_heads, \
            f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
        assert static_v.size(2) == head_dim, \
            f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
        v = static_v

    # add zero attention along batch dimension (now first)
    if add_zero_attn:
        zero_attn_shape = (bsz * num_heads, 1, head_dim)
        k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
        v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))

    # update source sequence length after adjustments
    src_len = k.size(1)

    # merge key padding and attention masks
    if key_padding_mask is not None:
        assert key_padding_mask.shape == (bsz, src_len), \
            f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
        key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len).   \
            expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
        if attn_mask is None:
            attn_mask = key_padding_mask
        elif attn_mask.dtype == torch.bool:
            attn_mask = attn_mask.logical_or(key_padding_mask)
        else:
            attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf"))

    # convert mask to float
    if attn_mask is not None and attn_mask.dtype == torch.bool:
        new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
        new_attn_mask.masked_fill_(attn_mask, float("-inf"))
        attn_mask = new_attn_mask

    # adjust dropout probability
    if not training:
        dropout_p = 0.0

    #
    # (deep breath) calculate attention and out projection
    #
    
    B, Nt, E = q.shape
    q_scaled = q / math.sqrt(E)
    
    if concat_k is not None:
        k = torch.cat([concat_k[:, 1:].transpose(0, 1).contiguous().view(-1, k.shape[0], k.shape[2]).transpose(0, 1), k], axis=1)
    if concat_v is not None:
        v = torch.cat([concat_v[:, 1:].transpose(0, 1).contiguous().view(-1, v.shape[0], v.shape[2]).transpose(0, 1), v], axis=1)
        
    if attn_mask is not None:
        attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
    else:
        attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
    attn_output_weights = softmax(attn_output_weights, dim=-1)    
    
    if dropout_p > 0.0:
        attn_output_weights = dropout(attn_output_weights, p=dropout_p)    
    attn_output = torch.bmm(attn_output_weights, v)    

    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
    attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
    attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
    
    ret_k = k.transpose(0, 1).view(-1, bsz, num_heads, head_dim).transpose(0, 1)
    ret_v = v.transpose(0, 1).view(-1, bsz, num_heads, head_dim).transpose(0, 1)
    
    if need_weights:
        # optionally average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        if average_attn_weights:
            attn_output_weights = attn_output_weights.sum(dim=1) / num_heads
        
        if not is_batched:
            # squeeze the output if input was unbatched
            attn_output = attn_output.squeeze(1)
            attn_output_weights = attn_output_weights.squeeze(0)        
        return attn_output, attn_output_weights, ret_k, ret_v
    else:
        if not is_batched:
            # squeeze the output if input was unbatched
            attn_output = attn_output.squeeze(1) 
        return (attn_output, None, ret_k, ret_v)

class TransformerRNN(Module):
    def __init__(self, d_model: int = 512, nhead: int = 8, num_layers: int = 6,
                 mem_n = 16, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, device=None, dtype=None):
        
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(TransformerRNN, self).__init__()
        
        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,
                                                activation, layer_norm_eps, False, False,
                                                **factory_kwargs)        
        self.layers = _get_clones(encoder_layer, num_layers)
        self.norm = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        
        self.num_layers = num_layers
        self.mem_n = mem_n
        self.d_model = d_model
        self.nhead = nhead        
        self.head_dim = self.d_model // self.nhead
        self._reset_parameters()    
        
        
    def _reset_parameters(self):
        r"""Initiate parameters in the transformer model."""
        for p in self.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)
    
    def forward(self, src: Tensor, core_state, notdone) -> Tensor:
        # Core state stored in the form of mask, (k, v), (k, v), ...
        # mask shape: (batch_size, mem_n)
        # key k and value v shape: (batch_size, mem_n, num_head, head_dim)
        
        src_mask = core_state[0][0]
        src_mask[~(notdone.bool()), :] = True
        src_mask[:, :-1] = src_mask[:, 1:].clone().detach()
        src_mask[:, -1] = False        
        new_core_state = [src_mask.unsqueeze(0)]        
        output = src.unsqueeze(0)
        
        bsz = src.shape[0]
        src_mask_ = src_mask.view(bsz, 1, 1, -1).broadcast_to(bsz, self.nhead, 1, -1).contiguous().view(bsz * self.nhead, 1, -1)
        ks = []
        vs = []

        for n, mod in enumerate(self.layers):
            output, new_k, new_v = mod(output, src_mask=src_mask_.detach(), concat_k=core_state[1][n], concat_v=core_state[2][n])
            ks.append(new_k.unsqueeze(0))
            vs.append(new_v.unsqueeze(0))
            
        output = self.norm(output)
        new_core_state.append(torch.cat(ks, dim=0))
        new_core_state.append(torch.cat(vs, dim=0))
        return output, new_core_state
    
    def init_state(self, bsz):
        core_state = (torch.ones(1, bsz, self.mem_n).bool(),
                      torch.zeros(self.num_layers, bsz, self.mem_n, self.nhead, self.head_dim),
                      torch.zeros(self.num_layers, bsz, self.mem_n, self.nhead, self.head_dim))
        return core_state

class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, max_len: int = 5000, concat=False):
        super().__init__()
        self.concat = concat

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        self.register_buffer('pe', pe)

    def forward(self, x: Tensor, step: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
            step: int Tensor, shape [seq_len, batch_size]
        """
        if not self.concat:
            x = x + self.pe[step, :]
        else:
            x = torch.cat([x, self.pe[step, :]], dim=-1)
        return x       
    
def get_param(model, name):
    names, w = [], None
    for k, t in model.named_parameters():    
        names.append(k)
        if k == name:
            w = t
    return w, names    

In [None]:
# Testing if the results are the same
input = torch.rand(10, 6, 256)
num_layers = 1

model = TransformerRNN(d_model=256, num_layers=num_layers, mem_n=16, dropout=0)
core_state = model.init_state(6)
notdone = torch.ones(6).bool()

x = []
for i in range(10):
    input_i = input[i]
    x_i, core_state = model(input_i, core_state, notdone)    
    x.append(x_i)
    
model_out_1 = torch.concat(x, axis=0)

# Official implementation

from torch.nn.modules.transformer import TransformerEncoderLayer as TransformerEncoderLayer_
from torch.nn.modules.transformer import TransformerEncoder
src_mask = torch.triu(torch.ones(10, 10, dtype=bool), diagonal=1)
model_ = TransformerEncoder(TransformerEncoderLayer_(d_model=256, dropout=0, nhead=8), 
                            norm=LayerNorm(256, eps=1e-5),
                            num_layers=num_layers)
model_.load_state_dict(model.state_dict())
model_out_2 = model_(input, mask=src_mask)

print(torch.sum(torch.square(model_out_1 - model_out_2)))

In [None]:
# Testing if the results are the same
input = torch.rand(10, 6, 256)

model = TransformerRNN(d_model=256, num_layers=2, dropout=0)
core_state = model.init_state(6)

x = []
for i in range(10):
    input_i = input[i]
    x_i, core_state = model(input_i, core_state, torch.zeros(6).bool() if i == 4 else torch.ones(6).bool())    
    x.append(x_i)
    
model_out_1 = torch.concat(x, axis=0)
loss = torch.sum(model_out_1*10)
loss.backward()

# Official implementation

from torch.nn.modules.transformer import TransformerEncoderLayer as TransformerEncoderLayer_
from torch.nn.modules.transformer import TransformerEncoder
src_mask = torch.triu(torch.ones(10, 10, dtype=bool), diagonal=1)
src_mask[4:, :4] = True
model_ = TransformerEncoder(TransformerEncoderLayer_(d_model=256, dropout=0, nhead=8), 
                            norm=LayerNorm(256, eps=1e-5),
                            num_layers=2)
model_.load_state_dict(model.state_dict())
model_out_2 = model_(input, mask=src_mask)
loss = torch.sum(model_out_2*10)
loss.backward()

print(torch.sum(torch.square(model_out_1 - model_out_2)))
grad_1 = get_param(model, "layers.0.self_attn.out_proj.weight")[0].grad 
grad_2 = get_param(model_, "layers.0.self_attn.out_proj.weight")[0].grad
print(torch.sum(torch.square(grad_1 - grad_2)))

In [12]:
# Testing if the results are the same
input = torch.rand(10, 6, 256)
num_layers = 1

model = TransformerRNN(d_model=256, num_layers=num_layers, mem_n=1, dropout=0)
core_state = model.init_state(6)

x = []
for i in range(10):
    input_i = input[i]
    x_i, core_state = model(input_i, core_state, torch.zeros(6).bool() if i == 4 else torch.ones(6).bool())    
    x.append(x_i)
    
model_out_1 = torch.concat(x, axis=0)
loss = torch.sum(model_out_1*10)
loss.backward()

# Official implementation

from torch.nn.modules.transformer import TransformerEncoderLayer as TransformerEncoderLayer_
from torch.nn.modules.transformer import TransformerEncoder
src_mask = torch.ones(10, 10, dtype=bool)
src_mask.fill_diagonal_(0)
model_ = TransformerEncoder(TransformerEncoderLayer_(d_model=256, dropout=0, nhead=8), 
                            norm=LayerNorm(256, eps=1e-5),
                            num_layers=num_layers)
model_.load_state_dict(model.state_dict())
model_out_2 = model_(input, mask=src_mask)
loss = torch.sum(model_out_2*10)
loss.backward()

print(torch.sum(torch.square(model_out_1 - model_out_2)))
grad_1 = get_param(model, "layers.0.self_attn.out_proj.weight")[0].grad 
grad_2 = get_param(model_, "layers.0.self_attn.out_proj.weight")[0].grad
print(torch.sum(torch.square(grad_1 - grad_2)))

tensor(2.8266e-10, grad_fn=<SumBackward0>)
tensor(1.0471e-09)


In [None]:
env = wrap_pytorch(SokobanWrapper(gym.make("Sokoban-v0")))
_ = env.reset()

In [None]:
env = gym.make("Breakout-v4")
_ = env.reset()

In [None]:
ob, reward, done, info = env.step(0)
plt.imshow(ob.transpose(), interpolation='nearest')
plt.show()
print(ob.shape, reward, done)

In [158]:
from torchbeast.torchbeast.transformer_rnn import DepthSepConv, ConvTransformerRNN

class AtariNet(nn.Module):
    def __init__(self, observation_shape, num_actions, flags):
        

        super(AtariNet, self).__init__()
        self.observation_shape = observation_shape
        self.num_actions = num_actions        

        if flags.use_tran:
          self.num_im_actions = self.num_actions if flags.num_im_actions < num_actions else flags.num_im_actions
        else:
          self.num_im_actions = 0

        self.use_lstm = flags.use_lstm
        self.use_drc = flags.use_drc
        self.use_tran = flags.use_tran   
        self.tran_noskip = flags.tran_noskip   
        self.tran_t = flags.tran_t
        self.deep = flags.deep
        
        self.conv_hw = 8
        self.conv_out = 256 # dim for encoder output (flatten)
        self.d_model = 256 # number of dim for transformer model
        self.pos_encode = 64 # number of dim for positional embedding
        self.drc_hidden = 32
        
        if not self.deep:
            # Feature extraction.
            self.conv1 = nn.Conv2d(in_channels=self.observation_shape[0], out_channels=32, kernel_size=8, stride=4)        
            self.conv2 = nn.Conv2d(32, 32, kernel_size=4, stride=2)
            self.frame_conv = torch.nn.Sequential(self.conv1, nn.ReLU(), self.conv2, nn.ReLU())
            self.fc = nn.Linear(self.conv_hw * self.conv_hw * 32, self.conv_out)
        else:
            self.resnet = ResNet(BasicBlock, [2, 2, 2, 2], in_channel=self.observation_shape[0], num_classes=self.conv_out)

        # FC output size + one-hot of last action + last reward.
        env_input_size = self.conv_out + num_actions + 1   
        
        if self.use_tran:
            core_output_size = self.d_model
        else:
            core_output_size = env_input_size 

        print("core output size: ", core_output_size)

        if self.use_lstm:
            self.core = nn.LSTM(core_output_size, core_output_size, 2)            
        elif self.use_drc:
            self.core = ConvLSTM(h=self.conv_hw, w=self.conv_hw, input_dim=32 + num_actions + 1, hidden_dim=self.drc_hidden, 
                                kernel_size=3, num_layers=3, num_steps=3)
        elif self.use_tran:
            self.core = TransformerRNN(d_model=core_output_size, nhead=8, 
                num_layers=flags.tran_layer_n, dim_feedforward=flags.tran_ff_n, mem_n=flags.tran_mem_n, dropout=0.)
            self.pos = PositionalEncoding(d_model=self.pos_encode, max_len=500, concat=True)

        if self.use_tran and not self.tran_noskip:
            last_layer_size = core_output_size + env_input_size
        else:
            last_layer_size = core_output_size

        if self.use_tran:
            self.compress = nn.Linear(env_input_size + self.conv_out + self.num_im_actions + 1 + self.pos_encode, core_output_size)
            self.model_fc1 = nn.Linear(self.conv_out + self.num_im_actions, self.conv_out + 1)
            #self.model_fc2 = nn.Linear(self.conv_out + 1, self.conv_out + 1)  
            #self.model = nn.Sequential(self.model_fc1, nn.ReLU(), self.model_fc2, nn.ReLU())
            self.model = nn.Sequential(self.model_fc1, nn.ReLU())
            self.im_policy = nn.Linear(core_output_size, self.num_im_actions)          
            
        if self.use_drc:
            self.last_layer = nn.Linear(self.conv_hw * self.conv_hw * self.drc_hidden, 256)
            last_layer_size = 256

        self.policy = nn.Linear(last_layer_size, self.num_actions)        
        self.baseline = nn.Linear(last_layer_size, 1)

        self.reward_clipping = flags.reward_clipping

        print("model size: ", sum(p.numel() for p in self.parameters()))

    def initial_state(self, batch_size):
        state = ()
        if self.use_lstm:
            state = tuple(
                torch.zeros(self.core.num_layers, batch_size, self.core.hidden_size)
                for _ in range(2)
            )
        elif self.use_drc:
            state = tuple(
                torch.zeros(3, batch_size, self.drc_hidden, self.conv_hw, self.conv_hw)
                for _ in range(2)
            )
        elif self.use_tran:
            state = self.core.init_state(batch_size) 
            state = state + (torch.zeros(1, batch_size, self.conv_out + self.num_im_actions + 1),)
        return state

    def forward(self, inputs, core_state=()):
        x = inputs["frame"]  # [T, B, C, H, W].
        T, B, *_ = x.shape
        x = torch.flatten(x, 0, 1)  # Merge time and batch.
        x = x.float() / 255.0

        if not self.deep:
            x = self.frame_conv(x)            
            if not self.use_drc: 
                x = x.view(T * B, -1)
                x = F.relu(self.fc(x))
        else:
            x = self.resnet(x)

        one_hot_last_action = F.one_hot(inputs["last_action"].view(T * B), self.num_actions).float()
        clipped_reward = torch.clamp(inputs["reward"], -self.reward_clipping, self.reward_clipping).view(T * B, 1)
        
        if self.use_drc:
            one_hot_last_action = add_hw(one_hot_last_action, self.conv_hw, self.conv_hw)
            clipped_reward = add_hw(clipped_reward, self.conv_hw, self.conv_hw)
        
        env_input_list = [x, one_hot_last_action, clipped_reward]
        env_input = torch.cat([x, one_hot_last_action, clipped_reward], dim=1)

        model_loss = torch.zeros(T, B).to(x.device)
        core_input = env_input

        if self.use_lstm or self.use_drc:

            core_input = core_input.view((T, B) + core_input.shape[1:])
            core_output_list = []
            notdone = (~inputs["done"]).float()
            for input, nd in zip(core_input.unbind(), notdone.unbind()):                
                # Reset core state to zero whenever an episode ended.
                # Make `done` broadcastable with (num_layers, B, hidden_size)
                # states:
                nd = nd.view(1, -1, 1) if self.use_lstm else nd.view(1, -1, 1, 1, 1)
                core_state = tuple(nd * s for s in core_state)
                output, core_state = self.core(input.unsqueeze(0), core_state)
                core_output_list.append(output)
            core_output = torch.flatten(torch.cat(core_output_list), 0, 1)

        elif self.use_tran:

            core_input = core_input.view((T, B) + core_input.shape[1:])
            core_output_list = []
            notdone = (~inputs["done"]).bool()

            model_loss_list = [torch.zeros(B).to(x.device)]
            im_z = core_state[-1][0] # im_z shape: B, self.conv_out + self.num_im_actions + 1

            for n, (input, nd, ep_step, last_act) in enumerate(zip(core_input.unbind(), notdone.unbind(), 
                inputs["episode_step"], inputs["last_action"])):                
                # Input shape: B, env_input_size

                for m in range(self.tran_t):

                    # compute model loss if n > 0
                    if n > 0 and m == 0:
                        model_input = torch.concat([last_input[:, :self.conv_out+self.num_actions], 
                            torch.zeros(B, self.num_im_actions - self.num_actions).to(x.device)], dim=-1).detach()
                        model_output = self.model(model_input)
                        target_output = torch.concat([input[:, :self.conv_out], 
                            input[:, [self.conv_out+self.num_actions]]], dim=-1)
                        diff = torch.sum((target_output - model_output) ** 2, dim=-1)
                        mask = torch.logical_or(ep_step == 0, last_act == 0)
                        mask_diff = torch.where(mask, torch.zeros(B).to(x.device), diff)
                        model_loss_list.append(mask_diff + 0.0001 * torch.sum(model_output ** 2, dim=-1))

                    # re-set im_z to real_z if real aciton is taken
                    if m == 0:
                        reset = torch.logical_or(ep_step == 0, last_act != 0)
                        reset_z = torch.concat([input[:, :self.conv_out+self.num_actions], 
                            torch.zeros(B, self.num_im_actions - self.num_actions).to(x.device), 
                            input[:, [self.conv_out+self.num_actions]]], axis=-1)
                        im_z = torch.where(reset.unsqueeze(-1), reset_z, im_z)

                    input_p = torch.cat([input, im_z], axis=-1)
                    input_p = self.pos(input_p, ep_step.long() % 5000)            
                    input_p = self.compress(input_p)
                    
                    nd = nd.view(-1)              
                    output, core_state = self.core(input_p, core_state, nd) # output shape: 1, B, core_output_size                    

                    # recompute next im_z
                    im_policy_logits = self.im_policy(output[0])
                    soft_im_action = F.softmax(im_policy_logits, dim=1)
                    model_input = torch.cat([im_z[:, :self.conv_out], soft_im_action], axis=-1)
                    model_output = self.model(model_input) # model outputs (s, r)
                    im_z = torch.cat([model_output[:, :-1], soft_im_action, model_output[:, [-1]]], axis=-1)

                last_input = input                
                core_output_list.append(output)

            core_output = torch.cat(core_output_list)            
            core_output = torch.flatten(core_output, 0, 1)

            if not self.tran_noskip:
                core_output = torch.cat([core_output, env_input], axis=-1)

            if torch.any(torch.isnan(core_output)):
                print("T", T, " B", B)
                print("core_output: ", core_output)
                print("env_input: ", env_input)
                print("model_output: ", model_output)
                print("target_output: ", target_output)                
                raise Exception("nan detected")

            model_loss = torch.cat(model_loss_list)            
        else:
            core_output = env_input
            core_state = tuple()   
            
        if self.use_drc:
            core_output = core_output.view(T * B, -1)
            core_output = F.relu(self.last_layer(core_output))          
        
        policy_logits = self.policy(core_output)
        baseline = self.baseline(core_output)

        if self.use_tran: core_state.append(im_z.unsqueeze(0))

        action = torch.multinomial(F.softmax(policy_logits, dim=1), num_samples=1)        

        reg_loss = (1e-3 * torch.sum(policy_logits**2, dim=-1) / 2 + 
                    1e-5 * torch.sum(core_output**2, dim=-1) / 2)
        reg_loss = reg_loss.view(T, B)

        policy_logits = policy_logits.view(T, B, self.num_actions)
        baseline = baseline.view(T, B)
        action = action.view(T, B)

        return (
            dict(policy_logits=policy_logits, baseline=baseline, 
                 action=action, reg_loss=reg_loss, model_loss=model_loss),
            core_state,
        )

In [157]:
flags.use_tran = False
flags.use_drc = True
net = AtariNet(observation_shape=(3,84,84), num_actions=4, flags=flags)    
#checkpoint = torch.load("model.tar", map_location="cpu")
#net.load_state_dict(checkpoint["model_state_dict"])    
print("tot params: ", sum(p.numel() for p in net.parameters()))

from matplotlib import pyplot as plt
import gym
import gym_sokoban
import numpy as np
from torchbeast.torchbeast.atari_wrappers import *
from gym.wrappers import TimeLimit
from torchbeast.torchbeast.core import environment

from torchbeast.torchbeast import atari_wrappers
env = atari_wrappers.wrap_pytorch(atari_wrappers.SokobanWrapper(gym.make("Sokoban-v0")))
env = environment.Environment(env)
obs = env.initial()

for i in range(20):
    core_state = net.initial_state(batch_size=1)
    net_out_, core_state = net(obs, core_state)
    obs_ = env.step(net_out_["action"][-1])
    obs_["last_action"] = obs_["last_action"].unsqueeze(0)
    obs = {k: torch.concat([obs[k], obs_[k]]) for k in obs.keys()}
    if i > 0:
        #for k in net_out.keys(): print(k, net_out[k].shape, net_out_[k].shape)
        net_out = {k: torch.concat([net_out[k], net_out_[k][[-1]]]) for k in net_out.keys()}
    else:
        net_out = net_out_    

core output size:  261
model size:  1533285
tot params:  1533285
tensor([[ 0.0794,  0.1469,  0.1494,  ...,  0.0384,  0.0332, -0.0058]],
       grad_fn=<ViewBackward0>)
tensor([[ 0.0794,  0.1469,  0.1494,  ...,  0.0384,  0.0332, -0.0058],
        [ 0.5257,  0.5258,  0.5329,  ...,  0.5310,  0.5098, -0.2731]],
       grad_fn=<ViewBackward0>)
tensor([[ 0.0794,  0.1469,  0.1494,  ...,  0.0384,  0.0332, -0.0058],
        [ 0.5257,  0.5258,  0.5329,  ...,  0.5310,  0.5098, -0.2731],
        [ 0.2856,  0.5210,  0.5043,  ...,  0.5355,  0.5080, -0.5599]],
       grad_fn=<ViewBackward0>)
tensor([[ 0.0794,  0.1469,  0.1494,  ...,  0.0384,  0.0332, -0.0058],
        [ 0.5257,  0.5258,  0.5329,  ...,  0.5310,  0.5098, -0.2731],
        [ 0.2856,  0.5210,  0.5043,  ...,  0.5355,  0.5080, -0.5599],
        [ 0.5061,  0.5047,  0.5156,  ...,  0.5347,  0.5180, -0.5382]],
       grad_fn=<ViewBackward0>)
tensor([[ 0.0794,  0.1469,  0.1494,  ...,  0.0384,  0.0332, -0.0058],
        [ 0.5257,  0.5258,  0.532

In [117]:
for k, v in net.named_parameters():
    print(k, v.numel())

conv1.weight 6144
conv1.bias 32
conv2.weight 16384
conv2.bias 32
fc.weight 524288
fc.bias 256
core.cell_list.0.conv.weight 147456
core.cell_list.0.conv.bias 128
core.cell_list.1.conv.weight 147456
core.cell_list.1.conv.bias 128
core.cell_list.2.conv.weight 147456
core.cell_list.2.conv.bias 128
core.proj_list.0.weight 64
core.proj_list.0.bias 32
core.proj_list.1.weight 64
core.proj_list.1.bias 32
core.proj_list.2.weight 64
core.proj_list.2.bias 32
last_layer.weight 524288
last_layer.bias 256
policy.weight 1024
policy.bias 4
baseline.weight 256
baseline.bias 1


In [None]:
for k, v in net_out.items():
    print(k , v)

In [None]:
optimizer = torch.optim.Adam(net.parameters())
optimizer.zero_grad()
model_loss = torch.sum(net_out["model_loss"])
print(model_loss)

In [112]:
import argparse
# yapf: disable
parser = argparse.ArgumentParser(description="PyTorch Scalable Agent")

parser.add_argument("--env", type=str, default="PongNoFrameskip-v4",
                    help="Gym environment.")
parser.add_argument("--mode", default="train",
                    choices=["train", "test", "test_render"],
                    help="Training or test mode.")
parser.add_argument("--xpid", default=None,
                    help="Experiment id (default: None).")

# Training settings.
parser.add_argument("--disable_checkpoint", action="store_true",
                    help="Disable saving checkpoint.")
parser.add_argument("--savedir", default="~/logs/torchbeast",
                    help="Root dir where experiment data will be saved.")
parser.add_argument("--num_actors", default=48, type=int, metavar="N",
                    help="Number of actors (default: 48).")
parser.add_argument("--total_steps", default=100000000, type=int, metavar="T",
                    help="Total environment steps to train for.")
parser.add_argument("--batch_size", default=32, type=int, metavar="B",
                    help="Learner batch size.")
parser.add_argument("--unroll_length", default=20, type=int, metavar="T",
                    help="The unroll length (time dimension).")
parser.add_argument("--num_buffers", default=None, type=int,
                    metavar="N", help="Number of shared-memory buffers.")
parser.add_argument("--num_learner_threads", "--num_threads", default=1, type=int,
                    metavar="N", help="Number learner threads.")
parser.add_argument("--disable_cuda", action="store_true",
                    help="Disable CUDA.")
parser.add_argument("--use_lstm", action="store_true",
                    help="Use LSTM in agent model.")

parser.add_argument("--use_DRC", action="store_true",
                    help="Use DRC in agent model.")

parser.add_argument("--use_tran", action="store_true",
                    help="Use transformer in agent model.")
parser.add_argument("--tran_mem_n", default=8, type=int, metavar="N",
                    help="Size of transformer memory.")
parser.add_argument("--tran_layer_n", default=2, type=int, metavar="N",
                    help="Number of transformer layer.")
parser.add_argument("--tran_ff_n", default=256, type=int, metavar="N",
                    help="Size of transformer ff .")
parser.add_argument("--tran_noskip", action="store_true",
                    help="Whether to enable noskip.")
parser.add_argument("--tran_t", default=1, type=int, metavar="T",
                    help="Number of recurrent step for transformer.")

parser.add_argument("--num_im_actions", default=0, type=int, 
                    metavar="N", help="Number of imagainary action; 0 for no imagagination.")
parser.add_argument("--deep", action="store_true",
                    help="Use ResNet 18 to process input.")

# Loss settings.
parser.add_argument("--entropy_cost", default=0.01,
                    type=float, help="Entropy cost/multiplier.")
parser.add_argument("--baseline_cost", default=0.5,
                    type=float, help="Baseline cost/multiplier.")
parser.add_argument("--model_cost", default=1,
                    type=float, help="Model cost/multiplier.")
parser.add_argument("--discounting", default=0.97,
                    type=float, help="Discounting factor.")
parser.add_argument("--lamb", default=0.97,
                    type=float, help="Lambda when computing trace.")
parser.add_argument("--reward_clipping", default=10, type=int, 
                    metavar="N", help="Reward clipping.")

# Optimizer settings.
parser.add_argument("--learning_rate", default=0.0004,
                    type=float, metavar="LR", help="Learning rate.")
parser.add_argument("--use_adam", action="store_false",
                    help="Use Aadm optimizer or not.")
parser.add_argument("--alpha", default=0.99, type=float,
                    help="RMSProp smoothing constant.")
parser.add_argument("--momentum", default=0, type=float,
                    help="RMSProp momentum.")
parser.add_argument("--epsilon", default=0.01, type=float,
                    help="RMSProp epsilon.")
parser.add_argument("--grad_norm_clipping", default=40.0, type=float,
                    help="Global gradient norm clip.")
# yapf: enable

flags = parser.parse_args("--env Sokoban-v0".split())

In [107]:
# DRC (D, N)

import torch.nn as nn
import torch

class ConvLSTMCell(nn.Module):

    def __init__(self, input_dim, hidden_dim, kernel_size):

        super(ConvLSTMCell, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.kernel_size = kernel_size
        self.padding = kernel_size // 2

        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
                              out_channels= 4 * self.hidden_dim,
                              kernel_size=self.kernel_size,
                              padding=self.padding,)

    def forward(self, input_tensor, cur_state):
        h_cur, c_cur = cur_state

        combined = torch.cat([input_tensor, h_cur], dim=1)  # concatenate along channel axis
        combined_conv = self.conv(combined)
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)
        return h_next, c_next

class ConvLSTM(nn.Module):

    def __init__(self, h, w, input_dim, hidden_dim, kernel_size, num_layers, num_steps):
        super(ConvLSTM, self).__init__()
        
        self.h = h
        self.w = w
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.num_steps = num_steps

        cell_list = []
        proj_list = []
        
        for i in range(0, self.num_layers):
            cell_list.append(ConvLSTMCell(input_dim=input_dim+hidden_dim*2,
                                          hidden_dim=self.hidden_dim,
                                          kernel_size=self.kernel_size,))
            proj_list.append(torch.nn.Conv2d(hidden_dim, hidden_dim, (2,1), groups=hidden_dim))

        self.cell_list = nn.ModuleList(cell_list)
        self.proj_list = nn.ModuleList(proj_list)
    
    def init_state(self, bsz):
        core_state = (torch.zeros(self.num_layers, bsz, self.hidden_dim, self.h, self.w),
                      torch.zeros(self.num_layers, bsz, self.hidden_dim, self.h, self.w))
        return core_state
    
    def forward(self, x, core_state):        
        t, b, c, h, w = x.shape
        out = core_state[1][-1]
        
        core_out = []
        for input in x:
            for _ in range(self.num_steps):
                new_core_state = ([], [])
                for n, (cell, proj) in enumerate(zip(self.cell_list, self.proj_list)):
                    cell_input = torch.concat([input, out, self.proj_max_mean(out, proj)], dim=1)
                    state = (core_state[0][n], core_state[1][n])
                    out, state =  cell(cell_input, state)
                    new_core_state[0].append(out)
                    new_core_state[1].append(state)                
                core_state = new_core_state
            core_out.append(out.unsqueeze(0))
        
        core_out = torch.cat(core_out)
        core_state = tuple(torch.cat([u.unsqueeze(0) for u in v]) for v in core_state)
                
        return core_out, core_state

    def proj_max_mean(self, out, linear_proj):
        out_mean = torch.mean(out, dim=(-1,-2), keepdim=True)
        out_max = torch.max(torch.max(out, dim=-1, keepdim=True)[0], dim=-2, keepdim=True)[0]
        proj_in = torch.cat([out_mean, out_max], dim=-2)
        out_sum = linear_proj(proj_in).broadcast_to(out.shape)
        return out_sum

h = 4
w = 4
input_dim = 8 
hidden_dim = 6
kernel_size = 3
num_layers = 3
num_steps = 3
bsz = 5
T = 2

net = ConvLSTM(h=h, w=w, input_dim=input_dim, hidden_dim=hidden_dim, 
               kernel_size=kernel_size, num_layers=num_layers, num_steps=num_steps)
input = torch.rand(T, bsz, input_dim, h, w)
core_state = net.init_state(bsz)
out_1 = net(input, core_state)[0]

core_state = net.init_state(bsz)
out_2 = []
for t in range(T):
    out, core_state = net(input[[t]], core_state)
    out_2.append(out)
out_2 = torch.cat(out_2)    
print(torch.sum((out_1-out_2)**2))

tensor(6.2132e-07, grad_fn=<SumBackward0>)


In [50]:
torch.max(out, dim=-1, keepdim=True)

torch.return_types.max(
values=tensor([[[[[0.9908],
           [0.9832],
           [0.8082],
           [0.6583]],

          [[0.8729],
           [0.8853],
           [0.7920],
           [0.5834]],

          [[0.3262],
           [0.8025],
           [0.6625],
           [0.9462]]],


         [[[0.5029],
           [0.8083],
           [0.8200],
           [0.5762]],

          [[0.9971],
           [0.6909],
           [0.8239],
           [0.7388]],

          [[0.7379],
           [0.9413],
           [0.9169],
           [0.9335]]],


         [[[0.8956],
           [0.9918],
           [0.6687],
           [0.9968]],

          [[0.9054],
           [0.6694],
           [0.7598],
           [0.9490]],

          [[0.5472],
           [0.8139],
           [0.8856],
           [0.8371]]],


         [[[0.9506],
           [0.5120],
           [0.8993],
           [0.9453]],

          [[0.9173],
           [0.8397],
           [0.9131],
           [0.6695]],

          [[0.828

In [69]:
out = torch.rand(2, 3, 4, 4)
out_c = out.shape[-3]

linear_proj = torch.nn.Conv2d(out_c, out_c, (2,1), groups=out_c) 
out_mean = torch.mean(out, dim=(-1,-2), keepdim=True)
out_max = torch.max(torch.max(out, dim=-1, keepdim=True)[0], dim=-2, keepdim=True)[0]
proj_in = torch.cat([out_mean, out_max], dim=-2)
out_sum = linear_proj(proj_in).broadcast_to(out.shape)

print(out_sum.shape)

torch.Size([2, 3, 4, 4])


In [70]:
print(out_sum)

tensor([[[[ 0.0168,  0.0168,  0.0168,  0.0168],
          [ 0.0168,  0.0168,  0.0168,  0.0168],
          [ 0.0168,  0.0168,  0.0168,  0.0168],
          [ 0.0168,  0.0168,  0.0168,  0.0168]],

         [[ 1.0931,  1.0931,  1.0931,  1.0931],
          [ 1.0931,  1.0931,  1.0931,  1.0931],
          [ 1.0931,  1.0931,  1.0931,  1.0931],
          [ 1.0931,  1.0931,  1.0931,  1.0931]],

         [[-0.3826, -0.3826, -0.3826, -0.3826],
          [-0.3826, -0.3826, -0.3826, -0.3826],
          [-0.3826, -0.3826, -0.3826, -0.3826],
          [-0.3826, -0.3826, -0.3826, -0.3826]]],


        [[[-0.0571, -0.0571, -0.0571, -0.0571],
          [-0.0571, -0.0571, -0.0571, -0.0571],
          [-0.0571, -0.0571, -0.0571, -0.0571],
          [-0.0571, -0.0571, -0.0571, -0.0571]],

         [[ 1.1105,  1.1105,  1.1105,  1.1105],
          [ 1.1105,  1.1105,  1.1105,  1.1105],
          [ 1.1105,  1.1105,  1.1105,  1.1105],
          [ 1.1105,  1.1105,  1.1105,  1.1105]],

         [[-0.3897, -0.3897,

In [68]:
out_sum

torch.Size([2, 3, 1, 1, 4, 4])

In [56]:
for i in linear_proj.parameters():
    print(i)

Parameter containing:
tensor([[[[ 0.0842],
          [ 0.3822]]],


        [[[-0.1097],
          [ 0.3673]]],


        [[[-0.0785],
          [-0.1090]]]], requires_grad=True)
Parameter containing:
tensor([0.4775, 0.6595, 0.4080], requires_grad=True)
