In [None]:
!nvidia-smi

!pip -q install -U transformers datasets peft bitsandbytes scikit-learn
!pip -q install -U evaluate textstat accelerate

In [None]:
# Mount Google Drive so files under /content/drive (e.g. the sonnets source
# text read further down) are accessible from this runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
from pathlib import Path
import re
from huggingface_hub import login
import unicodedata
import numpy as np
from sklearn.model_selection import KFold
from datasets import Dataset, DatasetDict
import math, torch, gc
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling
)
from peft import LoraConfig, PeftModel, get_peft_model

In [None]:
# Authenticate with the Hugging Face Hub (login comes from huggingface_hub).
# Presumably required to download the Llama checkpoint used below — confirm.
login()

In [None]:
# Runtime capability flags: `device` and `use_bf16` are consulted later when
# building the quantisation config and the TrainingArguments precision flags.
device = "cuda" if torch.cuda.is_available() else "cpu"
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

In [None]:
# Load the raw sonnets text from Drive. errors="ignore" silently drops any
# bytes that are not valid UTF-8 instead of raising.
sonnets_txt_path = Path("/content/drive/MyDrive/bardify/shakespeares-sonnets_TXT_FolgerShakespeare.txt")
sonnets_txt = sonnets_txt_path.read_text(encoding="utf-8", errors="ignore")

# Cleaning the data

In [None]:
def remove_start_end(txt):
    """Trim the front matter and the trailing appendix from the raw text.

    Everything before the first line that consists solely of the number ``1``
    (optionally surrounded by whitespace) is dropped, and so is everything
    from the heading "Two Sonnets from The Passionate Pilgrim" (matched
    case-insensitively) onwards.

    Args:
        txt: Full text of the source file.

    Returns:
        The slice of ``txt`` starting at the first ``1`` heading and ending
        just before the appendix heading, or at the end of ``txt`` when that
        heading is absent.

    Raises:
        ValueError: If no line containing only ``1`` exists, i.e. the start
            of the sonnets cannot be located.
    """
    first_heading = re.search(r"(?m)^\s*1\s*$", txt)
    if first_heading is None:
        # Fail with an explicit message instead of an AttributeError on None.
        raise ValueError(
            "Could not find a line containing only '1' marking the first sonnet."
        )

    txt = txt[first_heading.start():]
    # Keep only what precedes the appendix heading (index 0 of the split).
    return re.split(
        r"\bTwo\s+Sonnets\s+from\s+The\s+Passionate\s+Pilgrim\b",
        txt,
        flags=re.I,
    )[0]

# Stage 1 of cleaning: strip front matter and the appendix from the raw text.
removed_st_end = remove_start_end(sonnets_txt)

In [None]:
def normalise_indents(txt):
    """Normalise line endings, per-line whitespace and blank-line runs.

    Steps, in order:
      1. Convert ``\\r\\n`` / ``\\r`` line endings to ``\\n``.
      2. Remove leading and trailing spaces/tabs on every line.
      3. Collapse runs of three or more newlines into exactly two.
      4. Strip leading/trailing whitespace from the whole text.

    Whitespace stripping runs before the blank-line collapse so that lines
    holding only spaces or tabs are treated as empty lines and collapsed too.

    Args:
        txt: Text to normalise.

    Returns:
        The normalised text.
    """
    txt = re.sub(r"\r\n?", "\n", txt)
    txt = re.sub(r"(?m)^[ \t]+", "", txt)
    txt = re.sub(r"(?m)[ \t]+$", "", txt)
    # Collapse only after whitespace-only lines have become truly empty.
    txt = re.sub(r"\n{3,}", "\n\n", txt)
    return txt.strip()

# Stage 2 of cleaning: normalise line endings, indentation and blank lines.
normalised = normalise_indents(removed_st_end)

In [None]:
# Unicode EM DASH used as the replacement for typewriter-style "--".
dash = "\u2014"

