# toy reverse

## 1. generate data

In [1]:
import os
import random


def generate_line():
    """Return a random 32-character string of ASCII letters and digits.

    Characters are drawn independently (with replacement) from the module-level
    ``random`` generator, so seeding ``random`` makes the output reproducible.
    """
    pool = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'
    chars = []
    for _ in range(32):
        chars.append(random.choice(pool))
    return ''.join(chars)


# Write the training corpus: one million random lines, one per row, into
# <data_dir>/train.txt (the directory is created if it does not exist yet).
data_dir = '/mnt/llm/data/toyreverse'
os.makedirs(data_dir, exist_ok=True)

train_path = os.path.join(data_dir, 'train.txt')
with open(train_path, 'w') as f:
    f.writelines(generate_line() + '\n' for _ in range(1000000))

## 2. train tokenizer

In [2]:
import os

import sentencepiece as spm
import re


def iter(directory):
    """Yield every line of every ``.txt`` file in ``directory`` as bytes.

    Each line is stripped of surrounding whitespace and encoded with
    ``str.encode()`` (UTF-8). Blank lines are yielded as ``b''``.

    Files are visited in sorted filename order so the stream is deterministic,
    and each file is read lazily line by line instead of being loaded into
    memory at once.

    Args:
        directory: Path of the directory to scan (non-recursive).

    Yields:
        bytes: One stripped, encoded line at a time.
    """
    # NOTE(review): this name shadows the builtin ``iter`` for the rest of the
    # notebook. It is kept because the tokenizer-training cell calls it by
    # this name.
    for fn in sorted(os.listdir(directory)):
        if not fn.endswith('.txt'):
            continue
        with open(os.path.join(directory, fn), 'r') as f:
            for line in f:
                yield line.strip().encode()


# Train a 70-piece BPE tokenizer on the generated corpus. Ids 0-3 are reserved
# for <pad>, <unk>, <s> and </s> respectively.
tokenizer_prefix = '/mnt/llm/tokenizer/toyreverse'

# The trainer writes <prefix>.model / <prefix>.vocab; make sure the target
# directory exists first (mirrors the makedirs done for the data directory).
os.makedirs(os.path.dirname(tokenizer_prefix), exist_ok=True)

spm.SentencePieceTrainer.train(
    sentence_iterator=iter('/mnt/llm/data/toyreverse'),
    model_prefix=tokenizer_prefix,
    model_type='bpe',
    vocab_size=70,
    pad_id=0,
    unk_id=1,
    bos_id=2,
    eos_id=3,
    user_defined_symbols=[])

