In [1]:
from pathlib import Path

import random
import string

from datasets import load_dataset
import torch
from torch.utils.data import Dataset
from transformers import RobertaConfig, RobertaForMaskedLM, RobertaTokenizer
from transformers.modeling_outputs import MaskedLMOutput
from tokenizers import normalizers


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = "cpu"

print(torch.cuda.is_available())
print(torch.cuda.device_count())
print(device)

  from .autonotebook import tqdm as notebook_tqdm


True
1
cuda


In [2]:
CLS = "<cls>"
EOS = "<eos>"
CHAR_PAD = "<char_pad>"
UNK = "<unk>"
NGRAM_PAD = "<ngram_pad>"
MASK = "<mask>"

CHAR_TOKENS: list[str] = list(string.printable) + [CLS, EOS, CHAR_PAD, UNK, NGRAM_PAD, MASK]
NGRAM_SIZE: int = 8
NUM_ATTENTION_HEADS: int = 12
HIDDEN_SIZE: int = 1536 # multiple of NUM_ATTENTION_HEADS, 768 default
MAX_SEQ_LEN: int = NGRAM_SIZE * 20  # multiple of NGRAM_SIZE
PROB_MASK: float = 0.15


num_chars = len(CHAR_TOKENS)
char_to_idx = {c: i for i, c in enumerate(CHAR_TOKENS)}
idx_to_char = {i: c for i, c in enumerate(CHAR_TOKENS)}

normalizer = normalizers.Sequence(
    [normalizers.NFD(), normalizers.Lowercase(), normalizers.StripAccents()]
)


def tokenize(seq: str):
    seq = normalizer.normalize_str(seq)
    seq = [CLS] + list(seq)
    # Pad such that len(seq) is divisible by NGRAM_SIZE
    if len(seq) % NGRAM_SIZE > 0:
        seq += [CHAR_PAD] * (NGRAM_SIZE - (len(seq) % NGRAM_SIZE))
    seq += [EOS] * NGRAM_SIZE
    return torch.tensor(
        [char_to_idx[c] if c in char_to_idx else char_to_idx[UNK] for c in seq]
    )


def collate(tokenized_seqs: list[torch.tensor], masking_probability: float = PROB_MASK):
    """Pad short seqs, truncate long seqs."""
    tokenized_seqs = [x[: MAX_SEQ_LEN] for x in tokenized_seqs]
    max_len = max(x.shape[-1] for x in tokenized_seqs)
    labels = torch.full(
        size=[len(tokenized_seqs), max_len],
        fill_value=char_to_idx[NGRAM_PAD],
        dtype=torch.long,
    )
    attention_mask = torch.ones_like(labels)
    for i, x in enumerate(tokenized_seqs):
        labels[i, 0 : len(x)] = x
        attention_mask[i, len(x) :] = 0
    # Masking, on ngram level rather than char
    masked_labels = labels.clone().detach()
    for row_idx in range(masked_labels.shape[0]):
        for ngram_idx in range(0, masked_labels.shape[1], NGRAM_SIZE):
            if random.random() < masking_probability:
                masked_labels[
                    row_idx, ngram_idx : ngram_idx + NGRAM_SIZE
                ] = char_to_idx[MASK]
    return {
        "labels": labels,
        "masked_labels": masked_labels,
        "attention_mask": attention_mask,
    }


# example_data = ["Hi..", "This is a second sentence."]
# example_data_tokenized = collate([tokenize(s) for s in example_data])
# print(example_data_tokenized)

In [3]:
class MyDataset(Dataset):
    def __init__(self, split):
        self.examples = load_dataset(
            "wikipedia",
            "20220301.en",
            split=split,
            cache_dir="/media/bigdata/datasets/",
        )

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        text = self.examples[i]["text"]
        if not type(text) == str:
            text = ""
        return tokenize(text)


dataset_train = MyDataset(split="train[:500]")
dataset_eval = MyDataset(split="train[-100:]")
print(len(dataset_train))
print(len(dataset_eval))

Found cached dataset wikipedia (/media/bigdata/datasets/wikipedia/20220301.en/2.0.0/aa542ed919df55cc5d3347f42dd4521d05ca68751f50dbc32bae2a7f1e167559)
Found cached dataset wikipedia (/media/bigdata/datasets/wikipedia/20220301.en/2.0.0/aa542ed919df55cc5d3347f42dd4521d05ca68751f50dbc32bae2a7f1e167559)


500
100


