# Text Joint Embedding Predictive Architecture

* [JEPA](https://arxiv.org/pdf/2306.02572)

## Installs

In [None]:
!pip install x-transformers



## Imports

In [None]:
import copy
from typing import *

import numpy as np
import torch
import torch.nn as nn
import transformers
from transformers import BertTokenizer, BertModel
from x_transformers import Encoder, Decoder
from x_transformers.x_transformers import ScaledSinusoidalEmbedding

## Tokeniser

* Tokenisers split words into chunks to be processed individually by a language model.
* We employ subword tokenisation using BERT: [explanation](https://h2o.ai/wiki/bert/#:~:text=BERT%2C%20short%20for%20Bidirectional%20Encoder,framework%20for%20natural%20language%20processing.), [paper](https://arxiv.org/abs/1810.04805). (See example below.)

### Config

In [None]:
# Instantiate the pretrained subword tokeniser published under the
# "bert-base-uncased" checkpoint name; it is reused by every cell below.
tokeniser = BertTokenizer.from_pretrained(
    "bert-base-uncased",
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


### Text Tokenisation Example

In [None]:
# A sample sentence and its subword pieces (strings, not ids yet)
text: str = "This is an example sentence for subword tokenization."
tokens: List[str] = tokeniser.tokenize(text)

# Show the raw sentence first, then how the tokeniser split it
print(f"{text=}")
print(f"{tokens=}")

text='This is an example sentence for subword tokenization.'
tokens=['this', 'is', 'an', 'example', 'sentence', 'for', 'sub', '##word', 'token', '##ization', '.']


### Text Encoding

* Words, or tokens, can't be processed by language models - they need to be converted into numbers.
* We need a mapping from each token in our dictionary to a given id (number) - these numbers are what get processed by the AI.

In [None]:
# Encode a one-sentence batch as PyTorch tensors of vocabulary ids.
# Padding and truncation (capped at 128 tokens) are switched on; with a single
# short sentence they presumably leave the ids untouched.
batch_encoding: transformers.tokenization_utils_base.BatchEncoding = tokeniser(
    [text],
    max_length=128,
    truncation=True,
    padding=True,
    return_tensors="pt",
)

# The ids live under the "input_ids" key, one row per sentence in the batch
token_ids: torch.Tensor = batch_encoding["input_ids"]

print(f"{token_ids.shape=}")
print(f"{token_ids=}")

token_ids.shape=torch.Size([1, 13])
token_ids=tensor([[  101,  2023,  2003,  2019,  2742,  6251,  2005,  4942, 18351, 19204,
          3989,  1012,   102]])


## T-JEPA

See [Joint Embedding Predictive Architecture (JEPA)](https://arxiv.org/abs/2306.02572), [I-JEPA](https://arxiv.org/abs/2301.08243).

### Config

In [None]:
# RANDOMNESS
# Seed both RNGs used by the cells below (NumPy draws the sampling ratios, PyTorch
# initialises the weights and samples the Bernoulli masks) so that a fresh
# "Restart & Run All" reproduces the same numbers.
seed: int = 0
np.random.seed(seed)
torch.manual_seed(seed)
# TOKENISATION
vocab_size: int = tokeniser.vocab_size  # Number of distinct token ids the tokeniser can emit
# TOKEN EMBEDDING
embed_dim: int = 512  # Width of every token vector (768 would match the BERT hidden dimension)
# CONTEXT/TARGET GENERATION
# Intervals from which the per-token selection probabilities are drawn
target_scale_interval: Tuple[float, float] = (0.15, 0.2)
context_scale_interval: Tuple[float, float] = (0.7, 0.9)
# TRANSFORMER ENCODER/DECODER
num_layers: int = 6  # Number of layers in the transformer
num_heads: int = 8  # Number of attention heads in the transformer
layer_dropout: float = 0.0

print(f"{vocab_size=}")
print(f"{embed_dim=}")

vocab_size=30522
embed_dim=512


### Layers

In [None]:
# Learns an `embed_dim`-dimensional (numerical) representation of each token in the vocabulary
embedding_layer = nn.Embedding(vocab_size, embed_dim)

# Student (context) encoder: the network that is trained by back-propagation
student_encoder = Encoder(
    dim=embed_dim,
    heads=num_heads,
    depth=num_layers,
    layer_dropout=layer_dropout,
)

# Teacher (target) encoder: starts as an exact copy of the student.
# It only supplies regression targets, so its weights are frozen; in JEPA they are
# meant to follow the student through a moving average rather than through gradients
# (no such update is performed in this part of the notebook).
teacher_encoder = copy.deepcopy(student_encoder)  # .cuda()
teacher_encoder.requires_grad_(False)

# Predictor: a smaller stack with half the depth and half the heads of the encoders
decoder = Decoder(
    dim=embed_dim,
    depth=num_layers // 2,
    heads=num_heads // 2,
    layer_dropout=layer_dropout,
)

# Positional signal that is added to the token embeddings
pos_embedding = ScaledSinusoidalEmbedding(embed_dim)

In [None]:
# Learnable placeholder vector of shape (1, 1, embed_dim) that stands in for every masked target token.
# It is allocated as zeros because the initialiser below overwrites every value anyway.
mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
# `std` must be passed by keyword: the second positional argument of `trunc_normal_`
# is `mean`, so the former call drew from N(0.02, 1) instead of N(0, 0.02**2).
nn.init.trunc_normal_(mask_token, std=0.02)
None  # keeps the cell from echoing the initialised tensor

## Embed Token Ids

Learn a numerical representation for the "meaning" of each token

In [None]:
# Look up one learned vector per token id
token_embeddings: torch.Tensor = embedding_layer(token_ids)

# NOTE: Transformer networks have no positional awareness of their own, so the order
# of the tokens is injected by summing positional embeddings onto the token vectors.
token_embeddings = torch.add(token_embeddings, pos_embedding(token_embeddings))
print(f"{token_embeddings.shape=}")

token_embeddings.shape=torch.Size([1, 13, 512])


All sequences in a batch must have the same length, so we must pad short sentences with a "pad" token.

In [None]:
# Display the id and the literal string the tokeniser uses for padding
(tokeniser.pad_token_id, tokeniser.pad_token)

(0, '[PAD]')

## Extract Targets and Contexts

In the JEPA framework, we mask the targets and use the contexts to reconstruct the targets.

In [None]:
# Interval (low, high) from which a target-masking probability is drawn
target_prob_range: Tuple[float, float] = (0.15, 0.35)

# One uniform draw from that interval.
# NOTE(review): no later cell in this section reads `target_prob`; the sampling
# below uses `target_scale` instead - confirm whether this cell is still needed.
target_prob: float = np.random.uniform(*target_prob_range)
print(f"{target_prob=}")

target_prob=0.1684745906939956


In [None]:
def _sample_token_mask(shape: torch.Size, keep_prob: float) -> torch.Tensor:
    """Sample a boolean token mask in which every row selects at least one token.

    Args:
        shape: 2-D shape of the mask to draw, `(num_rows, num_tokens)`.
        keep_prob: Independent probability of selecting each position.

    Returns:
        A bool tensor of shape `shape`.
    """
    mask = torch.bernoulli(torch.full(shape, keep_prob)).bool()
    # A row with no selected token would produce empty tensors that break the
    # later cells, so force one uniformly chosen position in every such row.
    empty_rows = ~mask.any(dim=-1)
    if empty_rows.any():
        fallback = torch.randint(shape[-1], (int(empty_rows.sum()),))
        mask[empty_rows, fallback] = True
    return mask


# Draw the per-token selection probabilities for this pass from their configured intervals
target_scale: float = np.random.uniform(
    low=target_scale_interval[0], high=target_scale_interval[1]
)
context_scale: float = np.random.uniform(
    low=context_scale_interval[0], high=context_scale_interval[1]
)

# Each token is independently marked as target / context with the drawn probability.
# NOTE(review): every position of `token_ids` is eligible, which presumably includes
# special and padding tokens - confirm whether those should be excluded.
target_indices: torch.Tensor = _sample_token_mask(token_ids.shape, target_scale)
context_indices: torch.Tensor = _sample_token_mask(token_ids.shape, context_scale)
print(f"{target_indices.shape=}")
print(f"{target_indices=}")
print()
print(f"{context_indices.shape=}")
print(f"{context_indices=}")
# NOTE: The targets and contexts are allowed to overlap

target_indices.shape=torch.Size([1, 13])
target_indices=tensor([[False, False,  True, False,  True, False, False, False, False, False,
         False, False, False]])

context_indices.shape=torch.Size([1, 13])
context_indices=tensor([[ True,  True,  True, False,  True,  True,  True,  True,  True,  True,
          True,  True,  True]])


In [None]:
# Gather the embedding vectors at the selected positions with a boolean mask, then
# restore a leading axis of size 1.
# NOTE: the mask spans the batch and sequence axes together, so the selected tokens of
# all rows end up in one sequence; the leading axis equals the batch size only for a
# batch of one sentence.
target_embeddings: torch.Tensor = token_embeddings[target_indices].unsqueeze(
    0
)  # (1, num_target_tokens, embed_dim)
context_embeddings: torch.Tensor = token_embeddings[context_indices].unsqueeze(
    0
)  # (1, num_context_tokens, embed_dim)

In [None]:
# Teacher pass: the targets are regression labels, so no graph is recorded for them
# (stop-gradient on the target branch).
with torch.no_grad():
    target_encoding: torch.Tensor = teacher_encoder(
        target_embeddings
    )  # expected to keep the shape of `target_embeddings`

# Student pass: this is the branch that receives gradients
context_encoding: torch.Tensor = student_encoder(
    context_embeddings
)  # expected to keep the shape of `context_embeddings`

# Display the shapes of the encodings computed in this cell
target_encoding.shape, context_encoding.shape

(torch.Size([1, 2, 512]), torch.Size([1, 12, 512]))

## Latent Reconstruction

Use the context to predict the targets

In [None]:
batch_dim, num_patches, _ = target_embeddings.shape

# Positional embeddings of the masked positions, gathered with the same boolean mask
# that selected the target tokens. Without them every placeholder is the same vector
# and the predictor has no way of knowing which position it is asked to predict.
target_positions: torch.Tensor = torch.broadcast_to(
    pos_embedding(token_embeddings), token_embeddings.shape
)[target_indices].unsqueeze(0)

# One copy of the learnable mask token per target position, tagged with that position
target_masks: torch.Tensor = (
    mask_token.repeat(batch_dim, num_patches, 1) + target_positions
)
print(f"{target_masks.shape=}")
assert target_masks.shape == target_embeddings.shape

target_masks.shape=torch.Size([1, 2, 512])


In [None]:
# Feed the predictor the *encoded* context (the student output), followed by one
# placeholder per target token.
# NOTE: the context tokens received positional embeddings before they were encoded,
# and concatenation does not alter the individual token vectors.
x: torch.Tensor = torch.cat([context_encoding, target_masks], dim=1)
print(f"{x.shape=}")

# Decode
x = decoder(x)  # Self-attention

# Return the output corresponding to target tokens, i.e., everything after the context.
# Slicing from the context length (rather than with a negative offset) stays correct
# when there are zero target tokens, where `x[:, -0:]` would return the whole sequence.
prediction: torch.Tensor = x[:, context_encoding.shape[1] :, :]
print(f"{prediction.shape=}")

x.shape=torch.Size([1, 14, 512])
prediction.shape=torch.Size([1, 2, 512])


## Calculate Loss

In [None]:
# Mean squared error between the predicted and the teacher-encoded target representations
criterion = nn.MSELoss()

# The regression target is the teacher's encoding of the target tokens, not their raw
# input embeddings. It is detached so that the loss only trains the prediction branch.
loss: torch.Tensor = criterion(prediction, target_encoding.detach())
print(f"{loss=}")

loss=tensor(1.9379, grad_fn=<MseLossBackward0>)
