### Imports

In [None]:
import json
import math
import os
from itertools import chain
from pathlib import Path

import torch
from accelerate import Accelerator
from accelerate.utils import set_seed
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    default_data_collator,
    get_scheduler,
)

### Hyperparameters

In [None]:
# --- reproducibility ---
seed: int = 42

# --- output: model weights, tokenizer files and all_results.json land here ---
output_dir: str = './'

# --- data ---
dataset_name: str = 'wikitext'
dataset_config_name: str = 'wikitext-2-raw-v1'
# Share of the train split carved off as validation, used only when the
# dataset does not ship its own "validation" split.
validation_split_percentage: int = 5
# Worker processes for datasets.map(num_proc=...); os.cpu_count() may be None.
preprocessing_num_workers = os.cpu_count()

# --- model ---
model_name_or_path: str = 'distilbert/distilgpt2'
gradient_accumulation_steps: int = 1

# --- optimisation schedule ---
num_train_epochs: int = 3
per_device_train_batch_size: int = 8
per_device_eval_batch_size: int = 8
weight_decay: float = 0.0
learning_rate: float = 5e-5
lr_scheduler_type: str = 'linear'
num_warmup_steps: int = 0
# Experiment tracker passed to Accelerator(log_with=...).
report_to: str = 'wandb'

### Seed and Accelerator

In [None]:
# Seed the RNGs via accelerate.utils.set_seed so runs are repeatable.
set_seed(seed)

# Tracker backend and project directory for the Accelerator. Gradient
# accumulation itself is driven by accelerator.accumulate() in the training loop.
accelerator_log_kwargs = {
    "log_with": report_to,
    "project_dir": output_dir,
}
accelerator = Accelerator(
    gradient_accumulation_steps=gradient_accumulation_steps,
    **accelerator_log_kwargs,
)

### Dataset

In [None]:
# Load every split the dataset provides.
raw_datasets = load_dataset(dataset_name, dataset_config_name, trust_remote_code=True)

# If there is no "validation" split, carve the first
# `validation_split_percentage` percent of "train" into validation and keep the
# rest as train, so the two never overlap.
if "validation" not in raw_datasets:
    raw_datasets["validation"] = load_dataset(
        dataset_name, dataset_config_name,
        split=f"train[:{validation_split_percentage}%]",
        trust_remote_code=True,
    )
    raw_datasets["train"] = load_dataset(
        dataset_name, dataset_config_name,
        split=f"train[{validation_split_percentage}%:]",
        trust_remote_code=True,
    )

### Model and Tokenizer

In [None]:
# Fetch the architecture config, a fast tokenizer and the pretrained weights.
# NOTE(review): trust_remote_code=True permits running model code downloaded
# from the Hub; confirm the chosen checkpoint actually needs it.
config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    use_fast=True,
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    config=config,
    trust_remote_code=True,
)

# Grow the input-embedding matrix only when the tokenizer has more tokens than
# the matrix has rows; otherwise larger token ids would index out of range.
embedding_size = model.get_input_embeddings().weight.shape[0]
if embedding_size < len(tokenizer):
    model.resize_token_embeddings(len(tokenizer))

### Preprocess Dataset

In [None]:
# Tokenize the raw text. Use a column literally named "text" when present,
# otherwise the first column. All original columns are removed so only the
# tokenizer outputs remain.
column_names = raw_datasets["train"].column_names
if "text" in column_names:
    text_column_name = "text"
else:
    text_column_name = column_names[0]


def tokenize_function(examples):
    """Tokenize one batch; `examples` maps column name -> list of values."""
    texts = examples[text_column_name]
    return tokenizer(texts)


# Run on the main process first; the other processes wait, presumably so they
# can reuse the datasets cache it writes.
with accelerator.main_process_first():
    tokenized_datasets = raw_datasets.map(
        tokenize_function,
        batched=True,
        num_proc=preprocessing_num_workers,
        remove_columns=column_names,
        desc="Running tokenizer on dataset",
    )

