# Text AutoEncoders

## Imports

In [40]:
import copy
from typing import *

import numpy as np
import torch
import torch.nn as nn
import transformers
from transformers import BertTokenizer, BertModel
from x_transformers import Encoder, Decoder
from x_transformers.x_transformers import ScaledSinusoidalEmbedding

from model.predictor import Predictor

## Tokeniser

Subword tokenisation using BERT

### Config

In [2]:
# Load pretrained BERT model and tokeniser
tokeniser = BertTokenizer.from_pretrained("bert-base-uncased")



### Text Tokenisation Example

In [3]:
# Example text
text: str = "This is an example sentence for subword tokenization."
print(f"{text=}")

# Tokenize a sentence
tokens: List[str] = tokeniser.tokenize(text)
print(f"{tokens=}")

text='This is an example sentence for subword tokenization.'
tokens=['this', 'is', 'an', 'example', 'sentence', 'for', 'sub', '##word', 'token', '##ization', '.']


### Text Encoding

In [286]:
# Tokenize the text and get input IDs and attention mask
batch_encoding: transformers.tokenization_utils_base.BatchEncoding = tokeniser(
    [
        text,
        # "this is another sentence",
    ],
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=128,
)
token_ids: torch.Tensor = batch_encoding["input_ids"]  # Token IDs
print(f"{token_ids.shape=}")
print(f"{token_ids=}")
# token_type_ids: torch.Tensor = batch_encoding["token_type_ids"]
# print(f"{token_type_ids=}")
# attention_mask: torch.Tensor = batch_encoding[
#     "attention_mask"
# ]  # Mask to ignore padding tokens
# print(f"{attention_mask=}")

token_ids.shape=torch.Size([1, 13])
token_ids=tensor([[  101,  2023,  2003,  2019,  2742,  6251,  2005,  4942, 18351, 19204,
          3989,  1012,   102]])


## T-JEPA

### Config

In [16]:
# TOKENISATION
vocab_size: int = tokeniser.vocab_size
# TOKEN EMBEDDING
embed_dim: int = 512  # 768  # BERT hidden dimension
# CONTEXT/TARGET GENERATION
target_scale_interval: Tuple[float, float] = (0.15, 0.2)
context_scale_interval: Tuple[float, float] = (0.7, 0.9)
# TRANSFORMER ENCODER/DECODER
num_layers: int = 6  # Number of layers in the transformer
num_heads: int = 8  # Number of attention heads in the transformer
layer_dropout: float = 0.0

print(f"{vocab_size=}")
print(f"{embed_dim=}")

vocab_size=30522
embed_dim=512


### Layers

In [63]:
embedding_layer = nn.Embedding(vocab_size, embed_dim)

# TODO: Add positional embeddings
student_encoder = Encoder(
    dim=embed_dim,
    heads=num_heads,
    depth=num_layers,
    layer_dropout=layer_dropout,
)

teacher_encoder = copy.deepcopy(student_encoder)  # .cuda()  # copy student encoder

decoder = Decoder(
    dim=embed_dim,
    depth=num_layers // 2,
    heads=num_heads // 2,
    layer_dropout=layer_dropout,
)
# predictor = Predictor(
#     embed_dim=embed_dim,
#     num_heads=num_heads // 2,
#     depth=num_layers // 2,
#     layer_dropout=layer_dropout,
# )

pos_embedding = ScaledSinusoidalEmbedding(embed_dim)

In [54]:
mask_token = nn.Parameter(torch.randn(1, 1, embed_dim))
nn.init.trunc_normal_(mask_token, 0.02)
None

In [31]:
token_ids: torch.Tensor

token_embeddings: torch.Tensor = embedding_layer(token_ids)

# Add positional embeddings
token_embeddings = token_embeddings + pos_embedding(token_embeddings)
print(f"{token_embeddings.shape=}")

token_embeddings.shape=torch.Size([1, 13, 512])


