In [1]:
import warnings
warnings.filterwarnings('ignore')

In [2]:
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader

# 1. Base case tokenization

In [3]:
from transformers import DistilBertTokenizer

# Checkpoint whose vocabulary the rest of the notebook indexes into.
tokenizer_ckpt = 'distilbert/distilbert-base-uncased-finetuned-sst-2-english'

# The single (plaintext, cyphertext) pair this base case works with.
plaintext, cyphertext = 'happy', 'ogatzAKtNKEvabbzM7kUig=='

tokenizer = DistilBertTokenizer.from_pretrained(tokenizer_ckpt)
tokenizer.vocab_size

30522

In [4]:
# How the tokenizer splits the plaintext; no special tokens are added by tokenize().
tokenizer.tokenize(text=plaintext)

['happy']

In [5]:
tokenizer(plaintext)['input_ids']

[101, 3407, 102]

In [6]:
# Map each side of the (plaintext, cyphertext) pair to a vocabulary index.
#
# GOTCHA: tokenizer(plaintext)['input_ids'] wraps the word in special tokens
# ([101, 3407, 102] above), so the word's own id sits at position 1, not 0.
# Looking the token up directly avoids depending on that position.
tokens = tokenizer.tokenize(plaintext)
assert len(tokens) == 1, f'expected a single wordpiece for {plaintext!r}, got {tokens}'
token = tokens[0]
token_id = tokenizer.convert_tokens_to_ids(token)  # avoid shadowing builtin `id`
assert token_id == tokenizer(plaintext)['input_ids'][1]  # agrees with the positional read

plain_to_index = {token: token_id}
# The cyphertext gets its own index in a separate (preprocessor) vocabulary.
# NOTE(review): 1 rather than 0 presumably keeps clear of index 0, which the
# pretrained word-embedding layer uses as padding_idx — confirm intent.
cypher_to_index = {cyphertext: 1}

plain_to_index, cypher_to_index

({'happy': 3407}, {'ogatzAKtNKEvabbzM7kUig==': 1})

In [7]:
# Bidirectional lookup between the plaintext token and its cyphertext;
# the reverse map is derived by inverting the forward one.
plain_to_cypher = dict([(token, cyphertext)])
cypher_to_plain = {cypher: plain for plain, cypher in plain_to_cypher.items()}

plain_to_cypher, cypher_to_plain

({'happy': 'ogatzAKtNKEvabbzM7kUig=='}, {'ogatzAKtNKEvabbzM7kUig==': 'happy'})

# 2. Token masking and remapping

In [8]:
from transformers import DistilBertForSequenceClassification

# The classifier must come from the same checkpoint as the tokenizer so that
# token ids line up with the embedding rows; reuse the single checkpoint name
# instead of repeating the string.
model_ckpt = tokenizer_ckpt

model = DistilBertForSequenceClassification.from_pretrained(model_ckpt)
model.config.dim

768

In [9]:
# `base_model` resolves to the classifier's DistilBERT backbone.
model.base_model.embeddings

