In [1]:
from typing import Optional, Any, Union, Callable
import math
from typing import Tuple

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer, Transformer
from torch.utils.data import dataset
from memory.environment.generator import OQAGenerator
from memory.environment.gym import MemoryEnv
from memory import EpisodicMemory, SemanticMemory

class Embeddings(nn.Module):
    """Turn a given memory table into learnable embeddings.

    Every memory (or question) becomes one row that is the concatenation of a
    head, a relation and a tail embedding, so a row has ``embedding_dim * 3``
    features.
    """

    def __init__(
        self,
        generator_params: dict,
        embedding_dim: int,
        num_rows: int,
        num_cols: int,
        special_tokens: Optional[dict] = None,
        device: str = "cpu",
    ) -> None:
        """Initialize an Embeddings object.

        Args
        ----
        generator_params: OQAGenerator parameters
        embedding_dim: embedding dimension (e.g., 4)
        num_rows: number of rows in the table.
        num_cols: number of columns in the table. ``forward`` adds a positional
            encoding of width ``embedding_dim * 3`` to every row, so this must be
            ``embedding_dim * 3``.
        special_tokens: mapping from special token to index. Defaults to
            {"<pad>": 0, "<mask>": 1}.
        device: "cpu" or "cuda"

        """
        super().__init__()
        # A fresh dict per instance instead of a shared mutable default argument.
        if special_tokens is None:
            special_tokens = {"<pad>": 0, "<mask>": 1}

        oqag = OQAGenerator(**generator_params)
        self.embedding_dim = embedding_dim
        self.num_rows = num_rows
        self.num_cols = num_cols

        self.heads = nn.Embedding(len(oqag.heads), self.embedding_dim)
        self.head2num = {head: idx for idx, head in enumerate(oqag.heads)}

        self.relations = nn.Embedding(len(oqag.relations), self.embedding_dim)
        self.relation2num = {
            relation: idx for idx, relation in enumerate(oqag.relations)
        }

        self.tails = nn.Embedding(len(oqag.tails), self.embedding_dim)
        self.tail2num = {tail: idx for idx, tail in enumerate(oqag.tails)}

        self.names = nn.Embedding(len(oqag.names), self.embedding_dim)
        self.name2num = {name: idx for idx, name in enumerate(oqag.names)}

        self.specials = nn.Embedding(
            len(special_tokens), self.embedding_dim, padding_idx=special_tokens["<pad>"]
        )
        self.special2num = dict(special_tokens)

        # Fixed sinusoidal encoding over a full row (head | relation | tail).
        self.pos_encoder = PositionalEncoding(embedding_dim * 3)
        self.device = torch.device(device)

    def _embed(self, table: nn.Embedding, index: int) -> torch.Tensor:
        """Look up a single ``index`` in ``table`` (index tensor built on self.device)."""
        return table(torch.tensor(index, device=self.device))

    def _stack_rows(self, rows: list) -> torch.Tensor:
        """Stack 1-D rows into a table; no rows gives shape (0, embedding_dim*3)."""
        if not rows:
            # torch.stack([]) raises, so an empty memory needs an explicit empty table.
            return torch.zeros(0, self.embedding_dim * 3, device=self.device)
        return torch.stack(rows)

    def episodic2numerics(self, M_e: EpisodicMemory) -> torch.Tensor:
        """Convert the EpisodicMemory object into numerical values.

        Args
        ----
        M_e: EpisodicMemory object.

        Returns
        -------
        Episodic memory table whose shape is (num_memories, embedding_dim*3).
        An empty memory gives a (0, embedding_dim*3) table.

        """
        rows = []
        for mem in M_e.entries:
            # mem[3] (timestamp) is not used at the moment since the memories
            # are ordered by time anyways.
            name1, head = M_e.split_name_entity(mem[0])
            relation = mem[1]
            name2, tail = M_e.split_name_entity(mem[2])

            assert name1 == name2

            # Head and tail are each the owner's name embedding plus the entity's.
            head_emb = self._embed(self.names, self.name2num[name1]) + self._embed(
                self.heads, self.head2num[head]
            )
            relation_emb = self._embed(self.relations, self.relation2num[relation])
            tail_emb = self._embed(self.names, self.name2num[name2]) + self._embed(
                self.tails, self.tail2num[tail]
            )

            rows.append(torch.cat([head_emb, relation_emb, tail_emb]))

        return self._stack_rows(rows)

    def eq2numerics(self, question: list) -> torch.Tensor:
        """Convert the question into numerical values.

        Args
        ----
        question: [head, relation]. E.g., [Tae's laptop, AtLocation]

        Returns
        -------
        Episodic question whose shape is (1, embedding_dim*3)

        """
        name, head = EpisodicMemory.split_name_entity(question[0])
        relation = question[1]

        head_emb = self._embed(self.names, self.name2num[name]) + self._embed(
            self.heads, self.head2num[head]
        )
        relation_emb = self._embed(self.relations, self.relation2num[relation])
        # The unknown tail is the owner's name plus the <mask> token.
        tail_emb = self._embed(self.names, self.name2num[name]) + self._embed(
            self.specials, self.special2num["<mask>"]
        )

        return torch.cat([head_emb, relation_emb, tail_emb]).unsqueeze(0)

    def semantic2numerics(self, M_s: SemanticMemory) -> torch.Tensor:
        """Convert the SemanticMemory object into numerical values.

        Args
        ----
        M_s: SemanticMemory object.

        Returns
        -------
        Semantic memory table whose shape is (num_memories, embedding_dim*3).
        An empty memory gives a (0, embedding_dim*3) table.

        """
        rows = []
        for mem in M_s.entries:
            # mem[3] (num_gen) is not used at the moment since the memories are
            # ordered by num_gen anyways.
            head_emb = self._embed(self.heads, self.head2num[mem[0]])
            relation_emb = self._embed(self.relations, self.relation2num[mem[1]])
            tail_emb = self._embed(self.tails, self.tail2num[mem[2]])

            rows.append(torch.cat([head_emb, relation_emb, tail_emb]))

        return self._stack_rows(rows)

    def sq2numerics(self, question: list) -> torch.Tensor:
        """Convert the question into numerical values.

        Args
        ----
        question: [head, relation]. E.g., [Tae's laptop, AtLocation]

        Returns
        -------
        Semantic question whose shape is (1, embedding_dim*3)

        """
        head = SemanticMemory.remove_name(question[0])
        relation = question[1]

        head_emb = self._embed(self.heads, self.head2num[head])
        relation_emb = self._embed(self.relations, self.relation2num[relation])
        # The unknown tail is represented by the <mask> token alone.
        tail_emb = self._embed(self.specials, self.special2num["<mask>"])

        return torch.cat([head_emb, relation_emb, tail_emb]).unsqueeze(0)

    def forward(
        self, M_e: EpisodicMemory, M_s: SemanticMemory, question: list, policy_type: str
    ):
        """One forward pass.

        Args
        ----
        M_e: EpisodicMemory object.
        M_s: SemanticMemory object.
        question: [head, relation]. E.g., [Tae's laptop, AtLocation]
        policy_type: one of the below:
            episodic_memory_manage
            episodic_question_answer
            semantic_memory_manage
            semantic_question_answer
            episodic_semantic_question_answer
            (episodic_to_semantic is not handled here and raises ValueError.)

        Returns
        -------
        Table of shape (num_rows, num_cols). Memory rows are filled from the top,
        unused rows stay zero, and the question (when the policy has one) is in
        the last row. Every row then gets the sinusoidal encoding of its index.

        Raises
        ------
        ValueError: if policy_type is not supported.

        """
        x = torch.zeros(
            self.num_rows, self.num_cols, dtype=torch.float32, device=self.device
        )

        if policy_type == "episodic_memory_manage":
            table = self.episodic2numerics(M_e)
            x[: table.shape[0], :] = table

        elif policy_type == "episodic_question_answer":
            table = self.episodic2numerics(M_e)
            x[: table.shape[0], :] = table
            # The question always occupies the last row.
            x[-1, :] = self.eq2numerics(question).squeeze(0)

        elif policy_type == "semantic_memory_manage":
            table = self.semantic2numerics(M_s)
            x[: table.shape[0], :] = table

        elif policy_type == "semantic_question_answer":
            table = self.semantic2numerics(M_s)
            x[: table.shape[0], :] = table
            x[-1, :] = self.sq2numerics(question).squeeze(0)

        elif policy_type == "episodic_semantic_question_answer":
            table = self.episodic2numerics(M_e)
            if len(table) > 0:
                x[: table.shape[0], :] = table

            # Semantic rows start right after the episodic memory's capacity.
            table = self.semantic2numerics(M_s)
            if len(table) > 0:
                x[M_e.capacity : M_e.capacity + table.shape[0], :] = table

            x[-1, :] = self.eq2numerics(question).squeeze(0)

        else:
            raise ValueError(f"Unsupported policy_type: {policy_type!r}")

        # Add the fixed sinusoidal encoding of each row index (pe buffer has
        # shape (max_len, 1, embedding_dim*3)); pos_encoder's dropout is not applied.
        x = x + self.pos_encoder.pe[: x.shape[0], 0, :]

        return x


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a sequence, then dropout.

    The encodings for ``max_len`` positions are precomputed into the buffer
    ``pe`` of shape (max_len, 1, d_model): even feature columns hold
    sin(pos * freq) and odd columns hold cos(pos * freq).
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        """
        Args:
            d_model: feature size of the encoding (odd values are supported).
            dropout: dropout probability applied in ``forward``.
            max_len: number of positions precomputed.
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        # One frequency per even column: ceil(d_model / 2) frequencies.
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # There are only floor(d_model / 2) odd columns, so drop the last
        # frequency when d_model is odd (no-op when it is even).
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer("pe", pe)

    def forward(self, x: Tensor) -> Tensor:
        """Add the encodings of positions 0..seq_len-1 to ``x`` and apply dropout.

        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]

        Returns:
            Tensor with the same shape as ``x`` for inputs of the shape above.
        """
        x = x + self.pe[: x.size(0)]
        return self.dropout(x)


# Smoke test: run a tiny nn.Transformer on random src/tgt sequences.
transformer_params = {
    "d_model": 4,
    "nhead": 1,
    "dropout": 0.1,
    "num_encoder_layers": 1,
    "num_decoder_layers": 1,
    "dim_feedforward": 16,
}

# Seed so the initial weights and the random inputs are reproducible.
torch.manual_seed(0)
transformer = Transformer(**transformer_params)

# Shapes are (seq_len, batch, d_model) because batch_first defaults to False;
# the feature size must equal d_model, so derive it instead of hard-coding it.
feature_dim = transformer_params["d_model"]
src = torch.rand((5, 8, feature_dim))
tgt = torch.rand((1, 8, feature_dim))
out = transformer(src, tgt)


In [5]:
src.size()  # (seq_len, batch, d_model) of the random source sequence

torch.Size([5, 8, 4])

In [None]:
from typing import Optional, Any, Union, Callable
import math
from typing import Tuple

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset
from memory.environment.generator import OQAGenerator
from memory.environment.gym import MemoryEnv

# Tiny nn.Transformer configuration (passed to MemoryTransformer in a later cell).
transformer_params = dict(
    d_model=4,
    nhead=1,
    dropout=0.1,
    num_encoder_layers=1,
    num_decoder_layers=1,
    dim_feedforward=16,
)


class MemoryTransformer(nn.Transformer):
    """nn.Transformer plus a vocabulary, token embedding and positional encoder.

    The vocabulary is ``["<pad>", "<mask>"]`` followed by the sorted heads,
    relations and tails of an OQAGenerator. ``tokenizer`` and ``pos_encoder``
    are created here but ``forward`` does not use them; it expects already
    embedded ``src``/``tgt`` (NOTE(review): confirm this is intended).
    """

    def __init__(self, transformer_params: dict, generator_params: dict) -> None:
        """
        Args:
            transformer_params: keyword arguments forwarded to nn.Transformer.
            generator_params: keyword arguments forwarded to OQAGenerator.
        """
        # nn.Module.__init__ must run before any attribute is assigned.
        super().__init__(**transformer_params)
        self.transformer_params = transformer_params
        self.generator_params = generator_params
        # Fall back to nn.Transformer's own defaults when a key is omitted.
        self.d_model = transformer_params.get("d_model", 512)
        self.dropout = transformer_params.get("dropout", 0.1)
        self.pos_encoder = PositionalEncoding(self.d_model, self.dropout)
        oqag = OQAGenerator(**self.generator_params)

        self.vocab = (
            ["<pad>", "<mask>"]
            + sorted(oqag.heads)
            + sorted(oqag.relations)
            + sorted(oqag.tails)
        )
        self.ntoken = len(self.vocab)
        self.tokenizer = nn.Embedding(self.ntoken, self.d_model)

    def forward(
        self,
        src: Tensor,
        tgt: Tensor,
        src_mask: Optional[Tensor] = None,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Encode ``src`` and decode ``tgt`` against it (masks are optional).

        Args:
            src: the sequence to the encoder (required).
            tgt: the sequence to the decoder (required).
            src_mask: the additive mask for the src sequence.
            tgt_mask: the additive mask for the tgt sequence.
            memory_mask: the additive mask for the encoder output.
            src_key_padding_mask: mask for src keys per batch.
            tgt_key_padding_mask: mask for tgt keys per batch.
            memory_key_padding_mask: mask for memory keys per batch.

        Shape:
            - src: (S, E) unbatched, (S, N, E) if batch_first=False, else (N, S, E).
            - tgt: (T, E) unbatched, (T, N, E) if batch_first=False, else (N, T, E).
            - src_mask: (S, S); tgt_mask: (T, T); memory_mask: (T, S).
            - *_key_padding_mask: (S) or (T) unbatched, otherwise (N, S) / (N, T).
            - output: same layout as tgt.
            where S/T are source/target lengths, N the batch size, E = d_model.

        Raises:
            RuntimeError: if the batch sizes of src and tgt differ, or their
                feature size is not d_model.
        """
        is_batched = src.dim() == 3
        if not self.batch_first and src.size(1) != tgt.size(1) and is_batched:
            raise RuntimeError("the batch number of src and tgt must be equal")
        elif self.batch_first and src.size(0) != tgt.size(0) and is_batched:
            raise RuntimeError("the batch number of src and tgt must be equal")

        if src.size(-1) != self.d_model or tgt.size(-1) != self.d_model:
            raise RuntimeError(
                "the feature number of src and tgt must be equal to d_model"
            )

        memory = self.encoder(
            src, mask=src_mask, src_key_padding_mask=src_key_padding_mask
        )
        output = self.decoder(
            tgt,
            memory,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return output


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a sequence, then dropout.

    The encodings for ``max_len`` positions are precomputed into the buffer
    ``pe`` of shape (max_len, 1, d_model): even feature columns hold
    sin(pos * freq) and odd columns hold cos(pos * freq).
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        """
        Args:
            d_model: feature size of the encoding (odd values are supported).
            dropout: dropout probability applied in ``forward``.
            max_len: number of positions precomputed.
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        # One frequency per even column: ceil(d_model / 2) frequencies.
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # There are only floor(d_model / 2) odd columns, so drop the last
        # frequency when d_model is odd (no-op when it is even).
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer("pe", pe)

    def forward(self, x: Tensor) -> Tensor:
        """Add the encodings of positions 0..seq_len-1 to ``x`` and apply dropout.

        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]

        Returns:
            Tensor with the same shape as ``x`` for inputs of the shape above.
        """
        x = x + self.pe[: x.size(0)]
        return self.dropout(x)


In [None]:
env = MemoryEnv()

# Seed so the model weights and random inputs are reproducible.
torch.manual_seed(0)
model = MemoryTransformer(
    transformer_params=transformer_params, generator_params=env.generator_params
)

# Random (seq_len, batch, d_model) inputs; the feature size must equal d_model.
feature_dim = transformer_params["d_model"]
src = torch.rand((5, 8, feature_dim))
tgt = torch.rand((1, 8, feature_dim))
out = model(src, tgt)


In [None]:
model  # display the module tree (the original `model.` was a SyntaxError)

In [None]:
# Parameters this notebook passes on to OQAGenerator (via MemoryTransformer).
env.generator_params

In [None]:
out[:, 0].size()  # shape of the decoder output for the first batch element

In [None]:
import math
from typing import Tuple

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset
from memory.environment.generator import OQAGenerator


class TransformerModel(nn.Module):
    """Transformer encoder over a token vocabulary built from an OQAGenerator.

    The vocabulary is ``["<pad>", "<mask>"]`` followed by the sorted heads,
    relations and tails of the generator. ``forward`` maps token ids to
    per-position logits over that vocabulary.
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        d_hid: int,
        nlayers: int,
        generator_params: dict,
        policy_type: str,
        dropout: float = 0.5,
    ):
        """
        Args:
            d_model: embedding / model dimension.
            nhead: number of attention heads (must divide d_model).
            d_hid: feedforward dimension of each encoder layer.
            nlayers: number of encoder layers.
            generator_params: keyword arguments for OQAGenerator.
            policy_type: stored on the model; not used by ``forward`` yet.
            dropout: dropout probability.
        """
        super().__init__()
        self.model_type = "Transformer"
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.generator_params = generator_params
        self.policy_type = policy_type
        self.special_tokens = ["<pad>", "<mask>"]

        # Sets self.vocab and self.ntoken.
        self.tokenize()

        # "<pad>" is index 0 of the vocabulary.
        self.encoder = nn.Embedding(self.ntoken, d_model, padding_idx=0)
        self.d_model = d_model
        self.decoder = nn.Linear(d_model, self.ntoken)

        self.init_weights()

    def tokenize(self) -> None:
        """(Re)build the vocabulary from ``self.generator_params``.

        Sets ``self.vocab`` (special tokens, then sorted heads, relations and
        tails) and ``self.ntoken``.
        """
        oqag = OQAGenerator(**self.generator_params)
        self.vocab = (
            list(self.special_tokens)
            + sorted(oqag.heads)
            + sorted(oqag.relations)
            + sorted(oqag.tails)
        )
        self.ntoken = len(self.vocab)

    def init_weights(self) -> None:
        """Initialise the embedding and output weights uniformly in [-0.1, 0.1].

        The output bias is zeroed. The embedding's padding row is reset to zero
        after the uniform fill so <pad> tokens still map to a zero vector.
        """
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        if self.encoder.padding_idx is not None:
            self.encoder.weight.data[self.encoder.padding_idx].zero_()
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src: Tensor, src_mask: Tensor) -> Tensor:
        """
        Args:
            src: Tensor, shape [seq_len, batch_size] of token ids
            src_mask: Tensor, shape [seq_len, seq_len]

        Returns:
            output Tensor of shape [seq_len, batch_size, ntoken]
        """
        # Scale embeddings by sqrt(d_model) before adding positional encodings.
        src = self.encoder(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, src_mask)
        output = self.decoder(output)
        return output


