In [1]:
# Install a blanket 'ignore' filter: every Python warning raised after this
# cell is suppressed (presumably to keep cell output uncluttered).
# NOTE(review): this also hides warnings that may matter — consider narrowing
# the filter to a specific category or module.
import warnings
warnings.filterwarnings('ignore')

In [70]:
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader

# 1. Base case tokenization

In [3]:
from transformers import DistilBertTokenizer

# Single-word example that the following cells tokenize and encode.
sentence = 'happy'

# Load the tokenizer published with the SST-2 fine-tuned DistilBERT checkpoint.
tokenizer_ckpt = 'distilbert/distilbert-base-uncased-finetuned-sst-2-english'
tokenizer = DistilBertTokenizer.from_pretrained(tokenizer_ckpt)

In [4]:
# Split the sentence into token strings (no ids yet).
tokenizer.tokenize(sentence)

['happy']

In [5]:
# Encode the sentence into vocabulary ids; note the result has more ids than
# the sentence has tokens (presumably the [CLS]/[SEP] special tokens).
tokenizer(sentence)['input_ids']

[101, 3407, 102]

In [6]:
# GOTCHA: why index 1? The encoded ids for the sentence are [101, 3407, 102]:
# the tokenizer wraps the text in special tokens ([CLS] first, [SEP] last), so
# position 0 is [CLS] and the first real token sits at position 1.

token = tokenizer.tokenize(sentence)[0]
# Named `token_id` rather than `id` so the built-in `id()` is not shadowed
# for every later cell that runs in this kernel.
token_id = tokenizer(sentence)['input_ids'][1]

# Mapping from token string to its vocabulary id.
token_to_index = { token: token_id }
token_to_index

{'happy': 3407}

# 2. Tokenizer emulator

In [7]:
from transformers import DistilBertForSequenceClassification

# Same checkpoint string as `tokenizer_ckpt`, so the model is loaded from the
# same published checkpoint as the tokenizer.
base_model_ckpt = 'distilbert/distilbert-base-uncased-finetuned-sst-2-english'

base_model = DistilBertForSequenceClassification.from_pretrained(base_model_ckpt)

In [8]:
# Inspect the embeddings sub-module of the DistilBERT backbone.
base_model.distilbert.embeddings

Embeddings(
  (word_embeddings): Embedding(30522, 768, padding_idx=0)
  (position_embeddings): Embedding(512, 768)
  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

In [9]:
# Keep a handle on the token-id -> vector lookup table; later cells read the
# reference embeddings from it.
word_embeddings = base_model.distilbert.embeddings.word_embeddings
word_embeddings

Embedding(30522, 768, padding_idx=0)

In [23]:
# Look up the pretrained embedding vector of every token in `token_to_index`
# and store it as a plain Python list keyed by the token string.
token_embeddings = {}

for k, v in token_to_index.items():
    print(f'{k} -> {v}')

    # Pure lookup: no_grad() avoids recording an autograd graph on the
    # pretrained weights for a value that is only exported as a list.
    with torch.no_grad():
        emb = word_embeddings(torch.tensor(v))
    print('Embeddings shape:', emb.shape)

    token_embeddings[k] = emb.tolist()


happy -> 3407
Embeddings shape: torch.Size([768])


## 2.1. Dataset for tokenizer emulator

In [34]:
class TokenEmbeddingDataset(Dataset):
    """Pairs each token's vocabulary id with its reference embedding vector.

    Args:
        token_to_index: mapping from token string to vocabulary id.
        token_embeddings: mapping from token string to its embedding, given
            as a sequence of numbers; its key order defines the item order.
    """

    def __init__(self, token_to_index, token_embeddings):
        self.token_to_index = token_to_index
        self.tokens = []
        self.embeddings = []
        # Walk the embedding mapping once, keeping tokens and tensors aligned.
        for tok, vector in token_embeddings.items():
            self.tokens.append(tok)
            self.embeddings.append(torch.tensor(vector))

    def __len__(self):
        """Number of tokens that have an embedding."""
        return len(self.tokens)

    def __getitem__(self, idx):
        """Return ``(token_id_tensor, embedding_tensor)`` for position ``idx``."""
        vocab_id = self.token_to_index[self.tokens[idx]]
        return torch.tensor(vocab_id), self.embeddings[idx]

In [35]:
# Wrap the token/embedding pairs in a Dataset and serve them one at a time.
dataset = TokenEmbeddingDataset(token_to_index, token_embeddings)

dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

# Smoke test: iterate once and show what a batch looks like.
for token_indices, embeddings in dataloader:
    print("Token Indices:", token_indices)
    print("Embeddings:", embeddings.shape)

Token Indices: tensor([3407])
Embeddings: torch.Size([1, 768])


## 2.2. Defining and training the tokenizer emulator

In [36]:
class TokenEmulator(nn.Module):
    """A single trainable lookup table mapping token ids to vectors.

    Args:
        vocab_size: number of rows (distinct token ids) in the table.
        embedding_dim: length of each embedding vector.
    """

    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

    def forward(self, token_indices):
        """Return the embedding vectors for the given integer token ids."""
        vectors = self.embedding(token_indices)
        return vectors

In [47]:
# Emulator sized to the tokenizer's vocabulary and the base model's `config.dim`.
model = TokenEmulator(vocab_size=tokenizer.vocab_size, embedding_dim=base_model.config.dim)
# Regress the emulator's vectors onto the reference embeddings.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 5000

for epoch in range(num_epochs):
    # Sum of the per-batch losses within this epoch (reported below).
    total_loss = 0
    for token_idx, embedding in dataloader:
        # Clear gradients left over from the previous step.
        optimizer.zero_grad()
        outputs = model(token_idx)
        loss = criterion(outputs, embedding)
        total_loss += loss.item()

        loss.backward()
        optimizer.step()

    # Report progress every 500 epochs.
    if (epoch + 1) % 500 == 0:
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss:.4f}")

