In [None]:
# TODO: try more hyperparams
# TODO: try more tokens
# TODO: add more data, i.e. bookcorpus/commoncrawl
# TODO: make it easier to load saved models
# TODO: log more info in wandb to make it easier to compare runs
# (subclass trainer to do that?)

%env WANDB_PROJECT=PACNLM
%env WANDB_START_METHOD=thread

import os
import atexit
import random
import string
import time
from pathlib import Path

import torch
import transformers.trainer as trainer
from datasets import load_dataset
from magic_timer import MagicTimer
from magic_timer.format_seconds import format_seconds
from tokenizers import normalizers
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import Dataset
from transformers import (
    PrinterCallback,
    RobertaPreLayerNormConfig,
    RobertaPreLayerNormForMaskedLM,
    Trainer,
    TrainingArguments,
)
from transformers.modeling_outputs import MaskedLMOutput
from transformers.trainer import SequentialSampler
import bitsandbytes as bnb


def sampler_monkey_patch(dataset, generator):
    # When the dataset size is large, I've measured:
    # RandomSampler -> ~50 samples/sec
    # SequentialSampler -> ~500 samples/sec
    # So this patch is to get a nearly 10x speedup...
    # This has got training time on wikipedia from 1+ day to 3 hrs...
    print("Monkey patching random sampler...")
    return SequentialSampler(dataset)


trainer.RandomSampler = sampler_monkey_patch

savedir = Path(os.environ["SAVEDIR"])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(torch.cuda.is_available())
print(torch.cuda.device_count())
print(device)

In [6]:
# Roberta-base has 514 input tokens, which is 512 * 4 ~> 2048 chars
NGRAM_SIZE: int = 6
MAX_SEQ_LEN: int = NGRAM_SIZE * 128  # multiple of NGRAM_SIZE
NUM_ATTENTION_HEADS: int = 12
HIDDEN_SIZE: int = 768  # multiple of NUM_ATTENTION_HEADS, 768 default
PROB_MASK: float = 0.15
NUM_HIDDEN_LAYERS: int = 6

TRAINING_BATCH_SIZE = 125
DATALOADER_NUM_WORKERS = 10
LEARNING_RATE = 3e-4  # defaults to 5e-5
NUM_TRAIN_EPOCHS = 1

print(f"Input length: {NGRAM_SIZE * MAX_SEQ_LEN} chars (roberta-base is 2048)")

assert HIDDEN_SIZE % NGRAM_SIZE == 0

num_accumulation_steps = int(1200 / TRAINING_BATCH_SIZE)
print(f"accumulate every {num_accumulation_steps}")

Input length: 1440 chars (roberta-base is 2048)
accumulate every 28


In [None]:
CLS = "<cls>"
EOS = "<eos>"
CHAR_PAD = "<char_pad>"
UNK = "<unk>"
NGRAM_PAD = "<ngram_pad>"
MASK = "<mask>"

SPECIAL_CHARS = {
    CLS,
    EOS,
    CHAR_PAD,
    UNK,
    NGRAM_PAD,
    MASK,
}
CHAR_TOKENS: list[str] = sorted(list(string.printable) + list(SPECIAL_CHARS))

num_chars = len(CHAR_TOKENS)
char_to_idx = {c: i for i, c in enumerate(CHAR_TOKENS)}
idx_to_char = {i: c for i, c in enumerate(CHAR_TOKENS)}

normalizer = normalizers.Sequence(
    [normalizers.NFD(), normalizers.Lowercase(), normalizers.StripAccents()]
)

In [None]:
def tokenize(seq: str):
    seq = normalizer.normalize_str(seq)
    seq = [CLS] + list(seq)
    # Pad such that len(seq) is divisible by NGRAM_SIZE
    if len(seq) % NGRAM_SIZE > 0:
        seq += [CHAR_PAD] * (NGRAM_SIZE - (len(seq) % NGRAM_SIZE))
    seq += [EOS] * NGRAM_SIZE
    return torch.tensor(
        [char_to_idx[c] if c in char_to_idx else char_to_idx[UNK] for c in seq]
    )


