In [24]:
from pathlib import Path

import random
import string

import torch
from torch.utils.data import Dataset
from transformers import RobertaConfig, RobertaForMaskedLM, RobertaTokenizer
from transformers.modeling_outputs import MaskedLMOutput
from tokenizers import normalizers


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = "cpu"

print(torch.cuda.is_available())
print(torch.cuda.device_count())
print(device)

True
1
cuda


In [25]:
EOS = "<eos>"
CHAR_PAD = "<char_pad>"
UNK = "<unk>"
NGRAM_PAD = "<ngram_pad>"
MASK = "<mask>"

CHAR_TOKENS: list[str] = list(string.printable) + [EOS, CHAR_PAD, UNK, NGRAM_PAD, MASK]
NGRAM_SIZE: int = 3
HIDDEN_SIZE: int = 768
MAX_SEQ_LEN: int = 256
PROB_MASK: float = 0.15

DATASET = "./data/oscar.eo.txt"

num_chars = len(CHAR_TOKENS)
char_to_idx = {c: i for i, c in enumerate(CHAR_TOKENS)}
idx_to_char = {i: c for i, c in enumerate(CHAR_TOKENS)}

normalizer = normalizers.Sequence(
    [normalizers.NFD(), normalizers.Lowercase(), normalizers.StripAccents()]
)


def tokenize(seq: str):
    seq = normalizer.normalize_str(seq)
    seq = list(seq)
    # Pad such that len(seq) is divisible by NGRAM_SIZE
    if len(seq) % NGRAM_SIZE > 0:
        seq += [CHAR_PAD] * (NGRAM_SIZE - (len(seq) % NGRAM_SIZE))
    seq += [EOS] * NGRAM_SIZE
    return torch.tensor(
        [char_to_idx[c] if c in char_to_idx else char_to_idx[UNK] for c in seq]
    )


def collate(tokenized_seqs: list[torch.tensor], masking_probability: float = 0.15):
    """Pad short seqs, truncate long seqs."""
    tokenized_seqs = [x[: MAX_SEQ_LEN - 1] for x in tokenized_seqs]
    max_len = max(x.shape[-1] for x in tokenized_seqs)
    labels = torch.full(
        size=[len(tokenized_seqs), max_len],
        fill_value=char_to_idx[NGRAM_PAD],
        dtype=torch.long,
    )
    attention_mask = torch.ones_like(labels)
    for i, x in enumerate(tokenized_seqs):
        labels[i, 0 : len(x)] = x
        attention_mask[i, len(x) :] = 0
    # Masking, on ngram level rather than char
    masked_labels = labels.clone().detach()
    for row_idx in range(masked_labels.shape[0]):
        for ngram_idx in range(0, masked_labels.shape[1], NGRAM_SIZE):
            if random.random() < masking_probability:
                masked_labels[
                    row_idx, ngram_idx : ngram_idx + NGRAM_SIZE
                ] = char_to_idx[MASK]
    return {
        "labels": labels,
        "masked_labels": masked_labels,
        "attention_mask": attention_mask,
    }


# example_data = ["Hi..", "This is a second sentence."]
# example_data_tokenized = collate([tokenize(s) for s in example_data])
# print(example_data_tokenized)

In [26]:
class MyDataset(Dataset):
    def __init__(self):
        self.examples = Path(DATASET).read_text().split("\n")

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        return tokenize(self.examples[i])


dataset = MyDataset()

