In [38]:
import string

import torch

# Everything in this notebook is pinned to the CPU, even when CUDA is present.
# To select the GPU automatically instead, use:
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cpu"

# Report CUDA availability, the number of visible GPUs, and the device chosen above
# (one value per line).
print(torch.cuda.is_available(), torch.cuda.device_count(), device, sep="\n")

True
1
cpu


In [61]:
# Special multi-character tokens; they are appended after the printable characters.
EOS = "eos"  # end-of-sequence marker appended by tokenize()
PAD_1 = "pad1"  # pads a sequence up to a multiple of NGRAM_SIZE
UNK = "unk"  # presumably the out-of-vocabulary token
PAD_2 = "pad2"  # fill value used when batching sequences in collate()

# Tokenization / model hyper-parameters.
NGRAM_SIZE: int = 3
HIDDEN_SIZE: int = 768
MAX_SEQ_LEN = 16

# Vocabulary: every character of string.printable followed by the special tokens.
CHAR_TOKENS: list[str] = [*string.printable, EOS, PAD_1, UNK, PAD_2]
num_chars = len(CHAR_TOKENS)

# Token -> integer index, in vocabulary order.
char_to_idx = dict(zip(CHAR_TOKENS, range(num_chars)))


def tokenize(seq: str):
    """Convert a string into a 1-D tensor of character-token indices.

    The characters of ``seq`` are padded with ``PAD_1`` up to the next multiple
    of ``NGRAM_SIZE`` and then terminated by one full n-gram of ``EOS`` tokens.
    Characters that are not in ``char_to_idx`` are mapped to ``UNK`` instead of
    raising ``KeyError``.

    Args:
        seq: The text to tokenize.

    Returns:
        An int64 tensor whose length is a multiple of ``NGRAM_SIZE``.
    """
    tokens = list(seq)
    # (-n) % k is the number of pads needed to reach the next multiple of k;
    # it is 0 when the length is already aligned.
    tokens += [PAD_1] * (-len(tokens) % NGRAM_SIZE)
    tokens += [EOS] * NGRAM_SIZE
    unk_idx = char_to_idx[UNK]
    return torch.tensor([char_to_idx.get(c, unk_idx) for c in tokens], dtype=torch.long)


def collate(tokenized_seqs: list[torch.Tensor]):
    """Batch already-tokenized sequences: truncate long ones, pad short ones.

    Each sequence is cut to at most ``MAX_SEQ_LEN - 1`` tokens, then all
    sequences are right-padded with the ``PAD_2`` index to the length of the
    longest remaining one. Truncation keeps only the leading tokens, so the
    trailing ``EOS`` tokens of a long sequence may be dropped.

    NOTE(review): the limit is ``MAX_SEQ_LEN - 1`` rather than ``MAX_SEQ_LEN``;
    presumably one position is reserved for something else -- confirm.

    Args:
        tokenized_seqs: 1-D index tensors, e.g. the outputs of ``tokenize``.

    Returns:
        An int32 tensor with one row per input sequence. An empty input list
        yields a tensor of shape ``(0, 0)``.
    """
    # Use the sequences that were passed in; do not re-tokenize any global data.
    truncated = [x[: MAX_SEQ_LEN - 1] for x in tokenized_seqs]
    # default=0 keeps an empty batch from raising ValueError in max().
    max_len = max((x.shape[-1] for x in truncated), default=0)
    result = torch.full(
        size=[len(truncated), max_len], fill_value=char_to_idx[PAD_2], dtype=torch.int
    )
    for i, x in enumerate(truncated):
        result[i, : len(x)] = x
    return result


# Two inputs of different lengths, so the batch shows both padding and truncation.
example_data = ["Hi..", "This is a second sentence."]
collate(list(map(tokenize, example_data)))

tensor([[ 43,  18,  75,  75, 101, 101, 100, 100, 100, 103, 103, 103, 103, 103,
         103],
        [ 55,  17,  18,  28,  94,  18,  28,  94,  10,  94,  28,  14,  12,  24,
          23]], dtype=torch.int32)

