In [14]:
import json
import math
import torch
import torch.nn.functional as F
import torch.optim as optim
from lion import Lion
from datasets import load_from_disk
from sophia import SophiaG
from sgd.sgd import signSGD
from transformers import (
    PreTrainedTokenizerFast,
    DataCollatorForLanguageModeling,
    BertConfig,
    BertForMaskedLM,
    Trainer,
    TrainingArguments,
)

In [23]:
def load_dataset(path, name, root="./save"):
    """Load a tokenized dataset and its fast tokenizer from ``{root}/{path}/{name}/``.

    Layout read by this function:
        {root}/{path}/{name}/datasets/                         -> ``load_from_disk``
        {root}/{path}/{name}/tokenizer/tokenizer.json          -> tokenizer file
        {root}/{path}/{name}/tokenizer/special_tokens_map.json -> special tokens

    Args:
        path: first directory level under ``root``.
        name: second directory level under ``root``.
        root: base directory; defaults to ``"./save"`` (the previously
            hard-coded location).

    Returns:
        ``(tokenized_datasets, tokenizer)``.

    Raises:
        KeyError: if ``special_tokens_map.json`` lacks one of the sep/cls/mask/
            unk/pad token entries.
    """
    base_dir = f"{root}/{path}/{name}"

    # Read the token map and close the file immediately, instead of keeping the
    # handle open for the whole (potentially slow) dataset load below.
    with open(f"{base_dir}/tokenizer/special_tokens_map.json") as f:
        special_tokens = json.load(f)

    def _token(key):
        # Entries may be plain strings or serialized AddedToken dicts
        # ({"content": ..., "lstrip": ..., ...}); the constructor only needs
        # the token text.
        value = special_tokens[key]
        return value["content"] if isinstance(value, dict) else value

    tokenized_datasets = load_from_disk(f"{base_dir}/datasets/")
    tokenizer = PreTrainedTokenizerFast(
        # TODO: make sure these are set for MASKED models
        # https://huggingface.co/docs/transformers/v4.30.0/en/main_classes/tokenizer#transformers.PreTrainedTokenizerFast
        sep_token=_token("sep_token"),
        cls_token=_token("cls_token"),
        mask_token=_token("mask_token"),
        unk_token=_token("unk_token"),
        pad_token=_token("pad_token"),
        tokenizer_file=f"{base_dir}/tokenizer/tokenizer.json",
    )
    # Sanity check that the special tokens were picked up.
    print(
        tokenizer.sep_token,
        tokenizer.cls_token,
        tokenizer.mask_token,
        tokenizer.unk_token,
        tokenizer.pad_token,
    )
    return tokenized_datasets, tokenizer

In [None]:
def compute_metric_with_tokenizer(tokenizer):
    """Build a ``compute_metrics`` callback for ``Trainer`` reporting MLM perplexity.

    Args:
        tokenizer: the training tokenizer. Kept for interface compatibility;
            the vocabulary width is taken from the logits themselves, because
            the model is sized with ``len(tokenizer)``, which can differ from
            ``tokenizer.vocab_size`` (the latter excludes added tokens).

    Returns:
        ``compute_custom_metric(pred)``, which takes the prediction object
        passed by ``Trainer`` (with ``.predictions`` and ``.label_ids``) and
        returns ``{"perplexity": float, "calculated_loss": float}``.
    """

    def compute_custom_metric(pred):
        predictions = pred.predictions
        # Some models return a tuple of outputs; the logits come first.
        if isinstance(predictions, tuple):
            predictions = predictions[0]
        # Upcast: with fp16 training the gathered logits may be float16, which
        # CPU cross-entropy may not support and which loses precision.
        logits = torch.as_tensor(predictions).float()
        labels = torch.as_tensor(pred.label_ids).long()
        # Flatten using the logits' own last dimension. Labels equal to -100
        # are skipped by cross_entropy's default ignore_index.
        loss = F.cross_entropy(
            logits.reshape(-1, logits.shape[-1]), labels.reshape(-1)
        ).item()
        try:
            perplexity = math.exp(loss)
        except OverflowError:
            # exp() overflows for loss above roughly 709.
            perplexity = float("inf")
        return {"perplexity": perplexity, "calculated_loss": loss}

    return compute_custom_metric