def collate(tokenized_seqs: list[torch.tensor], masking_probability: float = PROB_MASK):
    """Pad short seqs, truncate long seqs."""
    tokenized_seqs = [x[:MAX_SEQ_LEN] for x in tokenized_seqs]
    max_len = max(x.shape[-1] for x in tokenized_seqs)
    labels = torch.full(
        size=[len(tokenized_seqs), max_len],
        fill_value=char_to_idx[NGRAM_PAD],
        dtype=torch.long,
    )
    attention_mask = torch.ones_like(labels)
    for i, x in enumerate(tokenized_seqs):
        labels[i, 0 : len(x)] = x
        attention_mask[i, len(x) :] = 0
    # Masking, on ngram level rather than char
    masked_labels = labels.clone().detach()
    for row_idx in range(masked_labels.shape[0]):
        for ngram_idx in range(0, masked_labels.shape[1], NGRAM_SIZE):
            if random.random() < masking_probability:
                masked_labels[
                    row_idx, ngram_idx : ngram_idx + NGRAM_SIZE
                ] = char_to_idx[MASK]
    return {
        "labels": labels,
        "masked_labels": masked_labels,
        "attention_mask": attention_mask,
    }


def decode(labels):
    # To convert back to text
    predicted_sentences = []
    for sentence_ids in labels:
        chars = []
        for i in sentence_ids:
            char = idx_to_char[i]
            # `char in chars[-1:]` is to compare to the last char,
            # that also works when there are no no chars..
            if char in SPECIAL_CHARS and char in chars[-1:]:
                continue
            chars.append(char)
        predicted_sentences.append("".join(chars))
    return predicted_sentences


class CharModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # An embedding table for each slot in the the ngram, (e.g. 0, 1, 2 for a NGRAM_SIZE=3).
        self.ngram_embedding_tables = torch.nn.ModuleList(
            [
                torch.nn.Embedding(
                    num_embeddings=num_chars,
                    # embedding_dim=HIDDEN_SIZE,
                    # # for concat varient:
                    embedding_dim=HIDDEN_SIZE // NGRAM_SIZE,
                    padding_idx=char_to_idx[NGRAM_PAD],
                )
                for _ in range(NGRAM_SIZE)
            ]
        )
        self.language_model = RobertaPreLayerNormForMaskedLM(
            config=RobertaPreLayerNormConfig(
                vocab_size=2,  # won't use
                hidden_size=HIDDEN_SIZE,  # default 768
                max_position_embeddings=514,
                num_attention_heads=NUM_ATTENTION_HEADS,
                num_hidden_layers=NUM_HIDDEN_LAYERS,
                type_vocab_size=1,
                # cramming turns off dropout
                attention_probs_dropout_prob=0,
                hidden_dropout_prob=0,
            )
        )
        # To map from the lm embeddings back to the chars
        self.ngram_prediction_heads = torch.nn.ModuleList(
            [torch.nn.Linear(HIDDEN_SIZE, num_chars) for _ in range(NGRAM_SIZE)]
        )

    def forward(self, labels, masked_labels, attention_mask):
        logits = self.predict(masked_labels, attention_mask)[0]
        loss = self.get_loss(logits, labels, attention_mask)
        return MaskedLMOutput(loss=loss, logits=logits)

    def predict(self, labels, attention_mask):
        input_embeddings = self.get_input_embeddings(labels)
        lm_embeddings = self.language_model.roberta_prelayernorm.forward(
            inputs_embeds=input_embeddings,
            attention_mask=attention_mask[:, ::NGRAM_SIZE],
        ).last_hidden_state
        logits = self.get_predicted_char_logits(lm_embeddings)
        return logits, lm_embeddings, input_embeddings

    def get_loss(self, logits, labels, attention_mask):
        loss_array = torch.nn.functional.cross_entropy(
            logits.reshape(-1, num_chars), labels.reshape(-1), reduction="none"
        ) * attention_mask.reshape(-1)
        return torch.mean(loss_array, dim=0)

    def get_input_embeddings(self, x_batch: torch.tensor):
        result = []
        for ngram_slot_idx in range(NGRAM_SIZE):
            ngram_slot_embeddings = self.ngram_embedding_tables[ngram_slot_idx](
                x_batch[:, ngram_slot_idx::NGRAM_SIZE]
            )
            result.append(ngram_slot_embeddings)
        # result = torch.stack(result).sum(dim=0)
        # for concat varient:
        result = torch.concatenate(result, dim=2)
        return result

    def get_predicted_char_logits(self, xbatch_lm_embeddings: torch.tensor):
        """Map from the lm embeddings back to the chars"""
        result = []
        for ngram_slot_idx in range(NGRAM_SIZE):
            predicted_char = self.ngram_prediction_heads[ngram_slot_idx](
                xbatch_lm_embeddings
            )
            result.append(predicted_char)
        result = torch.concatenate(result, dim=1)
        return result

    def to(self, *args, **kwargs):
        for x in self.ngram_embedding_tables:
            x.to(*args, **kwargs)
        self.language_model.to(*args, **kwargs)
        for x in self.ngram_prediction_heads:
            x.to(*args, **kwargs)
        return self