In [64]:
token_ids.shape

torch.Size([1, 13])

In [91]:
token_ids.shape[1]

13

In [278]:
torch.ones(4).shape[0]

4

In [293]:
tokeniser.pad_token_id, tokeniser.pad_token

(0, '[PAD]')

In [291]:
torch.tensor([101, 2023, 2003, 2178, 6251, 102, 0, 0, 0, 0, 0, 0, 0])

tensor([ 101, 2023, 2003, 2178, 6251,  102,    0,    0,    0,    0,    0,    0,
           0])

In [259]:
token_ids: torch.Tensor

target_prob_range: Tuple[float, float] = (0.15, 0.35)

target_prob: float = np.random.uniform(
    low=target_prob_range[0], high=target_prob_range[1]
)
print(f"{target_prob=}")

target_indices: torch.Tensor

target_prob=0.29749277535754115


3

In [298]:
torch.randperm(13)[:4].tolist()

[7, 1, 9, 8]

In [79]:
target_scale: float = np.random.uniform(
    low=target_scale_interval[0], high=target_scale_interval[1]
)
context_scale: float = np.random.uniform(
    low=context_scale_interval[0], high=context_scale_interval[1]
)

target_indices: torch.Tensor = torch.bernoulli(
    torch.full(token_ids.shape, target_scale)  # target_probability_matrix
).bool()
context_indices: torch.Tensor = torch.bernoulli(
    torch.full(token_ids.shape, context_scale)  # context_probability_matrix
).bool()
print(f"{target_indices.shape=}")
print(f"{target_indices=}")
print()
print(f"{context_indices.shape=}")
print(f"{context_indices=}")
# NOTE: The targets and contexts are allowed to overlap

target_indices.shape=torch.Size([1, 13])
target_indices=tensor([[False,  True,  True, False, False,  True, False, False, False, False,
          True,  True, False]])

context_indices.shape=torch.Size([1, 13])
context_indices=tensor([[False,  True,  True, False,  True,  True,  True,  True,  True,  True,
          True,  True,  True]])


In [65]:
target_embeddings: torch.Tensor = token_embeddings[
    None, target_indices
]  # (batch_size, num_target_tokens, embed_dim)
context_embeddings: torch.Tensor = token_embeddings[
    None, context_indices
]  # (batch_size, num_context_tokens, embed_dim)

In [67]:
target_encoding: torch.Tensor = teacher_encoder(
    target_embeddings
)  # (batch_size, num_target_tokens, embed_dim)
context_encoding: torch.Tensor = student_encoder(
    context_embeddings
)  # (batch_size, num_context_tokens, embed_dim)
target_embeddings.shape, context_embeddings.shape

(torch.Size([1, 2, 512]), torch.Size([1, 12, 512]))

Use the context to predict the targets

In [58]:
batch_dim, num_patches, _ = target_embeddings.shape
target_masks: torch.Tensor = mask_token.repeat(batch_dim, num_patches, 1)
print(f"{target_masks.shape=}")
assert target_masks.shape == target_embeddings.shape

target_masks.shape=torch.Size([1, 2, 512])


In [59]:
# NOTE: Targets and contexts contain positional information
# This positional information is un-affected by the concatenation
x: torch.Tensor = torch.cat([context_embeddings, target_masks], dim=1)
print(f"{x.shape=}")

# Decode
x = decoder(x)
# Return the output corresponding to target tokens, i.e., the last len(target_masks) tokens
prediction: torch.Tensor = x[:, -target_masks.shape[1] :, :]
print(f"{prediction.shape=}")

x.shape=torch.Size([1, 14, 512])
prediction.shape=torch.Size([1, 2, 512])


In [60]:
# prediction: torch.Tensor = predictor(
#     context_encoding=context_encoding,
#     target_masks=target_masks,
# )
# print(f"{prediction.shape=}")

prediction.shape=torch.Size([1, 2, 512])