# Chunk length used by group_texts. Start from the tokenizer's advertised
# maximum length, but if that exceeds what the model's position embeddings can
# address (e.g. a tokenizer reporting a very large model_max_length), fall back
# to min(1024, max_position_embeddings).
#
# The former `else` branch was removed: block_size had just been set to
# tokenizer.model_max_length, so its comparison could never be true and its
# min() was a no-op.
block_size = tokenizer.model_max_length
if block_size > config.max_position_embeddings:
    print(
        f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
        f"Using block_size={min(1024, config.max_position_embeddings)} instead"
    )
    block_size = min(1024, config.max_position_embeddings)

# Main data processing function: concatenate all texts of a batch and cut the
# result into chunks of `block_size` tokens.
def group_texts(examples):
    """Concatenate each tokenized column of a batch and split it into fixed-size chunks.

    Every column is flattened into one long list, trimmed down to a multiple of
    the module-level `block_size`, and sliced into consecutive `block_size`
    pieces. The trailing remainder is dropped (padding could be used instead if
    the model supported it). If the batch holds fewer than `block_size` tokens,
    every column comes back as an empty list.

    A "labels" column is added as a copy of "input_ids".
    """
    merged = {
        column: list(chain.from_iterable(values))
        for column, values in examples.items()
    }
    first_column = next(iter(examples))
    usable_length = len(merged[first_column]) // block_size * block_size
    starts = range(0, usable_length, block_size)
    result = {
        column: [tokens[start : start + block_size] for start in starts]
        for column, tokens in merged.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

# Cut the tokenized splits into fixed-length blocks. With batched=True, map()
# feeds group_texts 1,000 rows at a time, so the remainder group_texts discards
# is lost once per 1,000-row batch. Passing a larger batch_size to map() loses
# less but may preprocess more slowly. num_proc enables multiprocessing; see
# https://huggingface.co/docs/datasets/process#map
with accelerator.main_process_first():
    lm_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=preprocessing_num_workers,
        desc=f"Grouping texts in chunks of {block_size}",
    )

train_dataset, eval_dataset = lm_datasets["train"], lm_datasets["validation"]

### Dataloaders

In [None]:
# Training batches are reshuffled each epoch; evaluation keeps dataset order.
# Rows are already equal-length blocks from group_texts, so
# default_data_collator can batch them without padding.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=per_device_train_batch_size,
    shuffle=True,
    collate_fn=default_data_collator,
)
eval_dataloader = DataLoader(
    eval_dataset,
    batch_size=per_device_eval_batch_size,
    collate_fn=default_data_collator,
)

### Optimizer

In [None]:
# Split parameters into two AdamW groups: one that receives `weight_decay` and
# one exempt from it. Name-based matching alone is architecture-dependent:
# GPT-2-style checkpoints name their LayerNorms `ln_1` / `ln_2` / `ln_f`, which
# never contain "layer_norm.weight". Any parameter with fewer than 2 dims
# (biases, normalization gains/shifts) is therefore exempted too, whatever it is
# called. The original name patterns are still honoured.
no_decay = ["bias", "layer_norm.weight"]


def _is_decay_exempt(param_name, param):
    """Return True if `param` should be optimized without weight decay."""
    return param.ndim < 2 or any(nd in param_name for nd in no_decay)


optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if not _is_decay_exempt(n, p)],
        "weight_decay": weight_decay,
    },
    {
        "params": [p for n, p in model.named_parameters() if _is_decay_exempt(n, p)],
        "weight_decay": 0.0,
    },
]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=learning_rate)

### Scheduler

In [None]:
# Total optimizer updates for the schedule. Each update consumes
# `gradient_accumulation_steps` batches, and a trailing partial group still
# counts, hence the ceiling. Note: this uses the dataloader *before*
# accelerator.prepare().
overrode_max_train_steps = False  # not read by any visible cell
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
max_train_steps = num_update_steps_per_epoch * num_train_epochs

# Warmup is scaled by the process count, presumably because the scheduler
# returned by accelerator.prepare() advances once per process on every
# optimizer step; confirm against accelerate's AcceleratedScheduler.
lr_scheduler = get_scheduler(
    name=lr_scheduler_type,
    optimizer=optimizer,
    num_training_steps=max_train_steps,
    num_warmup_steps=num_warmup_steps * accelerator.num_processes,
)