In [None]:
class MyDataset(Dataset):
    def __init__(self, split):
        self.examples = load_dataset(
            "wikipedia",
            "20220301.en",
            split=split,
            cache_dir="/media/bigdata/datasets/",
        )

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        text = self.examples[i]["text"]
        if not type(text) == str:
            text = ""
        return tokenize(text)


# dataset_train = MyDataset(split=f"train[:1000]")
dataset_train = MyDataset(split=f"train[:-{TRAINING_BATCH_SIZE}]")
dataset_eval = MyDataset(split=f"train[-{TRAINING_BATCH_SIZE}:]")
print(len(dataset_train))
print(len(dataset_eval))

In [None]:
model = CharModel()
model = torch.compile(model)

In [5]:
class MyCallback(PrinterCallback):
    def __init__(self) -> None:
        super().__init__()
        self.total_timer = None

    def on_log(self, args, state, control, **kwargs):
        """
        Event called after logging the last logs.
        """
        if not self.total_timer:
            self.total_timer = MagicTimer()
        eta = format_seconds(
            (state.max_steps - state.global_step)
            * (self.total_timer.time_elapsed() / state.global_step)
        )
        samples_per_second = (
            state.global_step * TRAINING_BATCH_SIZE * num_accumulation_steps
        ) / self.total_timer.time_elapsed()
        print(
            f"{time.strftime('%Y%m%d-%H%M')}"
            f" -- Time elapsed: {self.total_timer}"
            f" -- Steps: {state.global_step} / {state.max_steps}"
            f" -- Estimated time left: {eta}"
            f" -- Samples per second: {samples_per_second}"
        )


# source: github.com/JonasGeiping/cramming
def get_one_cycle(optimizer, num_training_steps):
    """Simple single-cycle scheduler. Not including paper/fastai three-phase things or asymmetry."""

    def lr_lambda(current_step):
        if current_step < num_training_steps / 2:
            return float(current_step / (num_training_steps / 2))
        else:
            return float(2 - current_step / (num_training_steps / 2))

    return LambdaLR(optimizer, lr_lambda, -1)


# adam_fn = partial(torch.optim._functional.adam, amsgrad=False, beta1=0.9, beta2=0.98, weight_decay=0, eps=1e-6, maximize=False)
optimizer = bnb.optim.Adam8bit(
    model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.98), eps=1e-6
)
schedular = get_one_cycle(
    optimizer, int(NUM_TRAIN_EPOCHS * len(dataset_train) / TRAINING_BATCH_SIZE)
)