def hyphen_to_dash(txt):
    """Replace double hyphens with an em dash while preserving line structure.

    The text is first NFC-normalised. Then ``--`` is rewritten as follows:
      * between two non-space characters -> bare dash (``a--b`` -> ``a—b``);
      * surrounded by spaces/tabs        -> dash with single spaces;
      * at the end of a line             -> `` —`` (space + dash);
      * at the start of a line           -> bare dash.

    Only horizontal whitespace (space, tab) is matched around the hyphens.
    Matching ``\\s`` would also consume newlines, which merges adjacent verse
    lines or swallows the blank line after a trailing ``--``.

    Args:
        txt: Text to convert.

    Returns:
        The converted text with the same line breaks as the input.
    """
    txt = unicodedata.normalize("NFC", txt)
    txt = re.sub(r"(?<=\S)--(?=\S)", dash, txt)
    txt = re.sub(r"[ \t]--[ \t]", f" {dash} ", txt)
    txt = re.sub(r"(?m)--[ \t]*$", f" {dash}", txt)
    txt = re.sub(r"(?m)^--", dash, txt)
    return txt

# Stage 3 of cleaning: convert "--" into em dashes.
dashes_added = hyphen_to_dash(normalised)

In [None]:
def separate_sonnets(txt):
    """Split cleaned text into ``(number, body)`` pairs, one per sonnet.

    A line holding only a 1-3 digit number starts a new sonnet. Every line
    after it, up to the next such heading, belongs to that sonnet. Whitespace
    runs inside each line are collapsed to a single space, and the body is
    stripped of leading/trailing blank lines. Lines before the first heading
    are ignored, as are sonnets with an empty body.

    Args:
        txt: Cleaned text containing numbered sonnets.

    Returns:
        A list of ``(int, str)`` tuples in order of appearance.
    """
    sonnets = []
    curr_num = None
    curr_block = []

    def add_curr_sonnet():
        # Reads curr_num / curr_block from the enclosing scope at call time.
        if curr_num is not None and curr_block:
            body = "\n".join(curr_block).strip()
            if body:
                sonnets.append((curr_num, body))

    # Iterate over the argument itself, not a notebook-level global.
    for line in txt.splitlines():
        new_heading = re.match(r"^\s*(\d{1,3})\s*$", line)
        if new_heading:
            add_curr_sonnet()
            curr_num = int(new_heading.group(1))
            curr_block = []
        elif curr_num is not None:
            curr_block.append(re.sub(r"\s+", " ", line).strip())

    # Flush the last sonnet, which has no following heading.
    add_curr_sonnet()
    return sonnets

# Stage 4 of cleaning: split the text into (sonnet number, sonnet text) pairs.
separated = separate_sonnets(dashes_added)

In [None]:
# Display the parsed (number, text) pairs for a visual sanity check.
separated

In [None]:
# Join every sonnet body into one blank-line-separated corpus string.
corpus = "\n\n".join(body for _number, body in separated).strip()
print(f"Corpus size (chars): {len(corpus)}")

# Cleaned sonnets to file

In [None]:
# Persist the cleaned corpus to the runtime's local disk (created if missing).
Path("/content/corpus").mkdir(parents=True, exist_ok=True)
(Path("/content/corpus") / "all_sonnets.txt").write_text(corpus, encoding="utf-8")

print("Saved:", "/content/corpus/all_sonnets.txt")

# Train/validation split (10-fold cross-validation)

In [None]:
# Split the sonnets into 10 shuffled folds (fixed seed, so the split is
# repeatable) and write each fold's train/validation poems to disk.
sonnets = [sonnet for _, sonnet in separated]
idx = np.arange(len(sonnets))
kf = KFold(n_splits=10, shuffle=True, random_state=42)

BASE = Path("/content/corpus_cv")
BASE.mkdir(parents=True, exist_ok=True)

# fold_paths[k - 1] holds the (train.txt, valid.txt) paths for fold k;
# get_folds() looks folds up in this list.
fold_paths = []
for fold, (tr, va) in enumerate(kf.split(idx), 1):
    fdir = BASE / f"f{fold}"
    fdir.mkdir(parents=True, exist_ok=True)
    # Poems are separated by a blank line, the same delimiter read_poems() splits on.
    (fdir / "train.txt").write_text("\n\n".join(sonnets[i] for i in tr), encoding="utf-8")
    (fdir / "valid.txt").write_text("\n\n".join(sonnets[i] for i in va), encoding="utf-8")
    fold_paths.append((str(fdir / "train.txt"), str(fdir / "valid.txt")))