In [27]:
class CharModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # An embedding table for each slot in the the ngram, (e.g. 0, 1, 2 for a NGRAM_SIZE=3).
        self.ngram_embedding_tables = [
            torch.nn.Embedding(
                num_embeddings=num_chars,
                embedding_dim=HIDDEN_SIZE,
                padding_idx=char_to_idx[NGRAM_PAD],
            )
            for _ in range(NGRAM_SIZE)
        ]
        self.language_model = RobertaForMaskedLM(
            config=RobertaConfig(
                vocab_size=2,  # won't use
                hidden_size=HIDDEN_SIZE,  # default 768
                max_position_embeddings=514,
                num_attention_heads=12,
                num_hidden_layers=6,
                type_vocab_size=1,
                attention_probs_dropout_prob=0,
                hidden_dropout_prob=0,
            )
        )
        # To map from the lm embeddings back to the chars
        self.ngram_prediction_heads = [
            torch.nn.Linear(HIDDEN_SIZE, num_chars) for _ in range(NGRAM_SIZE)
        ]

    def forward(self, labels, masked_labels, attention_mask):
        logits = self.predict(masked_labels, attention_mask)[0]
        loss = self.get_loss(logits, labels)
        return MaskedLMOutput(loss=loss, logits=logits)

    def predict(self, labels, attention_mask):
        input_embeddings = self.get_input_embeddings(labels)
        lm_embeddings = self.language_model.roberta.forward(
            inputs_embeds=input_embeddings,
            attention_mask=attention_mask[:, ::NGRAM_SIZE],
        ).last_hidden_state
        logits = self.get_predicted_char_logits(lm_embeddings)
        return logits, lm_embeddings, input_embeddings

    def get_loss(self, logits, labels):
        return torch.nn.functional.cross_entropy(
            logits.reshape(-1, num_chars), labels.reshape(-1)
        )

    def get_input_embeddings(self, x_batch: torch.tensor):
        result = []
        for ngram_slot_idx in range(NGRAM_SIZE):
            ngram_slot_embeddings = self.ngram_embedding_tables[ngram_slot_idx](
                x_batch[:, ngram_slot_idx::NGRAM_SIZE]
            )
            result.append(ngram_slot_embeddings)
        result = torch.stack(result).sum(dim=0)
        return result

    def get_predicted_char_logits(self, xbatch_lm_embeddings: torch.tensor):
        """Map from the lm embeddings back to the chars"""
        result = []
        for ngram_slot_idx in range(NGRAM_SIZE):
            predicted_char = self.ngram_prediction_heads[ngram_slot_idx](
                xbatch_lm_embeddings
            )
            result.append(predicted_char)
        result = torch.concatenate(result, dim=1)
        return result

    def to(self, *args, **kwargs):
        for x in self.ngram_embedding_tables:
            x.to(*args, **kwargs)
        self.language_model.to(*args, **kwargs)
        for x in self.ngram_prediction_heads:
            x.to(*args, **kwargs)
        return self

In [28]:
model = CharModel()
# logits = model.predict(
#     example_data_tokenized["labels"], example_data_tokenized["attention_mask"]
# )[0]
# loss = model.get_loss(logits, example_data_tokenized["labels"])

# print(example_data_tokenized["labels"].shape)
# print(logits.shape)
# print(loss)

In [29]:
def decode(logits):
    # To convert back to text
    predicted_sentences = []
    for sentence_ids in logits.argmax(axis=2).detach().tolist():
        predicted_sentences.append("".join(idx_to_char[i] for i in sentence_ids))
    return predicted_sentences


# decode(logits)

In [32]:
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./data/hf_trainer/",
    logging_dir="./data/hf_trainer/runs",
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_device_train_batch_size=64,
    save_steps=5000,
    save_total_limit=2,
    prediction_loss_only=True,
)
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate,
    train_dataset=dataset,
)

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


In [33]:
trainer.train()

***** Running training *****
  Num examples = 974373
  Num Epochs = 1
  Instantaneous batch size per device = 64
  Total train batch size (w. parallel, distributed & accumulation) = 64
  Gradient Accumulation steps = 1
  Total optimization steps = 15225
  Number of trainable parameters = 43517954


Step,Training Loss


Saving model checkpoint to ./data/hf_trainer/checkpoint-5000
Trainer.model is not a `PreTrainedModel`, only saving its state dict.
Trainer.model is not a `PreTrainedModel`, only saving its state dict.


KeyboardInterrupt: 

In [34]:
text = ["Jen la komenco de bela urbo"]
data = collate([tokenize(s) for s in text])
logits = model.predict(data["labels"].to(device), data["attention_mask"].to(device))[0]
decode(logits)

['<ngram_pad><ngram_pad><ngram_pad><ngram_pad><ngram_pad><ngram_pad><ngram_pad><ngram_pad><ngram_pad><ngram_pad><ngram_pad><ngram_pad><ngram_pad><ngram_pad><ngram_pad><ngram_pad><ngram_pad><ngram_pad><ngram_pad><ngram_pad><ngram_pad><ngram_pad><ngram_pad><ngram_pad><ngram_pad><ngram_pad><ngram_pad><ngram_pad><ngram_pad><ngram_pad>']