In [None]:
import math
import os
import collections
import multiprocessing
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchtext
import datasets
import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger
from tqdm.notebook import tqdm
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
# --- Data ---------------------------------------------------------------
# Fraction used for each split (see the split cell for how they are applied).
splits = dict(train=0.1, val=0.05, test=0.05)

# Maximum vocabulary size, special tokens included.
vocab_len = 20_000

# --- Batching / model input ----------------------------------------------
batch_size = 64
seq_len = 128
emb_dim = 256

# Divisor used when deriving the per-epoch batch limits (float on purpose:
# it is only ever used inside a true division).
data_step = seq_len / 2

# --- Transformer architecture ---------------------------------------------
n_blocks = 4
n_heads = 4
d_ff = 512
dropout = 0.1

# --- Optimisation -----------------------------------------------------------
lr = 3e-4
epochs = 20
patience = 5        # early-stopping patience (epochs)
lr_factor = 0.1     # ReduceLROnPlateau multiplicative factor
lr_patience = 2     # ReduceLROnPlateau patience (epochs)
min_lr = 1e-5

In [None]:
# Use 'fork' so worker processes inherit the in-memory datasets of this kernel.
# force=True keeps the cell idempotent: without it, running the cell a second
# time in the same kernel raises "RuntimeError: context has already been set".
multiprocessing.set_start_method('fork', force=True)

In [None]:
# Load the 'wikipedia' dataset, configuration '20220301.simple'.
# trust_remote_code=True allows the dataset's loading script to be executed.
dataset = datasets.load_dataset('wikipedia', '20220301.simple', trust_remote_code=True)
dataset

In [None]:
# torchtext's 'basic_english' tokenizer; used both for the corpus and for prompts.
tokenizer = torchtext.data.utils.get_tokenizer('basic_english')

In [None]:
def _tokenize(sample):
    """Add a 'tokens' entry holding the tokenized 'text' of `sample`.

    The sample is updated in place and returned, which is the shape
    `dataset.map` expects from its mapping function.
    """
    text = sample['text']
    sample['tokens'] = tokenizer(text)
    return sample


# Tokenize every sample using one process per CPU; the raw 'text' column is
# removed, so only the columns produced/kept by _tokenize remain.
dataset = dataset.map(_tokenize, remove_columns=['text'], num_proc=os.cpu_count())
dataset

In [None]:
# Carve test, then validation, then the (sub-sampled) training set out of the
# single 'train' split. Each size is a fraction of what is left at that point,
# and every split call uses seed=0 so the partition is reproducible.
full_set = dataset['train']

n_test = int(splits['test'] * len(full_set))
remainder, test_set = full_set.train_test_split(n_test, seed=0).values()

n_val = int(splits['val'] * len(remainder))
remainder, val_set = remainder.train_test_split(n_val, seed=0).values()

# Only the held-out part of this last split is kept as the training set.
n_train = int(splits['train'] * len(remainder))
_, train_set = remainder.train_test_split(n_train, seed=0).values()

data_splits = dict(train=train_set, val=val_set, test=test_set)
len(train_set), len(val_set), len(test_set)

In [None]:
def build_sample_map(dataset, seq_len):
    """Enumerate every training window contained in `dataset`.

    A window is `seq_len + 1` consecutive tokens (inputs plus the shifted
    targets), so a document of `n` tokens contributes `n - seq_len` windows,
    starting at offsets `0 .. n - seq_len - 1`. Documents shorter than
    `seq_len + 1` tokens contribute nothing.

    Args:
        dataset: iterable of samples, each with a 'tokens' sequence.
        seq_len: number of input tokens per window.

    Returns:
        List of `(document_index, start_offset)` tuples, ordered by document
        and then by offset.
    """
    sample_map = []  # not named `map`, to avoid shadowing the builtin
    for idx, sample in enumerate(tqdm(dataset)):
        # Number of valid start offsets; the previous `len - (seq_len + 1)`
        # dropped the last full window of every document.
        n_windows = len(sample['tokens']) - seq_len
        if n_windows < 1:
            continue
        sample_map.extend((idx, start) for start in range(n_windows))
    return sample_map


# Sample maps are expensive to build, so each one is cached on disk under a
# name that encodes the split and the sequence length.
sample_maps = {}
for split in splits:
    filename = f'sample_map-{split}-{seq_len}.pt'
    try:
        cached_map = torch.load(filename)
    except FileNotFoundError:
        # No cache yet: build the map, store it as a tensor and persist it.
        cached_map = torch.tensor(build_sample_map(data_splits[split], seq_len=seq_len))
        torch.save(cached_map, filename)
    sample_maps[split] = cached_map