print(f"Prepared such folds:\n")
for fold_path in fold_paths:
    print(f"{fold_path}\n")

# Tokenisation

In [None]:
# Base checkpoint that is fine-tuned in every fold and reloaded for inference.
train_model = "meta-llama/Llama-3.1-8B"

tokeniser = AutoTokenizer.from_pretrained(
    train_model,
    use_fast=False)

if tokeniser.pad_token is None:
    # Do not alias the pad token to EOS when a dedicated pad token exists.
    # DataCollatorForLanguageModeling(mlm=False) sets the label of every token
    # equal to pad_token_id to -100, so pad == EOS would also mask the EOS
    # appended to each poem and the model would never be trained to stop.
    dedicated_pad = "<|finetune_right_pad_id|>"  # reserved pad token in the Llama 3.1 vocabulary
    if dedicated_pad in tokeniser.get_vocab():
        tokeniser.pad_token = dedicated_pad
    else:
        # Fallback for vocabularies without a dedicated pad token.
        tokeniser.pad_token = tokeniser.eos_token

In [None]:
def token_len(text):
    """Return how many tokens ``tokeniser`` yields for ``text`` (no special tokens added)."""
    encoded = tokeniser(text, add_special_tokens=False)
    return len(encoded["input_ids"])

# Token count of every sonnet, used to find and report the longest one.
lengths = [(num, token_len(txt)) for num, txt in separated]
max_sonnet_num, max_token_seq = max(lengths, key=lambda x: x[1])

print(f"Max token sequence length is {max_token_seq} (Sonnet {max_sonnet_num})")

In [None]:
# Maximum number of tokens per example; passed as max_length when get_folds()
# tokenises and chunks the poems.
BLOCK_SIZE = 512

# Causal-LM collator (mlm=False), shared by the Trainer of every fold.
collator = DataCollatorForLanguageModeling(
    tokenizer=tokeniser,
    mlm=False)

In [None]:
def read_poems(path: str):
    """Load a UTF-8 text file and return its blank-line-separated poems.

    The file content is stripped, split on double newlines, and chunks that
    contain only whitespace are discarded.
    """
    contents = Path(path).read_text(encoding="utf-8").strip()
    poems = []
    for chunk in contents.split("\n\n"):
        if chunk.strip():
            poems.append(chunk)
    return poems

In [None]:
def get_folds(fold_k):
    """Build the tokenised train/validation datasets for one fold.

    Args:
        fold_k: 1-based fold number; indexes ``fold_paths`` at ``fold_k - 1``.

    Returns:
        A ``DatasetDict`` with "train" and "validation" splits whose only
        column is "input_ids". Each poem has the tokeniser's EOS token
        appended and is truncated to ``BLOCK_SIZE`` tokens, with overflowing
        tokens requested from the tokeniser.
    """
    train_path, valid_path = fold_paths[fold_k - 1]

    def _encode_with_eos(batch):
        # Terminate every poem with EOS before tokenising.
        poems_with_eos = [poem + tokeniser.eos_token for poem in batch["text"]]
        encoded = tokeniser(
            poems_with_eos,
            add_special_tokens=False,
            truncation=True,
            max_length=BLOCK_SIZE,
            return_overflowing_tokens=True,
            return_attention_mask=False,
        )
        return {"input_ids": encoded["input_ids"]}

    text_splits = DatasetDict({
        "train": Dataset.from_dict({"text": read_poems(train_path)}),
        "validation": Dataset.from_dict({"text": read_poems(valid_path)}),
    })

    return text_splits.map(
        _encode_with_eos,
        batched=True,
        remove_columns=["text"],
    )


# Training

In [None]:
# Hyper-parameters consumed by train_all_folds().
FIXED_HPARAMS = {
    "lr": 5e-05,       # learning_rate passed to TrainingArguments
    "r": 32,           # LoRA rank
    "alpha": 64,       # lora_alpha
    "dropout": 0.1,    # lora_dropout
    "batch_size": 1,   # per-device batch size for both training and evaluation
    "grad_accum": 8,   # gradient_accumulation_steps
    "epochs": 3,       # num_train_epochs
}

