In [None]:
import transformers
import codebook_features
import torch
import evaluate
import numpy as np
import copy

from transformers import GPTNeoXConfig, GPTNeoXForCausalLM, GPT2TokenizerFast
from torch.utils.data import IterableDataset
from codebook_features import models, run_clm, trainer as cb_trainer
import os

## Dataset Generation

In [21]:
class ToyGraph():
    """Random-walk generator over a directed graph described by a transition matrix.

    Nodes are the integers ``0 .. N-1``. Row ``i`` of ``transition_matrix`` is the
    probability distribution over the successor of node ``i``.

    Attributes:
        rng: ``numpy.random.Generator`` used for graph construction and sampling.
        transition_matrix: ``(N, N)`` array of transition probabilities.
        N: number of nodes (always taken from the matrix actually in use).
        state: current node for the stateful ``step`` / ``reset`` API.
        digits: number of decimal digits needed to write the largest node id.
    """

    def __init__(self, N=100, transition_matrix=None, seed=None, edges=10):
        """Build a random graph, or wrap a caller-provided transition matrix.

        Args:
            N: number of nodes of the random graph. Ignored when
                ``transition_matrix`` is given (its size is used instead).
            transition_matrix: optional square matrix of transition
                probabilities. Rows are assumed to sum to 1; this is not checked.
            seed: seed for the internal random generator.
            edges: out-degree of every node in the random graph; successors are
                drawn without replacement and are equally likely.

        Raises:
            AssertionError: if the transition matrix is not square.
        """
        self.rng = np.random.default_rng(seed=seed)
        if transition_matrix is None:
            self.transition_matrix = np.zeros((N, N))
            for i in range(N):
                self.transition_matrix[i, self.rng.choice(N, size=edges, replace=False)] = 1
            # Normalise each row so the chosen successors are equally likely.
            self.transition_matrix = self.transition_matrix / self.transition_matrix.sum(axis=1, keepdims=True)
            self.N = N
        else:
            self.transition_matrix = np.asarray(transition_matrix)
            self.N = self.transition_matrix.shape[0]
        # Validate against the real size, not the (possibly defaulted) N argument,
        # so a custom matrix does not have to be 100x100.
        assert self.transition_matrix.shape == (self.N, self.N)

        self.state = 0
        # Width of the largest node id (N-1). Same value as ceil(log10(N)) for
        # N >= 2, but also well-defined for N == 1 and free of float rounding.
        self.digits = len(str(max(self.N - 1, 0)))

    def step(self):
        """Advance the internal state by one transition and return the new state."""
        self.state = self.rng.choice(self.N, p=self.transition_matrix[self.state])
        return self.state

    def step_with(self, state):
        """Sample a successor of ``state`` without touching the internal state."""
        return self.rng.choice(self.N, p=self.transition_matrix[state])

    def reset(self):
        """Reset the internal state to node 0 and return it."""
        self.state = 0
        return self.state

    def set_seed(self, seed):
        """Replace the random generator with a freshly seeded one."""
        self.rng = np.random.default_rng(seed=seed)

    def save(self, path):
        """Write the transition matrix to ``path`` with ``np.save``."""
        np.save(path, self.transition_matrix)

    def generate_trajectory(self, length):
        """Return a list of visited nodes starting from a uniformly random node.

        The start node is the first entry and each further entry is one
        transition, so at least one node is always returned. The internal
        ``state`` is not modified.
        """
        trajectory = [self.rng.choice(self.N)]
        for _ in range(length-1):
            trajectory.append(self.step_with(trajectory[-1]))
        return trajectory

    def generate_trajectories(self, length, start_states=None):
        """Roll out several walks in lock-step.

        Args:
            length: number of transitions recorded per walk.
            start_states: mutable sequence of start nodes, one per walk. It is
                copied, not modified. When ``None``, ``length`` walks are
                started from the current internal ``state``.

        Returns:
            Float array of shape ``(num_walks, length)``. Column ``i`` holds the
            node reached after ``i + 1`` transitions; start nodes are not included.
        """
        if start_states is None:
            # Previously this built a 0-d array and then called len(None),
            # so the default path could never run.
            curr_states = np.full(length, self.state)
        else:
            curr_states = copy.deepcopy(start_states)
        num_walks = len(curr_states)
        trajectories = np.zeros((num_walks, length))
        # Time is the outer loop so random draws are consumed step by step
        # across all walks.
        for i in range(length):
            for j in range(num_walks):
                curr_states[j] = self.step_with(curr_states[j])
                trajectories[j, i] = curr_states[j]
        return trajectories

    def verify_trajectory(self, traj):
        """Return True if every consecutive pair in ``traj`` has non-zero probability.

        Entries are cast to ``int`` so float rows from ``generate_trajectories``
        can be checked directly.
        """
        for i in range(len(traj)-1):
            if self.transition_matrix[int(traj[i]), int(traj[i+1])] == 0:
                return False
        return True