### Accelerator

In [None]:
# Wrap model, optimizer, loaders and scheduler for the current device /
# distributed setup. In a multi-process run the dataloaders become per-process
# shards, so len(train_dataloader) may differ from what the scheduler cell saw.
(
    model,
    optimizer,
    train_dataloader,
    eval_dataloader,
    lr_scheduler,
) = accelerator.prepare(model, optimizer, train_dataloader, eval_dataloader, lr_scheduler)

### Train

In [None]:
# Samples contributing to one optimizer update, over all processes and
# accumulation micro-steps.
total_batch_size = per_device_train_batch_size * accelerator.num_processes * gradient_accumulation_steps

# accelerator.log() only forwards to trackers started with init_trackers();
# without this call nothing logged below reaches the `report_to` tracker.
# The project name is arbitrary; rename as needed.
accelerator.init_trackers(
    "clm_no_trainer",
    config={
        "seed": seed,
        "dataset_name": dataset_name,
        "dataset_config_name": dataset_config_name,
        "model_name_or_path": model_name_or_path,
        "block_size": block_size,
        "num_train_epochs": num_train_epochs,
        "per_device_train_batch_size": per_device_train_batch_size,
        "per_device_eval_batch_size": per_device_eval_batch_size,
        "gradient_accumulation_steps": gradient_accumulation_steps,
        "total_batch_size": total_batch_size,
        "learning_rate": learning_rate,
        "weight_decay": weight_decay,
        "lr_scheduler_type": lr_scheduler_type,
        "num_warmup_steps": num_warmup_steps,
    },
)

# Count updates from the *prepared* loader. In a multi-process run each process
# only iterates its shard, so the pre-prepare max_train_steps would overstate
# this process's progress-bar total.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
total_update_steps = num_train_epochs * num_update_steps_per_epoch

progress_bar = tqdm(range(total_update_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0
starting_epoch = 0

for epoch in range(starting_epoch, num_train_epochs):
    # ---- training ----
    model.train()
    total_loss = 0
    for batch in train_dataloader:
        # accumulate() decides whether this micro-step performs a real update
        # (see accelerate's gradient-accumulation docs).
        with accelerator.accumulate(model):
            outputs = model(**batch)
            loss = outputs.loss
            total_loss += loss.detach().float()
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        # True when the accelerator has just performed an optimization step.
        if accelerator.sync_gradients:
            progress_bar.update(1)
            completed_steps += 1

    # ---- evaluation ----
    model.eval()
    losses = []
    for batch in eval_dataloader:
        with torch.no_grad():
            outputs = model(**batch)
        # Repeat the batch-mean loss once per sample actually in the batch, so a
        # short final batch is not over-weighted. gather_for_metrics collects
        # across processes and drops samples duplicated to even out the last batch.
        num_samples = batch["input_ids"].shape[0]
        losses.append(accelerator.gather_for_metrics(outputs.loss.repeat(num_samples)))

    losses = torch.cat(losses)
    # Plain Python float, so printing/logging/JSON don't carry a device tensor.
    eval_loss = torch.mean(losses).item()
    try:
        perplexity = math.exp(eval_loss)
    except OverflowError:
        perplexity = float("inf")

    # Mean of this process's per-batch training losses (not gathered).
    train_loss = float(total_loss) / max(len(train_dataloader), 1)

    print(f"epoch {epoch}: perplexity: {perplexity} eval_loss: {eval_loss}")
    accelerator.log(
        {
            "perplexity": perplexity,
            "eval_loss": eval_loss,
            "train_loss": train_loss,
            "epoch": epoch,
            "step": completed_steps,
        },
        step=completed_steps,
    )

# Finish the tracker run(s) started by init_trackers().
accelerator.end_training()

# Save weights (written by the main process via accelerator.save), the
# tokenizer, and the final perplexity.
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(
    output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
)
if accelerator.is_main_process:
    tokenizer.save_pretrained(output_dir)
    with open(os.path.join(output_dir, "all_results.json"), "w") as f:
        json.dump({"perplexity": perplexity}, f)