In [None]:
def train_all_folds(h, n_folds=10, out_root="/content/bardify_cv"):
    """Fine-tune a fresh 4-bit LoRA model on every fold and report perplexity.

    For each fold ``1..n_folds`` the base model is loaded with NF4
    quantisation, wrapped with a LoRA adapter on the attention projections,
    trained on the fold's training split and evaluated on its validation
    split. The adapter and tokeniser are saved to
    ``<out_root>/fold_<k>/lora``; Trainer checkpoints go to
    ``<out_root>/fold_<k>/trainer_out``.

    Args:
        h: Dict of hyper-parameters with keys ``"lr"``, ``"r"``, ``"alpha"``,
            ``"dropout"``, ``"batch_size"``, ``"grad_accum"`` and ``"epochs"``.
        n_folds: Number of folds to train (fold numbers are 1-based).
        out_root: Directory under which per-fold outputs are written.

    Returns:
        Dict with ``"ppl_by_fold"``, ``"loss_by_fold"``, ``"mean_ppl"``,
        ``"std_ppl"`` (sample standard deviation, 0.0 for a single fold) and
        ``"out_root"`` (as a string).
    """
    out_root = Path(out_root)
    out_root.mkdir(parents=True, exist_ok=True)

    fold_ppls = []
    fold_losses = []

    for fold_k in range(1, n_folds + 1):
        print(f"\n")
        print(f"Training fold {fold_k}/{n_folds}")

        lm_dsets = get_folds(fold_k)

        # 4-bit NF4 quantisation with double quantisation; compute dtype
        # follows the hardware capability detected at notebook start.
        bnb_cfg = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16 if use_bf16 else torch.float16,
            bnb_4bit_use_double_quant=True,
        )

        # A fresh base model per fold so that folds do not share adapter weights.
        model = AutoModelForCausalLM.from_pretrained(
            train_model,
            quantization_config=bnb_cfg,
            device_map="auto",
        )

        lora_cfg = LoraConfig(
            r=h["r"],
            lora_alpha=h["alpha"],
            lora_dropout=h["dropout"],
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
            bias="none",
            task_type="CAUSAL_LM",
            fan_in_fan_out=False,
        )

        model = get_peft_model(model, lora_cfg)
        model.print_trainable_parameters()

        fold_out = out_root / f"fold_{fold_k}"
        fold_out.mkdir(parents=True, exist_ok=True)

        args = TrainingArguments(
            output_dir=str(fold_out / "trainer_out"),

            per_device_train_batch_size=h["batch_size"],
            per_device_eval_batch_size=h["batch_size"],
            gradient_accumulation_steps=h["grad_accum"],

            # Evaluate and checkpoint once per epoch. A fixed 200-step interval
            # is only reached when a run has at least 200 optimiser steps; with
            # a small poem corpus and gradient accumulation that is presumably
            # never the case (TODO confirm against the training log), which
            # would leave load_best_model_at_end without any checkpoint.
            eval_strategy="epoch",
            logging_strategy="steps",
            logging_steps=10,

            save_strategy="epoch",
            # Bound disk usage across folds; the best checkpoint is retained.
            save_total_limit=1,
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
            greater_is_better=False,

            learning_rate=h["lr"],
            num_train_epochs=h["epochs"],

            warmup_ratio=0.1,
            lr_scheduler_type="cosine",
            weight_decay=0.01,

            bf16=use_bf16,
            fp16=not use_bf16 and device == "cuda",

            report_to=[],
            # The dataset only carries "input_ids"; keep it untouched for the collator.
            remove_unused_columns=False,
        )

        trainer = Trainer(
            model=model,
            args=args,
            train_dataset=lm_dsets["train"],
            eval_dataset=lm_dsets["validation"],
            data_collator=collator,
        )

        trainer.train()
        ev = trainer.evaluate()

        # Perplexity is exp of the mean validation loss.
        loss = float(ev["eval_loss"])
        ppl = float(math.exp(loss))

        fold_losses.append(loss)
        fold_ppls.append(ppl)

        print(f"\nFinal PPL (fold {fold_k}): {ppl:.3f}  (eval_loss={loss:.3f})")

        (fold_out / "lora").mkdir(parents=True, exist_ok=True)
        model.save_pretrained(str(fold_out / "lora"))
        tokeniser.save_pretrained(str(fold_out / "lora"))

        # Release GPU memory before the next fold loads another base model.
        del trainer, model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    mean_ppl = float(np.mean(fold_ppls))
    std_ppl  = float(np.std(fold_ppls, ddof=1)) if len(fold_ppls) > 1 else 0.0

    print("\n")
    print("CV summary")
    for k, ppl in enumerate(fold_ppls, 1):
        print(f"fold {k}: PPL={ppl:.3f}")
    print(f"\nMean PPL: {mean_ppl:.3f}  |  Std: {std_ppl:.3f}")

    return {
        "ppl_by_fold": fold_ppls,
        "loss_by_fold": fold_losses,
        "mean_ppl": mean_ppl,
        "std_ppl": std_ppl,
        "out_root": str(out_root),
    }