sentencepiece_trainer.cc(77) LOG(INFO) Starts training with : 
trainer_spec {
  input_format: 
  model_prefix: /mnt/llm/tokenizer/toyreverse
  model_type: BPE
  vocab_size: 70
  self_test_sample_size: 0
  character_coverage: 0.9995
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  pretokenization_delimiter: 
  treat_whitespace_as_suffix: 0
  allow_whitespace_only_pieces: 0
  required_chars: 
  byte_fallback: 0
  vocabulary_output_piece_score: 1
  train_extremely_large_corpus: 0
  hard_vocab_limit: 1
  use_all_vocab: 0
  unk_id: 1
  bos_id: 2
  eos_id: 3
  pad_id: 0
  unk_piece: <unk>
  bos_piece: <s>
  eos_piece: </s>
  pad_piece: <pad>
  unk_surface:  ⁇ 
  enable_differential_privacy: 0
  differential_privacy_noise_level: 0
  diff

## 3. train model

In [8]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint


class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    The input is divided by ``sqrt(mean(x**2) + eps)`` computed along the last
    axis and then multiplied element-wise by the scale vector ``w``.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Args:
            dim: Size of the last (feature) dimension.
            eps: Constant added to the mean square for numerical stability.
        """
        super().__init__()
        self.eps = eps
        # Scale starts at ones and is created with requires_grad=False, so it
        # is stored in the state dict but never receives gradients.
        self.w = nn.Parameter(torch.ones(dim), requires_grad=False)

    def forward(self, x):
        """Return ``w * x / rms(x)`` with the RMS taken over the last axis."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        normed = x * torch.rsqrt(mean_square + self.eps)
        return self.w * normed


def apply_rotary_position_embedding(x, freqs):
    """Rotate consecutive feature pairs of ``x`` by position-dependent angles.

    Args:
        x: 4-D tensor whose axis 1 is the sequence position and whose last
            axis has even size; adjacent pairs along the last axis are
            treated as (real, imag) parts of complex numbers.
        freqs: Complex tensor indexed by position on axis 0; the first
            ``seq_len`` rows must contain exactly ``x.shape[-1] // 2`` values
            each.

    Returns:
        A float32 tensor with the same shape as ``x``.
    """
    # (..., d) -> (..., d/2, 2) -> complex (..., d/2)
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    x_complex = torch.view_as_complex(pairs)

    _batch, seq_len, _heads, half_head_dim = x_complex.shape
    # Broadcast the per-position rotations over the batch axis and axis 2.
    rotation = freqs[0:seq_len].view(1, seq_len, 1, half_head_dim)

    rotated = torch.view_as_real(x_complex * rotation)
    # Merge the trailing (d/2, 2) axes back into a single axis of size d.
    return rotated.flatten(3)


def compute_freqs(head_dim, max_seq_len):
    """Build the table of complex rotations used by the rotary embedding.

    Args:
        head_dim: Per-head feature size; ``head_dim // 2`` frequencies are
            produced.
        max_seq_len: The table covers ``2 * max_seq_len`` positions.

    Returns:
        Complex tensor of shape ``(2 * max_seq_len, head_dim // 2)`` whose
        entry ``[p, j]`` is the unit-magnitude number with angle
        ``p / 10000 ** (2j / head_dim)``.
    """
    exponents = torch.arange(0, head_dim, 2)[:(head_dim // 2)].float() / head_dim
    inv_freq = 1.0 / (10000**exponents)
    positions = torch.arange(max_seq_len * 2)
    angles = torch.outer(positions, inv_freq).float()
    return torch.polar(torch.ones_like(angles), angles)


class MLP(nn.Module):
    """Gated feed-forward block: ``w_2(silu(w(x)) * v(x))`` without biases."""

    def __init__(self, dim):
        """
        Args:
            dim: Input and output feature size. The hidden size is
                ``int(4 * dim * 2 / 3)`` rounded up to a multiple of 256.
        """
        super().__init__()
        multiple_of = 256
        hidden_dim = int(4 * dim * 2 / 3)
        # Round up to the nearest multiple of 256. The previous expression
        # (hidden + 256 - hidden % 256) added a full extra 256 when the value
        # was already aligned (e.g. dim=384 -> 1280 instead of 1024). For
        # unaligned values such as dim=2560 (-> 6912) the result is unchanged.
        hidden_dim = multiple_of * (
            (hidden_dim + multiple_of - 1) // multiple_of)
        self.w = nn.Linear(dim, hidden_dim, bias=False)
        self.v = nn.Linear(dim, hidden_dim, bias=False)
        self.w_2 = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        """Apply the gated projection; output has the same shape as ``x``."""
        return self.w_2(F.silu(self.w(x)) * self.v(x))


class Attention(nn.Module):
    """Multi-head self-attention with rotary position embedding.

    No attention mask is applied: every position attends to every position in
    the sequence, padding included.
    """

    def __init__(self, dim, num_heads, head_dim, freqs):
        """
        Args:
            dim: Model (input/output) feature size.
            num_heads: Number of attention heads.
            head_dim: Feature size of each head.
            freqs: Complex rotation table passed to
                ``apply_rotary_position_embedding``. It is kept as a plain
                attribute (not a buffer), so it is not part of the state dict
                and does not follow ``.to()`` / ``.cuda()`` calls.
        """
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.freqs = freqs

        self.wq = nn.Linear(dim, num_heads * head_dim, bias=False)
        self.wk = nn.Linear(dim, num_heads * head_dim, bias=False)
        self.wv = nn.Linear(dim, num_heads * head_dim, bias=False)
        self.wo = nn.Linear(num_heads * head_dim, dim, bias=False)

    def forward(self, x):
        """Attend over ``x`` of shape (batch, seq, dim); returns the same shape."""
        batch_size, seq_len, _ = x.size()
        q = self.wq(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        k = self.wk(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        v = self.wv(x).view(batch_size, seq_len, self.num_heads, self.head_dim)

        # Position information is injected into queries and keys only.
        q = apply_rotary_position_embedding(q, self.freqs)
        k = apply_rotary_position_embedding(k, self.freqs)

        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
        scores = F.softmax(scores.float(), dim=-1)

        o = torch.matmul(scores, v)
        # Merge the heads back. The merged width is num_heads * head_dim (the
        # input width of ``wo``), which is not necessarily equal to ``dim``;
        # reshaping to ``dim`` raised whenever the two differed.
        o = o.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.num_heads * self.head_dim)
        o = self.wo(o)
        return o


class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual.

    Both sub-layers run under activation checkpointing, so their intermediate
    activations are recomputed during the backward pass instead of stored.
    """

    def __init__(self, dim, num_heads, head_dim, max_seq_len):
        """
        Args:
            dim: Model feature size.
            num_heads: Number of attention heads.
            head_dim: Feature size of each head.
            max_seq_len: Forwarded unchanged as the ``freqs`` argument of
                ``Attention``. Despite its name, ``LLM`` passes the
                precomputed rotary table here, not an integer length. The
                parameter name is kept for backward compatibility.
        """
        super().__init__()
        self.norm_1 = RMSNorm(dim)
        self.attention = Attention(dim, num_heads, head_dim, max_seq_len)
        self.norm_2 = RMSNorm(dim)
        self.mlp = MLP(dim)

    def forward(self, x):
        """Apply ``x + attn(norm_1(x))`` followed by ``x + mlp(norm_2(x))``."""
        # use_reentrant is passed explicitly: omitting it is deprecated and
        # emits a warning, and the non-reentrant implementation is the one
        # PyTorch recommends.
        x = x + checkpoint(self.custom(self.attention),
                           self.norm_1(x),
                           use_reentrant=False)
        x = x + checkpoint(self.custom(self.mlp),
                           self.norm_2(x),
                           use_reentrant=False)
        return x

    def custom(self, module):
        """Wrap ``module`` in a function that calls it on its first argument."""

        def custom_forward(*inputs):
            inputs = module(inputs[0])
            return inputs

        return custom_forward


class LLM(nn.Module):
    """Token embedding, a stack of ``Block`` layers, final norm and output head."""

    def __init__(self, vocab_size, padding_idx, config):
        """
        Args:
            vocab_size: Number of token ids (embedding rows / output logits).
            padding_idx: Token id passed to ``nn.Embedding`` as ``padding_idx``.
            config: Object providing ``dim``, ``num_layers``, ``num_heads``,
                ``head_dim`` and ``max_seq_len``.
        """
        super().__init__()
        self.config = config

        self.embedding = nn.Embedding(vocab_size, config.dim, padding_idx)

        # A single rotary table is computed once and shared by every block.
        freqs = compute_freqs(config.head_dim, config.max_seq_len)
        self.layers = nn.ModuleList([
            Block(config.dim, config.num_heads, config.head_dim, freqs)
            for _ in range(config.num_layers)
        ])

        self.norm = RMSNorm(config.dim)
        self.output = nn.Linear(config.dim, vocab_size, bias=False)

    def init_weights(self):
        """Re-initialise all Linear and Embedding weights from N(0, std^2).

        ``std`` is 0.02, except for Linear modules whose name ends in ``wo``,
        which use ``0.02 / sqrt(2 * num_layers)``.
        """
        # NOTE(review): normal_ is applied to the whole embedding matrix, so
        # the padding_idx row presumably ends up non-zero after this call.
        # Confirm whether that is intended.
        for name, module in self.named_modules():
            if isinstance(module, nn.Linear):
                if name.endswith('wo'):
                    factor = 1 / math.sqrt(2 * self.config.num_layers)
                else:
                    factor = 1
                nn.init.normal_(module.weight, mean=0.0, std=0.02 * factor)
            if isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x):
        """Map token ids ``x`` to float32 logits over the vocabulary."""
        hidden = self.embedding(x)
        for block in self.layers:
            hidden = block(hidden)
        logits = self.output(self.norm(hidden))
        return logits.float()

    def count_parameters_B(self):
        """Return the number of Linear and Embedding parameters, in billions."""
        counted_types = (nn.Linear, nn.Embedding)
        total = sum(p.numel()
                    for module in self.modules()
                    if isinstance(module, counted_types)
                    for p in module.parameters())
        return total / 1e9

In [9]:
import datetime
import math
import os
from dataclasses import dataclass
from pathlib import Path

import torch
import torch.nn.functional as F
import wandb
from torch.optim import AdamW
from torch.utils.data import DataLoader


class ToyReserveDataset(torch.utils.data.IterableDataset):
    """Streams (input, reversed-target) token pairs from ``*.txt`` files.

    For each non-blank line the input ``x`` is the tokenized line padded on
    the right, and the target ``y`` is the tokenized reversed line padded on
    the left, both to ``context_size`` tokens.

    (The class name is spelled "Reserve"; it is kept because callers refer to
    it by this name.)
    """

    def __init__(self, context_size, tokenizer, data_dir):
        """
        Args:
            context_size: Fixed length of every yielded tensor.
            tokenizer: Object exposing ``encode(str) -> list[int]`` and
                ``pad_id() -> int``.
            data_dir: Directory searched recursively for ``*.txt`` files.
        """
        self.context_size = context_size
        self.files = list(Path(data_dir).rglob('*.txt'))
        assert len(self.files) > 0
        self.tokenizer = tokenizer
        self.pad_id = tokenizer.pad_id()

    def _encode_pair(self, line):
        """Turn one text line into padded ``(x, y)`` LongTensors.

        The trailing newline is stripped before anything else. Previously it
        was left on the line, so the first ``[:-2]`` truncation removed the
        newline plus ONE character, whereas the inference cell (which has no
        newline) removes TWO; training and inference therefore saw different
        prompt lengths.

        The line is shortened two characters at a time until the input has
        fewer than ``context_size`` tokens and the reversed target has at most
        ``context_size`` tokens. Without the target check ``F.pad`` would
        receive a negative pad width and silently crop the target.
        """
        line = line.strip()
        src_ids = self.tokenizer.encode(line)
        tgt_ids = self.tokenizer.encode(line[::-1])
        # ``line`` in the condition guards against spinning forever on ''.
        while line and (len(src_ids) >= self.context_size
                        or len(tgt_ids) > self.context_size):
            line = line[:-2]
            src_ids = self.tokenizer.encode(line)
            tgt_ids = self.tokenizer.encode(line[::-1])

        x = torch.LongTensor(src_ids)
        x = F.pad(x, (0, self.context_size - x.shape[0]), "constant",
                  self.pad_id)

        y = torch.LongTensor(tgt_ids)
        y = F.pad(y, (self.context_size - y.shape[0], 0), "constant",
                  self.pad_id)
        return x, y

    def __iter__(self):
        """Yield ``(x, y)`` CUDA tensors for every non-blank line of every file."""
        for fn in self.files:
            with open(fn) as f:
                # Iterate the file lazily instead of loading it with readlines().
                for raw_line in f:
                    if raw_line.strip() == '':
                        continue
                    x, y = self._encode_pair(raw_line)
                    yield x.cuda(), y.cuda()


@dataclass
class Config:
    """Hyper-parameters for the model, the optimizer and the data pipeline."""

    # model
    dim: int = 2560
    num_layers: int = 32
    num_heads: int = 1
    # Per-head size. Left as None it is derived as dim // num_heads in
    # __post_init__. The former class-level default ``dim // num_heads`` was
    # evaluated once at class creation and stayed 2560 even when dim or
    # num_heads were overridden.
    head_dim: "int | None" = None
    max_seq_len: int = 32  # same as context_size

    # adamw
    learning_rate: float = 1e-6
    weight_decay: float = 0.01

    # data
    batch_size: int = 8
    context_size: int = 32

    def __post_init__(self):
        """Fill in ``head_dim`` from ``dim`` and ``num_heads`` when not given."""
        if self.head_dim is None:
            self.head_dim = self.dim // self.num_heads


class Trainer():
    """Builds the model, data loader and optimizer and runs the training loop."""

    def __init__(self, project, tokenizer, data_dir, output_dir):
        """
        Args:
            project: wandb project name.
            tokenizer: Tokenizer exposing ``vocab_size()``, ``pad_id()`` and
                ``encode()``.
            data_dir: Directory containing the training ``*.txt`` files.
            output_dir: Directory under which checkpoints are written.
        """
        self.project = project
        self.tokenizer = tokenizer
        self.vocab_size = self.tokenizer.vocab_size()
        self.padding_idx = self.tokenizer.pad_id()

        self.config = Config()
        self.llm = LLM(self.vocab_size, self.padding_idx, self.config)
        self.llm.init_weights()
        print(self.llm)
        print(f'model parameter {self.llm.count_parameters_B():.2f}B')

        self.train_dataset = ToyReserveDataset(self.config.context_size,
                                               self.tokenizer, data_dir)
        self.train_loader = DataLoader(self.train_dataset,
                                       batch_size=self.config.batch_size)

        self.optimizer = AdamW(
            self.llm.parameters(),
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay,
        )
        self.output_dir = output_dir

    def compute_loss(self, logits, y):
        """Mean cross-entropy over every position (padding targets included)."""
        return F.cross_entropy(logits.view(-1, self.vocab_size), y.view(-1))

    def train(self):
        """Run one pass over the data with early stopping.

        A checkpoint is written when the loss reaches a new minimum and more
        than 100 steps have passed since the last save. Training stops after
        more than 1000 consecutive steps without a new minimum, or when the
        data is exhausted.
        """
        wandb.init(project=self.project,
                   name=datetime.datetime.now().strftime('%Y%m%d_%H%M%S'),
                   config=self.config)
        self.llm.train()

        token_cnt = 0
        best_loss = math.inf
        patience = 0
        last_save_step = 0

        for step, (x, y) in enumerate(self.train_loader):
            self.optimizer.zero_grad()

            # Count real tokens by comparing against the tokenizer's pad id
            # rather than assuming that id is 0.
            token_cnt += int((x != self.padding_idx).sum())
            # Call the module itself (not .forward) so registered hooks run.
            logits = self.llm(x)

            loss = self.compute_loss(logits, y)
            loss.backward()
            self.optimizer.step()

            # Convert once to a Python float: logging and best-loss tracking
            # then hold no reference to the tensor or its autograd history.
            loss_value = loss.item()
            wandb.log({
                'loss': loss_value,
                'token_cnt': token_cnt,
            }, step=step)
            patience += 1

            if loss_value < best_loss:
                best_loss = loss_value
                wandb.log({"best_loss": best_loss}, step=step)
                patience = 0

                if step - last_save_step > 100:
                    self.save(step)
                    last_save_step = step

            if patience > 1000:
                break
        wandb.finish()

    def save(self, step):
        """Write the model state dict to ``<output_dir>/step=<step>/weights.pt``."""
        directory = os.path.join(self.output_dir, f'step={step}')
        os.makedirs(directory, exist_ok=True)
        torch.save(self.llm.state_dict(), os.path.join(directory, 'weights.pt'))

In [5]:
# Launch training. The default device/dtype must be set BEFORE the Trainer is
# constructed so that the model parameters and the rotary table are created
# on the GPU in float32.
torch.set_default_device('cuda')
torch.set_default_dtype(torch.float32)

# `spm` is the sentencepiece module imported in the tokenizer-training cell;
# `tokenizer` is reused by the inference cells further down.
tokenizer = spm.SentencePieceProcessor('/mnt/llm/tokenizer/toyreverse.model')

trainer = Trainer(project='llm_toyreverse',
                  tokenizer=tokenizer,
                  data_dir='/mnt/llm/data/toyreverse',
                  output_dir='/mnt/llm_toyreverse')

# Runs until the data is exhausted or early stopping triggers; checkpoints are
# written under output_dir/step=<N>/weights.pt.
trainer.train()

LLM(
  (embedding): Embedding(70, 2560, padding_idx=0)
  (layers): ModuleList(
    (0-31): 32 x Block(
      (norm_1): RMSNorm()
      (attention): Attention(
        (wq): Linear(in_features=2560, out_features=2560, bias=False)
        (wk): Linear(in_features=2560, out_features=2560, bias=False)
        (wv): Linear(in_features=2560, out_features=2560, bias=False)
        (wo): Linear(in_features=2560, out_features=2560, bias=False)
      )
      (norm_2): RMSNorm()
      (mlp): MLP(
        (w): Linear(in_features=2560, out_features=6912, bias=False)
        (v): Linear(in_features=2560, out_features=6912, bias=False)
        (w_2): Linear(in_features=6912, out_features=2560, bias=False)
      )
    )
  )
  (norm): RMSNorm()
  (output): Linear(in_features=2560, out_features=70, bias=False)
)
model parameter 2.54B


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mwzy816[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='0.003 MB of 0.014 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.206254…

0,1
best_loss,███▇▇▆▆▆▆▆▅▅▅▅▄▄▄▃▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,███▇▇▇▆▆▆▆▆▆▆▆▅▅▅▅▄▄▃▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
token_cnt,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███

0,1
best_loss,0.00026
loss,0.00038
lr,0.0
token_cnt,3524685.0


## 4. run inference

In [6]:
# Rebuild the model and restore the most recently written checkpoint.
torch.set_default_device('cuda')

model = LLM(tokenizer.vocab_size(), tokenizer.pad_id(), Config())

output_dir = '/mnt/llm_toyreverse'
# Pick the checkpoint directory with the latest creation/change time.
newest = max(os.listdir(output_dir),
             key=lambda x: os.path.getctime(os.path.join(output_dir, x)))
checkpoint_path = os.path.join(output_dir, newest, 'weights.pt')
print(checkpoint_path)

# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state dict needs and avoids executing arbitrary pickled code.
state = torch.load(checkpoint_path, map_location='cuda:0', weights_only=True)
# strict=False is kept so a partial checkpoint still loads, but any mismatch
# is reported instead of silently leaving layers at their random init.
load_result = model.load_state_dict(state, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print('WARNING: checkpoint/model key mismatch:', load_result)
# Inference only from here on.
model.eval()

print(next(model.parameters()).is_cuda)

/mnt/llm_toyreverse/step=13616/weights.pt
True


In [14]:
# Reverse one random prompt with the trained model and print it next to the
# expected answer.
context_size = Config().context_size

# Shorten the prompt two characters at a time until it fits the context.
prompt = generate_line()
while len(tokenizer.encode(prompt)) >= context_size:
    prompt = prompt[:-2]
print('prompt', prompt)
print('truth ', prompt[::-1])

# Right-pad the token ids to the fixed context length.
x = torch.LongTensor(tokenizer.encode(prompt))
pad = (0, context_size - x.shape[0])
x = F.pad(x, pad, "constant", tokenizer.pad_id())

with torch.inference_mode():
    batch = x.unsqueeze(0).cuda()
    logits = model.forward(batch)
    probs = F.softmax(logits[0], dim=-1)
    # Greedy decoding: most likely token at every position.
    y = torch.argmax(probs, dim=-1)
    result = tokenizer.decode(y.tolist())
    print('result', result)

prompt lUcxrqfpGQ5kzBcHrL773lEum4hcfG
truth  Gfch4muEl377LrHcBzk5QGpfqrxcUl
result fccmumEl33LLrccBcG5QffpGGxccc
