<a href="https://colab.research.google.com/github/zzehli/ml-notebooks/blob/main/training_tiny_stories.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Training a causal language model from scratch (PyTorch)

Install the Transformers, Datasets, and Evaluate libraries to run this notebook.

In [24]:
from IPython.display import HTML, display

def set_css():
  display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))
get_ipython().events.register('pre_run_cell', set_css)

In [7]:
!pip install datasets transformers[sentencepiece]
# !pip install accelerate evaluate
# To run the training on TPU, you will need to uncomment the following line:
# !pip install cloud-tpu-client==0.10 torch==1.9.0 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl
!apt install git-lfs

Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.5.0-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.2/491.2 kB[0m [31m28.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m10.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.12.0-py3-none-any.

You will need to setup git, adapt your email and name in the following cell.

In [8]:
!git config --global user.email "jaeli_ottawa@outlook.com"
!git config --global user.name "jaeli-collab"

You will also need to be logged in to the Hugging Face Hub. Execute the following and enter your credentials.

In [9]:
from datasets import get_dataset_split_names
get_dataset_split_names("roneneldan/TinyStories")

README.md:   0%|          | 0.00/1.06k [00:00<?, ?B/s]

['train', 'validation']

In [10]:
from datasets import load_dataset, DatasetDict

train_data = load_dataset(f"roneneldan/TinyStories", split="train[:5%]")
validation_data = load_dataset(f"roneneldan/TinyStories", split="validation[:5%]")
raw_datasets = DatasetDict(
    {
        "train": train_data,
        "valid": validation_data,
    }
)
raw_datasets

(…)-00000-of-00004-2d5a1467fff1081b.parquet:   0%|          | 0.00/249M [00:00<?, ?B/s]

(…)-00001-of-00004-5852b56a2bd28fd9.parquet:   0%|          | 0.00/248M [00:00<?, ?B/s]

(…)-00002-of-00004-a26307300439e943.parquet:   0%|          | 0.00/246M [00:00<?, ?B/s]

(…)-00003-of-00004-d243063613e5a057.parquet:   0%|          | 0.00/248M [00:00<?, ?B/s]

(…)-00000-of-00001-869c898b519ad725.parquet:   0%|          | 0.00/9.99M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2119719 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/21990 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 105986
    })
    valid: Dataset({
        features: ['text'],
        num_rows: 1100
    })
})

In [11]:
for key in raw_datasets["train"][0]:
    print(f"{key.upper()}: {raw_datasets['train'][0][key][:200]}")

TEXT: One day, a little girl named Lily found a needle in her room. She knew it was difficult to play with it because it was sharp. Lily wanted to share the needle with her mom, so she could sew a button on


In [12]:
from transformers import GPTNeoForCausalLM, BertTokenizerFast

context_length = 32
tokenizer =  BertTokenizerFast.from_pretrained("google-bert/bert-base-cased")
tokenizer.bos_token_id = tokenizer.cls_token_id
tokenizer.eos_token_id = tokenizer.sep_token_id
print("Vocab size:", tokenizer.vocab_size)
print("PAD token ID:", tokenizer.pad_token_id)
print("CLS token ID:", tokenizer.cls_token_id)
print("SEP token ID:", tokenizer.sep_token_id)

Vocab size: 28996
PAD token ID: 0
CLS token ID: 101
SEP token ID: 102


In [13]:
tokenizer.bos_token_id = tokenizer.cls_token_id
tokenizer.eos_token_id = tokenizer.sep_token_id
tokenizer.bos_token_id

101

In [14]:
print(tokenizer.tokenize(raw_datasets["train"][0]["text"]))

