In [1]:
!pip install torch datasets transformers huggingface_hub wandb hf_transfer



In [2]:
import os
import sys
import torch
import wandb
from torch.amp import autocast, GradScaler
from torch.optim import AdamW
import math, time
from create_shards import split_and_shard, load_split
from model import BasicGPT, ModelConfig
from create_dataloaders import build_packed_dataloaders
from transformers import AutoTokenizer



In [3]:
from huggingface_hub import login

# Extra package directory on this pod; only affects imports executed after this line.
sys.path.extend(["/workspace/.local"])

# Interactive Hugging Face Hub authentication (renders a token-entry widget).
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [4]:
# Sanity check: print which Hugging Face account the cached token belongs to.
!hf auth whoami

[1muser: [0m timdadum


In [5]:
# Authenticate with Weights & Biases without hard-coding a secret: wandb.login() uses
# WANDB_API_KEY if it is already set in the environment (or cached credentials), and
# prompts interactively otherwise.
wandb.login()

# Start exactly ONE run; train() logs to it via wandb.log.
# (Previously wandb.init was called twice, so the first run was immediately superseded
# by a second one in a template "my-awesome-project" project.)
run = wandb.init(
    # Team / entity that owns the project.
    entity="timdadum-personal",
    # Project this training run is logged under.
    project="runpod-llm-recursion",
)

[34m[1mwandb[0m: Currently logged in as: [33mtimdadum[0m ([33mtimdadum-personal[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [6]:
def _extract_loss(output):
    """
    Robustly extract a scalar loss from various output shapes:
    - CausalLMOutput-like objects (have .loss)
    - dicts with 'loss'
    - tuples/lists with loss in first position
    """
    if hasattr(output, "loss") and output.loss is not None:
        return output.loss
    if isinstance(output, dict) and "loss" in output:
        return output["loss"]
    if isinstance(output, (tuple, list)) and len(output) > 0:
        return output[0]
    raise ValueError("Could not extract loss from model output.")

@torch.no_grad()
def evaluate(dataloader, max_batches=None):
    """
    Compute the mean per-batch loss and its perplexity over ``dataloader``.

    Uses the module-level ``model`` and ``device`` globals. The model is put in
    eval mode for the pass and its previous train/eval mode is restored
    afterwards (also if an exception is raised).

    Args:
        dataloader: Iterable of batch dicts with an ``"input_ids"`` tensor and
            optional ``"attention_mask"`` / ``"labels"`` tensors. When
            ``"labels"`` is absent, the inputs themselves are used as labels,
            consistent with how ``train`` builds labels.
        max_batches: Stop after this many batches. ``None`` evaluates the whole
            loader; ``0`` evaluates nothing.

    Returns:
        ``(avg_loss, perplexity)`` where ``avg_loss`` is the unweighted mean of
        per-batch losses. Both values are ``nan`` when no batch was evaluated,
        so a caller checking ``ppl < best`` never treats an empty evaluation as
        an improvement. Perplexity is ``inf`` if ``exp(avg_loss)`` overflows.
    """
    was_training = model.training
    model.eval()
    total_loss, n_batches = 0.0, 0

    try:
        for i, batch in enumerate(dataloader):
            if max_batches is not None and i >= max_batches:
                break

            ids = batch.get("input_ids").to(device)
            mask = batch.get("attention_mask")
            labels = batch.get("labels")

            if mask is not None:
                mask = mask.to(device)
            # Without explicit labels the model would return no loss; fall back
            # to next-token prediction on the inputs themselves.
            labels = ids if labels is None else labels.to(device)

            out = model(input_ids=ids, attention_mask=mask, labels=labels)
            total_loss += float(_extract_loss(out).item())
            n_batches += 1
    finally:
        model.train(was_training)

    if n_batches == 0:
        return float("nan"), float("nan")

    avg = total_loss / n_batches
    try:
        ppl = math.exp(avg)
    except OverflowError:  # diverged model: loss > ~709
        ppl = float("inf")
    return avg, ppl

def _save_checkpoint(path, config, step, val_ppl):
    """Save model/optimizer state plus run metadata to ``path``.

    Reads the module-level ``model`` and ``optimizer`` globals and creates the
    parent directory first, so a missing output directory does not crash a long
    run at its first checkpoint.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save({
        "config": vars(config),
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "step": step,
        "val_ppl": val_ppl,
    }, path)


def train(config, train_loader, eval_loader, run_name, epochs=3, log_every=50, eval_every=500,
          eval_max_batches=None):
    """
    Train the module-level ``model`` with optional mixed precision and periodic eval.

    Relies on these notebook globals: ``model``, ``optimizer``, ``scaler``,
    ``device``, an active ``wandb`` run, and the helpers ``evaluate`` and
    ``_extract_loss``.

    Args:
        config: Config object; reads ``grad_clip``, ``out_dir`` and optionally
            ``amp`` (defaults to True, matching how the GradScaler is built).
        train_loader: Iterable of batch dicts with ``"input_ids"`` and optional
            ``"attention_mask"`` / ``"labels"``. If ``"labels"`` is absent the
            inputs are used as labels.
        eval_loader: Loader passed to ``evaluate``.
        run_name: Prefix for checkpoint filenames inside ``config.out_dir``.
        epochs: Number of passes over ``train_loader``.
        log_every: Print the mean training loss every this many optimizer steps.
        eval_every: Evaluate every this many optimizer steps and save a
            ``{run_name}_best.pt`` checkpoint whenever validation perplexity improves.
        eval_max_batches: Optional cap on batches for the periodic evaluations
            (``None`` = whole loader). The final evaluation always uses the
            whole loader.
    """
    global_step = 0
    best_perplexity = float("inf")
    # Same default as the GradScaler construction, so a config without `amp` works.
    use_amp = device == "cuda" and getattr(config, "amp", True)

    print("[TRAINING]: Starting training...")
    model.to(device).train()
    t0 = time.time()
    running = 0.0

    for epoch in range(epochs):
        print(f"[TRAINING]: Starting epoch {epoch+1}/{epochs}...")
        for batch in train_loader:
            # Prepare tensors (guard for None masks / labels)
            x = batch.get("input_ids").to(device)
            mask = batch.get("attention_mask")
            if mask is not None:
                mask = mask.to(device)
            # Use the loader's labels when provided (e.g. with padding masked out),
            # so training and evaluate() compute the same objective.
            labels = batch.get("labels")
            labels = x if labels is None else labels.to(device)

            # Forward + loss
            optimizer.zero_grad(set_to_none=True)
            with autocast(device_type=device, enabled=use_amp):
                out = model(input_ids=x, attention_mask=mask, labels=labels)
                loss = _extract_loss(out)

            # Backward + update. Unscale before clipping so the clip threshold
            # applies to the true (unscaled) gradients.
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
            scaler.step(optimizer)
            scaler.update()

            global_step += 1
            # Read the scalar once and reuse it (log a plain float, not the live tensor).
            loss_value = float(loss.item())
            running += loss_value
            wandb.log({"train_loss": loss_value, "epoch": epoch})

            # Log every N steps (use the function arg 'log_every', not config.log_every)
            if global_step % log_every == 0:
                print(f"[TRAINING]: epoch {epoch+1} | step {global_step} | loss {running / log_every:.4f}")
                running = 0.0

            # Eval every M steps; keep the best checkpoint by validation perplexity.
            if global_step % eval_every == 0:
                eval_loss, eval_ppl = evaluate(eval_loader, max_batches=eval_max_batches)
                print(f"[TRAINING - EVAL] step {global_step} | val_loss {eval_loss:.4f} | val_ppl {eval_ppl:.2f}")
                if eval_ppl < best_perplexity:
                    best_perplexity = eval_ppl
                    path = os.path.join(config.out_dir, f"{run_name}_best.pt")
                    _save_checkpoint(path, config, global_step, eval_ppl)
                    print("Saved best:", path)

    # Final eval + save
    final_eval_loss, final_eval_ppl = evaluate(eval_loader)
    elapsed = time.time() - t0
    print(f"Epoch {epochs} in {elapsed:.1f}s | val_loss {final_eval_loss:.4f} | val_ppl {final_eval_ppl:.2f}")
    path = os.path.join(config.out_dir, f"{run_name}_e{epochs}_step{global_step}.pt")
    _save_checkpoint(path, config, global_step, final_eval_ppl)
    print("Saved:", path)

In [None]:
# Interactive driver: build config/tokenizer, then either train (fresh or resumed from
# a checkpoint) or only load a saved checkpoint for inference.
DO_TRAIN = False
BASE = "/recursion"
CHECKPOINT_PATH = f"{BASE}/trained_models/simple_baseline_best.pt"
SHARD_DIR = f"{BASE}/data/shards"

config = ModelConfig(
    # NOTE(review): 1028 is not a power of two and project_name says "d512" — confirm the
    # intended width (1024?). Existing checkpoints were trained with 1028, so changing it
    # breaks loading them.
    d_model=1028,
    max_len=512,
    project_name="basic-llm-d512-l512-10e-[D]25%en",
)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")


def _ask_yes(question):
    """Prompt the user and return True for a 'y'/'Y' answer (surrounding whitespace ignored)."""
    return input(question).strip().lower() == "y"


if _ask_yes("Train? (Y/N)"):
    DO_TRAIN = True

    checkpoint = None
    if _ask_yes("Load existing model? (Y/N)"):
        model, checkpoint = BasicGPT.load_weights(CHECKPOINT_PATH, config)
        print("Successfully loaded model weights from earlier run. Continuing...")
    else:
        model = BasicGPT(config)
        print("Initialized model. Continuing...")
        print(model)
        total_params = sum(p.numel() for p in model.parameters())
        print(f"Model params: {total_params/1e6:.2f}M")
    # Move to the target device before the optimizer is built on the parameters.
    model.to(device)

    # Alternative dataset previously tried: "Cohere/wikipedia-2023-11-embed-multilingual-v3" (config="en").
    split_and_shard(
        ds_name="manu/project_gutenberg",
        hf_split="en",
        out_dir=SHARD_DIR,
        max_shard_size="1024MB",
        num_proc=16,
    )

    train_split = load_split(SHARD_DIR, "train", shard_by_rank=True, shuffle_seed=42)
    eval_split = load_split(SHARD_DIR, "val", shard_by_rank=True, shuffle_seed=43)
    test_split = load_split(SHARD_DIR, "test", shard_by_rank=True, shuffle_seed=44)

    train_loader, eval_loader, test_loader = build_packed_dataloaders(
        train_split, eval_split, test_split,
        tokenizer_name="gpt2",
        seq_len=256,
        batch_size=32,
    )

    # Peek at ONE batch so inputs and labels come from the same sample
    # (calling next(iter(...)) twice would draw two different batches).
    sample_batch = next(iter(train_loader))
    xb, yb = sample_batch["input_ids"], sample_batch["labels"]
    print("Sample batch:", xb.shape, yb.shape)
    print("Sample item (token integers):", xb[1], yb[1])

    text = tokenizer.decode(xb[1].tolist(), skip_special_tokens=False)
    print(f"Decoded sample batch: {text}")

    optimizer = AdamW(
        model.parameters(),
        lr=config.lr,
        betas=config.betas,
        weight_decay=config.weight_decay,
    )
    # When resuming, also restore the Adam moments if the checkpoint stored them.
    # NOTE: this restores the checkpoint's learning rate as well.
    if isinstance(checkpoint, dict) and "optimizer_state" in checkpoint:
        optimizer.load_state_dict(checkpoint["optimizer_state"])

    scaler = GradScaler(enabled=(device == "cuda" and getattr(config, "amp", True)))

    train(config, train_loader, eval_loader, run_name="simple_baseline", epochs=10)
else:
    if _ask_yes("Load existing model? (Y/N)"):
        model, checkpoint = BasicGPT.load_weights(CHECKPOINT_PATH, config)
        print(model)
        total_params = sum(p.numel() for p in model.parameters())
        print(f"Model params: {total_params/1e6:.2f}M")
        model.to(device)
        print("Successfully loaded model.")


Using device: cuda


Train? (Y/N) y
Load existing model? (Y/N) n


Initialized model. Continuing...
BasicGPT(
  (token_emb): Embedding(50257, 1028)
  (pos_emb): Embedding(512, 1028)
  (drop): Dropout(p=0.1, inplace=False)
  (blocks): TransformerEncoder(
    (layers): ModuleList(
      (0-3): 4 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=1028, out_features=1028, bias=True)
        )
        (linear1): Linear(in_features=1028, out_features=512, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=512, out_features=1028, bias=True)
        (norm1): LayerNorm((1028,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((1028,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (norm_f): LayerNorm((1028,), eps=1e-05, elementwise_affine=True)
  (lm_head): Linear(in_features=1028, out_features=50257, bias=False

Resolving data files:   0%|          | 0/52 [00:00<?, ?it/s]

Downloading data:   0%|          | 0/52 [00:00<?, ?files/s]

data/en-00002-of-00052-835bd07d97f52cbd.(…):   0%|          | 0.00/215M [00:00<?, ?B/s]

data/en-00003-of-00052-3827386b583e4d76.(…):   0%|          | 0.00/239M [00:00<?, ?B/s]

data/en-00004-of-00052-a2f24c4fe858fe0f.(…):   0%|          | 0.00/268M [00:00<?, ?B/s]

data/en-00005-of-00052-2a13fc98474cabed.(…):   0%|          | 0.00/277M [00:00<?, ?B/s]

data/en-00006-of-00052-81d2618caede0093.(…):   0%|          | 0.00/195M [00:00<?, ?B/s]

data/en-00007-of-00052-de0de442f370b789.(…):   0%|          | 0.00/217M [00:00<?, ?B/s]

data/en-00008-of-00052-a83055e8ef415c07.(…):   0%|          | 0.00/244M [00:00<?, ?B/s]

data/en-00009-of-00052-f2f126633fa25668.(…):   0%|          | 0.00/237M [00:00<?, ?B/s]

data/en-00010-of-00052-2226b3722aa696eb.(…):   0%|          | 0.00/276M [00:00<?, ?B/s]

data/en-00011-of-00052-6c9ae05ed451701f.(…):   0%|          | 0.00/319M [00:00<?, ?B/s]

data/en-00012-of-00052-2de5b14941be3266.(…):   0%|          | 0.00/201M [00:00<?, ?B/s]

data/en-00013-of-00052-a66a5e317603bb21.(…):   0%|          | 0.00/234M [00:00<?, ?B/s]

data/en-00014-of-00052-e976ff9fa7c0a4c2.(…):   0%|          | 0.00/234M [00:00<?, ?B/s]

data/en-00015-of-00052-9a9fd49be8a70a6c.(…):   0%|          | 0.00/257M [00:00<?, ?B/s]

data/en-00016-of-00052-5006e8c00e35ad72.(…):   0%|          | 0.00/278M [00:00<?, ?B/s]

data/en-00017-of-00052-c37121d3035604a6.(…):   0%|          | 0.00/196M [00:00<?, ?B/s]

data/en-00018-of-00052-76fc57ebfaac39a2.(…):   0%|          | 0.00/216M [00:00<?, ?B/s]

data/en-00019-of-00052-05068b06d8a4ffb7.(…):   0%|          | 0.00/235M [00:00<?, ?B/s]

data/en-00020-of-00052-31ef1cece5305678.(…):   0%|          | 0.00/244M [00:00<?, ?B/s]

data/en-00021-of-00052-82812d42cefe9a5b.(…):   0%|          | 0.00/279M [00:00<?, ?B/s]

data/en-00022-of-00052-061a44c5aeff4f98.(…):   0%|          | 0.00/297M [00:00<?, ?B/s]

data/en-00023-of-00052-380917a82781d4aa.(…):   0%|          | 0.00/201M [00:00<?, ?B/s]

data/en-00024-of-00052-fb2f9a960ee8c75e.(…):   0%|          | 0.00/221M [00:00<?, ?B/s]

data/en-00025-of-00052-9c00570871767dea.(…):   0%|          | 0.00/228M [00:00<?, ?B/s]

data/en-00026-of-00052-8719637a331ec653.(…):   0%|          | 0.00/249M [00:00<?, ?B/s]

data/en-00027-of-00052-9d8bfd7843e9718f.(…):   0%|          | 0.00/277M [00:00<?, ?B/s]

data/en-00028-of-00052-2e237a8d8b810dec.(…):   0%|          | 0.00/209M [00:00<?, ?B/s]

data/en-00029-of-00052-780d5f400fd85afa.(…):   0%|          | 0.00/199M [00:00<?, ?B/s]

data/en-00030-of-00052-cfdc6f381f17e852.(…):   0%|          | 0.00/231M [00:00<?, ?B/s]

data/en-00031-of-00052-e7f5d815a26b08d0.(…):   0%|          | 0.00/232M [00:00<?, ?B/s]

data/en-00032-of-00052-4ed9cbf89e3d13b5.(…):   0%|          | 0.00/263M [00:00<?, ?B/s]

data/en-00033-of-00052-4c525b2c5bfc7b3f.(…):   0%|          | 0.00/277M [00:00<?, ?B/s]

data/en-00034-of-00052-0f04bcad9d91ea41.(…):   0%|          | 0.00/197M [00:00<?, ?B/s]

data/en-00035-of-00052-617bd55ce8eaaf8c.(…):   0%|          | 0.00/207M [00:00<?, ?B/s]

data/en-00036-of-00052-dea698a178ee5475.(…):   0%|          | 0.00/222M [00:00<?, ?B/s]

data/en-00037-of-00052-80239e718491affb.(…):   0%|          | 0.00/248M [00:00<?, ?B/s]

data/en-00038-of-00052-671117f0cc621546.(…):   0%|          | 0.00/275M [00:00<?, ?B/s]

data/en-00039-of-00052-f766a27e24b911d6.(…):   0%|          | 0.00/192M [00:00<?, ?B/s]

data/en-00040-of-00052-5f304ed689a13135.(…):   0%|          | 0.00/197M [00:00<?, ?B/s]

data/en-00041-of-00052-6f04cd012627fa08.(…):   0%|          | 0.00/210M [00:00<?, ?B/s]

data/en-00042-of-00052-94e90b865e11f015.(…):   0%|          | 0.00/221M [00:00<?, ?B/s]

data/en-00043-of-00052-545c5249d7f68142.(…):   0%|          | 0.00/254M [00:00<?, ?B/s]

data/en-00044-of-00052-2f8a81e4b2cb26bd.(…):   0%|          | 0.00/270M [00:00<?, ?B/s]

data/en-00045-of-00052-bbb08eac7b16b553.(…):   0%|          | 0.00/185M [00:00<?, ?B/s]

data/en-00046-of-00052-fbf5b6f877101255.(…):   0%|          | 0.00/200M [00:00<?, ?B/s]

data/en-00047-of-00052-b53a9aba6fd08df1.(…):   0%|          | 0.00/220M [00:00<?, ?B/s]

data/en-00048-of-00052-c3e4665ddb21ff40.(…):   0%|          | 0.00/238M [00:00<?, ?B/s]

data/en-00049-of-00052-238a44ba6d899475.(…):   0%|          | 0.00/256M [00:00<?, ?B/s]

data/en-00050-of-00052-119c77546b4d5bc3.(…):   0%|          | 0.00/294M [00:00<?, ?B/s]

data/en-00051-of-00052-416e3a1d8d8e7d86.(…):   0%|          | 0.00/200M [00:00<?, ?B/s]

data/es-00000-of-00001-ad684e007393cf76.(…):   0%|          | 0.00/195M [00:00<?, ?B/s]

data/fr-00000-of-00005-a475c3836a0ce5b5.(…):   0%|          | 0.00/189M [00:00<?, ?B/s]

data/fr-00001-of-00005-b6213362f858795a.(…):   0%|          | 0.00/185M [00:00<?, ?B/s]

data/fr-00002-of-00005-cf81716abdc38e6a.(…):   0%|          | 0.00/202M [00:00<?, ?B/s]

data/fr-00003-of-00005-32d4f9159674f920.(…):   0%|          | 0.00/200M [00:00<?, ?B/s]

data/fr-00004-of-00005-6c2bd994bbdaed75.(…):   0%|          | 0.00/199M [00:00<?, ?B/s]

data/it-00000-of-00001-62485f87cf89f498.(…):   0%|          | 0.00/123M [00:00<?, ?B/s]

data/nl-00000-of-00002-ec9215c2bba222c1.(…):   0%|          | 0.00/85.5M [00:00<?, ?B/s]

data/nl-00001-of-00002-a8989d8ed3a39aaf.(…):   0%|          | 0.00/123M [00:00<?, ?B/s]

data/pl-00000-of-00001-68933bed4abc7dfd.(…):   0%|          | 0.00/2.45M [00:00<?, ?B/s]

data/pt-00000-of-00001-f6940f76b5585b13.(…):   0%|          | 0.00/73.9M [00:00<?, ?B/s]

data/ru-00000-of-00001-c0a99d21d8849748.(…):   0%|          | 0.00/343k [00:00<?, ?B/s]

data/sv-00000-of-00001-6d07271c1c00324d.(…):   0%|          | 0.00/40.5M [00:00<?, ?B/s]

data/zh-00000-of-00001-bca42330c6f1826c.(…):   0%|          | 0.00/120M [00:00<?, ?B/s]

Generating de split:   0%|          | 0/3131 [00:00<?, ? examples/s]

Generating en split:   0%|          | 0/61340 [00:00<?, ? examples/s]

In [33]:
# TODO: Evaluate on test set

def finish_prompt(prompt: str, model: BasicGPT, max_new_tokens: int = 256,
                  temperature: float = 0.9, do_sample: bool = True) -> str:
    """
    Generate a continuation of ``prompt`` with ``model`` and return the decoded text.

    Uses the module-level ``tokenizer``. Inputs are placed on the device of the
    model's first parameter. The model is switched to eval mode for generation
    and its previous train/eval mode is restored afterwards.

    Args:
        prompt: Text to continue.
        model: Model exposing a ``generate(...)`` method with the keyword
            arguments used below.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature forwarded to ``generate``.
        do_sample: Whether ``generate`` should sample (vs. greedy decode).

    Returns:
        The decoded first sequence returned by ``generate``, with special tokens
        skipped.
    """
    was_training = model.training
    model.eval()
    device = next(model.parameters()).device
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    try:
        with torch.no_grad():
            output_ids = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs.get("attention_mask", None),
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                temperature=temperature,
                eos_token_id=getattr(tokenizer, "eos_token_id", None),
            )
    finally:
        # Don't leave a model that is mid-training silently stuck in eval mode.
        model.train(was_training)

    return tokenizer.decode(output_ids[0].tolist(), skip_special_tokens=True)

In [34]:
# No trailing space: GPT-2's BPE attaches the space to the *next* word, so a trailing
# space becomes a lone space token that the model rarely saw during training and it
# degrades the continuation.
prompt = "There once was a man who"
print(finish_prompt(prompt, model, max_new_tokens=128))

There once was a man who   tried
                                                                                     CHLEPHOES

                                   