In [None]:
def set_optimizer(model , i):
    match i:
        case 1:
            optimizer = optim.SGD(model.parameters())
        case 2:
            optimizer = Lion(model.parameters())
        case 3:
            optimizer = optim.AdamW(model.parameters())
        case 4:
            optimizer = SophiaG(model.parameters())
        case _:
            print("Invalid optimizer")        
    return optimizer


In [None]:
import gc

def train(tokenizer, tokenized_datasets, optimizer, model, data_collator, training_args):
    """Train ``model`` with ``optimizer``, save it, then evaluate and report.

    The trained model is saved to ``./bert/output/<OptimizerClassName>``.

    Args:
        tokenizer: tokenizer handed to ``Trainer`` and used to build the metric.
        tokenized_datasets: mapping with ``"train"`` and ``"validation"`` splits.
        optimizer: pre-built optimizer over ``model``'s parameters.
        model: the model to train.
        data_collator: batch collator passed to ``Trainer``.
        training_args: ``TrainingArguments`` for this run.

    Returns:
        dict: the metrics from ``trainer.evaluate()``. They were previously
        only printed; returning them lets callers collect results.
    """
    optimizer_name = optimizer.__class__.__name__
    compute_custom_metric = compute_metric_with_tokenizer(tokenizer)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["validation"],
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=compute_custom_metric,
        # Custom optimizer; None lets Trainer create its default LR scheduler.
        optimizers=(optimizer, None),
    )
    try:
        trainer.train()
        trainer.save_model(f"./bert/output/{optimizer_name}")
        eval_results = trainer.evaluate()
    finally:
        # Release the Trainer and cached GPU blocks so the next run starts
        # cleaner. The caller still holds the model and optimizer.
        del trainer
        gc.collect()
        torch.cuda.empty_cache()

    print(f"{optimizer_name} results: {eval_results}")
    return eval_results

In [None]:
#get the configs
# Import training configs
from configs import SEED, TRAINING_CONFIGS

from transformers import set_seed

config = TRAINING_CONFIGS["bert-wikitext"]
tokenizer_name = config["tokenizer_name"]
path = config["dataset_path"]
name = config["dataset_name"]

# Load the dataset.
# NOTE(review): dataset_name is passed as `path` and dataset_path as `name`, so
# this reads ./save/{dataset_name}/{dataset_path}/. Kept unchanged because it
# must match how the data was saved; confirm against the preprocessing step.
tokenized_datasets, tokenizer = load_dataset(path=name, name=path)

# TrainingArguments shared by every run. Learning rate / weight decay are not
# set here: each optimizer is built by set_optimizer with its own defaults.
shared_training_kwargs = dict(
    evaluation_strategy="epoch",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    seed=SEED,
    fp16=True,
    # Number of eval steps whose outputs are accumulated before being moved to
    # the CPU.
    eval_accumulation_steps=50,
)

# The collator does not depend on the run, so build it once.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm_probability=0.15,
)

# Optimizer codes understood by set_optimizer: 1=SGD, 2=Lion, 3=AdamW, 4=SophiaG.
opt = [1, 2, 3, 4]
for i in opt:
    # Re-seed before building the model so every optimizer starts from the
    # same initial weights. Otherwise the init depends on how much RNG earlier
    # runs consumed, and the comparison is not reproducible.
    set_seed(SEED)
    model_config = BertConfig(vocab_size=len(tokenizer))
    model = BertForMaskedLM(model_config)

    optimizer = set_optimizer(model, i)
    run_name = optimizer.__class__.__name__

    # Per-optimizer directories. With one shared output_dir, each run's
    # checkpoint-N folders and logs would overwrite or mix with the previous run's.
    training_args = TrainingArguments(
        output_dir=f"./bert/output/{run_name}/",
        logging_dir=f"./bert/logs/{run_name}/",
        **shared_training_kwargs,
    )

    train(tokenizer, tokenized_datasets, optimizer, model, data_collator, training_args)

    # Drop this run's model/optimizer before the next one is created.
    del model, optimizer
    gc.collect()
    torch.cuda.empty_cache()