['One', 'day', ',', 'a', 'little', 'girl', 'named', 'Lily', 'found', 'a', 'needle', 'in', 'her', 'room', '.', 'She', 'knew', 'it', 'was', 'difficult', 'to', 'play', 'with', 'it', 'because', 'it', 'was', 'sharp', '.', 'Lily', 'wanted', 'to', 'share', 'the', 'needle', 'with', 'her', 'mom', ',', 'so', 'she', 'could', 'se', '##w', 'a', 'button', 'on', 'her', 'shirt', '.', 'Lily', 'went', 'to', 'her', 'mom', 'and', 'said', ',', '"', 'Mom', ',', 'I', 'found', 'this', 'needle', '.', 'Can', 'you', 'share', 'it', 'with', 'me', 'and', 'se', '##w', 'my', 'shirt', '?', '"', 'Her', 'mom', 'smiled', 'and', 'said', ',', '"', 'Yes', ',', 'Lily', ',', 'we', 'can', 'share', 'the', 'needle', 'and', 'fix', 'your', 'shirt', '.', '"', 'Together', ',', 'they', 'shared', 'the', 'needle', 'and', 'se', '##wed', 'the', 'button', 'on', 'Lily', "'", 's', 'shirt', '.', 'It', 'was', 'not', 'difficult', 'for', 'them', 'because', 'they', 'were', 'sharing', 'and', 'helping', 'each', 'other', '.', 'After', 'they', 'fini

In [32]:
text = raw_datasets["train"][0]["text"]
encoding = tokenizer(text, return_tensors="pt", add_special_tokens=False)

# View token IDs
print("Input IDs:", encoding["input_ids"][0].tolist())

# Convert back to readable tokens
tokens = tokenizer.convert_ids_to_tokens(encoding["input_ids"][0])
print("Tokens:", tokens)
print(text)

Input IDs: [1448, 1285, 117, 170, 1376, 1873, 1417, 6916, 1276, 170, 13864, 1107, 1123, 1395, 119, 1153, 1450, 1122, 1108, 2846, 1106, 1505, 1114, 1122, 1272, 1122, 1108, 4295, 119, 6916, 1458, 1106, 2934, 1103, 13864, 1114, 1123, 4113, 117, 1177, 1131, 1180, 14516, 2246, 170, 6324, 1113, 1123, 2969, 119, 6916, 1355, 1106, 1123, 4113, 1105, 1163, 117, 107, 4563, 117, 146, 1276, 1142, 13864, 119, 2825, 1128, 2934, 1122, 1114, 1143, 1105, 14516, 2246, 1139, 2969, 136, 107, 1430, 4113, 2387, 1105, 1163, 117, 107, 2160, 117, 6916, 117, 1195, 1169, 2934, 1103, 13864, 1105, 8239, 1240, 2969, 119, 107, 6333, 117, 1152, 3416, 1103, 13864, 1105, 14516, 11547, 1103, 6324, 1113, 6916, 112, 188, 2969, 119, 1135, 1108, 1136, 2846, 1111, 1172, 1272, 1152, 1127, 6303, 1105, 4395, 1296, 1168, 119, 1258, 1152, 1845, 117, 6916, 16490, 1123, 4113, 1111, 6303, 1103, 13864, 1105, 17509, 1123, 2969, 119, 1220, 1241, 1464, 2816, 1272, 1152, 1125, 3416, 1105, 1589, 1487, 119]
Tokens: ['One', 'day', ',', 'a', 

In [31]:
from transformers import GPT2Tokenizer

tokenizer_gpt =  GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
encoding_gpt = tokenizer_gpt(text, return_tensors="pt", add_special_tokens=True)

# View token IDs
print("Input IDs:", encoding_gpt["input_ids"][0].tolist())

# Convert back to readable tokens
tokens_gpt = tokenizer_gpt.convert_ids_to_tokens(encoding_gpt["input_ids"][0])
print("Tokens:", tokens_gpt)

Input IDs: [3198, 1110, 11, 257, 1310, 2576, 3706, 20037, 1043, 257, 17598, 287, 607, 2119, 13, 1375, 2993, 340, 373, 2408, 284, 711, 351, 340, 780, 340, 373, 7786, 13, 20037, 2227, 284, 2648, 262, 17598, 351, 607, 1995, 11, 523, 673, 714, 34249, 257, 4936, 319, 607, 10147, 13, 198, 198, 43, 813, 1816, 284, 607, 1995, 290, 531, 11, 366, 29252, 11, 314, 1043, 428, 17598, 13, 1680, 345, 2648, 340, 351, 502, 290, 34249, 616, 10147, 1701, 2332, 1995, 13541, 290, 531, 11, 366, 5297, 11, 20037, 11, 356, 460, 2648, 262, 17598, 290, 4259, 534, 10147, 526, 198, 198, 41631, 11, 484, 4888, 262, 17598, 290, 384, 19103, 262, 4936, 319, 20037, 338, 10147, 13, 632, 373, 407, 2408, 329, 606, 780, 484, 547, 7373, 290, 5742, 1123, 584, 13, 2293, 484, 5201, 11, 20037, 26280, 607, 1995, 329, 7373, 262, 17598, 290, 18682, 607, 10147, 13, 1119, 1111, 2936, 3772, 780, 484, 550, 4888, 290, 3111, 1978, 13]
Tokens: ['One', 'Ġday', ',', 'Ġa', 'Ġlittle', 'Ġgirl', 'Ġnamed', 'ĠLily', 'Ġfound', 'Ġa', 'Ġneedle', 'Ġin

In [15]:
outputs = tokenizer(
    raw_datasets["train"][2]['text'],
    truncation=True,
    max_length=200,
    return_overflowing_tokens=True,
    return_length=True,
    padding=True
)


print(f"Input IDs length: {len(outputs['input_ids'])}")
print(f"Input chunk lengths: {(outputs['length'])}")
print(outputs)

Input IDs length: 2
Input chunk lengths: [200, 200]
{'input_ids': [[101, 1448, 1285, 117, 170, 1376, 3489, 1417, 19140, 1108, 5947, 1485, 1103, 5781, 119, 1124, 1486, 170, 1992, 24121, 1105, 1458, 1106, 1129, 2053, 119, 107, 8790, 117, 146, 1821, 19140, 119, 2091, 1128, 1328, 1106, 1505, 136, 107, 1455, 1103, 1376, 3489, 119, 1109, 24121, 1350, 1120, 19140, 1105, 1163, 117, 107, 1302, 117, 146, 1274, 112, 189, 1328, 1106, 1505, 119, 146, 1821, 2504, 1105, 146, 1274, 112, 189, 1631, 2503, 119, 107, 19140, 1464, 6782, 1133, 1458, 1106, 1494, 1103, 24121, 1631, 1618, 119, 1124, 18065, 1283, 1105, 1354, 1104, 170, 2197, 119, 1124, 3801, 1115, 1103, 3336, 1180, 1294, 1614, 3258, 119, 1573, 117, 19140, 18065, 1106, 1103, 1499, 1104, 1103, 1447, 1105, 1270, 1106, 1103, 3336, 117, 107, 4203, 117, 3336, 117, 1494, 1139, 1207, 1910, 1631, 2503, 1105, 1136, 16020, 106, 107, 1109, 3336, 1767, 19140, 112, 188, 1840, 1105, 15515, 1157, 3258, 1609, 1113, 1103, 5781, 119, 1109, 24121, 1408, 1106, 1631

In [16]:
def tokenize(element):
    outputs = tokenizer(
        element["text"],
        truncation=True,
        max_length=context_length,
        return_overflowing_tokens=True,
        return_length=True,
    )
    input_batch = []
    for input_ids in outputs["input_ids"]:
        input_batch.append(input_ids)
    return {"input_ids": input_batch}

tokenized_datasets = raw_datasets.map(
    tokenize, batched=True, remove_columns=raw_datasets["train"].column_names
)
tokenized_datasets

Map:   0%|          | 0/105986 [00:00<?, ? examples/s]

Map:   0%|          | 0/1100 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids'],
        num_rows: 805866
    })
    valid: Dataset({
        features: ['input_ids'],
        num_rows: 7432
    })
})

In [17]:
from transformers import GPTNeoConfig, GPTNeoModel
configuration = GPTNeoConfig(
    attention_types = [[['global', 'local'], 1]],
    num_layers=2,
    hidden_size=32,
    vocab_size=tokenizer.vocab_size,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
    bos_token_id=tokenizer.bos_token_id,
    )

# configuration.attention_layers = 8
model = GPTNeoForCausalLM(configuration)

# model = GPTNeoForCausalLM.from_config(config)
model_size = sum(t.numel() for t in model.parameters())
print(f"GPT size: {model_size/1000**2:.1f}M parameters")
# sum(p.numel() for p in model.parameters() if p.requires_grad)

GPT size: 1.0M parameters


In [18]:
print(tokenizer.pad_token)
print(tokenizer.eos_token)

[PAD]
[SEP]


In [19]:
from transformers import DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

In [20]:
out = data_collator([tokenized_datasets["train"][i] for i in range(5)])
for key in out:
    print(f"{key} shape: {out[key].shape}")

input_ids shape: torch.Size([5, 32])
attention_mask shape: torch.Size([5, 32])
labels shape: torch.Size([5, 32])


In [21]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="gpt-sc",
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    evaluation_strategy="steps",
    eval_steps=1_000,
    logging_steps=1_000,
    gradient_accumulation_steps=8,
    num_train_epochs=1,
    weight_decay=0.1,
    warmup_steps=0,
    lr_scheduler_type="liner",
    learning_rate=5e-4,
    save_steps=1_000,
    fp16=True,
    push_to_hub=True,
)

trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["valid"],
)

  trainer = Trainer(


In [None]:
trainer.train()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mjaeli_ottawa[0m ([33mjaeli_ottawa-university-of-ottawa[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss,Validation Loss
1000,7.8732,5.882756
2000,4.6949,4.106545
3000,3.7649,3.773097


TrainOutput(global_step=3148, training_loss=5.3603761277931925, metrics={'train_runtime': 443.0613, 'train_samples_per_second': 1818.859, 'train_steps_per_second': 7.105, 'total_flos': 3911480156160.0, 'train_loss': 5.3603761277931925, 'epoch': 1.0})

In [None]:
trainer.push_to_hub()

CommitInfo(commit_url='https://huggingface.co/Jae-star/gpt-sc/commit/cbacd247e2be559241cb127038d8eca9663af7f2', commit_message='End of training', commit_description='', oid='cbacd247e2be559241cb127038d8eca9663af7f2', pr_url=None, repo_url=RepoUrl('https://huggingface.co/Jae-star/gpt-sc', endpoint='https://huggingface.co', repo_type='model', repo_id='Jae-star/gpt-sc'), pr_revision=None, pr_num=None)

In [None]:
predictions = trainer.predict(tokenized_datasets["valid"])

OutOfMemoryError: CUDA out of memory. Tried to allocate 680.00 MiB. GPU 0 has a total capacity of 14.74 GiB of which 180.12 MiB is free. Process 9613 has 14.56 GiB memory in use. Of the allocated memory 13.75 GiB is allocated by PyTorch, and 692.34 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
import torch
torch.cuda.empty_cache()
# del raw_datasets

In [None]:
import torch
from transformers import pipeline

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
pipe = pipeline(
    "text-generation", model="huggingface-course/codeparrot-ds", device=device
)

config.json:   0%|          | 0.00/938 [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


pytorch_model.bin:   0%|          | 0.00/510M [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/510M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/265 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/789k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/448k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.34M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/90.0 [00:00<?, ?B/s]

Device set to use cuda


In [None]:
txt = """\
One night, fish princess walked out of the castle.
"""
print(pipe(txt, num_return_sequences=1)[0]["generated_text"])

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


One night, fish princess walked out of the castle.

"""

import sys
from math import log, ceil
import time
from datetime import datetime

import mclits.models.light_curves import


In [None]:
txt = """\
# create some data
x = np.random.randn(100)
y = np.random.randn(100)

# create dataframe from x and y
"""
print(pipe(txt, num_return_sequences=1)[0]["generated_text"])

In [None]:
txt = """\
# dataframe with profession, income and name
df = pd.DataFrame({'profession': x, 'income':y, 'name': z})

# calculate the mean income per profession
"""
print(pipe(txt, num_return_sequences=1)[0]["generated_text"])

In [None]:
txt = """
# import random forest regressor from scikit-learn
from sklearn.ensemble import RandomForestRegressor

# fit random forest model with 300 estimators on X, y:
"""
print(pipe(txt, num_return_sequences=1)[0]["generated_text"])

In [None]:
keytoken_ids = []
for keyword in [
    "plt",
    "pd",
    "sk",
    "fit",
    "predict",
    " plt",
    " pd",
    " sk",
    " fit",
    " predict",
    "testtest",
]:
    ids = tokenizer([keyword]).input_ids[0]
    if len(ids) == 1:
        keytoken_ids.append(ids[0])
    else:
        print(f"Keyword has not single token: {keyword}")

In [None]:
from torch.nn import CrossEntropyLoss
import torch


def keytoken_weighted_loss(inputs, logits, keytoken_ids, alpha=1.0):
    # Shift so that tokens < n predict n
    shift_labels = inputs[..., 1:].contiguous()
    shift_logits = logits[..., :-1, :].contiguous()
    # Calculate per-token loss
    loss_fct = CrossEntropyLoss(reduce=False)
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    # Resize and average loss per sample
    loss_per_sample = loss.view(shift_logits.size(0), shift_logits.size(1)).mean(axis=1)
    # Calculate and scale weighting
    weights = torch.stack([(inputs == kt).float() for kt in keytoken_ids]).sum(
        axis=[0, 2]
    )
    weights = alpha * (1.0 + weights)
    # Calculate weighted average
    weighted_loss = (loss_per_sample * weights).mean()
    return weighted_loss

In [None]:
from torch.utils.data.dataloader import DataLoader

tokenized_dataset.set_format("torch")
train_dataloader = DataLoader(tokenized_dataset["train"], batch_size=32, shuffle=True)
eval_dataloader = DataLoader(tokenized_dataset["valid"], batch_size=32)

In [None]:
weight_decay = 0.1


def get_grouped_params(model, no_decay=["bias", "LayerNorm.weight"]):
    params_with_wd, params_without_wd = [], []
    for n, p in model.named_parameters():
        if any(nd in n for nd in no_decay):
            params_without_wd.append(p)
        else:
            params_with_wd.append(p)
    return [
        {"params": params_with_wd, "weight_decay": weight_decay},
        {"params": params_without_wd, "weight_decay": 0.0},
    ]

In [None]:
def evaluate():
    model.eval()
    losses = []
    for step, batch in enumerate(eval_dataloader):
        with torch.no_grad():
            outputs = model(batch["input_ids"], labels=batch["input_ids"])

        losses.append(accelerator.gather(outputs.loss))
    loss = torch.mean(torch.cat(losses))
    try:
        perplexity = torch.exp(loss)
    except OverflowError:
        perplexity = float("inf")
    return loss.item(), perplexity.item()

In [None]:
model = GPT2LMHeadModel(config)

In [None]:
from torch.optim import AdamW

optimizer = AdamW(get_grouped_params(model), lr=5e-4)

In [None]:
from accelerate import Accelerator

accelerator = Accelerator(fp16=True)

model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)

In [None]:
from transformers import get_scheduler

num_train_epochs = 1
num_update_steps_per_epoch = len(train_dataloader)
num_training_steps = num_train_epochs * num_update_steps_per_epoch

lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=1_000,
    num_training_steps=num_training_steps,
)

In [None]:
from huggingface_hub import Repository, get_full_repo_name

model_name = "codeparrot-ds-accelerate"
repo_name = get_full_repo_name(model_name)
repo_name

In [None]:
output_dir = "codeparrot-ds-accelerate"
repo = Repository(output_dir, clone_from=repo_name)

In [None]:
evaluate()

In [None]:
from tqdm.notebook import tqdm

gradient_accumulation_steps = 8
eval_steps = 5_000

model.train()
completed_steps = 0
for epoch in range(num_train_epochs):
    for step, batch in tqdm(
        enumerate(train_dataloader, start=1), total=num_training_steps
    ):
        logits = model(batch["input_ids"]).logits
        loss = keytoken_weighted_loss(batch["input_ids"], logits, keytoken_ids)
        if step % 100 == 0:
            accelerator.print(
                {
                    "lr": get_lr(),
                    "samples": step * samples_per_step,
                    "steps": completed_steps,
                    "loss/train": loss.item() * gradient_accumulation_steps,
                }
            )
        loss = loss / gradient_accumulation_steps
        accelerator.backward(loss)
        if step % gradient_accumulation_steps == 0:
            accelerator.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            completed_steps += 1
        if (step % (eval_steps * gradient_accumulation_steps)) == 0:
            eval_loss, perplexity = evaluate()
            accelerator.print({"loss/eval": eval_loss, "perplexity": perplexity})
            model.train()
            accelerator.wait_for_everyone()
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
            if accelerator.is_main_process:
                tokenizer.save_pretrained(output_dir)
                repo.push_to_hub(
                    commit_message=f"Training in progress step {step}", blocking=False
                )