# Language training (one-layers)

In [None]:
import datasets
import math
import numpy as np

import torch
import torch.nn as nn
from tqdm import tqdm

from copy import deepcopy
from transformer_lens import HookedTransformer
from transformer_lens import HookedTransformerConfig
from transformer_lens.utils import lm_cross_entropy_loss
from transformer_lens.utils import tokenize_and_concatenate

In [None]:
from prettytable import PrettyTable

def count_parameters(model):
    """Print a per-parameter size table for ``model`` and return the trainable total.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs, where each parameter provides
            ``requires_grad`` and ``numel()``.

    Returns:
        The summed element count of all parameters whose ``requires_grad``
        is true. Parameters with ``requires_grad`` false are skipped and do
        not appear in the printed table.
    """
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            continue
        params = parameter.numel()
        table.add_row([name, params])
        total_params += params
    print(table)
    # True division with two decimals: the earlier `// 1e6` floored the value
    # and rendered it as a float, so ~2.9M parameters were reported as "2.0M".
    print(f"Total Trainable Params: {total_params} ({total_params / 1e6:.2f}M)")
    return total_params

In [None]:
# --- Optimisation settings ---
MAX_STEPS = 25_000     # step index at which each training loop breaks
LR = 0.001             # learning rate handed to AdamW
BATCH_SIZE = 100       # batch size handed to the DataLoader
WEIGHT_DECAY = 0.05    # weight decay handed to AdamW

# --- Data source ---
DATASET = "oknMswoztTPaAVreBrWy/dsir-pile-5m"
DS_COL = "contents"    # dataset column passed as `column_name` when tokenizing

# --- Run identity / hardware ---
MODEL_NAME = "L1"      # suffix of the Hub repo id used by the upload helpers
DEVICE = "cuda"

In [None]:
# Configuration for a single-layer, attention-only transformer
# (8 heads x 32 dims per head, matching d_model = 256).
model_cfg = HookedTransformerConfig(
    n_layers=1,
    d_model=256,
    d_head=32,
    n_heads=8,
    n_ctx=1024,
    d_vocab=5000,
    tokenizer_name='oknMswoztTPaAVreBrWy/TinyStories-tokenizer-5k',
    normalization_type='LN',
    attn_only=True,
    seed=1,
    positional_embedding_type='shortformer',
)

# This instance feeds the parameter report below and supplies `model.tokenizer`
# / `model.cfg.n_ctx` to the dataset-tokenization cell. The training cell later
# builds a fresh HookedTransformer from the same config and rebinds `model`.
# NOTE(review): constructing the model here presumably advances the RNG seeded
# through `seed=1`, which may be why the cells must be run in order to
# reproduce the run -- confirm against transformer_lens before removing it.
model = HookedTransformer(model_cfg) # ~3M params
count_parameters(model)
# Moves the model to the GPU directly; the DEVICE constant is not consulted.
model.cuda()

# AdamW over all model parameters. The training cell rebinds `optimizer` to a
# new instance built on the re-created model, so this one is not the optimizer
# used for the actual run.
optimizer = torch.optim.AdamW(model.parameters(),
                              lr=LR,
                              weight_decay=WEIGHT_DECAY)

In [None]:
# Fetch the raw training split from the Hub.
dataset = datasets.load_dataset(DATASET, split='train')

# Tokenize the DS_COL text column with the model's tokenizer, using the
# model's context length as max_length.
tokens_dataset = tokenize_and_concatenate(
    dataset,
    model.tokenizer,
    streaming=False,
    max_length=model.cfg.n_ctx,
    column_name=DS_COL,
    add_bos_token=True,
    num_proc=12,
)

# Shuffled batches, loaded in the main process (num_workers=0).
data_loader = torch.utils.data.DataLoader(
    tokens_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
)
print("Number of batches:", len(data_loader))

In [None]:
import tempfile
import json
import os

from huggingface_hub import HfApi

# Shared Hugging Face Hub client used by upload_temp_dir() and save_losses().
# `token` is None as committed: fill with your HF API token.
HF_API = HfApi(endpoint="https://huggingface.co", token=None)