Epoch [500/5000], Loss: 0.4592
Epoch [1000/5000], Loss: 0.2081
Epoch [1500/5000], Loss: 0.0907
Epoch [2000/5000], Loss: 0.0372
Epoch [2500/5000], Loss: 0.0141
Epoch [3000/5000], Loss: 0.0048
Epoch [3500/5000], Loss: 0.0014
Epoch [4000/5000], Loss: 0.0004
Epoch [4500/5000], Loss: 0.0001
Epoch [5000/5000], Loss: 0.0000


In [69]:
# Compare the pretrained vector for 'happy' with the emulator's learned vector.
base_model_embedding = word_embeddings(torch.tensor(token_to_index['happy']))
output = model(torch.tensor(token_to_index['happy']))
# NOTE(review): atol=1e-01 accepts an absolute per-element difference of
# roughly 0.1, which is a very coarse check — consider tightening it.
torch.allclose(base_model_embedding, output, rtol=1e-06, atol=1e-01)

True

In [71]:
# Ensure the output directory exists, then persist only the emulator's weights
# (state_dict), not the whole module object.
Path("models").mkdir(exist_ok=True)
torch.save(obj=model.state_dict(), f='./models/token-emulator.pt')

# 2. Preprocessor network

## 2.1. Class definition

In [9]:
from torch.nn import Module, Linear, ReLU, Sequential, Embedding

class Preprocessor(Module):
    """Token-id autoencoder that emits per-position vocabulary logits.

    Pipeline: embedding lookup -> encoder (embedding_dim -> 32 -> 16)
    -> decoder (16 -> 32 -> embedding_dim) -> linear projection to vocab_size.

    Args:
        vocab_size: number of distinct token ids (also the logit width).
        embedding_dim: width of the embedding vectors.
    """

    def __init__(self, vocab_size, embedding_dim):
        super().__init__()

        hidden_dim, bottleneck_dim = 32, 16

        # Token id -> dense vector.
        self.embedding = Embedding(vocab_size, embedding_dim)

        # Compress each embedding down to the bottleneck width.
        self.encoder = Sequential(
            Linear(embedding_dim, hidden_dim),
            ReLU(),
            Linear(hidden_dim, bottleneck_dim),
        )

        # Expand the bottleneck back to the embedding width.
        self.decoder = Sequential(
            Linear(bottleneck_dim, hidden_dim),
            ReLU(),
            Linear(hidden_dim, embedding_dim),
        )

        # Score every vocabulary entry for each position.
        self.output = Linear(embedding_dim, vocab_size)

    def forward(self, x):
        """Map integer token ids ``x`` to logits with a trailing vocab_size axis."""
        return self.output(self.decoder(self.encoder(self.embedding(x))))

## 2.2. Initialization

In [None]:
# Toy configuration: 7 distinct token ids, 50-dimensional embeddings.
VOCAB_SIZE = 7
EMBEDDING_DIM = 50

# NOTE: this rebinds `model`, which earlier referred to the TokenEmulator.
model = Preprocessor(VOCAB_SIZE, EMBEDDING_DIM)
model

In [None]:
from torch import tensor, argmax

# One fixed sequence of six token ids (batch of 1), reused again in the
# testing cell after training.
fixed_input = tensor([[ 1, 2, 3, 4, 5, 6]])

logits = model(fixed_input)
# Highest-scoring token id at each position.
output = argmax(logits, dim=-1)
# Show both shapes and the raw logits of the untrained model.
output.shape, logits.shape, logits

## 2.3. Training

In [None]:
from torch import argmax, randint
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset

BATCH_SIZE = 32
SEQ_LENGTH = 6

model.train()

criterion = CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=1e-3)

# Identity (autoencoding) task: every epoch draws one fresh random batch of
# token ids and trains the model to predict those same ids back.
num_epochs = 2500
for epoch in range(num_epochs):
    # The batch is generated in memory and is exactly one optimisation step's
    # worth of data, so it is fed to the model directly instead of being
    # wrapped in a brand-new TensorDataset/DataLoader on every epoch.
    inputs = randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LENGTH))
    # Targets are the inputs themselves; nothing mutates them, so no copy is needed.
    targets = inputs

    # logits: (batch, seq, vocab)
    logits = model(inputs)
    # CrossEntropyLoss expects the class axis second: (batch, vocab, seq).
    loss = criterion(logits.permute(0, 2, 1), targets)
    total_loss = loss.item()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Report progress every 250 epochs.
    if (epoch + 1) % 250 == 0:
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss:.4f}")

## 2.4. Testing

In [None]:
# Re-run the fixed sequence through the trained model and show the
# highest-scoring token id at each position.
logits = model(fixed_input)
output = argmax(logits, dim=-1)
output

## 2.5. Save model

In [14]:
from pathlib import Path
from torch import save

# Ensure the output directory exists, then persist only the preprocessor's
# weights (state_dict), not the whole module object.
Path("models").mkdir(exist_ok=True)
save(obj=model.state_dict(), f='./models/preprocessor.pt')