In [28]:

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(MultiHeadSelfAttention, self).__init__()
        assert embed_size % heads == 0, "Embedding size needs to be divisible by heads"
        self.heads = heads
        self.embed_size = embed_size
        self.head_dim = embed_size // heads

        self.values = nn.Linear(self.head_dim, self.embed_size, bias=False)
        self.keys = nn.Linear(self.head_dim, self.embed_size, bias=False)
        self.queries = nn.Linear(self.head_dim, self.embed_size, bias=False)
        self.fc_out = nn.Linear(self.embed_size, self.embed_size)

    def forward(self, values, keys, queries, mask):
        N = queries.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], queries.shape[1]

        # Split embedding into self.heads different pieces
        values = values.view(N, value_len, self.heads, self.head_dim)
        keys = keys.view(N, key_len, self.heads, self.head_dim)
        queries = queries.view(N, query_len, self.heads, self.head_dim)

        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])  # (N, heads, query_len, key_len)
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.embed_size
        )

        out = self.fc_out(out)
        return out


class FeedForward(nn.Module):
    def __init__(self, embed_size, ff_hidden):
        super(FeedForward, self).__init__()
        self.fc1 = nn.Linear(embed_size, ff_hidden)
        self.fc2 = nn.Linear(ff_hidden, embed_size)

    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x)))


class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, ff_hidden, dropout):
        super(TransformerBlock, self).__init__()
        self.attention = MultiHeadSelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = FeedForward(embed_size, ff_hidden)
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)
        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out


class PositionalEncoding(nn.Module):
    def __init__(self, embed_size, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.encoding = torch.zeros(max_len, embed_size)
        self.encoding.requires_grad = False  # We don't want to update positional encoding during training

        pos = torch.arange(0, max_len).unsqueeze(1)
        _2i = torch.arange(0, embed_size, step=2)

        self.encoding[:, 0::2] = torch.sin(pos / (10000 ** (_2i / embed_size)))
        self.encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i / embed_size)))

    def forward(self, x):
        seq_len = x.size(1)
        return x + self.encoding[:seq_len, :].to(x.device)


class Transformer(nn.Module):
    def __init__(self, embed_size, heads, ff_hidden, num_layers, vocab_size, max_len=100, dropout=0.5):
        super(Transformer, self).__init__()
        self.word_embedding = nn.Embedding(vocab_size, embed_size)
        self.position = PositionalEncoding(embed_size, max_len)
        self.layers = nn.ModuleList(
            [TransformerBlock(embed_size, heads, ff_hidden, dropout) for _ in range(num_layers)]
        )
        self.fc_out = nn.Linear(embed_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        for layer in self.layers:
            out = layer(x, x, x, mask)
        out = self.fc_out(out)
        return out


In [39]:
# Define parameters for the Transformer
embed_size = 6           # Embedding dimension for each token
heads = 2                # Number of attention heads (embedding size should be divisible by this)
ff_hidden = 24           # Dimension of the feed-forward layer (you can adjust this)
num_layers = 4           # Number of transformer layers
vocab_size = 2       # Vocabulary size (can be any appropriate number for your use case)
max_len = 100            # Maximum sequence length (100 tokens in your case)
dropout = 0.1            # Dropout rate for regularization

# Create a Transformer object
transformer = Transformer(
    embed_size=embed_size,
    heads=heads,
    ff_hidden=ff_hidden,
    num_layers=num_layers,
    vocab_size=vocab_size,
    max_len=max_len,
    dropout=dropout
).cuda()

# Print the model summary (or model architecture)
print(transformer)


# Create a random input tensor with token indices between 0 and vocab_size-1
random_input = torch.rand( 1, 16, 6).float().cuda()

# Print the shape of the input tensor
print(random_input.shape)  # Should be: (32, 100, 6)

# Example of passing the random input through the transformer model
mask = None  # No mask applied here, but can be added if needed
output = transformer(random_input, mask)

# Print the output shape
print(output.shape)  # Output should be: (32, 100, 1000)



Transformer(
  (word_embedding): Embedding(2, 6)
  (position): PositionalEncoding()
  (layers): ModuleList(
    (0-3): 4 x TransformerBlock(
      (attention): MultiHeadSelfAttention(
        (values): Linear(in_features=3, out_features=6, bias=False)
        (keys): Linear(in_features=3, out_features=6, bias=False)
        (queries): Linear(in_features=3, out_features=6, bias=False)
        (fc_out): Linear(in_features=6, out_features=6, bias=True)
      )
      (norm1): LayerNorm((6,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((6,), eps=1e-05, elementwise_affine=True)
      (feed_forward): FeedForward(
        (fc1): Linear(in_features=6, out_features=24, bias=True)
        (fc2): Linear(in_features=24, out_features=6, bias=True)
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (fc_out): Linear(in_features=6, out_features=2, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)
torch.Size([1, 16, 6])
torch.Size([1, 16, 2])


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor
from stable_baselines3.common.distributions import (
    DiagGaussianDistribution,
    CategoricalDistribution,
    MultiCategoricalDistribution,
    BernoulliDistribution,
    StateDependentNoiseDistribution,
    make_proba_distribution,
)
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import numpy as np
from functools import partial
import warnings
import collections

# Transformer Components

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(MultiHeadSelfAttention, self).__init__()
        assert embed_size % heads == 0, "Embedding size needs to be divisible by heads"
        self.heads = heads
        self.embed_size = embed_size
        self.head_dim = embed_size // heads

        self.values = nn.Linear(self.head_dim, self.embed_size, bias=False)
        self.keys = nn.Linear(self.head_dim, self.embed_size, bias=False)
        self.queries = nn.Linear(self.head_dim, self.embed_size, bias=False)
        self.fc_out = nn.Linear(self.embed_size, self.embed_size)

    def forward(self, values, keys, queries, mask):
        N = queries.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], queries.shape[1]

        values = values.view(N, value_len, self.heads, self.head_dim)
        keys = keys.view(N, key_len, self.heads, self.head_dim)
        queries = queries.view(N, query_len, self.heads, self.head_dim)

        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(N, query_len, self.embed_size)
        out = self.fc_out(out)
        return out


class FeedForward(nn.Module):
    def __init__(self, embed_size, ff_hidden):
        super(FeedForward, self).__init__()
        self.fc1 = nn.Linear(embed_size, ff_hidden)
        self.fc2 = nn.Linear(ff_hidden, embed_size)

    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x)))


class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, ff_hidden, dropout):
        super(TransformerBlock, self).__init__()
        self.attention = MultiHeadSelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = FeedForward(embed_size, ff_hidden)
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)
        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out


class PositionalEncoding(nn.Module):
    def __init__(self, embed_size, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.encoding = torch.zeros(max_len, embed_size)
        self.encoding.requires_grad = False

        pos = torch.arange(0, max_len).unsqueeze(1)
        _2i = torch.arange(0, embed_size, step=2)

        self.encoding[:, 0::2] = torch.sin(pos / (10000 ** (_2i / embed_size)))
        self.encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i / embed_size)))

    def forward(self, x):
        seq_len = x.size(1)
        return x + self.encoding[:seq_len, :].to(x.device)