def checkpoint_name(step):
  """Return the checkpoint file name for ``step``, left-padded with zeros to 7 digits."""
  padded_step = format(step, '0>7d')
  return 'checkpoint_' + padded_step + '.pth'


def save_to_temp_dir(temp_dir, model, optimizer, step):
  """Write a model + optimizer checkpoint for ``step`` beneath ``temp_dir``.

  The file is stored at ``<temp_dir.name>/checkpoints/<checkpoint_name(step)>``;
  the ``checkpoints`` folder is created on first use. The saved object is a
  dict with the keys ``model_state_dict`` and ``optimizer_state_dict``.
  """
  checkpoints_dir = os.path.join(temp_dir.name, "checkpoints")
  os.makedirs(checkpoints_dir, exist_ok=True)
  destination = os.path.join(checkpoints_dir, checkpoint_name(step))

  state = {
      'model_state_dict': model.state_dict(),
      'optimizer_state_dict': optimizer.state_dict(),
  }
  torch.save(state, destination)


def upload_temp_dir(temp_dir):
  """Upload everything in ``temp_dir`` to the model repo, then recycle it.

  Once the upload call returns, ``temp_dir`` is cleaned up and a brand-new
  ``tempfile.TemporaryDirectory`` object (not a path string) is returned so
  the caller can keep staging files in an empty directory.
  """
  target_repo = f"oknMswoztTPaAVreBrWy/{MODEL_NAME}"
  HF_API.upload_folder(
    repo_id=target_repo,
    repo_type="model",
    folder_path=temp_dir.name,
  )
  temp_dir.cleanup()
  fresh_dir = tempfile.TemporaryDirectory()
  return fresh_dir


def save_losses(losses):
  """Write ``losses`` to ``losses.txt`` (one value per line) and upload it.

  Args:
    losses: Iterable of loss values; each one is written as ``str(value)``.

  The file is staged in a temporary directory that is deleted as soon as the
  upload call returns. The earlier ``tempfile.mkdtemp()`` directory was never
  removed, so every call left another copy of the losses file on disk.
  """
  with tempfile.TemporaryDirectory() as staging_dir:
    losses_path = os.path.join(staging_dir, 'losses.txt')
    with open(losses_path, 'w') as f:
      f.writelines(str(loss) + '\n' for loss in losses)
    # The upload must happen inside the `with` block: the staged file is
    # removed when the block exits.
    HF_API.upload_file(
      repo_id=f'oknMswoztTPaAVreBrWy/{MODEL_NAME}',
      path_or_fileobj=losses_path,
      path_in_repo='losses.txt',
    )

In [None]:
import time
from datetime import datetime
import matplotlib.pyplot as plt

# First training half: train on `data_loader`, staging checkpoints locally and
# periodically pushing them (plus the loss history) to the Hub.

# Staging directory for checkpoints written between uploads.
temp_dir = tempfile.TemporaryDirectory()

# Rebuild the model and optimizer from scratch; these replace the instances
# created in the configuration cell.
model = HookedTransformer(model_cfg) # ~3M params
model.cuda()

optimizer = torch.optim.AdamW(model.parameters(),
                              lr=LR,
                              weight_decay=WEIGHT_DECAY)

# Seed torch's global RNG after model construction and before iterating the
# shuffled loader.
torch.manual_seed(1)
# have to run all these cells consecutively to correctly reproduce the training run

# Per-step training loss; the second training loop keeps appending to this list.
losses = []

# Reference point for the "time since last log" measurement below.
start = datetime.now()