In [4]:
# One embedding table per slot of the n-gram (e.g. slots 0, 1, 2 for NGRAM_SIZE=3).
# Each table maps a character index to a HIDDEN_SIZE-dimensional vector.
ngram_embeddings = [
    torch.nn.Embedding(num_chars, HIDDEN_SIZE) for _slot in range(NGRAM_SIZE)
]

In [5]:
def get_input_embeddings(seq: str):
    """Embed a string as a sequence of summed per-slot character embeddings.

    The characters are padded with ``PAD_1`` up to the next multiple of
    ``NGRAM_SIZE`` (no padding is added when the length is already a multiple).
    Position ``k`` of every n-gram is looked up in ``ngram_embeddings[k]``, and
    the ``NGRAM_SIZE`` slot embeddings of each n-gram are summed.

    Characters that are not in ``char_to_idx`` are mapped to ``UNK``.

    Args:
        seq: The text to embed.

    Returns:
        A float tensor of shape ``(number of n-grams, HIDDEN_SIZE)``. An empty
        string produces zero n-grams, i.e. shape ``(0, HIDDEN_SIZE)``.
    """
    chars = list(seq)
    # (-n) % k pads only when needed; the previous expression appended a whole
    # extra all-pad n-gram whenever len(seq) was already divisible by NGRAM_SIZE.
    chars += [PAD_1] * (-len(chars) % NGRAM_SIZE)
    unk_idx = char_to_idx[UNK]

    slot_embeddings = []
    for slot_idx in range(NGRAM_SIZE):
        # Every NGRAM_SIZE-th character starting at slot_idx: the characters
        # occupying this slot across all n-grams.
        slot_chars = chars[slot_idx::NGRAM_SIZE]
        # Explicit integer dtype so an empty index list does not become a float
        # tensor, which nn.Embedding rejects.
        slot_char_idxs = torch.tensor(
            [char_to_idx.get(c, unk_idx) for c in slot_chars], dtype=torch.long
        )
        slot_embeddings.append(ngram_embeddings[slot_idx](slot_char_idxs))
    # Stack to (NGRAM_SIZE, n_ngrams, HIDDEN_SIZE), then sum over the slot axis.
    return torch.stack(slot_embeddings).sum(dim=0)


# 25 characters: not a multiple of 3, so with NGRAM_SIZE = 3 this input needs padding.
seq = "This is a test sentence. "
get_input_embeddings(seq=seq)

tensor([[-0.4735,  2.0881,  4.7124,  ...,  0.6942,  2.9293, -1.5267],
        [-1.1553,  0.1940, -0.4858,  ..., -2.4012,  3.5435, -0.4625],
        [-1.3504, -1.3928, -0.0316,  ...,  0.6737,  1.0648,  1.6953],
        ...,
        [ 0.8786,  1.0970,  3.6359,  ...,  0.2867, -0.8149, -0.2832],
        [ 0.0515,  0.0914,  0.8045,  ...,  1.6901,  0.3185, -0.9622],
        [ 1.4446,  0.9301, -1.5488,  ...,  1.5180,  0.0222,  1.8606]],
       grad_fn=<SumBackward1>)

In [66]:
from transformers import RobertaConfig, RobertaForMaskedLM

# Randomly initialised RoBERTa masked-LM, built from a config rather than
# loaded from pretrained weights.
model = RobertaForMaskedLM(
    RobertaConfig(
        # Architecture.
        hidden_size=HIDDEN_SIZE,  # default 768
        num_hidden_layers=6,
        num_attention_heads=12,
        max_position_embeddings=514,
        # Vocabulary-related sizes, kept minimal.
        type_vocab_size=1,
        vocab_size=2,  # won't use
        # Dropout disabled.
        hidden_dropout_prob=0,
        attention_probs_dropout_prob=0,
    )
)

In [67]:
# Random stand-in input embeddings of shape (1, 5, HIDDEN_SIZE).
cn_embeddings = torch.rand([1, 5, HIDDEN_SIZE])
# Call the module itself rather than `.forward(...)`: nn.Module.__call__ runs any
# registered hooks, which a direct `.forward` call silently skips.
lm_embeddings = model.roberta(inputs_embeds=cn_embeddings).last_hidden_state
# TODO: then decode the lm embeddings back to chars