In [None]:
!pip install torch transformers datasets huggingface_hub

In [None]:
import torch
from transformers import GPT2Config, GPT2LMHeadModel, Trainer, TrainingArguments
from datasets import load_dataset, interleave_datasets
from huggingface_hub import login

In [None]:
# Authenticate this session with the Hugging Face Hub (`login` is imported from huggingface_hub).
login()

In [None]:
# Authenticate the gcloud CLI; the `gcloud storage cp` commands later in this notebook use it.
!gcloud auth login

In [None]:
# (Disabled) alternative corpus: ds = load_dataset("wikitext", "wikitext-2-raw-v1")
# Row budgets for the streamed C4 mixture. These are caps applied with .take();
# the actual per-source share of the stream is decided where the sources are interleaved.
TOTAL_ROWS = 2_000_000       # Training row budget; MAX_STEPS is derived from it
ENGLISH_ROWS = 1_000_000     # Cap on English rows (half of TOTAL_ROWS)
OTHER_ROWS = 10_000         # Cap per non-English config (~1M rows over ~100 configs)

ENGLISH_VAL_ROWS = 10_000  # Validation cap for English
OTHER_VAL_ROWS = 100       # Validation cap per non-English config

# C4 config names loaded for the non-English part of the mixture ("en" is loaded separately).
LANGUAGES = [
    "af", "am", "ar", "az", "be", "bg", "bn", "ca", "ceb", "co", "cs", "cy", "da", "de",
    "el", "eo", "es", "et", "eu", "fa", "fi", "fil", "fr", "fy", "ga", "gd", "gl", "gu",
    "ha", "haw", "iw", "hi", "hmn", "ht", "hu", "hy", "id", "ig", "is", "it", "ja",
    "jv", "ka", "kk", "km", "kn", "ko", "ku", "ky", "la", "lb", "lo", "lt", "lv", "mg",
    "mi", "mk", "ml", "mn", "mr", "ms", "mt", "my", "ne", "nl", "no", "ny", "pa", "pl",
    "ps", "pt", "ro", "ru", "sd", "si", "sk", "sl", "sm", "sn", "so", "sq", "sr", "st",
    "su", "sv", "sw", "ta", "te", "tg", "th", "tr", "uk", "und", "ur", "uz", "vi", "xh",
    "yi", "yo", "zh", "zu"
 ]

# Stream every source and keep only the "text" column so all parts share one schema.
en_train = load_dataset("allenai/c4", "en", split="train", streaming=True).select_columns(["text"])
en_val = load_dataset("allenai/c4", "en", split="validation", streaming=True).select_columns(["text"])


other_train_list = [
    load_dataset("allenai/c4", lang, split="train", streaming=True).take(OTHER_ROWS).select_columns(["text"])
    for lang in LANGUAGES
]
other_val_list = [
    load_dataset("allenai/c4", lang, split="validation", streaming=True).take(OTHER_VAL_ROWS).select_columns(["text"])
    for lang in LANGUAGES
]

# English is always the first part; the per-language parts follow in LANGUAGES order.
train_parts = [en_train.take(ENGLISH_ROWS)] + other_train_list
val_parts = [en_val.take(ENGLISH_VAL_ROWS)] + other_val_list


def _mixture_probabilities(english_rows, other_rows, n_other):
    """Return sampling weights proportional to each source's row cap.

    The first weight belongs to the English source; it is followed by
    ``n_other`` identical weights, one per non-English source. The weights
    sum to 1.
    """
    total = english_rows + other_rows * n_other
    return [english_rows / total] + [other_rows / total] * n_other


# Without `probabilities`, interleave_datasets cycles round-robin through the sources:
# English would be only 1/len(parts) of the stream (about 1%, not ~50%) and `seed`
# would have no effect. Sampling in proportion to the row caps yields the intended
# English-heavy mixture and lets all sources run out at roughly the same time under
# "all_exhausted", so one pass is close to the sum of the caps.
train_ds = interleave_datasets(
    train_parts,
    probabilities=_mixture_probabilities(ENGLISH_ROWS, OTHER_ROWS, len(LANGUAGES)),
    seed=42,
    stopping_strategy="all_exhausted",
)
val_ds = interleave_datasets(
    val_parts,
    probabilities=_mixture_probabilities(ENGLISH_VAL_ROWS, OTHER_VAL_ROWS, len(LANGUAGES)),
    seed=42,
    stopping_strategy="all_exhausted",
)