def generate_square_subsequent_mask(sz: int) -> Tensor:
    """Build an (sz, sz) additive causal attention mask.

    Entries on and below the diagonal are 0; entries above it are -inf, so
    position i cannot attend to any later position j > i.
    """
    all_blocked = torch.full((sz, sz), float("-inf"))
    # triu keeps the strict upper triangle and writes 0 everywhere else.
    return all_blocked.triu(diagonal=1)


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a sequence, then dropout.

    The encodings for ``max_len`` positions are precomputed into the buffer
    ``pe`` of shape (max_len, 1, d_model): even feature columns hold
    sin(pos * freq) and odd columns hold cos(pos * freq).
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        """
        Args:
            d_model: feature size of the encoding (odd values are supported).
            dropout: dropout probability applied in ``forward``.
            max_len: number of positions precomputed.
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        # One frequency per even column: ceil(d_model / 2) frequencies.
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # There are only floor(d_model / 2) odd columns, so drop the last
        # frequency when d_model is odd (no-op when it is even).
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer("pe", pe)

    def forward(self, x: Tensor) -> Tensor:
        """Add the encodings of positions 0..seq_len-1 to ``x`` and apply dropout.

        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]

        Returns:
            Tensor with the same shape as ``x`` for inputs of the shape above.
        """
        x = x + self.pe[: x.size(0)]
        return self.dropout(x)


In [None]:
# Generator built with its default parameters (not env.generator_params);
# the next cell inspects its heads, relations and tails.
oqag = OQAGenerator()


In [None]:
# Non-special part of the vocabulary, in the same order the models build it.
[*sorted(oqag.heads), *sorted(oqag.relations), *sorted(oqag.tails)]


In [None]:
# Number of special tokens in the default {"<pad>": 0, "<mask>": 1} mapping.
len({"<pad>": 0, "<mask>": 1})


In [None]:
from memory.environment.gym import MemoryEnv

# Re-create the environment; this rebinds the `env` created in an earlier cell.
env = MemoryEnv()


In [None]:
# Total count of heads, relations and names known to the environment's generator.
sum(map(len, (env.oqag.heads, env.oqag.relations, env.oqag.names)))


In [None]:
# Parameters this notebook passes on to OQAGenerator-based models.
env.generator_params


In [None]:
# Heads of the environment's generator (presumably an OQAGenerator — confirm).
env.oqag.heads


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

emsize = 200  # embedding dimension (d_model)
d_hid = 200  # dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 2  # number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 2  # number of heads in nn.MultiheadAttention (must divide emsize)
dropout = 0.2  # dropout probability

# TransformerModel derives its vocabulary size from generator_params, so no
# token count is passed. Keyword arguments keep the call aligned with its
# (d_model, nhead, d_hid, nlayers, generator_params, policy_type, dropout)
# signature; the previous positional call sent `nlayers` into generator_params.
model = TransformerModel(
    d_model=emsize,
    nhead=nhead,
    d_hid=d_hid,
    nlayers=nlayers,
    generator_params=env.generator_params,
    policy_type="episodic_question_answer",  # TODO: confirm; stored but unused by the model
    dropout=dropout,
).to(device)


In [None]:
import matplotlib.pyplot as plt

# Only the precomputed `pe` buffer is plotted, so dropout is irrelevant here.
# max_len was left blank (a SyntaxError); 50 positions keep the plot readable.
pos_encoder = PositionalEncoding(10, dropout=0.0, max_len=50)

fig, ax = plt.subplots(figsize=(10, 10))
for row in pos_encoder.pe.squeeze(1):  # one (d_model,) vector per position
    ax.plot(row)
ax.set(
    title="Sinusoidal positional encodings (one line per position)",
    xlabel="encoding dimension",
    ylabel="value",
);

In [None]:
import matplotlib.pyplot as plt

pos_encoder = PositionalEncoding(emsize, dropout)

# Plot the encoding vectors of the first five positions straight from the
# buffer. Calling pos_encoder(torch.tensor([i])) instead adds the scalar i to
# the position-0 encoding and applies dropout, which is not position i's code.
fig, ax = plt.subplots()
for position in range(5):
    ax.plot(pos_encoder.pe[position, 0], label=f"position {position}")
ax.set(
    title="Positional encodings of positions 0-4",
    xlabel="encoding dimension",
    ylabel="value",
)
ax.legend();


In [None]:
import torch

# Causal attention mask for 5 positions: 0 on/below the diagonal, -inf above.
torch.full((5, 5), float("-inf")).triu(diagonal=1)


In [None]:
from torch import nn
import torch

# Toy experiment: one shared embedding feeds two linear heads that are trained
# towards different targets. Seeded so the initial weights are reproducible.
torch.manual_seed(0)
embs = nn.Embedding(2, 3, padding_idx=0)  # row 0 is the (zero, frozen) padding row
fn = nn.Linear(3, 3)
fn2 = nn.Linear(3, 3)
optimizer = torch.optim.Adam(
    params=[*embs.parameters(), *fn.parameters(), *fn2.parameters()]
)


In [None]:
NUM_STEPS = 10000

# Loop-invariant inputs, built once instead of on every iteration.
token = torch.tensor(1)
target1 = torch.tensor([2.0, 2.0, 2.0])
target2 = torch.tensor([1.0, 1.0, 1.0])

for _ in range(NUM_STEPS):
    # Both heads read the same embedding. The gradient reaching embs is
    # loss1's plus loss2's, exactly as when the lookup was done twice.
    emb = embs(token)
    loss1 = ((fn(emb) - target1) ** 2).sum()
    loss2 = ((fn2(emb) - target2) ** 2).sum()

    loss = loss1 + loss2

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


In [None]:
fn(embs(torch.tensor(1, dtype=torch.long)))  # head-1 output for token 1 after training


In [None]:
(embs(torch.tensor(1, dtype=torch.long)), fn.weight)  # token-1 embedding and head-1 weights


In [None]:
fn(embs(torch.tensor(1, dtype=torch.long)))  # same expression as an earlier cell


In [None]:
embs(torch.tensor(1, dtype=torch.long))  # learned embedding of token 1


In [None]:
embs(torch.tensor(1, dtype=torch.long))  # same expression as the previous cell


In [None]:
# Inspect the Adam optimizer's parameter groups (hyperparameters + parameter tensors).
optimizer.param_groups
