In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, get_scheduler
import os
import torch
from torch.optim import AdamW
import time
from torch import nn, Tensor
import argparse

In [2]:
class BeaconEmbedding(nn.Module):
    def __init__(self, embedding: nn.Embedding, vocab_size: int, n_embed: int, window_length: int, *args, **kwargs):
        super().__init__()
        self.n_embed = n_embed
        self.b_embed = nn.Parameter(torch.empty(n_embed), requires_grad=True)
        self.window_length = window_length
        self.embedding = embedding
        self.reset_parameters()
    
    def reset_parameters(self):
        nn.init.zeros_(self.b_embed)

    def forward(self, input: Tensor) -> Tensor:
        B, N = input.shape
        regular_embedding = self.embedding(input)
        beacon_tensor = torch.zeros((B, N, self.n_embed))
        beacon_tensor[:, ::self.window_length, :] = self.b_embed
        beacon_tensor = beacon_tensor.to(regular_embedding.device)
        return regular_embedding + beacon_tensor

In [3]:
def generate_beacon_attention_mask_2d(size, window_length=4, direct_window_multiple=1, device=None):
    mask_tensor = torch.zeros((size, size), device=device)
    mask_tensor[:, ::window_length] = 1
    mask_tensor[:, :window_length] = 1 # takes care of attention sinks
    for i in range(size):
        start_index = max(0, i - window_length*direct_window_multiple)
        mask_tensor[i, start_index:i] = 1
        mask_tensor[i, i] = 0
    return mask_tensor.tril()

In [4]:
use_embedding = True
use_custom_attn_mask = True
window_size = 4

In [5]:
num_epochs = 1
batch_size = 128
block_size = 128

# %%
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M")

In [6]:
from torch.nn import CrossEntropyLoss
import torch

def causal_lm_loss(inputs, logits, alpha=1.0):
    # Shift so that tokens < n predict n
    shift_labels = inputs[..., 1:].contiguous()
    shift_logits = logits[..., :-1, :].contiguous()
    # Calculate per-token loss
    loss_fct = CrossEntropyLoss(reduce=False)
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    # Resize and average loss per sample
    loss_per_sample = loss.view(shift_logits.size(0), shift_logits.size(1)).mean(axis=1)
    # Calculate average
    loss = loss_per_sample.mean()
    return loss

In [7]:
from datasets import load_dataset

def tokenize(element):
    outputs = tokenizer(
        element["text"],
        truncation=True,
        max_length=block_size,
        return_overflowing_tokens=True,
        return_length=True,
    )
    input_batch = []
    for length, input_ids in zip(outputs["length"], outputs["input_ids"]):
        if length == block_size:
            input_batch.append(input_ids)
    return {"input_ids": input_batch}

ds = load_dataset("roneneldan/TinyStories")
tokenized_ds = ds.map(tokenize, batched=True, batch_size=batch_size, num_proc=4, remove_columns=ds["train"].column_names)
tokenized_ds.set_format("torch")

Repo card metadata block was not found. Setting CardData to empty.


In [8]:
from torch.utils.data.dataloader import DataLoader

train_dataloader = DataLoader(tokenized_ds["train"], batch_size=batch_size, shuffle=True)
eval_dataloader = DataLoader(tokenized_ds["validation"], batch_size=batch_size)

print(len(eval_dataloader))

212


In [9]:
weight_decay = 0.1

def get_grouped_params(model, no_decay=["bias", "LayerNorm.weight"]):
    params_with_wd, params_without_wd = [], []
    for n, p in model.named_parameters():
        if any(nd in n for nd in no_decay):
            params_without_wd.append(p)
        else:
            params_with_wd.append(p)
    return [
        {"params": params_with_wd, "weight_decay": weight_decay},
        {"params": params_without_wd, "weight_decay": 0.0},
    ]

In [10]:
model = AutoModelForCausalLM.from_pretrained('roneneldan/TinyStories-1m')
model = model.to(device)
if use_embedding:
    beacon_embedding = BeaconEmbedding(embedding=model.get_input_embeddings(), vocab_size=model.config.vocab_size, n_embed=model.config.hidden_size, window_length=window_size)
    beacon_embedding = beacon_embedding.to(model.device)
    model.set_input_embeddings(beacon_embedding)
model.config

GPTNeoConfig {
  "_name_or_path": "roneneldan/TinyStories-1m",
  "activation_function": "gelu_new",
  "architectures": [
    "GPTNeoForCausalLM"
  ],
  "attention_dropout": 0,
  "attention_layers": [
    "global",
    "local",
    "global",
    "local",
    "global",
    "local",
    "global",
    "local"
  ],
  "attention_types": [
    [
      [
        "global",
        "local"
      ],
      4
    ]
  ],
  "bos_token_id": 50256,
  "classifier_dropout": 0.1,
  "embed_dropout": 0,
  "eos_token_id": 50256,
  "gradient_checkpointing": false,
  "hidden_size": 64,
  "initializer_range": 0.02,
  "intermediate_size": null,
  "layer_norm_epsilon": 1e-05,
  "max_position_embeddings": 2048,
  "model_type": "gpt_neo",
  "num_heads": 16,
  "num_layers": 8,
  "resid_dropout": 0,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "torch_dtype": "float32",
  "transformers_version": "4.43.3",
  

In [11]:
beacon_attention_mask = generate_beacon_attention_mask_2d(block_size, window_length=window_size, device=device)
print("Attention mask: ", beacon_attention_mask[:18, :18])

Attention mask:  tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],

In [12]:
optimizer = AdamW(get_grouped_params(model), lr=5e-5)
num_training_steps = num_epochs * len(train_dataloader)

def evaluate():
    model.eval()
    losses = []
    for step, batch in enumerate(eval_dataloader):
        with torch.no_grad():
            if use_custom_attn_mask:
                # print("Shape: ", batch["input_ids"].shape, beacon_attention_mask.shape)
                X, Y = batch["input_ids"].shape
                outputs = model(batch["input_ids"], labels=batch["input_ids"], attention_mask=beacon_attention_mask[:X, :Y])
            else:
                outputs = model(batch["input_ids"], labels=batch["input_ids"])

        losses.append(accelerator.gather(outputs.loss))
    # print(losses)
    loss = torch.mean(torch.Tensor(losses))
    try:
        perplexity = torch.exp(loss)
    except OverflowError:
        perplexity = float("inf")
    return loss.item(), perplexity.item()

In [13]:
from accelerate import Accelerator
import wandb

accelerator = Accelerator() # Logging with wandb here isn't working as expected for some reason
wandb.init(
    project="beacon-attention-1m",
    config={
        "use_custom_embedding": use_embedding,
        "use_custom_attn_mask": use_custom_attn_mask,
        "window_size": window_size,
    },
    # mode="disabled"
)