Embeddings(
  (word_embeddings): Embedding(30522, 768, padding_idx=0)
  (position_embeddings): Embedding(512, 768)
  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

In [10]:
# The pretrained token-id -> vector lookup table (the model's input embeddings).
word_embeddings = model.get_input_embeddings()
word_embeddings

Embedding(30522, 768, padding_idx=0)

In [11]:
# Look up the pretrained embedding of every plaintext token. This is a pure
# lookup, so no autograd graph is needed.
plain_embeddings = {}

with torch.no_grad():
    for k, v in plain_to_index.items():
        print(f'{k} -> {v}')

        emb = word_embeddings(torch.tensor(v))
        print('Embeddings shape:', emb.shape)

        # Stored as a plain Python list of floats.
        plain_embeddings[k] = emb.tolist()


happy -> 3407
Embeddings shape: torch.Size([768])


In [12]:
# Training targets for the cypher vocabulary: each cyphertext gets the
# pretrained embedding of the plaintext token it stands for. The cypher index
# is only printed here; the lookup goes through the plaintext token. The
# embedding is recomputed (not copied from plain_embeddings) so the equality
# check below is a real comparison.
cypher_embeddings = {}

with torch.no_grad():
    for k, v in cypher_to_index.items():
        print(f'{k} -> {v}')

        plain = cypher_to_plain[k]

        emb = word_embeddings(torch.tensor(plain_to_index[plain]))
        print('Embeddings shape:', emb.shape)

        cypher_embeddings[k] = emb.tolist()


ogatzAKtNKEvabbzM7kUig== -> 1
Embeddings shape: torch.Size([768])


In [13]:
# Compare via the token the cyphertext maps to: plain_embeddings is keyed by
# the tokenizer's output, which may differ from the raw plaintext (an uncased
# tokenizer lower-cases, so 'Happy' would not be a key).
plain_embeddings[cypher_to_plain[cyphertext]] == cypher_embeddings[cyphertext]

True

# 3. Preprocessor

## 3.1. Dataset for preprocessor

In [14]:
class CypherEmbeddingDataset(Dataset):
    """Map-style dataset of ``(cypher_index, target_embedding)`` pairs.

    Args:
        cypher_to_index: dict from cyphertext string to its integer index.
        cypher_embeddings: dict from cyphertext string to its target embedding
            (a sequence of floats). Its key order fixes the item order.

    Each item is ``(torch.tensor(index), embedding_tensor)``. The index is
    looked up in ``cypher_to_index`` at access time; the embeddings are
    converted to tensors once, up front.
    """

    def __init__(self, cypher_to_index, cypher_embeddings):
        self.cypher_to_index = cypher_to_index
        self.cyphers = [*cypher_embeddings]
        self.embeddings = [torch.tensor(cypher_embeddings[key]) for key in self.cyphers]

    def __len__(self):
        return len(self.cyphers)

    def __getitem__(self, idx):
        key = self.cyphers[idx]
        index_tensor = torch.tensor(self.cypher_to_index[key])
        return index_tensor, self.embeddings[idx]

In [15]:
dataset = CypherEmbeddingDataset(cypher_to_index, cypher_embeddings)
dataloader = DataLoader(dataset, shuffle=True, batch_size=1)

# Sanity-check one pass: every batch is an (index tensor, embedding tensor) pair.
for index, embeddings in dataloader:
    print("Cypher Index:", index)
    print("Embeddings:", embeddings.shape)

Cypher Index: tensor([1])
Embeddings: torch.Size([1, 768])


## 3.2. Class definition and training

In [16]:
class Preprocessor(nn.Module):
    """A trainable lookup table mapping cypher indices to embedding vectors.

    Args:
        vocab_size: number of rows (distinct indices) in the table.
        embedding_dim: length of each embedding vector.
    """

    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        # The attribute name is part of the saved state_dict key ('embedding.weight').
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)

    def forward(self, indices):
        """Return the table rows for ``indices``; the output shape is
        ``indices.shape + (embedding_dim,)``."""
        return self.embedding(indices)

In [17]:
# Train the preprocessor so that looking up a cypher index yields the same
# vector the pretrained model assigns to the corresponding plaintext token.
SEED = 42
EPOCHS = 6000
LOG_EVERY = max(1, EPOCHS // 10)  # print progress ~10 times over the run
LEARNING_RATE = 0.001

# Seed before building the model: both the nn.Embedding initialisation and the
# DataLoader shuffle draw from torch's global RNG.
torch.manual_seed(SEED)

# Same table shape as the pretrained word embeddings (vocab_size x dim).
preprocessor = Preprocessor(vocab_size=tokenizer.vocab_size, embedding_dim=model.config.dim)
criterion = nn.MSELoss()
optimizer = optim.Adam(preprocessor.parameters(), lr=LEARNING_RATE)

for epoch in range(EPOCHS):
    total_loss = 0.0
    for idx, embedding in dataloader:
        optimizer.zero_grad()
        outputs = preprocessor(idx)
        loss = criterion(outputs, embedding)
        total_loss += loss.item()

        loss.backward()
        optimizer.step()

    if (epoch + 1) % LOG_EVERY == 0:
        # Mean per-batch loss, so the reported number does not scale with the
        # number of batches.
        print(f"Epoch [{epoch + 1}/{EPOCHS}], Loss: {total_loss / len(dataloader):.4f}")

Epoch [600/6000], Loss: 0.4065
Epoch [1200/6000], Loss: 0.1609
Epoch [1800/6000], Loss: 0.0621
Epoch [2400/6000], Loss: 0.0227
Epoch [3000/6000], Loss: 0.0077
Epoch [3600/6000], Loss: 0.0023
Epoch [4200/6000], Loss: 0.0005
Epoch [4800/6000], Loss: 0.0001
Epoch [5400/6000], Loss: 0.0000
Epoch [6000/6000], Loss: 0.0000


In [20]:
# Compare the pretrained embedding of the plaintext token with what the trained
# preprocessor emits for the cypher index. Index through cypher_to_plain because
# plain_to_index is keyed by the tokenized form, not the raw plaintext.
with torch.no_grad():
    embeddings = word_embeddings(torch.tensor(plain_to_index[cypher_to_plain[cyphertext]]))
    output = preprocessor(torch.tensor(cypher_to_index[cyphertext]))
torch.allclose(embeddings, output, rtol=1e-06, atol=1e-01)

True

In [21]:
# Same comparison with a 10x tighter absolute tolerance, to probe how close the fit is.
torch.allclose(input=embeddings, other=output, atol=1e-02, rtol=1e-06)

False

In [22]:
# Persist only the learned weights (state_dict). A single directory constant is
# used for both mkdir and the save path, so the two cannot drift apart.
MODELS_DIR = Path("models")
MODELS_DIR.mkdir(parents=True, exist_ok=True)
torch.save(obj=preprocessor.state_dict(), f=MODELS_DIR / 'preprocessor.pt')