In [None]:
# Run the full 10-fold cross-validation and keep only the per-fold perplexities.
ppls = train_all_folds(FIXED_HPARAMS, n_folds=10, out_root="/content/bardify_cv")["ppl_by_fold"]

# Loading the fine-tuned model

In [None]:
# Load the quantised base model and attach a saved LoRA adapter for inference.
# NOTE(review): no cell visible in this notebook writes to LORA_PATH —
# train_all_folds() saves adapters under <out_root>/fold_<k>/lora. Confirm
# where /content/bardify_final_model/lora is produced, otherwise this cell
# fails on a fresh runtime.
BASE_MODEL_ID = train_model
LORA_PATH = "/content/bardify_final_model/lora"

# Same 4-bit NF4 configuration as used during cross-validation training.
bnb_cfg = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16 if use_bf16 else torch.float16,
    bnb_4bit_use_double_quant=True,
)

# Inference-time tokenizer loaded from the adapter directory. Note the
# spelling: this is a separate object from the training-time `tokeniser`.
tokenizer = AutoTokenizer.from_pretrained(LORA_PATH, use_fast=False)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Place the whole model on GPU 0.
base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    quantization_config=bnb_cfg,
    device_map={"": 0},
)

model = PeftModel.from_pretrained(base, LORA_PATH)
model.eval()

# Generation

In [None]:
def build_sonnet_prompt(topic):
    """Return the instruction prompt asking for a Shakespearean sonnet on ``topic``.

    The prompt lists the formal constraints (14 lines, ABAB CDCD EFEF GG rhyme
    scheme, diction, metre) and ends with a ``Sonnet:`` cue line followed by a
    newline, after which the model is expected to continue with the poem.
    """
    return f"""Write a Shakespearean sonnet about: {topic}

Constraints:
- Exactly 14 lines
- Rhyme scheme: ABAB CDCD EFEF GG
- Shakespearean diction (thee, thou, thy, art)
- Iambic pentameter (approximate)
- Output ONLY the poem

Sonnet:
"""

@torch.inference_mode()
def generate_sonnet(topic):
    """Sample a sonnet about ``topic`` from the loaded model.

    The prompt from ``build_sonnet_prompt`` is tokenised and moved to the
    model's device, a continuation of at most 260 new tokens is sampled, and
    only the newly generated tokens are decoded. The result is reduced to its
    first 14 non-empty lines, each stripped of surrounding whitespace.

    Args:
        topic: Subject of the sonnet, inserted verbatim into the prompt.

    Returns:
        The generated poem as a newline-joined string of up to 14 lines.
    """
    prompt = build_sonnet_prompt(topic)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_token_count = inputs["input_ids"].shape[1]

    out = model.generate(
        **inputs,
        max_new_tokens=260,
        do_sample=True,
        temperature=0.9,
        top_p=0.95,
        repetition_penalty=1.1,
        no_repeat_ngram_size=3,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )

    # Cut the prompt off at token level. Slicing the decoded string by
    # len(prompt) is fragile because decoding is not guaranteed to reproduce
    # the prompt character-for-character.
    generated_ids = out[0][prompt_token_count:]
    poem = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
    lines = [l.strip() for l in poem.splitlines() if l.strip()][:14]
    return "\n".join(lines)

# Smoke-test generation with a sample topic.
print(generate_sonnet("England"))