**Importing libraries and configurations**

In [None]:
import inspect
import math
import os
import random

import numpy as np
import torch
from datasets import load_dataset, DatasetDict
from transformers import (AutoTokenizer, AutoModelForCausalLM, DataCollatorForLanguageModeling, Trainer, TrainingArguments, pipeline)



# Reproducibility: seed every RNG used by the notebook (Python, NumPy, PyTorch).
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Run configuration; everything a reader is likely to tune lives here.
model_id = "gpt2"              # base checkpoint; "distilgpt2" is a faster alternative
out_dir = "runs/jokes_gpt2"    # output directory for this run
block_size = 256               # max tokens per example; lower (e.g. 128/192) on OOM
train_bs = 16                  # per-device train batch size; lower (e.g. 8) on OOM
eval_bs = 16                   # per-device eval batch size
epochs = 3
lr = 5e-5
use_fp16 = torch.cuda.is_available()  # fp16 is requested only when CUDA is present

os.makedirs(out_dir, exist_ok=True)

# Device preference order: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
device

**Load the Reddit jokes dataset & create a single text field**

In [None]:
# The rest of this cell relies on the loaded dataset exposing a "train" split.
dataset_name = "reddit_jokes"
ds = load_dataset(dataset_name)

def merge(rec):
    """Collapse one record's title and body into a single ``text`` field.

    Args:
        rec: Mapping that may contain ``"title"`` and ``"body"``; a missing,
            ``None`` or empty value is treated as an empty string.

    Returns:
        ``{"text": ...}`` where the value is the title and body joined by a
        newline, with leading and trailing whitespace stripped.
    """
    parts = [rec.get(field) or "" for field in ("title", "body")]
    return {"text": "\n".join(parts).strip()}

# Replace the original columns with the merged "text" column, then drop
# examples whose text is 20 characters or shorter.
raw_columns = ds["train"].column_names
ds = ds.map(merge, remove_columns=raw_columns).filter(lambda row: len(row["text"]) > 20)

# Seeded 90/10 split of the train set; the held-out 10% becomes "validation".
split = ds["train"].train_test_split(test_size=0.1, seed=SEED)
ds = DatasetDict(
    {
        "train": split["train"],
        "validation": split["test"],
    }
)
ds


**BPE tokenizer & tokenization**

In [None]:
tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
if tok.pad_token is None:
    # GPT-2 ships without a pad token. Aliasing PAD to EOS looks harmless, but
    # DataCollatorForLanguageModeling(mlm=False) sets the label of every
    # position holding pad_token_id to -100, so genuine EOS tokens would be
    # dropped from the loss too and the model would never learn where a joke
    # ends. A dedicated PAD token keeps EOS trainable; the model cell's
    # resize_token_embeddings(len(tok)) makes room for the extra id.
    tok.add_special_tokens({"pad_token": "<|pad|>"})


def tokenize(batch):
    """Tokenize a batch of jokes, terminating each one with the EOS token.

    Args:
        batch: Batched mapping with a ``"text"`` list of strings, as passed
            by ``Dataset.map(..., batched=True)``.

    Returns:
        The tokenizer output for the batch, truncated to ``block_size`` tokens
        per example. Examples longer than ``block_size`` lose their trailing
        EOS through truncation, which is intended: a cut-off joke should not
        be presented to the model as finished.
    """
    # The GPT-2 tokenizer does not append EOS on its own, so add it explicitly.
    texts = [text + tok.eos_token for text in batch["text"]]
    return tok(texts, truncation=True, max_length=block_size)


tokenized = ds.map(tokenize, batched=True, remove_columns=ds["train"].column_names)
tokenized

**Load model + collator**

In [None]:
# Start from the pretrained causal-LM weights for the chosen checkpoint.
model = AutoModelForCausalLM.from_pretrained(model_id)

# Keep the embedding matrix in step with the tokenizer's vocabulary size
# (matters whenever a special token such as PAD has been added).
vocab_size = len(tok)
model.resize_token_embeddings(vocab_size)

# mlm=False: batches are collated for causal LM training, not masked LM.
collator = DataCollatorForLanguageModeling(mlm=False, tokenizer=tok)


**Training setup + fine-tuning**

In [None]:
# Newer transformers releases renamed `evaluation_strategy` to `eval_strategy`
# and eventually reject the old spelling with a TypeError. Pick whichever
# keyword the installed version accepts so this cell runs on both.
_eval_strategy_key = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments).parameters
    else "evaluation_strategy"
)