# Reserved ids come first; character ids start right after them (at `offset`).
SPECIAL_TOKENS_LIST = ["<pad>", "<s>", "</s>", "<unk>"]
SPECIAL_TOKENS = dict(zip(SPECIAL_TOKENS_LIST, range(len(SPECIAL_TOKENS_LIST))))
offset = len(SPECIAL_TOKENS_LIST)

# Named handles for the individual reserved ids.
pad_token_id, bos_token_id, eos_token_id, unk_token_id = (
    SPECIAL_TOKENS[name] for name in ("<pad>", "<s>", "</s>", "<unk>")
)

# Maximum number of character tokens kept per sequence.
MAX_SEQ_LEN = 256

# Pure-Python character split (mirrors AllenNLP when byte_encoding=None).
def chars_from_text(text):
    return list(text)

In [None]:
# Peek at the head of the interleaved stream to eyeball how the sources are mixed.
# The announced count and the number of rows taken share one constant so the
# message can no longer disagree with what is printed.
PREVIEW_SAMPLES = 30
print(f"Checking the first {PREVIEW_SAMPLES} samples of the interleaved stream:\n")
for i, example in enumerate(train_ds.take(PREVIEW_SAMPLES), start=1):
    # Show only the first 80 characters, flattened to one line, to identify the language.
    text_sample = example['text'][:80].replace('\n', ' ')
    print(f"Sample {i}: {text_sample}...")

In [None]:
def build_char_vocab(dataset):
    """Collect every distinct character in ``dataset`` and assign it an id.

    Iterates ``dataset`` once, reading each example's ``"text"`` field.

    Returns:
        (char2id, id2char): ``id2char`` is the sorted list of characters seen;
        ``char2id`` maps each character to its position in that list plus
        ``offset`` (the ids below ``offset`` are reserved for special tokens).
    """
    seen = set()
    for record in dataset:
        seen.update(chars_from_text(record["text"]))
    id2char = sorted(seen)
    char2id = dict(zip(id2char, range(offset, offset + len(id2char))))
    return char2id, id2char

print("Building character vocab...")
# Only a prefix of the stream is scanned to keep this step fast; the prefix size is a
# trade-off between run time and character coverage across the interleaved languages.
vocab_sample = train_ds.take(400000)
char2id, id2char = build_char_vocab(vocab_sample)
# Total vocabulary = reserved special-token ids + one id per distinct character.
vocab_size = len(id2char) + offset
print(f"Vocab size: {vocab_size} (chars: {len(id2char)})")

In [None]:
import gc

# Drop references left over from a previous run of this cell so their GPU memory
# can be reclaimed before a new model is allocated.
if 'model' in globals():
    del model
if 'trainer' in globals():
    del trainer
# Collect first: objects kept alive only by reference cycles are released by
# gc.collect(), and only afterwards can empty_cache() return their cached CUDA blocks.
gc.collect()
torch.cuda.empty_cache()

# GPT-2 architecture sized for the character vocabulary; the context window
# (n_positions) equals the MAX_SEQ_LEN used when truncating sequences.
config = GPT2Config(
    vocab_size=vocab_size,
    n_positions=MAX_SEQ_LEN,
    n_embd=768,
    n_layer=12,
    n_head=12,
    bos_token_id=bos_token_id,
    eos_token_id=eos_token_id,
    pad_token_id=pad_token_id,
)

# Freshly initialised (untrained) model, placed on the GPU.
model = GPT2LMHeadModel(config).to("cuda")
print(f"Total Parameters: {model.num_parameters():,}")

In [None]:
def chars_tokenize(examples):
    """Encode a batch of texts as character ids (used with ``map(batched=True)``).

    Args:
        examples: batch dict whose ``"text"`` entry is a list of strings.

    Returns:
        ``{"input_ids": [...]}`` with one id list per text, each at most
        ``MAX_SEQ_LEN`` long; characters missing from ``char2id`` are mapped
        to ``unk_token_id``.
    """
    all_encoded = []
    for text in examples["text"]:
        # Truncate before the vocabulary lookup so long documents do not pay for
        # encoding characters that would be discarded anyway.
        tokens = chars_from_text(text)[:MAX_SEQ_LEN]
        all_encoded.append([char2id.get(ch, unk_token_id) for ch in tokens])
    return {"input_ids": all_encoded}