In [22]:
class ToyDataset(IterableDataset):
    """Endless stream of tokenized random walks sampled from a graph.

    Every node id is written as a zero-padded decimal string of
    ``graph.digits`` characters, the ids of a walk are concatenated, and the
    resulting string is passed through the tokenizer.
    """

    def __init__(self, graph, seq_len, tokenizer=None):
        """
        Args:
            graph: object exposing ``digits``, ``generate_trajectory(length)``
                and ``set_seed(seed)`` (e.g. a ``ToyGraph``).
            seq_len: number of digit characters per sample; must be a multiple
                of ``graph.digits``.
            tokenizer: optional callable used by ``tokenize``. When ``None``,
                the notebook-level ``tokenizer`` variable is looked up at call
                time, as before.

        Raises:
            AssertionError: if ``seq_len`` is not a multiple of ``graph.digits``.
        """
        self.graph = graph
        self.seq_len = seq_len
        self.tokenizer = tokenizer
        assert self.seq_len % self.graph.digits == 0

    def __iter__(self):
        """Yield tokenized walks forever."""
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            # Re-seed the graph inside a DataLoader worker with that worker's seed.
            self.graph.set_seed(worker_info.seed)
        while True:
            yield self.tokenize(self.graph.generate_trajectory(self.seq_len//self.graph.digits))

    def tokenize(self, traj):
        """Encode a walk as flat tensors, with ``labels`` a copy of ``input_ids``."""
        width = self.graph.digits
        # zfill pads every id to the full width. The previous rule only padded
        # single-digit ids, which broke fixed-width encoding once digits >= 3.
        inp_str = "".join(str(int(x)).zfill(width) for x in traj)
        tok = self.tokenizer if self.tokenizer is not None else tokenizer
        inp_dict = {k: v.reshape(-1) for k, v in tok(inp_str, return_tensors="pt").items()}
        inp_dict["labels"] = inp_dict["input_ids"].clone()
        return inp_dict

In [None]:
# Build a tiny digit-level tokenizer: write a vocab and a merges file to
# ./toy, then load them with GPT2TokenizerFast.

# exist_ok avoids the check-then-create race and is a no-op on re-runs.
os.makedirs("toy", exist_ok=True)

# Ten digit tokens plus one end-of-text token, which also serves as padding.
vocab = '{"0": 0, "1": 1, "2": 2, "3": 3, "4": 4, "5": 5, "6": 6, "7": 7, "8": 8, "9": 9, "<|endoftext|>": 10}'
with open("toy/vocab.json", "w", encoding="utf-8") as f:
    f.write(vocab)

# Empty merges file — presumably so that no BPE merges apply and each digit
# stays a separate token; confirm against the tokenizer's behaviour.
with open("toy/merges.txt", "w", encoding="utf-8") as f:
    f.write("")

# `tokenizer` is a notebook-level name; ToyDataset.tokenize looks it up at call time.
tokenizer = GPT2TokenizerFast(vocab_file="toy/vocab.json", merges_file="toy/merges.txt", pad_token="<|endoftext|>")
tokenizer.save_pretrained("toy")

In [23]:
# Seed the graph so the transition matrix (and therefore the whole data
# stream) is reproducible after Restart & Run All; 42 matches the trainer seed.
automata = ToyGraph(N=100, edges=2, seed=42)
# Both datasets wrap the same graph object, so they share its transition
# matrix and its random generator.
train_dataset = ToyDataset(automata, seq_len=512)
eval_dataset = ToyDataset(automata, seq_len=512)

## Train Model

In [20]:
# Small GPT-NeoX model; the vocabulary covers the ten digits plus the
# end-of-text token (id 10), which doubles as BOS and EOS.
config = GPTNeoXConfig(
    vocab_size=11,
    bos_token_id=10,
    eos_token_id=10,
    hidden_size=8,
    intermediate_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    rotary_emb_base=10000,
    max_position_embeddings=512,
)
# Randomly initialised weights; nothing is loaded from a checkpoint here.
model = GPTNeoXForCausalLM(config=config)

In [24]:
# Argument objects consumed by the codebook trainer and `run_clm.run_trainer`.
training_args = run_clm.TrainingArguments(
    # --- run control / outputs ---
    output_dir="toy/output",
    overwrite_output_dir=True,
    do_train=True,
    do_eval=True,
    seed=42,
    report_to="none",
    # --- optimisation ---
    max_steps=10,
    learning_rate=0.001,
    lr_scheduler_type="linear",
    warmup_ratio=0.0,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    # --- evaluation / logging cadence ---
    # NOTE(review): eval_steps (100) is larger than max_steps (10), so
    # step-triggered evaluation presumably never fires during training — confirm intended.
    evaluation_strategy="steps",
    eval_steps=100,
    logging_first_step=True,
    logging_steps=10,
    # --- codebook-specific switches ---
    train_model_params=True,
    model_lr_factor=1.0,
)

model_args = run_clm.ModelArguments(
    model_name_or_path="toy/model",
)
data_args = run_clm.DataTrainingArguments(
    dataset_name="toy_graph",
    max_eval_samples=10,
)

# (optimizer, scheduler) pair handed to the trainer. It stays (None, None)
# unless `model` is a CodebookModel with parameters to optimise.
optimizers = (None, None)
if isinstance(model, models.CodebookModel):
    if not training_args.train_model_params:
        # Only the codebook parameters are optimised.
        params = model.get_codebook_params()
    else:
        # Two parameter groups: codebooks and the underlying model.
        params = [
            {
                "params": model.get_codebook_params(),
                "lr": training_args.learning_rate,
                # No optimizer-level decay for codebooks: their decay is applied
                # through the `codebook_weight_decay` param, which goes
                # directly into the regularized loss.
                "weight_decay": 0.0,
            },
            {
                "params": model.get_model_params(),
                "lr": training_args.model_lr_factor * training_args.learning_rate,
                "weight_decay": training_args.weight_decay,
            },
        ]
    if len(params) > 0:
        optimizer = torch.optim.AdamW(params, training_args.learning_rate)
        optimizers = (optimizer, None)


# NOTE(review): `report_to` is "none" in the training arguments while a
# WandbCallback is still registered here — confirm this combination is intended.
callbacks = [cb_trainer.WandbCallback()]

def preprocess_logits_for_metrics(logits, labels):
    """Reduce model logits to predicted token ids.

    Args:
        logits: a logits tensor, or a tuple whose first element is the logits
            tensor (extra entries such as past_key_values are ignored).
        labels: unused; kept to match the expected callback signature.

    Returns:
        Tensor of argmax indices over the last dimension.
    """
    token_scores = logits[0] if isinstance(logits, tuple) else logits
    return token_scores.argmax(dim=-1)

# Accuracy metric object used by compute_metrics below.
metric = evaluate.load("accuracy")
def compute_metrics(eval_preds):
    """Compute next-token accuracy from (predictions, labels).

    The predictions already went through the argmax in
    preprocess_logits_for_metrics, so they have the same shape as the labels.
    Position t predicts token t+1, so the first label and the last prediction
    are dropped before both are flattened and compared.
    """
    predicted_ids, label_ids = eval_preds
    shifted_references = label_ids[:, 1:].reshape(-1)
    shifted_predictions = predicted_ids[:, :-1].reshape(-1)
    return metric.compute(
        predictions=shifted_predictions,
        references=shifted_references,
    )

# Split-name -> dataset mapping; the same objects go to the trainer and to run_trainer.
lm_datasets = {"train": train_dataset, "validation": eval_dataset}

trainer = cb_trainer.CodebookTrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=lm_datasets["train"],
    eval_dataset=lm_datasets["validation"],
    optimizers=optimizers,
    callbacks=callbacks,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    compute_metrics=compute_metrics,
)

# No checkpoint is passed (last_checkpoint=None).
metrics = run_clm.run_trainer(
    model_args,
    data_args,
    training_args,
    trainer,
    lm_datasets,
    last_checkpoint=None,
)