class Transformer(nn.Module):
    def __init__(self, embed_size, heads, ff_hidden, num_layers, vocab_size, max_len=100, dropout=0.5):
        super(Transformer, self).__init__()
        self.word_embedding = nn.Embedding(vocab_size, embed_size)
        self.position = PositionalEncoding(embed_size, max_len)
        self.layers = nn.ModuleList(
            [TransformerBlock(embed_size, heads, ff_hidden, dropout) for _ in range(num_layers)]
        )
        self.fc_out = nn.Linear(embed_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        for layer in self.layers:
            out = layer(x, x, x, mask)
        out = self.fc_out(out)
        return out


# ActorCriticPolicy with Transformer

class ActorCriticTransformerPolicy(BasePolicy):
    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
        activation_fn: Type[nn.Module] = nn.Tanh,
        ortho_init: bool = True,
        use_sde: bool = False,
        log_std_init: float = 0.0,
        full_std: bool = True,
        use_expln: bool = False,
        squash_output: bool = False,
        features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        share_features_extractor: bool = True,
        normalize_images: bool = True,
        optimizer_class: Type[torch.optim.Optimizer] = torch.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        embed_size: int = 128,
        heads: int = 8,
        ff_hidden: int = 512,
        num_layers: int = 3,
        vocab_size: int = 100,
        max_len: int = 100,
        dropout: float = 0.5,
    ):
        super().__init__(
            observation_space,
            action_space,
            features_extractor_class,
            features_extractor_kwargs,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
            squash_output=squash_output,
            normalize_images=normalize_images,
        )

        # Initialize the Transformer policy and value networks
        self.transformer = Transformer(embed_size, heads, ff_hidden, num_layers, vocab_size, max_len, dropout)

        # Initialize action distribution
        self.action_dist = make_proba_distribution(action_space, use_sde=use_sde, dist_kwargs=None)

        self.optimizer = optimizer_class(self.parameters(), lr=lr_schedule(1), **optimizer_kwargs)

    def _build(self, lr_schedule: Schedule) -> None:
        self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)

    def forward(self, obs: torch.Tensor, mask=None, deterministic: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        features = self.extract_features(obs)
        latent = self.transformer(features, mask)

        # Split into policy and value latent vectors
        latent_pi, latent_vf = latent, latent

        values = self.value_net(latent_vf)
        distribution = self._get_action_dist_from_latent(latent_pi)
        actions = distribution.get_actions(deterministic=deterministic)
        log_prob = distribution.log_prob(actions)

        actions = actions.reshape((-1, *self.action_space.shape))
        return actions, values, log_prob

    def extract_features(self, obs: torch.Tensor) -> torch.Tensor:
        return obs  # You might want to replace this with a feature extractor if needed

    def _get_action_dist_from_latent(self, latent_pi: torch.Tensor) -> torch.distributions.Distribution:
        mean_actions = self.action_net(latent_pi)
        return self.action_dist.proba_distribution(mean_actions, self.log_std)

    def predict_values(self, obs: torch.Tensor) -> torch.Tensor:
        features = self.extract_features(obs)
        latent_vf = self.transformer(features, None)
        return self.value_net(latent_vf)