[len(split_map) for split_map in sample_maps.values()]

In [None]:
# The vocabulary is built from the training split only and cached on disk
# under a name that encodes its maximum size.
filename = f'vocab-{vocab_len}.pt'
try:
    vocab = torch.load(filename)
except FileNotFoundError:
    train_tokens = (sample['tokens'] for sample in tqdm(data_splits['train']))
    vocab = torchtext.vocab.build_vocab_from_iterator(
        train_tokens,
        specials=['<pad>', '<unk>'],
        max_tokens=vocab_len,
    )
    # Out-of-vocabulary tokens map to '<unk>'.
    vocab.set_default_index(vocab['<unk>'])
    torch.save(vocab, filename)
len(vocab)

In [None]:
class SequenceDataset(data.Dataset):
    """Map-style dataset of fixed-length token windows.

    Every item is addressed through `sample_map`, whose rows are
    `(document_index, start_offset)` pairs. An item is the pair `(x, y)` of
    vocabulary indices where `y` is `x` shifted one token to the right.

    Args:
        dataset: indexable collection of samples with a 'tokens' field.
        seq_len: number of tokens in `x` (and in `y`).
        tokenizer: stored on the instance; not used by this class itself.
        vocab: object exposing `lookup_indices(tokens)`.
        sample_map: indexable collection of `(document_index, start_offset)`.
        cache: when true, `dataset['tokens']` is read once up front and items
            are served from that copy.
    """

    def __init__(self, dataset, seq_len, tokenizer, vocab, sample_map, cache=False):
        super().__init__()
        self.dataset = dataset
        self.seq_len = seq_len
        self.tokenizer = tokenizer
        self.vocab = vocab
        self.sample_map = sample_map
        self.len = len(sample_map)
        self.cache = dataset['tokens'] if cache else None

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        # Both the incoming index and the map entries may be tensors, hence
        # the explicit conversions to int.
        doc_idx, offset = (int(value) for value in self.sample_map[int(idx)])
        if self.cache is None:
            doc_tokens = self.dataset[doc_idx]['tokens']
        else:
            doc_tokens = self.cache[doc_idx]
        window = doc_tokens[offset:offset + self.seq_len + 1]
        ids = torch.tensor(self.vocab.lookup_indices(window))
        return ids[:-1], ids[1:]


# One SequenceDataset per split; cache=True makes each dataset read
# dataset['tokens'] once at construction time and serve items from that copy.
seq_data = {split: SequenceDataset(data_splits[split], seq_len, tokenizer, vocab, sample_maps[split], cache=True) for split in splits}

In [None]:
# Settings shared by the three loaders; only the training loader shuffles.
loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=os.cpu_count(),
    persistent_workers=True,
    pin_memory=True,
)
train_loader = data.DataLoader(seq_data['train'], shuffle=True, **loader_kwargs)
val_loader = data.DataLoader(seq_data['val'], shuffle=False, **loader_kwargs)
test_loader = data.DataLoader(seq_data['test'], shuffle=False, **loader_kwargs)

# Number of batches to run per epoch for each split: the number of windows
# divided by (data_step * batch_size), truncated to an int.
windows_per_batch = data_step * batch_size
train_batches = int(len(seq_data['train']) / windows_per_batch)
val_batches = int(len(seq_data['val']) / windows_per_batch)
test_batches = int(len(seq_data['test']) / windows_per_batch)
train_batches, val_batches, test_batches