[34m[1mwandb[0m: Currently logged in as: [33mvdaita[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [14]:
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(model, optimizer, train_dataloader, eval_dataloader)

model_name = f"{'beacon_embed' if use_embedding else 'no_beacon_embed'}_{'beacon_attn_mask' if use_custom_attn_mask else 'regular_attn_mask'}_window_size_{window_size}_model"
output_dir = f"./models/{model_name}"
if not(os.path.exists(output_dir)):
    os.makedirs(output_dir)

lr_scheduler = get_scheduler(
    name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)

In [15]:
from tqdm import tqdm

gradient_accumulation_steps = 8
eval_steps = 5000

model.train()

completed_steps = 0
step_start_time = time.perf_counter()

eval_loss, perplexity = evaluate()

for epoch in range(num_epochs):
    for step, batch in tqdm(enumerate(train_dataloader, start=1), total=num_training_steps):
        X, Y = batch['input_ids'].shape
        if use_custom_attn_mask:
            # print("Shape: ", batch["input_ids"].shape, beacon_attention_mask.shape)
            logits = model(input_ids=batch["input_ids"], attention_mask=beacon_attention_mask[:X, :Y]).logits
        else:
            logits = model(input_ids=batch["input_ids"]).logits
        loss = causal_lm_loss(batch["input_ids"], logits)

        if step % 100 == 0:
            step_end_time = time.perf_counter()
            train_update = {
                "samples": step * batch_size,
                "steps": completed_steps,
                "loss/train": loss.item(), # * gradient_accumulation_steps,
                "step_time": step_end_time - step_start_time
            }
            accelerator.print(train_update)
            wandb.log(train_update)
            step_start_time = step_end_time

        loss = loss / gradient_accumulation_steps
        accelerator.backward(loss)

        if step % gradient_accumulation_steps == 0:
            accelerator.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            completed_steps += 1
        if (step % 1000) == 0:
            eval_loss, perplexity = evaluate()
            eval_update = {"loss/eval": eval_loss, "perplexity": perplexity}
            accelerator.print(eval_update)
            wandb.log(eval_update)
            model.train()
            accelerator.wait_for_everyone()
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
        
        end_time = time.perf_counter()

  0%|▎                                                               | 100/21210 [00:15<53:46,  6.54it/s]

{'samples': 12800, 'steps': 12, 'loss/train': 3.706055164337158, 'step_time': 27.49883888103068}


  1%|▌                                                               | 200/21210 [00:30<53:56,  6.49it/s]

{'samples': 25600, 'steps': 24, 'loss/train': 3.6019177436828613, 'step_time': 15.298237908631563}


  1%|▉                                                               | 300/21210 [00:45<52:49,  6.60it/s]

{'samples': 38400, 'steps': 37, 'loss/train': 3.4815988540649414, 'step_time': 15.178653098642826}


  2%|█▏                                                              | 400/21210 [01:01<53:04,  6.54it/s]

{'samples': 51200, 'steps': 49, 'loss/train': 3.372684955596924, 'step_time': 15.17441083677113}


  2%|█▌                                                              | 500/21210 [01:16<52:14,  6.61it/s]

{'samples': 64000, 'steps': 62, 'loss/train': 3.4231863021850586, 'step_time': 15.199720555916429}


  3%|█▊                                                              | 600/21210 [01:31<52:12,  6.58it/s]

{'samples': 76800, 'steps': 74, 'loss/train': 3.4056787490844727, 'step_time': 15.133750464767218}


  3%|██                                                              | 700/21210 [01:46<52:02,  6.57it/s]

{'samples': 89600, 'steps': 87, 'loss/train': 3.3394157886505127, 'step_time': 15.177361553534865}


  4%|██▍                                                             | 800/21210 [02:01<51:48,  6.57it/s]

{'samples': 102400, 'steps': 99, 'loss/train': 3.406207799911499, 'step_time': 15.205821800976992}


  4%|██▋                                                             | 900/21210 [02:16<51:15,  6.60it/s]

{'samples': 115200, 'steps': 112, 'loss/train': 3.2474241256713867, 'step_time': 15.154304040595889}


  5%|███                                                             | 999/21210 [02:32<51:29,  6.54it/s]

{'samples': 128000, 'steps': 124, 'loss/train': 3.315157413482666, 'step_time': 15.20774002186954}
{'loss/eval': 3.318002700805664, 'perplexity': 27.605159759521484}
[2024-07-28 14:50:12,628] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)


/usr/bin/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status
  5%|███▎                                                           | 1100/21210 [02:59<51:07,  6.55it/s]

{'samples': 140800, 'steps': 137, 'loss/train': 3.338210344314575, 'step_time': 27.474016327410936}


  6%|███▌                                                           | 1200/21210 [03:14<51:04,  6.53it/s]

{'samples': 153600, 'steps': 149, 'loss/train': 3.3388991355895996, 'step_time': 15.212668227031827}


  6%|███▊                                                           | 1300/21210 [03:30<50:44,  6.54it/s]

{'samples': 166400, 'steps': 162, 'loss/train': 3.278005599975586, 'step_time': 15.249477662146091}


  7%|████▏                                                          | 1400/21210 [03:45<50:52,  6.49it/s]

{'samples': 179200, 'steps': 174, 'loss/train': 3.2895989418029785, 'step_time': 15.24844809807837}


  7%|████▍                                                          | 1500/21210 [04:00<50:04,  6.56it/s]

{'samples': 192000, 'steps': 187, 'loss/train': 3.271786689758301, 'step_time': 15.2388582136482}


  8%|████▊                                                          | 1600/21210 [04:15<50:21,  6.49it/s]

{'samples': 204800, 'steps': 199, 'loss/train': 3.3087759017944336, 'step_time': 15.345154957845807}


  8%|█████                                                          | 1700/21210 [04:31<49:43,  6.54it/s]

{'samples': 217600, 'steps': 212, 'loss/train': 3.2831478118896484, 'step_time': 15.30849833600223}


  8%|█████▎                                                         | 1800/21210 [04:46<50:09,  6.45it/s]

{'samples': 230400, 'steps': 224, 'loss/train': 3.2434139251708984, 'step_time': 15.247142320498824}


  9%|█████▋                                                         | 1900/21210 [05:01<49:37,  6.48it/s]

{'samples': 243200, 'steps': 237, 'loss/train': 3.3293545246124268, 'step_time': 15.393601104617119}


  9%|█████▉                                                         | 1999/21210 [05:17<49:17,  6.50it/s]

{'samples': 256000, 'steps': 249, 'loss/train': 3.2747273445129395, 'step_time': 15.373081915080547}


  9%|█████▋                                                      | 2001/21210 [05:29<14:09:54,  2.65s/it]

{'loss/eval': 3.2624340057373047, 'perplexity': 26.113019943237305}


 10%|██████▏                                                        | 2100/21210 [05:44<48:30,  6.57it/s]

{'samples': 268800, 'steps': 262, 'loss/train': 3.306980609893799, 'step_time': 27.171576011925936}


 10%|██████▌                                                        | 2200/21210 [05:59<48:49,  6.49it/s]

{'samples': 281600, 'steps': 274, 'loss/train': 3.216930389404297, 'step_time': 15.306527188047767}


 11%|██████▊                                                        | 2300/21210 [06:15<48:23,  6.51it/s]

{'samples': 294400, 'steps': 287, 'loss/train': 3.170834541320801, 'step_time': 15.369003603234887}


 11%|███████▏                                                       | 2400/21210 [06:30<48:22,  6.48it/s]

{'samples': 307200, 'steps': 299, 'loss/train': 3.284022331237793, 'step_time': 15.386709183454514}


 12%|███████▍                                                       | 2500/21210 [06:45<48:09,  6.48it/s]

{'samples': 320000, 'steps': 312, 'loss/train': 3.210848331451416, 'step_time': 15.407389033585787}


 12%|███████▋                                                       | 2600/21210 [07:01<47:31,  6.53it/s]

{'samples': 332800, 'steps': 324, 'loss/train': 3.2127513885498047, 'step_time': 15.402944523841143}


 13%|████████                                                       | 2700/21210 [07:16<47:01,  6.56it/s]

{'samples': 345600, 'steps': 337, 'loss/train': 3.194694995880127, 'step_time': 15.239548804238439}


 13%|████████▎                                                      | 2800/21210 [07:31<47:15,  6.49it/s]

{'samples': 358400, 'steps': 349, 'loss/train': 3.175248622894287, 'step_time': 15.287344872951508}


 14%|████████▌                                                      | 2900/21210 [07:47<46:51,  6.51it/s]

{'samples': 371200, 'steps': 362, 'loss/train': 3.240105628967285, 'step_time': 15.377475205808878}


 14%|████████▉                                                      | 2999/21210 [08:02<46:00,  6.60it/s]

{'samples': 384000, 'steps': 374, 'loss/train': 3.2037529945373535, 'step_time': 15.312288960441947}


 14%|████████▍                                                   | 3001/21210 [08:14<13:25:10,  2.65s/it]

{'loss/eval': 3.230931043624878, 'perplexity': 25.303203582763672}


 15%|█████████▏                                                     | 3100/21210 [08:29<46:13,  6.53it/s]

{'samples': 396800, 'steps': 387, 'loss/train': 3.211419105529785, 'step_time': 27.242622423917055}


 15%|█████████▌                                                     | 3200/21210 [08:45<46:39,  6.43it/s]

{'samples': 409600, 'steps': 399, 'loss/train': 3.270693302154541, 'step_time': 15.354374218732119}


 16%|█████████▊                                                     | 3300/21210 [09:00<46:11,  6.46it/s]

{'samples': 422400, 'steps': 412, 'loss/train': 3.184932231903076, 'step_time': 15.429630931466818}


 16%|██████████                                                     | 3400/21210 [09:15<45:37,  6.51it/s]

{'samples': 435200, 'steps': 424, 'loss/train': 3.264925241470337, 'step_time': 15.371069364249706}


 17%|██████████▍                                                    | 3500/21210 [09:31<45:29,  6.49it/s]

{'samples': 448000, 'steps': 437, 'loss/train': 3.1577820777893066, 'step_time': 15.356283752247691}


 17%|██████████▋                                                    | 3600/21210 [09:46<45:26,  6.46it/s]

{'samples': 460800, 'steps': 449, 'loss/train': 3.2406702041625977, 'step_time': 15.431968413293362}


 17%|██████████▉                                                    | 3700/21210 [10:02<45:04,  6.47it/s]

{'samples': 473600, 'steps': 462, 'loss/train': 3.1941585540771484, 'step_time': 15.37352728843689}


 18%|███████████▎                                                   | 3800/21210 [10:17<44:38,  6.50it/s]

{'samples': 486400, 'steps': 474, 'loss/train': 3.204524517059326, 'step_time': 15.413010453805327}


 18%|███████████▌                                                   | 3900/21210 [10:32<44:01,  6.55it/s]

{'samples': 499200, 'steps': 487, 'loss/train': 3.1742818355560303, 'step_time': 15.260927218943834}


 19%|███████████▉                                                   | 3999/21210 [10:47<43:48,  6.55it/s]

{'samples': 512000, 'steps': 499, 'loss/train': 3.2341794967651367, 'step_time': 15.277994772419333}


 19%|███████████▎                                                | 4001/21210 [11:00<12:42:54,  2.66s/it]

{'loss/eval': 3.213136672973633, 'perplexity': 24.856931686401367}


 19%|████████████▏                                                  | 4100/21210 [11:15<43:22,  6.57it/s]

{'samples': 524800, 'steps': 512, 'loss/train': 3.1439924240112305, 'step_time': 27.146685687825084}


 20%|████████████▍                                                  | 4200/21210 [11:30<43:21,  6.54it/s]

{'samples': 537600, 'steps': 524, 'loss/train': 3.158674955368042, 'step_time': 15.21758084371686}


 20%|████████████▊                                                  | 4300/21210 [11:45<42:49,  6.58it/s]

{'samples': 550400, 'steps': 537, 'loss/train': 3.2478904724121094, 'step_time': 15.217522095888853}


 21%|█████████████                                                  | 4400/21210 [12:00<42:46,  6.55it/s]

{'samples': 563200, 'steps': 549, 'loss/train': 3.2131409645080566, 'step_time': 15.274464443325996}


 21%|█████████████▎                                                 | 4500/21210 [12:16<42:26,  6.56it/s]

{'samples': 576000, 'steps': 562, 'loss/train': 3.1998114585876465, 'step_time': 15.249573880806565}


 22%|█████████████▋                                                 | 4600/21210 [12:31<42:47,  6.47it/s]

{'samples': 588800, 'steps': 574, 'loss/train': 3.2523012161254883, 'step_time': 15.378986930474639}


 22%|█████████████▉                                                 | 4700/21210 [12:46<42:22,  6.49it/s]

{'samples': 601600, 'steps': 587, 'loss/train': 3.165095806121826, 'step_time': 15.42578348889947}


 23%|██████████████▎                                                | 4800/21210 [13:02<42:24,  6.45it/s]

{'samples': 614400, 'steps': 599, 'loss/train': 3.187195301055908, 'step_time': 15.423938067629933}


 23%|██████████████▌                                                | 4900/21210 [13:17<41:51,  6.50it/s]

{'samples': 627200, 'steps': 612, 'loss/train': 3.1863086223602295, 'step_time': 15.41173424385488}


 24%|██████████████▊                                                | 4999/21210 [13:33<41:19,  6.54it/s]

{'samples': 640000, 'steps': 624, 'loss/train': 3.2311925888061523, 'step_time': 15.411283118650317}


 24%|██████████████▏                                             | 5001/21210 [13:45<11:57:21,  2.66s/it]

{'loss/eval': 3.1950185298919678, 'perplexity': 24.410625457763672}


 24%|███████████████▏                                               | 5100/21210 [14:00<41:32,  6.46it/s]

{'samples': 652800, 'steps': 637, 'loss/train': 3.242732048034668, 'step_time': 27.171980254352093}


 25%|███████████████▍                                               | 5200/21210 [14:15<41:22,  6.45it/s]

{'samples': 665600, 'steps': 649, 'loss/train': 3.1773643493652344, 'step_time': 15.435719456523657}


 25%|███████████████▋                                               | 5300/21210 [14:31<40:44,  6.51it/s]

{'samples': 678400, 'steps': 662, 'loss/train': 3.210352897644043, 'step_time': 15.400411857292056}


 25%|████████████████                                               | 5400/21210 [14:46<40:45,  6.47it/s]

{'samples': 691200, 'steps': 674, 'loss/train': 3.210531711578369, 'step_time': 15.412172563374043}


 26%|████████████████▎                                              | 5500/21210 [15:02<39:54,  6.56it/s]

{'samples': 704000, 'steps': 687, 'loss/train': 3.1627910137176514, 'step_time': 15.395708967000246}


 26%|████████████████▋                                              | 5600/21210 [15:17<39:53,  6.52it/s]

{'samples': 716800, 'steps': 699, 'loss/train': 3.2041125297546387, 'step_time': 15.207401394844055}


 27%|████████████████▉                                              | 5700/21210 [15:32<39:28,  6.55it/s]

{'samples': 729600, 'steps': 712, 'loss/train': 3.206930637359619, 'step_time': 15.221868893131614}


 27%|█████████████████▏                                             | 5800/21210 [15:47<39:30,  6.50it/s]

{'samples': 742400, 'steps': 724, 'loss/train': 3.1917991638183594, 'step_time': 15.271051123738289}


 28%|█████████████████▌                                             | 5900/21210 [16:02<39:08,  6.52it/s]

{'samples': 755200, 'steps': 737, 'loss/train': 3.198979377746582, 'step_time': 15.273680010810494}


 28%|█████████████████▊                                             | 5999/21210 [16:18<38:26,  6.59it/s]

{'samples': 768000, 'steps': 749, 'loss/train': 3.1704907417297363, 'step_time': 15.1889086663723}


 28%|████████████████▉                                           | 6001/21210 [16:30<11:12:16,  2.65s/it]

{'loss/eval': 3.1808106899261475, 'perplexity': 24.066255569458008}


 29%|██████████████████                                             | 6100/21210 [16:45<38:30,  6.54it/s]

{'samples': 780800, 'steps': 762, 'loss/train': 3.1977295875549316, 'step_time': 27.16975362226367}


 29%|██████████████████▍                                            | 6200/21210 [17:00<38:21,  6.52it/s]

{'samples': 793600, 'steps': 774, 'loss/train': 3.2268457412719727, 'step_time': 15.257495919242501}


 30%|██████████████████▋                                            | 6300/21210 [17:15<37:51,  6.56it/s]

{'samples': 806400, 'steps': 787, 'loss/train': 3.1392452716827393, 'step_time': 15.25028233230114}


 30%|███████████████████                                            | 6400/21210 [17:31<37:49,  6.52it/s]

{'samples': 819200, 'steps': 799, 'loss/train': 3.184619665145874, 'step_time': 15.245156766846776}


 31%|███████████████████▎                                           | 6500/21210 [17:46<37:24,  6.55it/s]

{'samples': 832000, 'steps': 812, 'loss/train': 3.1746435165405273, 'step_time': 15.263730630278587}


 31%|███████████████████▌                                           | 6600/21210 [18:01<37:17,  6.53it/s]

{'samples': 844800, 'steps': 824, 'loss/train': 3.1499149799346924, 'step_time': 15.245827831327915}


 32%|███████████████████▉                                           | 6700/21210 [18:16<36:56,  6.55it/s]

{'samples': 857600, 'steps': 837, 'loss/train': 3.1714203357696533, 'step_time': 15.245092744007707}


 32%|████████████████████▏                                          | 6800/21210 [18:32<36:49,  6.52it/s]

{'samples': 870400, 'steps': 849, 'loss/train': 3.170987129211426, 'step_time': 15.254761820659041}


 33%|████████████████████▍                                          | 6900/21210 [18:47<36:12,  6.59it/s]

{'samples': 883200, 'steps': 862, 'loss/train': 3.1577491760253906, 'step_time': 15.3532350435853}


 33%|████████████████████▊                                          | 6999/21210 [19:02<36:05,  6.56it/s]

{'samples': 896000, 'steps': 874, 'loss/train': 3.1250052452087402, 'step_time': 15.210851140320301}


 33%|███████████████████▊                                        | 7001/21210 [19:14<10:27:46,  2.65s/it]

{'loss/eval': 3.1711606979370117, 'perplexity': 23.835134506225586}


 33%|█████████████████████                                          | 7100/21210 [19:29<36:00,  6.53it/s]

{'samples': 908800, 'steps': 887, 'loss/train': 3.1734161376953125, 'step_time': 27.26629125699401}


 34%|█████████████████████▍                                         | 7200/21210 [19:45<36:01,  6.48it/s]

{'samples': 921600, 'steps': 899, 'loss/train': 3.1736364364624023, 'step_time': 15.372930744662881}


 34%|█████████████████████▋                                         | 7300/21210 [20:00<35:27,  6.54it/s]

{'samples': 934400, 'steps': 912, 'loss/train': 3.16807222366333, 'step_time': 15.370019268244505}


 35%|█████████████████████▉                                         | 7400/21210 [20:15<35:24,  6.50it/s]

{'samples': 947200, 'steps': 924, 'loss/train': 3.191721200942993, 'step_time': 15.268311956897378}


 35%|██████████████████████▎                                        | 7500/21210 [20:31<34:47,  6.57it/s]

{'samples': 960000, 'steps': 937, 'loss/train': 3.185987949371338, 'step_time': 15.26256917975843}


 36%|██████████████████████▌                                        | 7600/21210 [20:46<34:39,  6.55it/s]

{'samples': 972800, 'steps': 949, 'loss/train': 3.2321977615356445, 'step_time': 15.253172105178237}


 36%|██████████████████████▊                                        | 7700/21210 [21:01<34:12,  6.58it/s]

{'samples': 985600, 'steps': 962, 'loss/train': 3.237745523452759, 'step_time': 15.251357624307275}


 37%|███████████████████████▏                                       | 7800/21210 [21:16<34:15,  6.52it/s]

{'samples': 998400, 'steps': 974, 'loss/train': 3.1582398414611816, 'step_time': 15.250601727515459}


 37%|███████████████████████▍                                       | 7900/21210 [21:32<34:05,  6.51it/s]

{'samples': 1011200, 'steps': 987, 'loss/train': 3.189488410949707, 'step_time': 15.314238043501973}


 38%|███████████████████████▊                                       | 7999/21210 [21:47<33:52,  6.50it/s]

{'samples': 1024000, 'steps': 999, 'loss/train': 3.2039995193481445, 'step_time': 15.371865209192038}


 38%|███████████████████████                                      | 8001/21210 [21:59<9:44:52,  2.66s/it]

{'loss/eval': 3.161325454711914, 'perplexity': 23.601858139038086}


 38%|████████████████████████                                       | 8100/21210 [22:14<33:24,  6.54it/s]

{'samples': 1036800, 'steps': 1012, 'loss/train': 3.1723427772521973, 'step_time': 27.25947868824005}


 39%|████████████████████████▎                                      | 8200/21210 [22:30<33:13,  6.53it/s]

{'samples': 1049600, 'steps': 1024, 'loss/train': 3.197455406188965, 'step_time': 15.271998165175319}


 39%|████████████████████████▋                                      | 8300/21210 [22:45<33:10,  6.49it/s]

{'samples': 1062400, 'steps': 1037, 'loss/train': 3.19858455657959, 'step_time': 15.217492207884789}


 40%|████████████████████████▉                                      | 8400/21210 [23:00<32:53,  6.49it/s]

{'samples': 1075200, 'steps': 1049, 'loss/train': 3.1176443099975586, 'step_time': 15.258848709985614}


 40%|█████████████████████████▏                                     | 8500/21210 [23:15<32:28,  6.52it/s]

{'samples': 1088000, 'steps': 1062, 'loss/train': 3.1852824687957764, 'step_time': 15.299051277339458}


 41%|█████████████████████████▌                                     | 8600/21210 [23:31<32:15,  6.52it/s]

{'samples': 1100800, 'steps': 1074, 'loss/train': 3.130291700363159, 'step_time': 15.26743215508759}


 41%|█████████████████████████▊                                     | 8700/21210 [23:46<31:45,  6.56it/s]

{'samples': 1113600, 'steps': 1087, 'loss/train': 3.1873350143432617, 'step_time': 15.256402691826224}


 41%|██████████████████████████▏                                    | 8800/21210 [24:01<31:46,  6.51it/s]

{'samples': 1126400, 'steps': 1099, 'loss/train': 3.2136147022247314, 'step_time': 15.218109507113695}


 42%|██████████████████████████▍                                    | 8900/21210 [24:16<31:19,  6.55it/s]

{'samples': 1139200, 'steps': 1112, 'loss/train': 3.1704702377319336, 'step_time': 15.255531709641218}


 42%|██████████████████████████▋                                    | 8999/21210 [24:32<31:00,  6.56it/s]

{'samples': 1152000, 'steps': 1124, 'loss/train': 3.213700294494629, 'step_time': 15.266179107129574}


 42%|█████████████████████████▉                                   | 9001/21210 [24:44<8:59:41,  2.65s/it]

{'loss/eval': 3.154001235961914, 'perplexity': 23.429624557495117}


 43%|███████████████████████████                                    | 9100/21210 [24:59<31:02,  6.50it/s]

{'samples': 1164800, 'steps': 1137, 'loss/train': 3.153841495513916, 'step_time': 27.293770868331194}


 43%|███████████████████████████▎                                   | 9200/21210 [25:14<30:31,  6.56it/s]

{'samples': 1177600, 'steps': 1149, 'loss/train': 3.168224334716797, 'step_time': 15.245415227487683}


 44%|███████████████████████████▌                                   | 9300/21210 [25:29<30:20,  6.54it/s]

{'samples': 1190400, 'steps': 1162, 'loss/train': 3.187918186187744, 'step_time': 15.23156487569213}


 44%|███████████████████████████▉                                   | 9400/21210 [25:45<30:16,  6.50it/s]

{'samples': 1203200, 'steps': 1174, 'loss/train': 3.1073389053344727, 'step_time': 15.294377218931913}


 45%|████████████████████████████▏                                  | 9500/21210 [26:00<29:45,  6.56it/s]

{'samples': 1216000, 'steps': 1187, 'loss/train': 3.1460232734680176, 'step_time': 15.293772408738732}


 45%|████████████████████████████▌                                  | 9600/21210 [26:15<29:59,  6.45it/s]

{'samples': 1228800, 'steps': 1199, 'loss/train': 3.1100564002990723, 'step_time': 15.40306742861867}


 46%|████████████████████████████▊                                  | 9700/21210 [26:31<29:36,  6.48it/s]

{'samples': 1241600, 'steps': 1212, 'loss/train': 3.14378023147583, 'step_time': 15.412749756127596}


 46%|█████████████████████████████                                  | 9800/21210 [26:46<29:18,  6.49it/s]

{'samples': 1254400, 'steps': 1224, 'loss/train': 3.1442034244537354, 'step_time': 15.34150379896164}


 47%|█████████████████████████████▍                                 | 9900/21210 [27:02<28:43,  6.56it/s]

{'samples': 1267200, 'steps': 1237, 'loss/train': 3.1263818740844727, 'step_time': 15.30252168700099}


 47%|█████████████████████████████▋                                 | 9999/21210 [27:17<28:28,  6.56it/s]

{'samples': 1280000, 'steps': 1249, 'loss/train': 3.1569223403930664, 'step_time': 15.22902194224298}


 47%|████████████████████████████▎                               | 10001/21210 [27:29<8:15:23,  2.65s/it]

{'loss/eval': 3.1471505165100098, 'perplexity': 23.269662857055664}


 48%|█████████████████████████████▌                                | 10100/21210 [27:44<28:06,  6.59it/s]

{'samples': 1292800, 'steps': 1262, 'loss/train': 3.1553306579589844, 'step_time': 27.113045040518045}


 48%|█████████████████████████████▊                                | 10200/21210 [27:59<28:01,  6.55it/s]

{'samples': 1305600, 'steps': 1274, 'loss/train': 3.1687843799591064, 'step_time': 15.218311298638582}


 49%|██████████████████████████████                                | 10300/21210 [28:14<27:37,  6.58it/s]

{'samples': 1318400, 'steps': 1287, 'loss/train': 3.1538243293762207, 'step_time': 15.242386164143682}


 49%|██████████████████████████████▍                               | 10400/21210 [28:30<27:35,  6.53it/s]

{'samples': 1331200, 'steps': 1299, 'loss/train': 3.152729034423828, 'step_time': 15.22530066780746}


 50%|██████████████████████████████▋                               | 10500/21210 [28:45<27:19,  6.53it/s]

{'samples': 1344000, 'steps': 1312, 'loss/train': 3.1476078033447266, 'step_time': 15.265525296330452}


 50%|██████████████████████████████▉                               | 10600/21210 [29:00<27:08,  6.51it/s]

{'samples': 1356800, 'steps': 1324, 'loss/train': 3.213650941848755, 'step_time': 15.338895935565233}


 50%|███████████████████████████████▎                              | 10700/21210 [29:15<26:41,  6.56it/s]

{'samples': 1369600, 'steps': 1337, 'loss/train': 3.1711559295654297, 'step_time': 15.251624524593353}


 51%|███████████████████████████████▌                              | 10800/21210 [29:31<26:48,  6.47it/s]

{'samples': 1382400, 'steps': 1349, 'loss/train': 3.167877435684204, 'step_time': 15.302441772073507}


 51%|███████████████████████████████▊                              | 10900/21210 [29:46<26:29,  6.48it/s]

{'samples': 1395200, 'steps': 1362, 'loss/train': 3.1321358680725098, 'step_time': 15.416037414222956}


 52%|████████████████████████████████▏                             | 10999/21210 [30:01<25:55,  6.56it/s]

{'samples': 1408000, 'steps': 1374, 'loss/train': 3.152438163757324, 'step_time': 15.28034739382565}


 52%|███████████████████████████████                             | 11001/21210 [30:13<7:31:27,  2.65s/it]

{'loss/eval': 3.1417407989501953, 'perplexity': 23.144121170043945}


 52%|████████████████████████████████▍                             | 11100/21210 [30:29<25:54,  6.51it/s]

{'samples': 1420800, 'steps': 1387, 'loss/train': 3.1155638694763184, 'step_time': 27.204007586464286}


 53%|████████████████████████████████▋                             | 11200/21210 [30:44<25:48,  6.47it/s]

{'samples': 1433600, 'steps': 1399, 'loss/train': 3.1286826133728027, 'step_time': 15.34654782898724}


 53%|█████████████████████████████████                             | 11300/21210 [30:59<25:26,  6.49it/s]

{'samples': 1446400, 'steps': 1412, 'loss/train': 3.1208393573760986, 'step_time': 15.438314007595181}


 54%|█████████████████████████████████▎                            | 11400/21210 [31:15<25:10,  6.49it/s]

{'samples': 1459200, 'steps': 1424, 'loss/train': 3.1502225399017334, 'step_time': 15.365637289360166}


 54%|█████████████████████████████████▌                            | 11500/21210 [31:30<24:49,  6.52it/s]

{'samples': 1472000, 'steps': 1437, 'loss/train': 3.1974964141845703, 'step_time': 15.24231636337936}


 55%|█████████████████████████████████▉                            | 11600/21210 [31:45<24:38,  6.50it/s]

{'samples': 1484800, 'steps': 1449, 'loss/train': 3.061699151992798, 'step_time': 15.314544109627604}


 55%|██████████████████████████████████▏                           | 11700/21210 [32:01<24:14,  6.54it/s]

{'samples': 1497600, 'steps': 1462, 'loss/train': 3.148198127746582, 'step_time': 15.303769806399941}


 56%|██████████████████████████████████▍                           | 11800/21210 [32:16<24:08,  6.50it/s]

{'samples': 1510400, 'steps': 1474, 'loss/train': 3.158815860748291, 'step_time': 15.393867656588554}


 56%|██████████████████████████████████▊                           | 11900/21210 [32:31<23:48,  6.52it/s]

{'samples': 1523200, 'steps': 1487, 'loss/train': 3.1549363136291504, 'step_time': 15.320684807375073}


 57%|███████████████████████████████████                           | 11999/21210 [32:47<23:37,  6.50it/s]

{'samples': 1536000, 'steps': 1499, 'loss/train': 3.169705390930176, 'step_time': 15.418253680691123}


 57%|█████████████████████████████████▉                          | 12001/21210 [32:59<6:47:58,  2.66s/it]

{'loss/eval': 3.1373205184936523, 'perplexity': 23.042043685913086}


 57%|███████████████████████████████████▎                          | 12100/21210 [33:14<23:28,  6.47it/s]

{'samples': 1548800, 'steps': 1512, 'loss/train': 3.1379098892211914, 'step_time': 27.39389636553824}


 58%|███████████████████████████████████▋                          | 12200/21210 [33:30<23:12,  6.47it/s]

{'samples': 1561600, 'steps': 1524, 'loss/train': 3.168424606323242, 'step_time': 15.432717997580767}


 58%|███████████████████████████████████▉                          | 12300/21210 [33:45<22:37,  6.57it/s]

{'samples': 1574400, 'steps': 1537, 'loss/train': 3.0790977478027344, 'step_time': 15.37229597195983}


 58%|████████████████████████████████████▏                         | 12400/21210 [34:00<22:29,  6.53it/s]

{'samples': 1587200, 'steps': 1549, 'loss/train': 3.1211442947387695, 'step_time': 15.254970889538527}


 59%|████████████████████████████████████▌                         | 12500/21210 [34:15<22:09,  6.55it/s]

{'samples': 1600000, 'steps': 1562, 'loss/train': 3.024564027786255, 'step_time': 15.265192318707705}


 59%|████████████████████████████████████▊                         | 12600/21210 [34:31<21:59,  6.52it/s]

{'samples': 1612800, 'steps': 1574, 'loss/train': 3.1344504356384277, 'step_time': 15.262562835589051}


 60%|█████████████████████████████████████                         | 12700/21210 [34:46<21:36,  6.56it/s]

{'samples': 1625600, 'steps': 1587, 'loss/train': 3.170604705810547, 'step_time': 15.225114531815052}


 60%|█████████████████████████████████████▍                        | 12800/21210 [35:01<21:27,  6.53it/s]

{'samples': 1638400, 'steps': 1599, 'loss/train': 3.127669095993042, 'step_time': 15.22264918871224}


 61%|█████████████████████████████████████▋                        | 12900/21210 [35:16<21:06,  6.56it/s]

{'samples': 1651200, 'steps': 1612, 'loss/train': 3.1071724891662598, 'step_time': 15.249160796403885}


 61%|█████████████████████████████████████▉                        | 12999/21210 [35:32<20:46,  6.59it/s]

{'samples': 1664000, 'steps': 1624, 'loss/train': 3.089346408843994, 'step_time': 15.211756391450763}


 61%|████████████████████████████████████▊                       | 13001/21210 [35:44<6:02:54,  2.65s/it]

{'loss/eval': 3.1314539909362793, 'perplexity': 22.907262802124023}


 62%|██████████████████████████████████████▎                       | 13100/21210 [35:59<20:39,  6.54it/s]

{'samples': 1676800, 'steps': 1637, 'loss/train': 3.1246633529663086, 'step_time': 27.141668424010277}


 62%|██████████████████████████████████████▌                       | 13200/21210 [36:14<20:26,  6.53it/s]

{'samples': 1689600, 'steps': 1649, 'loss/train': 3.179147243499756, 'step_time': 15.248074259608984}


 63%|██████████████████████████████████████▉                       | 13300/21210 [36:29<20:06,  6.56it/s]

{'samples': 1702400, 'steps': 1662, 'loss/train': 3.143482208251953, 'step_time': 15.255143063142896}


 63%|███████████████████████████████████████▏                      | 13400/21210 [36:45<19:55,  6.53it/s]

{'samples': 1715200, 'steps': 1674, 'loss/train': 3.1433305740356445, 'step_time': 15.235099341720343}


 64%|███████████████████████████████████████▍                      | 13500/21210 [37:00<19:31,  6.58it/s]

{'samples': 1728000, 'steps': 1687, 'loss/train': 3.156379461288452, 'step_time': 15.212583372369409}


 64%|███████████████████████████████████████▊                      | 13600/21210 [37:15<19:33,  6.48it/s]

{'samples': 1740800, 'steps': 1699, 'loss/train': 3.132992744445801, 'step_time': 15.333069298416376}


 65%|████████████████████████████████████████                      | 13700/21210 [37:31<19:14,  6.51it/s]

{'samples': 1753600, 'steps': 1712, 'loss/train': 3.1881656646728516, 'step_time': 15.422283098101616}


 65%|████████████████████████████████████████▎                     | 13800/21210 [37:46<19:15,  6.41it/s]

{'samples': 1766400, 'steps': 1724, 'loss/train': 3.1118950843811035, 'step_time': 15.418158555403352}


 66%|████████████████████████████████████████▋                     | 13900/21210 [38:01<18:37,  6.54it/s]

{'samples': 1779200, 'steps': 1737, 'loss/train': 3.143827199935913, 'step_time': 15.349711615592241}


 66%|████████████████████████████████████████▉                     | 13999/21210 [38:17<18:33,  6.47it/s]

{'samples': 1792000, 'steps': 1749, 'loss/train': 3.1177139282226562, 'step_time': 15.445408836007118}


 66%|███████████████████████████████████████▌                    | 14001/21210 [38:29<5:19:24,  2.66s/it]

{'loss/eval': 3.1278398036956787, 'perplexity': 22.824621200561523}


 66%|█████████████████████████████████████████▏                    | 14100/21210 [38:44<18:08,  6.53it/s]

{'samples': 1804800, 'steps': 1762, 'loss/train': 3.0932393074035645, 'step_time': 27.2016465626657}


 67%|█████████████████████████████████████████▌                    | 14200/21210 [38:59<17:54,  6.52it/s]

{'samples': 1817600, 'steps': 1774, 'loss/train': 3.0750949382781982, 'step_time': 15.247157420963049}


 67%|█████████████████████████████████████████▊                    | 14300/21210 [39:15<17:48,  6.46it/s]

{'samples': 1830400, 'steps': 1787, 'loss/train': 3.1447842121124268, 'step_time': 15.46529233083129}


 68%|██████████████████████████████████████████                    | 14400/21210 [39:30<17:38,  6.43it/s]

{'samples': 1843200, 'steps': 1799, 'loss/train': 3.163848876953125, 'step_time': 15.414519060403109}


 68%|██████████████████████████████████████████▍                   | 14500/21210 [39:45<17:02,  6.56it/s]

{'samples': 1856000, 'steps': 1812, 'loss/train': 3.137629270553589, 'step_time': 15.353413233533502}


 69%|██████████████████████████████████████████▋                   | 14600/21210 [40:01<16:49,  6.55it/s]

{'samples': 1868800, 'steps': 1824, 'loss/train': 3.10164213180542, 'step_time': 15.253202827647328}


 69%|██████████████████████████████████████████▉                   | 14700/21210 [40:16<16:28,  6.58it/s]

{'samples': 1881600, 'steps': 1837, 'loss/train': 3.1285181045532227, 'step_time': 15.204800225794315}


 70%|███████████████████████████████████████████▎                  | 14800/21210 [40:31<16:22,  6.52it/s]

{'samples': 1894400, 'steps': 1849, 'loss/train': 3.093383550643921, 'step_time': 15.22760328836739}


 70%|███████████████████████████████████████████▌                  | 14900/21210 [40:46<16:18,  6.45it/s]

{'samples': 1907200, 'steps': 1862, 'loss/train': 3.1227335929870605, 'step_time': 15.303525106981397}


 71%|███████████████████████████████████████████▊                  | 14999/21210 [41:02<15:56,  6.50it/s]

{'samples': 1920000, 'steps': 1874, 'loss/train': 3.0845961570739746, 'step_time': 15.433778071776032}


 71%|██████████████████████████████████████████▍                 | 15000/21210 [41:14<6:27:00,  3.74s/it]

{'loss/eval': 3.1233227252960205, 'perplexity': 22.721752166748047}


 71%|████████████████████████████████████████████▏                 | 15100/21210 [41:29<15:29,  6.57it/s]

{'samples': 1932800, 'steps': 1887, 'loss/train': 3.1608190536499023, 'step_time': 27.27055574208498}


 72%|████████████████████████████████████████████▍                 | 15200/21210 [41:44<15:22,  6.51it/s]

{'samples': 1945600, 'steps': 1899, 'loss/train': 3.1675291061401367, 'step_time': 15.234917480498552}


 72%|████████████████████████████████████████████▋                 | 15300/21210 [42:00<15:16,  6.45it/s]

{'samples': 1958400, 'steps': 1912, 'loss/train': 3.0900611877441406, 'step_time': 15.326861459761858}


 73%|█████████████████████████████████████████████                 | 15400/21210 [42:15<14:56,  6.48it/s]

{'samples': 1971200, 'steps': 1924, 'loss/train': 3.0884909629821777, 'step_time': 15.407085660845041}


 73%|█████████████████████████████████████████████▎                | 15500/21210 [42:30<14:36,  6.52it/s]

{'samples': 1984000, 'steps': 1937, 'loss/train': 3.160662889480591, 'step_time': 15.383156955242157}


 74%|█████████████████████████████████████████████▌                | 15600/21210 [42:46<14:35,  6.41it/s]

{'samples': 1996800, 'steps': 1949, 'loss/train': 3.0851354598999023, 'step_time': 15.393769336864352}


 74%|█████████████████████████████████████████████▉                | 15700/21210 [43:01<14:09,  6.49it/s]

{'samples': 2009600, 'steps': 1962, 'loss/train': 3.1457486152648926, 'step_time': 15.375523759052157}


 74%|██████████████████████████████████████████████▏               | 15800/21210 [43:16<13:44,  6.56it/s]

{'samples': 2022400, 'steps': 1974, 'loss/train': 3.1581778526306152, 'step_time': 15.210724387317896}


 75%|██████████████████████████████████████████████▍               | 15900/21210 [43:32<13:31,  6.54it/s]

{'samples': 2035200, 'steps': 1987, 'loss/train': 3.103574514389038, 'step_time': 15.224882941693068}


 75%|██████████████████████████████████████████████▊               | 15999/21210 [43:47<13:22,  6.49it/s]

{'samples': 2048000, 'steps': 1999, 'loss/train': 3.124329090118408, 'step_time': 15.296480871737003}


 75%|█████████████████████████████████████████████▎              | 16001/21210 [43:59<3:50:20,  2.65s/it]

{'loss/eval': 3.117807626724243, 'perplexity': 22.596784591674805}


 76%|███████████████████████████████████████████████               | 16100/21210 [44:14<13:06,  6.50it/s]

{'samples': 2060800, 'steps': 2012, 'loss/train': 3.1468286514282227, 'step_time': 27.177921291440725}


 76%|███████████████████████████████████████████████▎              | 16200/21210 [44:30<12:52,  6.49it/s]

{'samples': 2073600, 'steps': 2024, 'loss/train': 3.107010841369629, 'step_time': 15.395065939053893}


 77%|███████████████████████████████████████████████▋              | 16300/21210 [44:45<12:38,  6.47it/s]

{'samples': 2086400, 'steps': 2037, 'loss/train': 3.122710704803467, 'step_time': 15.43070762604475}


 77%|███████████████████████████████████████████████▉              | 16400/21210 [45:00<12:23,  6.47it/s]

{'samples': 2099200, 'steps': 2049, 'loss/train': 3.142444372177124, 'step_time': 15.407594319432974}


 78%|████████████████████████████████████████████████▏             | 16500/21210 [45:16<11:56,  6.58it/s]

{'samples': 2112000, 'steps': 2062, 'loss/train': 3.0749363899230957, 'step_time': 15.246076069772243}


 78%|████████████████████████████████████████████████▌             | 16600/21210 [45:31<11:48,  6.50it/s]

{'samples': 2124800, 'steps': 2074, 'loss/train': 3.130856990814209, 'step_time': 15.22500954568386}


 79%|████████████████████████████████████████████████▊             | 16700/21210 [45:46<11:26,  6.57it/s]

{'samples': 2137600, 'steps': 2087, 'loss/train': 3.1743998527526855, 'step_time': 15.246678261086345}


 79%|█████████████████████████████████████████████████             | 16800/21210 [46:01<11:19,  6.49it/s]

{'samples': 2150400, 'steps': 2099, 'loss/train': 3.1225345134735107, 'step_time': 15.240074280649424}


 80%|█████████████████████████████████████████████████▍            | 16900/21210 [46:17<11:07,  6.45it/s]

{'samples': 2163200, 'steps': 2112, 'loss/train': 3.0931649208068848, 'step_time': 15.29965285398066}


 80%|█████████████████████████████████████████████████▋            | 16999/21210 [46:32<10:49,  6.48it/s]

{'samples': 2176000, 'steps': 2124, 'loss/train': 3.153533458709717, 'step_time': 15.388990305364132}


 80%|████████████████████████████████████████████████            | 17001/21210 [46:44<3:06:16,  2.66s/it]

{'loss/eval': 3.1155667304992676, 'perplexity': 22.54620361328125}


 81%|█████████████████████████████████████████████████▉            | 17100/21210 [46:59<10:26,  6.57it/s]

{'samples': 2188800, 'steps': 2137, 'loss/train': 3.118669271469116, 'step_time': 27.165262205526233}


 81%|██████████████████████████████████████████████████▎           | 17200/21210 [47:14<10:14,  6.53it/s]

{'samples': 2201600, 'steps': 2149, 'loss/train': 3.0519793033599854, 'step_time': 15.258133513852954}


 82%|██████████████████████████████████████████████████▌           | 17300/21210 [47:30<09:55,  6.56it/s]

{'samples': 2214400, 'steps': 2162, 'loss/train': 3.1152455806732178, 'step_time': 15.244804853573442}


 82%|██████████████████████████████████████████████████▊           | 17400/21210 [47:45<09:48,  6.47it/s]

{'samples': 2227200, 'steps': 2174, 'loss/train': 3.0590786933898926, 'step_time': 15.355046335607767}


 83%|███████████████████████████████████████████████████▏          | 17500/21210 [48:00<09:32,  6.48it/s]

{'samples': 2240000, 'steps': 2187, 'loss/train': 3.1731083393096924, 'step_time': 15.429462501779199}


 83%|███████████████████████████████████████████████████▍          | 17600/21210 [48:16<09:20,  6.44it/s]

{'samples': 2252800, 'steps': 2199, 'loss/train': 3.153831720352173, 'step_time': 15.443170826882124}


 83%|███████████████████████████████████████████████████▋          | 17700/21210 [48:31<09:00,  6.50it/s]

{'samples': 2265600, 'steps': 2212, 'loss/train': 3.107401132583618, 'step_time': 15.439558029174805}


 84%|████████████████████████████████████████████████████          | 17800/21210 [48:47<08:43,  6.51it/s]

{'samples': 2278400, 'steps': 2224, 'loss/train': 3.0517499446868896, 'step_time': 15.393757276237011}


 84%|████████████████████████████████████████████████████▎         | 17900/21210 [49:02<08:24,  6.56it/s]

{'samples': 2291200, 'steps': 2237, 'loss/train': 3.0367746353149414, 'step_time': 15.226263400167227}


 85%|████████████████████████████████████████████████████▌         | 17999/21210 [49:17<08:08,  6.57it/s]

{'samples': 2304000, 'steps': 2249, 'loss/train': 3.1224770545959473, 'step_time': 15.254784133285284}


 85%|██████████████████████████████████████████████████▉         | 18001/21210 [49:29<2:22:16,  2.66s/it]

{'loss/eval': 3.1131622791290283, 'perplexity': 22.49205780029297}


 85%|████████████████████████████████████████████████████▉         | 18100/21210 [49:45<07:58,  6.50it/s]

{'samples': 2316800, 'steps': 2262, 'loss/train': 3.051994800567627, 'step_time': 27.331344857811928}


 86%|█████████████████████████████████████████████████████▏        | 18200/21210 [50:00<07:40,  6.54it/s]

{'samples': 2329600, 'steps': 2274, 'loss/train': 3.146117687225342, 'step_time': 15.341990092769265}


 86%|█████████████████████████████████████████████████████▍        | 18300/21210 [50:15<07:27,  6.51it/s]

{'samples': 2342400, 'steps': 2287, 'loss/train': 3.079951286315918, 'step_time': 15.375236308202147}


 87%|█████████████████████████████████████████████████████▊        | 18400/21210 [50:31<07:15,  6.46it/s]

{'samples': 2355200, 'steps': 2299, 'loss/train': 3.1435060501098633, 'step_time': 15.403165416792035}


 87%|██████████████████████████████████████████████████████        | 18500/21210 [50:46<06:51,  6.58it/s]

{'samples': 2368000, 'steps': 2312, 'loss/train': 3.080416202545166, 'step_time': 15.28788385540247}


 88%|██████████████████████████████████████████████████████▎       | 18600/21210 [51:01<06:40,  6.51it/s]

{'samples': 2380800, 'steps': 2324, 'loss/train': 3.1020894050598145, 'step_time': 15.251777900382876}


 88%|██████████████████████████████████████████████████████▋       | 18700/21210 [51:16<06:22,  6.56it/s]

{'samples': 2393600, 'steps': 2337, 'loss/train': 3.076526403427124, 'step_time': 15.267626520246267}


 89%|██████████████████████████████████████████████████████▉       | 18800/21210 [51:32<06:10,  6.50it/s]

{'samples': 2406400, 'steps': 2349, 'loss/train': 3.0677390098571777, 'step_time': 15.257199723273516}


 89%|███████████████████████████████████████████████████████▏      | 18900/21210 [51:47<05:52,  6.55it/s]

{'samples': 2419200, 'steps': 2362, 'loss/train': 3.120512008666992, 'step_time': 15.269322069361806}


 90%|███████████████████████████████████████████████████████▌      | 18999/21210 [52:02<05:38,  6.54it/s]

{'samples': 2432000, 'steps': 2374, 'loss/train': 3.1461915969848633, 'step_time': 15.241530722007155}


 90%|█████████████████████████████████████████████████████▊      | 19001/21210 [52:14<1:37:45,  2.66s/it]

{'loss/eval': 3.108614444732666, 'perplexity': 22.389999389648438}


 90%|███████████████████████████████████████████████████████▊      | 19100/21210 [52:30<05:25,  6.48it/s]

{'samples': 2444800, 'steps': 2387, 'loss/train': 3.122971296310425, 'step_time': 27.345556426793337}


 91%|████████████████████████████████████████████████████████      | 19200/21210 [52:45<05:07,  6.53it/s]

{'samples': 2457600, 'steps': 2399, 'loss/train': 3.1511340141296387, 'step_time': 15.250343082472682}


 91%|████████████████████████████████████████████████████████▍     | 19300/21210 [53:00<04:53,  6.50it/s]

{'samples': 2470400, 'steps': 2412, 'loss/train': 3.08945369720459, 'step_time': 15.259023815393448}


 91%|████████████████████████████████████████████████████████▋     | 19400/21210 [53:15<04:37,  6.53it/s]

{'samples': 2483200, 'steps': 2424, 'loss/train': 3.09544038772583, 'step_time': 15.28858694061637}


 92%|█████████████████████████████████████████████████████████     | 19500/21210 [53:31<04:23,  6.49it/s]

{'samples': 2496000, 'steps': 2437, 'loss/train': 3.0728073120117188, 'step_time': 15.37615505233407}


 92%|█████████████████████████████████████████████████████████▎    | 19600/21210 [53:46<04:05,  6.55it/s]

{'samples': 2508800, 'steps': 2449, 'loss/train': 3.0851974487304688, 'step_time': 15.309121139347553}


 93%|█████████████████████████████████████████████████████████▌    | 19700/21210 [54:01<03:49,  6.57it/s]

{'samples': 2521600, 'steps': 2462, 'loss/train': 3.116976022720337, 'step_time': 15.243124719709158}


 93%|█████████████████████████████████████████████████████████▉    | 19800/21210 [54:17<03:37,  6.49it/s]

{'samples': 2534400, 'steps': 2474, 'loss/train': 3.1074485778808594, 'step_time': 15.384843535721302}


 94%|██████████████████████████████████████████████████████████▏   | 19900/21210 [54:32<03:21,  6.49it/s]

{'samples': 2547200, 'steps': 2487, 'loss/train': 3.0793018341064453, 'step_time': 15.384544346481562}


 94%|██████████████████████████████████████████████████████████▍   | 19999/21210 [54:47<03:03,  6.59it/s]

{'samples': 2560000, 'steps': 2499, 'loss/train': 3.1162524223327637, 'step_time': 15.352581853047013}


 94%|██████████████████████████████████████████████████████████▍   | 20001/21210 [55:00<53:32,  2.66s/it]

{'loss/eval': 3.10537052154541, 'perplexity': 22.317485809326172}


 95%|██████████████████████████████████████████████████████████▊   | 20100/21210 [55:15<02:49,  6.53it/s]

{'samples': 2572800, 'steps': 2512, 'loss/train': 3.0948634147644043, 'step_time': 27.230187559500337}


 95%|███████████████████████████████████████████████████████████   | 20200/21210 [55:30<02:35,  6.50it/s]

{'samples': 2585600, 'steps': 2524, 'loss/train': 3.1488614082336426, 'step_time': 15.260149240493774}


 96%|███████████████████████████████████████████████████████████▎  | 20300/21210 [55:45<02:19,  6.55it/s]

{'samples': 2598400, 'steps': 2537, 'loss/train': 3.0367846488952637, 'step_time': 15.263579554855824}


 96%|███████████████████████████████████████████████████████████▋  | 20400/21210 [56:00<02:04,  6.50it/s]

{'samples': 2611200, 'steps': 2549, 'loss/train': 3.143766403198242, 'step_time': 15.285552438348532}


 97%|███████████████████████████████████████████████████████████▉  | 20500/21210 [56:16<01:48,  6.56it/s]

{'samples': 2624000, 'steps': 2562, 'loss/train': 3.070713520050049, 'step_time': 15.224907467141747}


 97%|████████████████████████████████████████████████████████████▏ | 20600/21210 [56:31<01:33,  6.51it/s]

{'samples': 2636800, 'steps': 2574, 'loss/train': 3.1219348907470703, 'step_time': 15.228856148198247}


 98%|████████████████████████████████████████████████████████████▌ | 20700/21210 [56:46<01:17,  6.55it/s]

{'samples': 2649600, 'steps': 2587, 'loss/train': 3.0618746280670166, 'step_time': 15.218804789707065}


 98%|████████████████████████████████████████████████████████████▊ | 20800/21210 [57:01<01:02,  6.52it/s]

{'samples': 2662400, 'steps': 2599, 'loss/train': 3.0801162719726562, 'step_time': 15.211174042895436}


 99%|█████████████████████████████████████████████████████████████ | 20900/21210 [57:17<00:47,  6.56it/s]

{'samples': 2675200, 'steps': 2612, 'loss/train': 3.066216468811035, 'step_time': 15.212190376594663}


 99%|█████████████████████████████████████████████████████████████▍| 20999/21210 [57:32<00:32,  6.53it/s]

{'samples': 2688000, 'steps': 2624, 'loss/train': 3.0875930786132812, 'step_time': 15.259734170511365}


 99%|█████████████████████████████████████████████████████████████▍| 21001/21210 [57:44<09:16,  2.66s/it]

{'loss/eval': 3.1032729148864746, 'perplexity': 22.270721435546875}


 99%|█████████████████████████████████████████████████████████████▋| 21100/21210 [57:59<00:16,  6.52it/s]

{'samples': 2700800, 'steps': 2637, 'loss/train': 3.113971710205078, 'step_time': 27.3056007809937}


100%|█████████████████████████████████████████████████████████████▉| 21200/21210 [58:15<00:01,  6.44it/s]

{'samples': 2713600, 'steps': 2649, 'loss/train': 3.1194911003112793, 'step_time': 15.376722183078527}


100%|██████████████████████████████████████████████████████████████| 21210/21210 [58:16<00:00,  6.07it/s]


In [16]:
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
wandb.finish()

VBox(children=(Label(value='0.008 MB of 0.008 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
loss/eval,█▆▅▅▄▄▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁
loss/train,█▆▅▃▃▃▄▃▃▃▃▃▂▃▃▃▃▂▂▂▂▂▃▂▁▂▁▂▂▂▂▂▁▁▂▂▂▂▂▂
perplexity,█▆▅▄▄▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁
samples,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
step_time,▅▁▃▄▅▅▇█▄██▂▄▇▄▄▃▄▆▃▅▇█▄▃▆▄▄▃▃▇▄▆▇█▄▅▆▃▇
steps,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███

0,1
loss/eval,3.10327
loss/train,3.11949
perplexity,22.27072
samples,2713600.0
step_time,15.37672
steps,2649.0