In [None]:
# Character-encode the training stream; the raw "text" column is dropped afterwards.
tokenized_ds = train_ds.map(
    chars_tokenize,
    batched=True,
    remove_columns=["text"],
)

# Evaluate on only the first 512 validation examples to keep each evaluation pass short.
val_subset = val_ds.take(512)
tokenized_val_ds = val_subset.map(
    chars_tokenize,
    batched=True,
    remove_columns=["text"],
)

In [None]:
import json
vocab_data = {'char2id': char2id, 'id2char': id2char}
with open('vocab.json', 'w', encoding='utf-8') as f:
    json.dump(vocab_data, f, ensure_ascii=False)
# Then download and copy to work/vocab.json
import json

vocab_data = {'char2id': char2id, 'id2char': id2char}
with open('vocab.json', 'w', encoding='utf-8') as f:
    json.dump(vocab_data, f, ensure_ascii=False)

BUCKET_NAME = "cse447"
print(f"Uploading vocab to gs://{BUCKET_NAME}/nano-char-gpt-c4-v2/vocab.json...")
!gcloud storage cp ./vocab.json gs://{BUCKET_NAME}/nano-char-gpt-c4-v2/vocab.json
print("Vocab upload complete!")

In [None]:
BATCH_SIZE = 32
# Number of batches of BATCH_SIZE needed to cover TOTAL_ROWS rows.
MAX_STEPS = TOTAL_ROWS // BATCH_SIZE

training_args = TrainingArguments(
    # Output location, checkpoint retention, and reporting.
    output_dir="./nano-char-gpt-c4-v2",
    save_total_limit=1,
    report_to="none",

    # Run length, batch size, and optimizer hyper-parameters.
    max_steps=MAX_STEPS,
    per_device_train_batch_size=BATCH_SIZE,
    learning_rate=1e-4,
    lr_scheduler_type="cosine",
    warmup_steps=500,
    weight_decay=0.1,

    # Precision / throughput settings (chosen by the author for an L4 GPU).
    bf16=True,
    tf32=True,
    dataloader_num_workers=1,  # author's note: 1 avoids a warning with streaming datasets

    # Log and evaluate every 100 steps.
    logging_steps=100,
    eval_strategy="steps",
    eval_steps=100,
)


def manual_collator(features):
    """Pad a list of examples to a common length and stack them into tensors.

    Args:
        features: list of dicts, each with an ``"input_ids"`` list of ints.

    Returns:
        Dict of ``torch.long`` tensors, each with one row per example and as
        many columns as the longest example:
        ``input_ids`` (right-padded with ``pad_token_id``),
        ``attention_mask`` (1 for real tokens, 0 for padding) and
        ``labels`` (a copy of the ids with padding positions set to -100).
    """
    sequences = [feature["input_ids"] for feature in features]
    target_len = max(len(seq) for seq in sequences)

    def right_pad(seq, fill):
        # Extend `seq` on the right with `fill` up to the batch-wide target length.
        return seq + [fill] * (target_len - len(seq))

    padded_ids = [right_pad(seq, pad_token_id) for seq in sequences]
    masks = [right_pad([1] * len(seq), 0) for seq in sequences]
    targets = [right_pad(seq, -100) for seq in sequences]

    return {
        "input_ids": torch.tensor(padded_ids, dtype=torch.long),
        "attention_mask": torch.tensor(masks, dtype=torch.long),
        "labels": torch.tensor(targets, dtype=torch.long),
    }

# Update your trainer to use this collator
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_ds,
    eval_dataset=tokenized_val_ds,
    data_collator=manual_collator # Use our custom function
  )

model = torch.compile(model)
trainer.train()

trainer.save_model("./nano-char-gpt-c4-v2")

BUCKET_NAME = "cse447"

print(f"Uploading model to gs://{BUCKET_NAME}...")
!gcloud storage cp -r ./nano-char-gpt-c4-v2 gs://{BUCKET_NAME}/nano-char-gpt-c4-v2/
print("Upload Complete!")

In [None]:
import time

