In [63]:
! pip install transformers accelerate datasets tqdm wandb
! wandb login



In [64]:
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, get_scheduler
import os
import torch
from torch.optim import AdamW
from transformers import DataCollatorForLanguageModeling
import time

In [65]:
# Experiment configuration: every knob a reader might tune lives here.
use_embedding = True          # wrap the model's input embedding in BeaconEmbedding
use_custom_attn_mask = True   # pass the beacon attention mask to the model during training
num_epochs = 1
window_size = 2               # beacon window length used to build the attention mask
batch_size = 128              # DataLoader batch size; also the leading dim of the precomputed mask
block_size = 128              # tokens per training example
seed = 42                     # fixed so DataLoader shuffling (and any torch-random init) is reproducible

torch.manual_seed(seed)

In [66]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # prefer the GPU when one is visible

# GPT-Neo tokenizer; presumably matches the TinyStories-3M vocabulary (its embedding has 50257 rows) — confirm.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="EleutherAI/gpt-neo-125M")

In [67]:
# https://huggingface.co/learn/nlp-course/en/chapter7/6?fw=pt
from torch.nn import CrossEntropyLoss
import torch

def causal_lm_loss(inputs, logits, alpha=1.0):
    """Next-token cross-entropy, averaged within each sample and then over the batch.

    Args:
        inputs: token ids, shape (batch, seq_len); they double as the labels.
        logits: model outputs, shape (batch, seq_len, vocab_size).
        alpha: unused; kept so existing callers that pass it keep working.

    Returns:
        A scalar tensor: mean over samples of each sample's mean token loss.
    """
    # Shift so that tokens < n predict n
    shift_labels = inputs[..., 1:].contiguous()
    shift_logits = logits[..., :-1, :].contiguous()
    # reduction="none" keeps one loss per token. The legacy `reduce=False` flag is
    # deprecated and emits a warning on every call.
    loss_fct = CrossEntropyLoss(reduction="none")
    per_token = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    # Back to (batch, seq_len - 1), then average within each sample first.
    loss_per_sample = per_token.view(shift_logits.size(0), shift_logits.size(1)).mean(dim=1)
    return loss_per_sample.mean()

In [68]:
from datasets import load_dataset

def tokenize(element):
    """Tokenize a batch of texts into chunks of exactly `block_size` tokens.

    Long texts are split into several chunks via `return_overflowing_tokens`.
    Any chunk shorter than `block_size` is discarded, e.g. a text's final
    overflow piece or a text that is short to begin with.
    """
    encoded = tokenizer(
        element["text"],
        truncation=True,
        max_length=block_size,
        return_overflowing_tokens=True,
        return_length=True,
    )
    full_chunks = [
        ids
        for ids, n_tokens in zip(encoded["input_ids"], encoded["length"])
        if n_tokens == block_size
    ]
    return {"input_ids": full_chunks}

ds = load_dataset(path="roneneldan/TinyStories")
# Each split ends up with only the `input_ids` column produced by `tokenize`; the
# original columns are dropped. Here `batch_size` is the number of raw rows handed
# to each `tokenize` call (the training batch size is simply reused).
tokenized_ds = ds.map(
    function=tokenize,
    batched=True,
    batch_size=batch_size,
    num_proc=4,
    remove_columns=ds["train"].column_names,
)
tokenized_ds.set_format(type="torch")

Repo card metadata block was not found. Setting CardData to empty.


In [69]:
from torch.utils.data.dataloader import DataLoader

# Training batches are reshuffled every epoch; validation order is fixed.
# Neither loader sets drop_last, so the final batch of each can be smaller than batch_size.
train_dataloader = DataLoader(dataset=tokenized_ds["train"], shuffle=True, batch_size=batch_size)
eval_dataloader = DataLoader(dataset=tokenized_ds["validation"], batch_size=batch_size)

In [70]:
# Number of validation batches processed by each evaluate() call.
len(eval_dataloader)

212


In [71]:
weight_decay = 0.1

def get_grouped_params(model, no_decay=("bias", "LayerNorm.weight")):
    """Split a model's parameters into AdamW groups with and without weight decay.

    A parameter is exempt from decay when any string in `no_decay` occurs in its
    qualified name as reported by `model.named_parameters()`.

    Args:
        model: a torch.nn.Module.
        no_decay: substrings that mark a parameter as decay-exempt. The default is a
            tuple so the shared default object can never be mutated between calls.
            NOTE(review): the GPT-Neo model printed further down names its layer norms
            `ln_1` / `ln_2`, so "LayerNorm.weight" matches nothing there; pass e.g.
            ("bias", "ln_") to exempt them. This helper is also not used by the
            optimizer cell in this notebook.

    Returns:
        Two param-group dicts for torch.optim: decayed parameters (with the module-level
        `weight_decay`), then exempt parameters (weight_decay 0.0).
    """
    decayed, exempt = [], []
    for name, param in model.named_parameters():
        target = exempt if any(token in name for token in no_decay) else decayed
        target.append(param)
    return [
        {"params": decayed, "weight_decay": weight_decay},
        {"params": exempt, "weight_decay": 0.0},
    ]

