In [1]:
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional, Tuple, Union, List, Dict, Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from positional_encodings import PositionalEncoding
import os
from sentence_vae_decoder import SentenceDecoder
from sentence_vae_encoder import SentenceEncoder

In [2]:
# import gpt2 tokenizer
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# create a batch of samples
samples = ["hello world", "hello universe"]

In [4]:
# apply the tokenizer
tokens = tokenizer(samples, return_tensors="pt")
tokens

{'input_ids': tensor([[31373,   995],
        [31373,  6881]]), 'attention_mask': tensor([[1, 1],
        [1, 1]])}

In [10]:
@dataclass
class SentenceEncoderConfig:
    word_embed_proj_dim: Optional[int] = 64
    hidden_size: int = 128
    vocab_size: int = tokenizer.vocab_size
    max_seq_len: int = 2
    num_hidden_layers: int = 2
    num_attention_heads: int = 2
    pad_id: int = 0
    dropout: float = 0.0
    load_embedding_weights: bool = False
    embedding_weights_path: Optional[str] = None
    do_finetune: bool = False

torch.manual_seed(42)
config = SentenceEncoderConfig()

In [6]:
sentence_vae_encoder = SentenceEncoder(**config.__dict__)

In [7]:
sentence_embedding = sentence_vae_encoder(tokens["input_ids"])

In [8]:
tokenizer.bos_token_id

50256

In [11]:
sentence_vae_decoder = SentenceDecoder(**config.__dict__)

In [14]:
tensor = torch.Tensor([50256]).long().unsqueeze(0)
tensor.shape

torch.Size([1, 1])

In [16]:
embedding = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_id)

In [17]:
_, seq_len = tensor.shape
attention_mask = (tensor != 0)
attention_mask = ~attention_mask.to(torch.bool)

In [23]:

input_embeddings= embedding(tensor)
positional_encoding = PositionalEncoding(config.hidden_size, seq_len)

In [26]:
embeddings = input_embeddings + positional_encoding(seq_len)

In [29]:
causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len) == -torch.inf

In [30]:
decoder_layer = nn.TransformerDecoderLayer(
            d_model=config.hidden_size,
            nhead=config.num_attention_heads,
            dropout=config.dropout,
            dim_feedforward=config.hidden_size * 2,
            batch_first=True)

decoder = nn.TransformerDecoder(
            decoder_layer=decoder_layer,
            num_layers=config.num_hidden_layers
        )

In [32]:
embeddings.shape

torch.Size([1, 1, 128])

In [31]:
output = decoder(embeddings, causal_mask)

AssertionError: For batched (3-D) `query`, expected `key` and `value` to be 3-D but found 2-D and 2-D tensors respectively