args = TrainingArguments(
    output_dir=out_dir,
    per_device_train_batch_size=train_bs,
    per_device_eval_batch_size=eval_bs,
    gradient_accumulation_steps=1,
    num_train_epochs=epochs,
    learning_rate=lr,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    # Evaluate and checkpoint on the same 500-step cadence:
    # load_best_model_at_end needs a checkpoint for every evaluation.
    **{_eval_strategy_key: "steps"},
    eval_steps=500,
    save_steps=500,
    save_total_limit=2,
    logging_steps=100,
    fp16=use_fp16,
    report_to="none",
    # Restore the checkpoint with the lowest validation loss after training.
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    seed=SEED
)

# Newer transformers releases take the tokenizer via `processing_class` and
# deprecate (later drop) the `tokenizer` keyword. Use whichever name the
# installed Trainer accepts so this cell runs on both.
_tokenizer_key = (
    "processing_class"
    if "processing_class" in inspect.signature(Trainer.__init__).parameters
    else "tokenizer"
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["validation"],
    data_collator=collator,
    **{_tokenizer_key: tok},
)

# Fine-tune; evaluation and checkpointing follow the schedule set in `args`.
trainer.train()


**Evaluate → Perplexity**

In [None]:
# Evaluate with the trainer, then report perplexity as
# exp(eval_loss).
eval_res = trainer.evaluate()
eval_loss = eval_res["eval_loss"]
ppl = math.exp(eval_loss)

print(eval_res)
print(f"Perplexity: {ppl:.2f}")

**Save checkpoint**

In [None]:
# Write the trained model through the Trainer, then the tokenizer files,
# into the run directory so the generation cell can reload it from `out_dir`.
trainer.save_model(out_dir)
tok.save_pretrained(out_dir)

# Additionally keep a bare state_dict next to the saved model.
weights_path = os.path.join(out_dir, "pytorch_model_weights_only.pt")
torch.save(model.state_dict(), weights_path)

print("Saved to:", out_dir)

**Generation helper (sampling: temperature / top-k / top-p)**

In [None]:
# Reload the fine-tuned model from `out_dir` and run it on the device detected
# in the setup cell. Passing the device string directly means "mps" is
# honoured too; the previous `0 if device=="cuda" else -1` sent MPS machines
# to the CPU.
gen = pipeline("text-generation", model=out_dir, tokenizer=tok, device=device)

def generate(prompt, max_new=80, temperature=0.9, top_k=50, top_p=0.9, n=3):
    """Sample continuations of ``prompt`` from the ``gen`` pipeline and print them.

    Args:
        prompt: Text the model should continue.
        max_new: Passed to the pipeline as ``max_new_tokens``.
        temperature: Sampling temperature.
        top_k: Top-k sampling cutoff.
        top_p: Nucleus (top-p) sampling cutoff.
        n: Number of sequences to return for the prompt.

    Returns:
        None. Each sample is printed under a numbered ``=== Sample i ===`` header.
    """
    sampling_kwargs = dict(
        do_sample=True,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        max_new_tokens=max_new,
        num_return_sequences=n,
        pad_token_id=tok.eos_token_id,
    )
    completions = gen(prompt, **sampling_kwargs)
    for idx, completion in enumerate(completions, start=1):
        text = completion["generated_text"]
        print(f"\n=== Sample {idx} ===\n{text}\n")

# Try the helper on a few prompts: the first uses the default n=3,
# the remaining two ask for two samples each.
generate("Why did the chicken cross the road?")
for _prompt in ("My lecturer said:", "An engineering student walks into a lab and"):
    generate(_prompt, n=2)


**Samples**

In [None]:
# Write three sampled completions per prompt to <out_dir>/samples.txt.
samples_path = os.path.join(out_dir, "samples.txt")
prompts = [
    "Why did the chicken cross the road?",
    "Write a one-line dad joke about GPUs:",
    "A software engineer and a hardware engineer walk into a bar and"
]

with open(samples_path, "w", encoding="utf-8") as f:
    # Same sampling settings for every prompt.
    sampling_kwargs = dict(
        do_sample=True,
        top_k=50,
        top_p=0.9,
        temperature=0.9,
        max_new_tokens=80,
        num_return_sequences=3,
        pad_token_id=tok.eos_token_id,
    )
    for prompt in prompts:
        completions = gen(prompt, **sampling_kwargs)
        f.write(f"\n\n# Prompt: {prompt}\n")
        for rank, completion in enumerate(completions, start=1):
            text = completion["generated_text"]
            f.write(f"\n[{rank}] {text}\n")
samples_path