In [72]:
# Pretrained GPT-Neo checkpoint: 8 blocks, 128-dim hidden state (see the printout below).
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path="roneneldan/TinyStories-3M")
model.to(device)

GPTNeoForCausalLM(
  (transformer): GPTNeoModel(
    (wte): Embedding(50257, 128)
    (wpe): Embedding(2048, 128)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-7): 8 x GPTNeoBlock(
        (ln_1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (attn): GPTNeoAttention(
          (attention): GPTNeoSelfAttention(
            (attn_dropout): Dropout(p=0.0, inplace=False)
            (resid_dropout): Dropout(p=0.0, inplace=False)
            (k_proj): Linear(in_features=128, out_features=128, bias=False)
            (v_proj): Linear(in_features=128, out_features=128, bias=False)
            (q_proj): Linear(in_features=128, out_features=128, bias=False)
            (out_proj): Linear(in_features=128, out_features=128, bias=True)
          )
        )
        (ln_2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (mlp): GPTNeoMLP(
          (c_fc): Linear(in_features=128, out_features=512, bias=True)
          (c_proj): Linear(in_featu

In [73]:
# Display the checkpoint's configuration (hidden size, head/layer counts, global/local attention layout, ...).
model.config

GPTNeoConfig {
  "_name_or_path": "roneneldan/TinyStories-3M",
  "activation_function": "gelu_new",
  "architectures": [
    "GPTNeoForCausalLM"
  ],
  "attention_dropout": 0,
  "attention_layers": [
    "global",
    "local",
    "global",
    "local",
    "global",
    "local",
    "global",
    "local"
  ],
  "attention_types": [
    [
      [
        "global",
        "local"
      ],
      4
    ]
  ],
  "bos_token_id": 50256,
  "classifier_dropout": 0.1,
  "embed_dropout": 0,
  "eos_token_id": 50256,
  "gradient_checkpointing": false,
  "hidden_size": 128,
  "initializer_range": 0.02,
  "intermediate_size": null,
  "layer_norm_epsilon": 1e-05,
  "max_position_embeddings": 2048,
  "model_type": "gpt_neo",
  "num_heads": 16,
  "num_layers": 8,
  "resid_dropout": 0,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "torch_dtype": "float32",
  "transformers_version": "4.42.3",
 

In [74]:
from utils import BeaconEmbedding, generate_beacon_attention_mask_2d

if use_embedding:
    # Swap in BeaconEmbedding, handing it the model's current token embedding module.
    # window_length follows the config cell (it was hard-coded to 2), so the embedding
    # and the attention mask below always use the same window.
    beacon_embedding = BeaconEmbedding(
        embedding=model.get_input_embeddings(),
        vocab_size=model.config.vocab_size,
        n_embed=model.config.hidden_size,
        window_length=window_size,
    )
    model.set_input_embeddings(beacon_embedding)

# 2-D mask for one sequence, repeated along a new leading dim: one copy per row of a full batch.
beacon_attention_mask = generate_beacon_attention_mask_2d(block_size, window_length=window_size, device=device)
beacon_attention_mask = beacon_attention_mask.unsqueeze(0).repeat(batch_size, 1, 1)


# NOTE(review): get_grouped_params / weight_decay (defined above) are not used here, so
# AdamW's default weight decay applies uniformly to every parameter — confirm this is intended.
optimizer = AdamW(model.parameters(), lr=5e-5)
# Counts micro-batches (one per DataLoader batch); the training loop performs an optimizer
# update only once every gradient_accumulation_steps of these.
num_training_steps = num_epochs * len(train_dataloader)

In [75]:
# Sanity check: one (block_size x block_size) mask for each example of a full batch.
assert beacon_attention_mask.shape == (batch_size, block_size, block_size), beacon_attention_mask.shape
beacon_attention_mask.shape

128 torch.Size([128, 128, 128])


In [76]:
# Top-left 10x10 corner of the first example's mask (presumably 1 = may attend, 0 = blocked;
# confirm against generate_beacon_attention_mask_2d).
# NOTE(review): in the saved output, every odd row omits its own position (row 1 -> col 0 only,
# row 3 -> cols 1-2, row 5 -> cols 3-4), while even rows include the diagonal. Confirm this is intended.
beacon_attention_mask[0, :10, :10]

tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 1., 0.]], device='cuda:0')

In [77]:
def evaluate():
    """Compute the mean validation loss and perplexity over the whole eval set.

    Uses the same forward pass and loss as the training loop. The beacon attention
    mask is applied when `use_custom_attn_mask` is set, and the loss is
    `causal_lm_loss`, so eval numbers describe the configuration being trained.

    Returns:
        (loss, perplexity) as Python floats; perplexity is inf if exp() overflows.
    """
    import math

    model.eval()
    losses = []
    for batch in eval_dataloader:
        input_ids = batch["input_ids"]
        with torch.no_grad():
            if use_custom_attn_mask:
                # The last eval batch can be smaller than batch_size; slice the
                # precomputed mask down to the actual number of rows.
                logits = model(
                    input_ids=input_ids,
                    attention_mask=beacon_attention_mask[: input_ids.size(0)],
                ).logits
            else:
                logits = model(input_ids=input_ids).logits
            loss = causal_lm_loss(input_ids, logits)
        # reshape(1) hands gather a 1-D tensor, so the per-batch results can always
        # be concatenated, whether one process or several contribute.
        losses.append(accelerator.gather(loss.reshape(1)))
    mean_loss = torch.cat(losses).mean().item()
    # math.exp raises OverflowError on huge inputs (torch.exp would silently return inf,
    # making an except branch around it unreachable).
    try:
        perplexity = math.exp(mean_loss)
    except OverflowError:
        perplexity = float("inf")
    return mean_loss, perplexity

In [78]:
from accelerate import Accelerator
import wandb

# NOTE(review): per the original author, wandb logging through Accelerator did not work as
# expected, so wandb is driven directly (wandb.init here, wandb.log in the training loop).
accelerator = Accelerator()
wandb.init(
    project="beacon-attention",
    config={
        "use_custom_embedding": use_embedding,
        "use_custom_attn_mask": use_custom_attn_mask,
        "window_size": window_size,
        # Remaining hyperparameters, so each run is fully described in wandb.
        "num_epochs": num_epochs,
        "batch_size": batch_size,
        "block_size": block_size,
    },
)

# NOTE(review): the saved output of this cell is a BrokenPipeError, so prepare() may not
# have completed in the recorded run.
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)

BrokenPipeError: [Errno 32] Broken pipe

In [None]:
# Encode the experiment switches in the directory name, so runs with different
# settings save their checkpoints to different places.
model_name = f"{'beacon_embed' if use_embedding else 'no_beacon_embed'}_{'beacon_attn_mask' if use_custom_attn_mask else 'regular_attn_mask'}_window_size_{window_size}_model"
output_dir = f"./models/{model_name}"
# exist_ok avoids the check-then-create race of a separate os.path.exists() test.
os.makedirs(output_dir, exist_ok=True)

In [None]:
# The training loop calls lr_scheduler.step() only once per optimizer update, i.e. once every
# `gradient_accumulation_steps` micro-batches. Sizing the schedule with the micro-batch count
# (num_training_steps) would leave the linear decay only 1/gradient_accumulation_steps of the
# way done at the end of training, so size it in optimizer updates instead.
gradient_accumulation_steps = 8  # must match the value used by the training loop cell
num_update_steps = max(1, num_epochs * (len(train_dataloader) // gradient_accumulation_steps))
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_update_steps,
)

In [None]:
from tqdm import tqdm

gradient_accumulation_steps = 8  # micro-batches per optimizer update
log_steps = 100                  # micro-batches between train-loss logs
# NOTE(review): previously declared as 5000 while the loop hard-coded 1000; the effective
# value of 1000 is kept and is now actually read by the loop.
eval_steps = 1000                # micro-batches between evaluation + checkpoint

model.train()

completed_steps = 0
step_start_time = time.perf_counter()

for epoch in range(num_epochs):
    num_batches = len(train_dataloader)
    # `step` restarts at 1 every epoch, so the progress-bar total is per epoch.
    for step, batch in tqdm(enumerate(train_dataloader, start=1), total=num_batches):
        input_ids = batch["input_ids"]
        if use_custom_attn_mask:
            # The final batch can be smaller than batch_size (no drop_last), so slice the
            # precomputed per-batch mask down to the actual number of rows.
            # NOTE(review): confirm how the installed transformers version interprets a
            # 3-D 0/1 float attention_mask for GPT-Neo.
            logits = model(
                input_ids=input_ids,
                attention_mask=beacon_attention_mask[: input_ids.size(0)],
            ).logits
        else:
            logits = model(input_ids=input_ids).logits
        loss = causal_lm_loss(input_ids, logits)

        if step % log_steps == 0:
            step_end_time = time.perf_counter()
            train_update = {
                "samples": step * batch_size,  # examples seen so far in this epoch
                "steps": completed_steps,      # optimizer updates so far
                "loss/train": loss.item(),     # un-scaled loss of the current micro-batch
                "step_time": step_end_time - step_start_time,  # seconds for the last log_steps micro-batches
            }
            accelerator.print(train_update)
            wandb.log(train_update)
            step_start_time = step_end_time

        # Scale so the accumulated gradient is the mean over the accumulation window.
        accelerator.backward(loss / gradient_accumulation_steps)

        # Also step on the epoch's last batch, so a trailing partial window is applied
        # instead of leaking its gradients into the next epoch.
        if step % gradient_accumulation_steps == 0 or step == num_batches:
            accelerator.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            completed_steps += 1

        if step % eval_steps == 0:
            eval_loss, perplexity = evaluate()
            eval_update = {"loss/eval": eval_loss, "perplexity": perplexity}
            accelerator.print(eval_update)
            wandb.log(eval_update)
            model.train()  # evaluate() switched the model to eval mode
            # Checkpoint after every evaluation.
            accelerator.wait_for_everyone()
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)

In [None]:
# Final checkpoint: every process must reach this point before the weights are written.
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(
    save_directory=output_dir,
    save_function=accelerator.save,
)
wandb.finish()  # close the run opened by wandb.init