In [None]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding table.

    The table has shape `(max_len, d_model)`; even columns hold sines and odd
    columns hold cosines of `pos / 10000 ** (2 * i / d_model)`. It is stored
    as a non-persistent buffer, so it follows the module across devices but
    is not written to the state dict.

    Args:
        max_len: number of positions to precompute.
        d_model: encoding width; must be even.
    """

    def __init__(self, max_len, d_model):
        assert d_model % 2 == 0
        super().__init__()
        positions = torch.arange(max_len).float().unsqueeze(1)
        pair_index = torch.arange(d_model // 2)
        angles = positions / 10_000 ** (2 * pair_index / d_model)
        # Stacking on a trailing axis and flattening interleaves the columns
        # as sin, cos, sin, cos, ...
        table = torch.stack((torch.sin(angles), torch.cos(angles)), dim=-1).flatten(1)
        self.register_buffer('enc', table, persistent=False)

    def forward(self, x):
        """Return the first `x.size(-2)` rows of the table (values of `x` are unused)."""
        length = x.size(-2)
        return self.enc[:length]


class Transformer(pl.LightningModule):
    """Causally-masked Transformer language model (LightningModule).

    The constructor takes no arguments: all hyperparameters (`vocab_len`,
    `emb_dim`, `seq_len`, `n_heads`, `n_blocks`, `d_ff`, `dropout`, `lr`,
    `lr_factor`, `lr_patience`, `min_lr`) are read from the notebook's
    configuration cell.
    """

    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(vocab_len, emb_dim)
        self.pos_enc = PositionalEncoding(seq_len, emb_dim)
        self.dropout = nn.Dropout(dropout)
        encoder_layer = nn.TransformerEncoderLayer(emb_dim, n_heads, d_ff, dropout=dropout, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, n_blocks)
        self.linear = nn.Linear(emb_dim, vocab_len)
        # Causal mask for the longest supported sequence; forward() slices it.
        mask = nn.Transformer.generate_square_subsequent_mask(seq_len)
        self.register_buffer('mask', mask)

    def forward(self, x):
        """Return next-token logits for a batch of token indices.

        Args:
            x: integer tensor indexed as (batch, position); the number of
               positions must not exceed `seq_len`.

        Returns:
            Tensor of logits with a trailing dimension of size `vocab_len`.
        """
        n = x.size(1)
        x = self.embedding(x) * math.sqrt(self.embedding.embedding_dim)
        # Apply dropout to the sum of embeddings and positional encodings.
        # The dropout module was previously created but never used.
        x = self.dropout(x + self.pos_enc(x).unsqueeze(0))
        x = self.transformer(x, mask=self.mask[:n, :n], is_causal=True)
        y = self.linear(x)
        return y

    @staticmethod
    def loss(logits, targets):
        """Cross-entropy over all positions, with batch and position flattened."""
        return F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

    def configure_optimizers(self):
        """Adam plus a ReduceLROnPlateau scheduler stepped per epoch on 'val_loss'."""
        optimizer = optim.Adam(self.parameters(), lr=lr)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=lr_factor, patience=lr_patience, min_lr=min_lr)
        return {'optimizer': optimizer, 'lr_scheduler': {'scheduler': scheduler, 'monitor': 'val_loss', 'interval': 'epoch'}}

    def predict_step(self, batch, batch_idx):
        """Return the logits for the inputs of an `(inputs, targets)` batch."""
        x, _ = batch
        y = self.forward(x)
        return y

    def _dev_step(self, batch, batch_idx, name, **metrics):
        """Shared train/val/test step.

        Logs `{name}_loss` and `{name}_acc` (token-level accuracy), plus any
        extra `metrics`, aggregated per epoch, and returns the loss.
        """
        _, targets = batch
        logits = self.predict_step(batch, batch_idx)
        loss = self.loss(logits, targets)
        correct = int((logits.detach().argmax(dim=2) == targets).sum())
        accuracy = correct / targets.numel()
        self.log_dict({f'{name}_loss': loss, f'{name}_acc': accuracy, **metrics}, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._dev_step(batch, batch_idx, 'train')

    def validation_step(self, batch, batch_idx):
        return self._dev_step(batch, batch_idx, 'val')

    def test_step(self, batch, batch_idx):
        return self._dev_step(batch, batch_idx, 'test')


# Instantiate the model and display its total number of parameters.
model = Transformer()
sum(p.numel() for p in model.parameters())

In [None]:
# Callbacks are kept in named variables because later cells need them
# (e.g. `model_checkpoint.best_model_path`).
lr_monitor = pl.callbacks.LearningRateMonitor()
early_stopping = pl.callbacks.EarlyStopping(monitor='val_loss', mode='min', patience=patience)
model_checkpoint = pl.callbacks.ModelCheckpoint(monitor='val_loss', mode='min')
callbacks = [lr_monitor, early_stopping, model_checkpoint]

# Each epoch is capped to the batch counts computed alongside the loaders;
# metrics are written as CSV under the current directory.
trainer = pl.Trainer(
    max_epochs=epochs,
    limit_train_batches=train_batches,
    limit_val_batches=val_batches,
    limit_test_batches=test_batches,
    callbacks=callbacks,
    logger=CSVLogger(save_dir='./'),
)
trainer.fit(model, train_loader, val_loader)

In [None]:
# Replace the in-memory model with the checkpoint tracked as best by
# `model_checkpoint` (which monitors 'val_loss' in 'min' mode), then show its path.
model = Transformer.load_from_checkpoint(model_checkpoint.best_model_path)
model_checkpoint.best_model_path

In [None]:
# Evaluate the reloaded model on the test loader.
trainer.test(model, test_loader)

In [None]:
# Read the metrics written by the CSV logger and plot loss, accuracy and
# learning rate.
log = pd.read_csv(f'{trainer.logger.log_dir}/metrics.csv')
df = log.set_index('epoch')

# Move the validation metrics down by one row before filtering, then keep
# only the rows that carry a train or a validation loss.
for column in ('val_loss', 'val_acc'):
    df[column] = df[column].shift(1)
has_loss = df['train_loss'].notnull() | df['val_loss'].notnull()
df = df[has_loss]

df[['train_loss', 'val_loss']].plot()
plt.show()
df[['train_acc', 'val_acc']].plot()
plt.show()

# The learning rate is only present on some rows of the raw log.
has_lr = log['lr-Adam'].notnull()
log[has_lr][['lr-Adam']].plot()
plt.show()

In [None]:
# model = Transformer.load_from_checkpoint('../models/textgen/epoch=18-step=9728.ckpt')

In [None]:
@torch.no_grad()
def sample(
        model,
        x,
        output_length,
        block_size,
        eos_class=None,
        exclude_classes=None,
        temperature=1,
        top_k=None,
        top_p=None,
        generator=None
):
    """Autoregressively sample a continuation of `x` from `model`.

    Args:
        model: callable mapping a (1, n) tensor of token ids to logits whose
            last two dimensions are (position, class). It is put in eval mode.
        x: 1-D tensor of prompt token ids.
        output_length: maximum number of tokens to generate.
        block_size: maximum number of trailing tokens fed to the model.
        eos_class: token id that ends generation (it is not emitted);
            `None` disables early stopping.
        exclude_classes: token ids that must never be sampled.
        temperature: divisor applied to the logits before the softmax.
        top_k: if set, sample only among the `top_k` most likely tokens.
        top_p: if set, sample only among the smallest set of most likely
            tokens (nucleus), chosen by cumulative probability against `top_p`.
        generator: optional `torch.Generator` for reproducible sampling.

    Returns:
        1-D tensor with the generated token ids (prompt excluded).
    """
    model.eval()
    seq = x
    for _ in range(output_length):
        inputs = seq[-block_size:].unsqueeze(0)
        logits = model(inputs).squeeze(0)[-1]
        if exclude_classes:
            logits[exclude_classes] = float('-inf')
        probas = (logits / temperature).softmax(dim=-1)
        if top_k:
            # Clamp so that a top_k larger than the number of classes is harmless.
            probas, indices = probas.topk(min(top_k, probas.size(-1)))
        else:
            # Created on the same device as the probabilities so it can be
            # indexed by (and concatenated with) tensors living on that device.
            indices = torch.arange(probas.size(-1), device=probas.device)
        if top_p:
            sorted_probas, sorted_indices = probas.sort()  # ascending: the nucleus is the tail
            cumprobas = sorted_probas.cumsum(-1)
            nucleus_size = cumprobas.size(-1) - int(torch.sum(cumprobas <= (1 - top_p)))
            # Keep at least one token: a size of 0 would turn the `[-0:]`
            # slices below into "keep everything".
            nucleus_size = max(nucleus_size, 1)
            probas = sorted_probas[-nucleus_size:]
            indices = indices[sorted_indices[-nucleus_size:]]
        index = probas.multinomial(1, generator=generator)
        # `index` is a position inside the (possibly filtered) distribution;
        # map it back to the actual class id before comparing with eos_class.
        token = indices[index]
        if eos_class is not None and int(token) == eos_class:
            break
        seq = torch.cat([seq, token])
    return seq[x.size(-1):]

In [None]:
@torch.no_grad()
def beam_search(
        model,
        x,
        beam_width,
        output_length,
        block_size=None,
        eos_class=None,
        exclude_classes=None,
        length_penalty=0
):
    """Generate a continuation of `x` with beam search.

    Args:
        model: callable mapping a (1, n) tensor of token ids to logits whose
            last two dimensions are (position, class). It is put in eval mode.
        x: 1-D tensor of prompt token ids.
        beam_width: number of hypotheses kept after every step.
        output_length: maximum number of tokens to generate.
        block_size: if set, maximum number of trailing tokens (prompt plus
            hypothesis) fed to the model.
        eos_class: token id that terminates a hypothesis; `None` disables it.
        exclude_classes: token ids that must never be generated.
        length_penalty: hypotheses are scored as
            `log_prob / length ** length_penalty`.

    Returns:
        1-D int64 tensor with the token ids of the selected hypothesis
        (prompt excluded, trailing `eos_class` removed). Finished hypotheses
        are preferred over unfinished ones; ties are broken by score.
    """
    Node = collections.namedtuple('Node', ['path', 'proba', 'score'])

    def is_finished(node):
        # The root has an empty path, so guard before looking at the last token.
        path = node.path
        return eos_class is not None and path.numel() > 0 and int(path[-1]) == eos_class

    model.eval()
    # Built on the prompt's device so it works with any module, not only ones
    # exposing a `.device` attribute.
    empty = torch.empty(0, dtype=torch.int64, device=x.device)
    root = Node(empty, 0.0, 0.0)
    nodes = branches = [root]
    leaves = []
    for level in range(output_length):
        candidates = []
        score_divisor = (level + 1) ** length_penalty
        best_score = max(float(leaf.score) for leaf in leaves) if leaves else float('-inf')
        # Bound on the score a branch could still reach, used to prune
        # branches that cannot beat the best finished hypothesis.
        early_stopping_divisor = score_divisor if length_penalty <= 0 else output_length ** length_penalty
        for branch in branches:
            if float(branch.proba) / early_stopping_divisor < best_score:
                continue
            _x, _path = x, branch.path
            if block_size:
                # Keep the most recent tokens of the hypothesis and fill the
                # remaining room with the tail of the prompt.
                _path = branch.path[-block_size:]
                _x = x[max(x.size(0) + _path.size(0) - block_size, 0):]
            inputs = torch.cat([_x, _path]).unsqueeze(0)
            logits = model(inputs).squeeze(0)[-1]
            if exclude_classes:
                logits[exclude_classes] = float('-inf')
            log_probas = logits.log_softmax(0)
            # Clamp so that a beam wider than the number of classes is harmless.
            log_probas, indices = log_probas.topk(min(beam_width, log_probas.size(0)))
            log_probas = log_probas + branch.proba
            scores = log_probas / score_divisor
            candidates.extend(
                Node(torch.cat([branch.path, indices[i:i + 1]]), proba, score)
                for i, (proba, score) in enumerate(zip(log_probas, scores))
            )
        candidates += leaves
        candidates.sort(key=lambda node: float(node.score), reverse=True)
        nodes = candidates[:beam_width]
        # Lists (rather than sets) keep the branch order deterministic.
        leaves = [node for node in nodes if is_finished(node)]
        branches = [node for node in nodes if not is_finished(node)]
        if not branches:
            break
    output = max(nodes, key=lambda node: (is_finished(node), float(node.score)))
    # Work on the token path itself: indexing the namedtuple with [-1] / [:-1]
    # would address its `score` field / a tuple slice, not the tokens.
    path = output.path
    if is_finished(output):
        path = path[:-1]
    return path

In [None]:
# Generate up to 100 tokens after the prompt by sampling, never emitting
# the special '<pad>' / '<unk>' tokens.
prompt = 'marry had a little lamb'
prompt_tokens = tokenizer(prompt)
prompt_ids = torch.tensor(vocab.lookup_indices(prompt_tokens), device=model.device)
exclude = vocab.lookup_indices(['<pad>', '<unk>'])
generated = sample(
    model,
    prompt_ids,
    100,
    block_size=seq_len,
    exclude_classes=exclude,
    temperature=1,
    top_k=None,
    top_p=None,
)
tokens = vocab.lookup_tokens(generated.tolist())
prompt + ' ' + ' '.join(tokens)

In [None]:
# Generate up to 100 tokens after the prompt with beam search (width 10),
# never emitting the special '<pad>' / '<unk>' tokens.
prompt = 'marry had a little lamb'
prompt_tokens = tokenizer(prompt)
prompt_ids = torch.tensor(vocab.lookup_indices(prompt_tokens), device=model.device)
exclude = vocab.lookup_indices(['<pad>', '<unk>'])
generated = beam_search(
    model,
    prompt_ids,
    10,
    100,
    block_size=seq_len,
    exclude_classes=exclude,
)
tokens = vocab.lookup_tokens(generated.tolist())
prompt + ' ' + ' '.join(tokens)