# Retry loop to handle streaming/network connection drops
MAX_RETRIES = 100
retry_count = 0

while retry_count < MAX_RETRIES:
    try:
        print(f"\n=== Training Attempt {retry_count + 1} ===")
        # resume_from_checkpoint=True automatically picks the latest checkpoint in output_dir
        trainer.train(resume_from_checkpoint=True)
        print("Training completed successfully!")
        break
    except Exception as e:
        print(f"\nTraining interrupted with error: {e}")
        print("Wait 30 seconds before retrying to clear connection issues...")
        time.sleep(30)
        retry_count += 1

if retry_count == MAX_RETRIES:
    print("Max retries reached. Training failed.")
else:
    # Only save and upload if we finished successfully
    trainer.save_model("./nano-char-gpt-c4-v2")

    BUCKET_NAME = "cse447"

    print(f"Uploading model to gs://{BUCKET_NAME}...")
    !gcloud storage cp -r ./nano-char-gpt-c4-v2 gs://{BUCKET_NAME}/nano-char-gpt-c4-v2/
    print("Upload Complete!")

In [None]:
BUCKET_NAME = "cse447"
print("Downloading model from GCS...")
!mkdir -p ./nano-char-gpt-c4-v2-downloaded
!gcloud storage cp -r gs://{BUCKET_NAME}/nano-char-gpt-c4-v2/nano-char-gpt-c4-v2 ./nano-char-gpt-c4-v2-downloaded
print("Download complete!")

local_model_path = "./nano-char-gpt-c4-v2-downloaded/nano-char-gpt-c4-v2"
print("Loading model into memory...")
model = GPT2LMHeadModel.from_pretrained(local_model_path).to("cuda")
print("Model loaded!")

In [None]:
def decode_id(tok_id):
    """Turn a token id back into its string.

    Ids below ``offset`` index ``SPECIAL_TOKENS_LIST``; every other id indexes
    ``id2char`` after subtracting ``offset``.
    """
    is_special = tok_id < offset
    table, index = (SPECIAL_TOKENS_LIST, tok_id) if is_special else (id2char, tok_id - offset)
    return table[index]


def encode_text(text):
    """Map every character of ``text`` to its id; unknown characters get ``unk_token_id``.

    No truncation is applied here.
    """
    ids = []
    for ch in chars_from_text(text):
        ids.append(char2id.get(ch, unk_token_id))
    return ids


def predict_next_token_top3(text, k=3):
    """Print the ``k`` most likely next characters after ``text``.

    Args:
        text: prompt string; only its last ``MAX_SEQ_LEN`` characters are fed to
            the model. An empty prompt is replaced by a single BOS token.
        k: number of candidates to print (default 3, capped at the vocabulary size).

    Prints one ``Rank i: '<token>' - <probability>`` line per candidate, where a
    space is shown as ``Space`` and single characters with code point below 32 as
    ``Control (<id>)``. Returns ``None``.
    """
    encoded = encode_text(text)[-MAX_SEQ_LEN:]
    if not encoded:
        encoded = [bos_token_id]

    # Run on whatever device the model lives on instead of assuming CUDA.
    device = next(model.parameters()).device
    input_ids = torch.tensor([encoded], dtype=torch.long, device=device)

    # Dropout must be off for a deterministic prediction; the previous mode is
    # restored afterwards so a model that is mid-training is left as it was.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            logits = model(input_ids).logits[:, -1, :]  # last position only
            probs = torch.softmax(logits, dim=-1)
    finally:
        model.train(was_training)

    # topk raises if k exceeds the size of the dimension, so cap it.
    k = min(k, probs.shape[-1])
    top_k_probs, top_k_ids = torch.topk(probs, k)

    ranked = zip(top_k_probs[0].tolist(), top_k_ids[0].tolist())
    for rank, (prob, tok_id) in enumerate(ranked, start=1):
        token_text = decode_id(tok_id)
        if token_text == " ":
            token_display = "Space"
        elif len(token_text) == 1 and ord(token_text) < 32:
            token_display = f"Control ({tok_id})"
        else:
            token_display = token_text
        print(f"Rank {rank}: '{token_display}' - {prob:.2%}")

In [None]:
# Quick smoke test of the predictor on a sample prompt.
predict_next_token_top3("one small step for mankind")