In [4]:
class CharModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # An embedding table for each slot in the the ngram, (e.g. 0, 1, 2 for a NGRAM_SIZE=3).
        self.ngram_embedding_tables = [
            torch.nn.Embedding(
                num_embeddings=num_chars,
                embedding_dim=HIDDEN_SIZE,
                padding_idx=char_to_idx[NGRAM_PAD],
            )
            for _ in range(NGRAM_SIZE)
        ]
        self.language_model = RobertaForMaskedLM(
            config=RobertaConfig(
                vocab_size=2,  # won't use
                hidden_size=HIDDEN_SIZE,  # default 768
                max_position_embeddings=514,
                num_attention_heads=NUM_ATTENTION_HEADS,
                num_hidden_layers=6,
                type_vocab_size=1,
                attention_probs_dropout_prob=0,
                hidden_dropout_prob=0,
            )
        )
        # To map from the lm embeddings back to the chars
        self.ngram_prediction_heads = [
            torch.nn.Linear(HIDDEN_SIZE, num_chars) for _ in range(NGRAM_SIZE)
        ]

    def forward(self, labels, masked_labels, attention_mask):
        logits = self.predict(masked_labels, attention_mask)[0]
        loss = self.get_loss(logits, labels)
        return MaskedLMOutput(loss=loss, logits=logits)

    def predict(self, labels, attention_mask):
        input_embeddings = self.get_input_embeddings(labels)
        lm_embeddings = self.language_model.roberta.forward(
            inputs_embeds=input_embeddings,
            attention_mask=attention_mask[:, ::NGRAM_SIZE],
        ).last_hidden_state
        logits = self.get_predicted_char_logits(lm_embeddings)
        return logits, lm_embeddings, input_embeddings

    def get_loss(self, logits, labels):
        return torch.nn.functional.cross_entropy(
            logits.reshape(-1, num_chars), labels.reshape(-1)
        )

    def get_input_embeddings(self, x_batch: torch.tensor):
        result = []
        for ngram_slot_idx in range(NGRAM_SIZE):
            ngram_slot_embeddings = self.ngram_embedding_tables[ngram_slot_idx](
                x_batch[:, ngram_slot_idx::NGRAM_SIZE]
            )
            result.append(ngram_slot_embeddings)
        result = torch.stack(result).sum(dim=0)
        return result

    def get_predicted_char_logits(self, xbatch_lm_embeddings: torch.tensor):
        """Map from the lm embeddings back to the chars"""
        result = []
        for ngram_slot_idx in range(NGRAM_SIZE):
            predicted_char = self.ngram_prediction_heads[ngram_slot_idx](
                xbatch_lm_embeddings
            )
            result.append(predicted_char)
        result = torch.concatenate(result, dim=1)
        return result

    def to(self, *args, **kwargs):
        for x in self.ngram_embedding_tables:
            x.to(*args, **kwargs)
        self.language_model.to(*args, **kwargs)
        for x in self.ngram_prediction_heads:
            x.to(*args, **kwargs)
        return self

In [5]:
model = CharModel()
# logits = model.predict(
#     example_data_tokenized["labels"], example_data_tokenized["attention_mask"]
# )[0]
# loss = model.get_loss(logits, example_data_tokenized["labels"])

# print(example_data_tokenized["labels"].shape)
# print(logits.shape)
# print(loss)

In [6]:
def decode(labels):
    # To convert back to text
    predicted_sentences = []
    for sentence_ids in labels:
        predicted_sentences.append("".join(idx_to_char[i] for i in sentence_ids))
    return predicted_sentences


# decode(logits.argmax(axis=2).detach().tolist())

In [7]:
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./data/hf_trainer/",
    logging_dir="./data/hf_trainer/runs",
    overwrite_output_dir=True,
    num_train_epochs=100,
    per_device_train_batch_size=64,
    save_steps=5000,
    learning_rate=5e-5,  # defaults to 5e-5
    # logging_steps=5000,
    save_total_limit=2,
    prediction_loss_only=True,
)
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate,
    train_dataset=dataset_train,
    # eval_dataset=dataset_eval,
)

In [8]:
import logging

logging.basicConfig(level=logging.DEBUG, format="%(asctime)s %(message)s")

In [9]:
trainer.train()

***** Running training *****
  Num examples = 500
  Num Epochs = 100
  Instantaneous batch size per device = 64
  Total train batch size (w. parallel, distributed & accumulation) = 64
  Gradient Accumulation steps = 1
  Total optimization steps = 800
  Number of trainable parameters = 43517954


Step,Training Loss


TrainOutput(global_step=800, training_loss=2.733551330566406, metrics={'train_runtime': 294.1207, 'train_samples_per_second': 169.998, 'train_steps_per_second': 2.72, 'total_flos': 0.0, 'train_loss': 2.733551330566406, 'epoch': 100.0})

In [10]:
for i in range(20):
    data = collate([dataset_train[i]])
    # data = collate([dataset_train[random.randint(0, len(dataset_eval) - 1)]])
    logits = model.predict(
        data["labels"].to(device), data["attention_mask"].to(device)
    )[0]
    print(decode(data["labels"].detach().tolist()))
    print(decode(logits.argmax(axis=2).detach().tolist()))
    print()

['<cls>anarchism is a political philosophy and movement that is sceptical of authority and rejects all involuntary, coercive forms of hierarchy. anarchism calls for t']
['<cls>acarcrisa is a political philohs in and mohement that is sceptical ot autoorit  ant resects all intolintary, roercive eorms oh hierarcry, anarlhisc calls  or t']

['<cls>autism is a neurodevelopmental disorder characterized by difficulties with social interaction and communication, and by restricted and repetitive behavior. par']
['<cls>aatise is a nearooeneltpmental disorrer charatterihea by aicfiiinties iitt social inaeractien and comannihationn and  n reetricted and repetiti e iehaeiora par']

['<cls>albedo (; ) is the measure of the diffuse reflection of solar radiation out of the total solar radiation and measured on a scale from 0, corresponding to a bla']
['<cls>al eto (  e is the seasure of the i ffuse ref ection of solar raiiation oat oa  he total solal radiation and measured on a scale froa r, correspon