training_args = TrainingArguments(
    output_dir="./data/hf_trainer/",
    logging_dir="./data/hf_trainer/runs",
    overwrite_output_dir=True,
    num_train_epochs=NUM_TRAIN_EPOCHS,
    per_device_train_batch_size=TRAINING_BATCH_SIZE,
    per_device_eval_batch_size=TRAINING_BATCH_SIZE,
    save_steps=200,
    logging_steps=2,
    gradient_accumulation_steps=num_accumulation_steps,
    # eval_steps=100,
    # evaluation_strategy="steps",
    # prediction_loss_only=True,
    # learning_rate=LEARNING_RATE,
    save_total_limit=5,
    dataloader_num_workers=DATALOADER_NUM_WORKERS,
    disable_tqdm=True,
    logging_first_step=True,
    report_to="wandb",
    max_grad_norm=0.5,
)
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate,
    train_dataset=dataset_train,
    optimizers=[optimizer, schedular],
    # eval_dataset=dataset_eval,
    callbacks=[MyCallback],
)

NameError: name 'PrinterCallback' is not defined

In [None]:
def save_on_exit():
    print("Saving model on exit...")
    torch.save(model.state_dict(),  savedir / "model_on_exit.torch")


atexit.register(save_on_exit)

In [None]:
trainer.train()
torch.save(model.state_dict(), savedir / "trained_model.torch")

In [None]:
for ds_idx, ds in enumerate([dataset_train, dataset_eval]):
    print(f"--- dataset {ds_idx} ---\n")
    for i in range(20):
        data = collate([ds[i]])
        # data = collate([ds[random.randint(0, len(dataset_eval) - 1)]])
        logits = model.predict(
            data["masked_labels"].to(device), data["attention_mask"].to(device)
        )[0]
        print(decode(data["labels"].detach().tolist()))
        print(decode(data["masked_labels"].detach().tolist()))
        print(decode(logits.argmax(axis=2).detach().tolist()))
        print()

In [None]:
# without masking (checking it can encode/decode input)
examples = [
    "Hi, how are you? Hi, how are you? Hi, how are you? Hi, how are you?",
    "it seems that this can output its input pretty well, as long as the input is of a decent length, but for short sentences it seems to not be good at all, this is very interesting",
    "Hi, how are you doin? Hi, how are you? Hi, how are you? Hi, how are you? Hi, how are you? Hi, how are you? Hi, how are you? Hi, how are you? Hi, how are you? Hi, how are you? Hi, how are you? Hi, how are you?",
    "Interesting, it seems to fail even for the longer input, if it's very repetitive. I wonder what's going on there. Will this work better? Is it more like wiki text? that's interesting o.O.",
    "alain connes (; born 1 april 1947) is a french mathematician, and a theoretical physicist, known for his contributions to the study of operator algebras and no",
    "peter connes (; born april fools) is a french mathematician, and a masterful physicist, known for his many many contributions to the study of operator algebras and no",
    "Hello, world!",
    "One two three",
    "the 2022 fa women's league cup",
    "badreddine assouar (born may 5, 1974) is a physicist,",
]
for example in examples:
    data = collate([tokenize(example)])
    logits = model.predict(
        data["labels"].to(device), data["attention_mask"].to(device)
    )[0]
    print(decode(data["labels"].detach().tolist()))
    print(decode(logits.argmax(axis=2).detach().tolist()))
    print()

In [None]:
# with masking
for example in examples:
    # data = collate([tokenize(example), dataset_train[40]])
    data = collate([tokenize(example)], masking_probability=0.15)
    # data = collate([ds[random.randint(0, len(dataset_eval) - 1)]])
    logits = model.predict(
        data["masked_labels"].to(device), data["attention_mask"].to(device)
    )[0]
    print(decode(data["labels"].detach().tolist()))
    print(decode(data["masked_labels"].detach().tolist()))
    print(decode(logits.argmax(axis=2).detach().tolist()))
    print()