for c, batch in enumerate(data_loader):
  # One plain optimisation step: forward, LM cross-entropy, backward, AdamW update.
  tokens = batch['tokens'].cuda()
  logits = model(tokens)
  loss = lm_cross_entropy_loss(logits, tokens)
  loss.backward()
  optimizer.step()
  optimizer.zero_grad()
  losses.append(loss.item())

  # Every 100 steps (including step 0): log and stage a checkpoint locally.
  if c % 100 == 0:
    print(f"Step: {c}, Loss: {loss.item():.4f}")
    save_to_temp_dir(temp_dir, model, optimizer, c)
  # Every 500 steps (excluding step 0): push the staged checkpoints, report
  # timing, upload the loss history and plot it.
  if c % 500 == 0 and c > 0:
    # upload_temp_dir returns a fresh, empty staging directory.
    temp_dir = upload_temp_dir(temp_dir)
    split_time = datetime.now() - start
    start = datetime.now()
    print(f'Time since last log: {split_time}')
    # Scales the last 500-step interval by the number of remaining intervals.
    print(f'Estimated time remaining: {split_time * ((MAX_STEPS  - c) / 500)}')
    save_losses(losses)
    plt.plot(losses)
    plt.show()
  # Checked after the update, so step index MAX_STEPS itself is trained on
  # (and checkpointed/uploaded above) before the loop exits.
  # NOTE(review): if the loader is exhausted before a multiple of 500,
  # checkpoints staged since the last upload are never pushed -- confirm
  # whether that matters for this run.
  if c >= MAX_STEPS:
    break

In [None]:
# Second data source; from here on `dataset`, `tokens_dataset` and
# `data_loader` refer to this dataset instead of the first one.
DATASET2 = "oknMswoztTPaAVreBrWy/dsir-pile-5m-2"

dataset = datasets.load_dataset(DATASET2, split='train')

# Same tokenization settings as for the first dataset.
tokens_dataset = tokenize_and_concatenate(
    dataset,
    model.tokenizer,
    streaming=False,
    max_length=model.cfg.n_ctx,
    column_name=DS_COL,
    add_bos_token=True,
    num_proc=12,
)

# Shuffled batches, loaded in the main process (num_workers=0).
data_loader = torch.utils.data.DataLoader(
    tokens_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
)
print("Number of batches:", len(data_loader))

In [None]:
# The optimizer was reset between training halves: a brand-new AdamW instance
# is bound here, so no optimizer state from the first loop carries over.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LR,
    weight_decay=WEIGHT_DECAY,
)

In [None]:
# Second training half: continue training the same `model` on the current
# `data_loader`, appending to the existing `losses` list.
# NOTE(review): the step counter `c` restarts at 0 here, so checkpoint file
# names repeat those of the first loop; unless MODEL_NAME is changed between
# halves, these uploads presumably overwrite the first-half checkpoints in
# the repo -- confirm.

# Fresh staging directory for this half's checkpoints.
temp_dir = tempfile.TemporaryDirectory()

# Reference point for the "time since last log" measurement below.
start = datetime.now()

for c, batch in enumerate(data_loader):
  # The first batch (index 0) is skipped entirely: no update, no checkpoint.
  if c == 0:
    continue
  # One plain optimisation step: forward, LM cross-entropy, backward, AdamW update.
  tokens = batch['tokens'].cuda()
  logits = model(tokens)
  loss = lm_cross_entropy_loss(logits, tokens)
  loss.backward()
  optimizer.step()
  optimizer.zero_grad()
  losses.append(loss.item())

  # Every 100 steps: log and stage a checkpoint locally.
  if c % 100 == 0:
    print(f"Step: {c}, Loss: {loss.item():.4f}")
    save_to_temp_dir(temp_dir, model, optimizer, c)
  # Every 500 steps: push the staged checkpoints, report timing, upload the
  # loss history and plot it.
  if c % 500 == 0 and c > 0:
    # upload_temp_dir returns a fresh, empty staging directory.
    temp_dir = upload_temp_dir(temp_dir)
    split_time = datetime.now() - start
    start = datetime.now()
    print(f'Time since last log: {split_time}')
    # Scales the last 500-step interval by the number of remaining intervals.
    print(f'Estimated time remaining: {split_time * ((MAX_STEPS  - c) / 500)}')
    save_losses(losses)
    plt.plot(losses)
    plt.show()
  # Checked after the update, so step index MAX_STEPS itself is trained on
  # before the loop exits.
  if c >= MAX_STEPS:
    break