In [62]:
criterion = nn.MSELoss()

loss: torch.Tensor = criterion(prediction, target_embeddings)
print(f"{loss=}")

loss=tensor(2.0081, grad_fn=<MseLossBackward0>)


## T-JEPA Class

## Vision-Language Models?

# Other Stuff

In [16]:
bert_model: BertModel = BertModel.from_pretrained("bert-base-uncased")

In [28]:
bert_embeddings: (
    transformers.modeling_outputs.BaseModelOutputWithPoolingAndCrossAttentions
) = bert_model(input_ids=input_ids, attention_mask=attention_mask)
print(f"{bert_embeddings.keys()=}")
print(f"{bert_embeddings.last_hidden_state.shape=}")

bert_embeddings.keys()=odict_keys(['last_hidden_state', 'pooler_output'])
bert_embeddings.last_hidden_state.shape=torch.Size([1, 13, 768])


In [37]:
tokeniser.pad_token_id, tokeniser.cls_token_id, tokeniser.sep_token_id, tokeniser.mask_token_id

(0, 101, 102, 103)

In [36]:
# Masking some tokens
def mask_tokens(
    input_ids: torch.Tensor,
    tokeniser: BertTokenizer,
    masking_prob: float = 0.15,
    ignore_index: int = -100,  # nn.CrossEntropyLoss default ignore_index (for non-masked tokens)
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Tokens to be predicted are masked with the [MASK] token (e.g. 103).
    Their corresponding labels will not be ignored.
    All other other tokens are not masked, and their labels will be ignored (e.g. set to -100).
    ([CLS], [SEP], and [PAD]) are not predicted and will not be masked.
    """
    # Initialise labels as ground thruths (input_ids)
    # NOTE: Once we mask tokens (to be predicted), the unmasked tokens will be ignored (i.e. their labels will be set to -100)
    labels: torch.Tensor = input_ids.clone()

    # Create mask for tokens to be predicted (i.e. mask to replace token ids with [MASK] token id)
    probability_matrix: torch.Tensor = torch.full(labels.shape, masking_prob)
    masked_indices: torch.Tensor = torch.bernoulli(probability_matrix).bool()

    # Avoid masking [CLS], [SEP], and [PAD] tokens
    masked_indices[input_ids == tokeniser.cls_token_id] = False
    masked_indices[input_ids == tokeniser.sep_token_id] = False
    masked_indices[input_ids == tokeniser.pad_token_id] = False

    # Replace masked tokens with [MASK] token id
    input_ids[masked_indices] = tokeniser.mask_token_id

    # Labels are the original tokens where we mask
    labels[~masked_indices] = (
        ignore_index  # NOTE: -100 will be ignored by the criterion (i.e. the unmasked tokens not being predicted)
    )

    return input_ids, labels


masked_input_ids: torch.Tensor
labels: torch.Tensor
masked_input_ids, labels = mask_tokens(
    input_ids=batch_encoding["input_ids"], tokeniser=tokeniser
)
print(f"{masked_input_ids=}")
print(f"{labels=}")

masked_input_ids=tensor([[  101,  2023,  2003,   103,  2742,   103,   103,   103, 18351, 19204,
          3989,   103,   102]])
labels=tensor([[-100, -100, -100, -100, -100, 6251,  103, -100, -100, -100, -100, -100,
         -100]])


In [10]:
from x_transformers import Encoder


class TextMaskedAutoencoder(nn.Module):
    def __init__(
        self,
        bert_model: BertModel,
        embed_dim: int = 64,
        enc_depth: int = 8,
        num_heads: int = 8,
        layer_dropout: float = 0.0,
    ):
        super().__init__()
        self.bert_model = (
            bert_model  # This provides token embeddings (the BERT backbone)
        )

        self.layer_dropout = layer_dropout

        self.encoder = Encoder(
            dim=embed_dim,
            heads=num_heads,
            depth=enc_depth,
            layer_dropout=self.layer_dropout,
        )

        self.fc = nn.Linear(
            embed_dim, tokenizer.vocab_size
        )  # Output layer to predict masked tokens

    def forward(
        self,
        input_ids: torch.Tensor,  # Shape: (batch_size, seq_len)
        attention_mask: torch.Tensor,  # Shape: (batch_size, seq_len)
    ) -> torch.Tensor:
        # Get embeddings from BERT backbone
        with torch.no_grad():  # Freeze BERT parameters if needed
            outputs: (
                transformers.modeling_outputs.BaseModelOutputWithPoolingAndCrossAttentions
            ) = self.bert_model(input_ids, attention_mask=attention_mask)
            hidden_states: torch.Tensor = (
                outputs.last_hidden_state
            )  # Shape: (batch_size, seq_len, hidden_dim)

        # Pass embeddings through transformer encoder
        transformer_output: torch.Tensor = self.encoder(
            hidden_states
        )  # Shape: (batch_size, seq_len)

        # Predict the tokens at masked positions
        logits: torch.Tensor = self.fc(
            transformer_output
        )  # Shape: (batch_size, seq_len, vocab_size)

        return logits

In [11]:
# Instantiate the model
hidden_dim = 768  # BERT hidden dimension
num_layers = 6  # Number of layers in the transformer encoder
num_heads = 8  # Number of attention heads

model = TextMaskedAutoencoder(
    bert_model=bert_model,
    embed_dim=hidden_dim,
    enc_depth=num_layers,
    num_heads=num_heads,
)
# Forward pass
logits = model(masked_input_ids, inputs["attention_mask"])
logits.shape

hidden_states.shape=torch.Size([1, 13, 768])
transformer_output.shape=torch.Size([1, 13, 768])


torch.Size([1, 13, 30522])

In [17]:
logits[:, 0, :]

tensor([[-0.6295,  0.0651, -0.9394,  ...,  0.5241,  0.3121, -0.9289]],
       grad_fn=<SliceBackward0>)

In [18]:
# Loss function and optimizer
criterion = nn.CrossEntropyLoss(
    ignore_index=-100
)  # Ignore non-masked tokens in the loss
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

In [21]:
help(criterion)

Help on CrossEntropyLoss in module torch.nn.modules.loss object:

class CrossEntropyLoss(_WeightedLoss)
 |  CrossEntropyLoss(weight: Optional[torch.Tensor] = None, size_average=None, ignore_index: int = -100, reduce=None, reduction: str = 'mean', label_smoothing: float = 0.0) -> None
 |  
 |  This criterion computes the cross entropy loss between input logits
 |  and target.
 |  
 |  It is useful when training a classification problem with `C` classes.
 |  If provided, the optional argument :attr:`weight` should be a 1D `Tensor`
 |  assigning weight to each of the classes.
 |  This is particularly useful when you have an unbalanced training set.
 |  
 |  The `input` is expected to contain the unnormalized logits for each class (which do `not` need
 |  to be positive or sum to 1, in general).
 |  `input` has to be a Tensor of size :math:`(C)` for unbatched input,
 |  :math:`(minibatch, C)` or :math:`(minibatch, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1` for the
 |  `K`-dimensional ca

In [22]:
criterion.ignore_index  # NOTE: Specifies a target value that is ignored and does not contribute to the input gradient.

-100

In [20]:
# Compute the loss
logits = logits.view(
    -1, tokeniser.vocab_size
)  # Flatten the output for loss computation
labels = labels.view(-1)  # Flatten the labels

print(f"{logits.shape=}")
print(f"{labels.shape=}")

loss: torch.Tensor = criterion(
    input=logits,
    target=labels,  # Contains -100 for non-masked tokens
)
print(f"{loss=}")

logits.shape=torch.Size([13, 30522])
labels.shape=torch.Size([13])
loss=tensor(10.2919, grad_fn=<NllLossBackward0>)
