# EmoBERTa-style RoBERTa on IEMOCAP (6-way)

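This notebook trains a **RoBERTa** classifier on **IEMOCAP 6-way** labels using an **EmoBERTa-style context string** in which the **target utterance is delimited by exactly two `</s>` tokens** — one immediately before it and one immediately after it. Past and future context utterances are not separated by `</s>`.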

## What you should add for a GitHub repo
- Put your CSVs under `data/` (not `/content/...`) and keep data **out of git** (see `.gitignore`).
- Add `requirements.txt` (or `environment.yml`) instead of `!pip install ...`.
- Add a `README.md` explaining:
  - how to obtain IEMOCAP (license restrictions apply),
  - how to create the train/val/test CSVs,
  - how to run training and reproduce results.


## 0. Imports & environment

In [None]:
# If you're running this from a clean environment, install deps once:
#   pip install -r requirements.txt
#
# (In Colab you can alternatively run: !pip install -r requirements.txt)

import os
import random
import shutil
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import optuna

from datasets import Dataset
from sklearn.metrics import accuracy_score, f1_score
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    DataCollatorWithPadding,
    TrainingArguments,
    Trainer,
    TrainerCallback,
    set_seed,
)

# Optional: version printout for reproducibility
import transformers, datasets
print("transformers:", transformers.__version__)
print("datasets:", datasets.__version__)
print("torch:", torch.__version__)


## 1. Configuration

In [None]:
# ------------------------------------------------------------------
# Experiment configuration for the IEMOCAP 6-way classification run
# ------------------------------------------------------------------

# Repository layout: inputs are read from ./data, artifacts are written to ./outputs.
PROJECT_ROOT = Path(".").resolve()
DATA_DIR = PROJECT_ROOT.joinpath("data")
OUTPUT_DIR = PROJECT_ROOT.joinpath("outputs")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# One CSV per split, all expected directly under ./data.
TRAIN_CSV = DATA_DIR.joinpath("iemocap_emoberta_train.csv")
VAL_CSV = DATA_DIR.joinpath("iemocap_emoberta_val.csv")
TEST_CSV = DATA_DIR.joinpath("iemocap_emoberta_test.csv")

# Names of the columns read from the CSVs.
DIALOG_COL = "Dialogue_ID"
UTTID_COL = "Utterance_ID"
SPEAKER_COL = "Speaker"  # "F" / "M"
TEXT_COL = "Utterance"
LABEL_COL = "Emotion"

# The six emotion classes; a label's position in this list is its integer id.
LABELS = ["neutral", "frustration", "sadness", "anger", "excited", "happiness"]
label2id = dict(zip(LABELS, range(len(LABELS))))
id2label = dict(enumerate(LABELS))

# Pretrained checkpoint to fine-tune.
MODEL_BASE = "roberta-base"

# Optimisation constants held fixed across runs ("paper-like" settings).
WEIGHT_DECAY = 0.01
EPOCHS = 5
WARMUP_RATIO = 0.20
LR_SCHED = "linear"

# Optuna search: only the peak learning rate is tuned, within [LR_LOW, LR_HIGH].
N_TRIALS = 5
LR_LOW = 1e-6
LR_HIGH = 1e-4

# Sequence length and batching.
MAX_LEN = 512
BATCH_TRAIN = 8
BATCH_EVAL = 16
GRAD_ACCUM = 1

# Seeds: one for the search, several for the final reported runs.
SEED = 42
SEEDS_FINAL = [42, 43, 44, 45, 46]

if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("DEVICE:", DEVICE)
print("LABELS:", LABELS)

# ------------------------------------------------------------------
# EmoBERTa-style speaker names for IEMOCAP.
# Key = session prefix (SesXX) + speaker letter (F/M); value = name.
# ------------------------------------------------------------------
NAME_MAP = {
    "Ses01F": "MARY",
    "Ses02F": "PATRICIA",
    "Ses03F": "JENNIFER",
    "Ses04F": "LINDA",
    "Ses05F": "ELIZABETH",
    "Ses01M": "JAMES",
    "Ses02M": "JOHN",
    "Ses03M": "ROBERT",
    "Ses04M": "MICHAEL",
    "Ses05M": "WILLIAM",
}

# Early warning only: report any split file that is absent so a later
# FileNotFoundError is easier to interpret. Nothing is raised here.
for p in (TRAIN_CSV, VAL_CSV, TEST_CSV):
    if p.exists():
        continue
    print(f"⚠️ Missing file: {p} (put your CSVs under ./data)")


## 2. Load data and filter labels

In [None]:
# ==========================
# Load + filter to IEMOCAP-6
# ==========================

# Map common label variants -> canonical names (helps if your CSV uses abbreviations)
# Canonical emotion name -> every spelling accepted for it in the CSVs.
_LABEL_ALIASES = {
    "neutral": ("neu", "neutral"),
    "frustration": ("fru", "frustrated", "frustration"),
    "sadness": ("sad", "sadness"),
    "anger": ("ang", "anger"),
    "excited": ("exc", "excited"),
    "happiness": ("hap", "happy", "happiness"),
}

# Flattened lookup: accepted spelling -> canonical emotion name.
LABEL_MAP = {
    alias: canonical
    for canonical, aliases in _LABEL_ALIASES.items()
    for alias in aliases
}

def load_and_filter_iemocap6(path: Path) -> pd.DataFrame:
    """Load one split CSV and keep only rows labelled with one of the 6 project emotions.

    The text, speaker, utterance-id and dialogue-id columns are normalised to
    strings (speaker is stripped and upper-cased). Labels are stripped,
    lower-cased and mapped through ``LABEL_MAP`` before filtering against
    ``LABELS``. Progress information is printed along the way.

    Args:
        path: Location of the CSV file to read.

    Returns:
        A new DataFrame containing only rows whose normalised label is in ``LABELS``.

    Raises:
        KeyError: If any of the required columns is missing from the CSV.
    """
    print(f"--- Processing {path} ---")
    df = pd.read_csv(path)
    print("Original shape:", df.shape)
    print("Columns:", df.columns.tolist())

    # Fail early with the file name and the full list of missing columns,
    # instead of a bare KeyError on whichever column is touched first.
    required = [DIALOG_COL, UTTID_COL, SPEAKER_COL, TEXT_COL, LABEL_COL]
    missing = [col for col in required if col not in df.columns]
    if missing:
        raise KeyError(
            f"{path}: missing required column(s) {missing}; found {df.columns.tolist()}"
        )

    # Normalize types / casing.
    # Missing utterance text becomes "" here: a plain astype(str) would turn
    # NaN into the literal string "nan" and feed that word to the model.
    df[TEXT_COL] = df[TEXT_COL].fillna("").astype(str)
    df[SPEAKER_COL] = df[SPEAKER_COL].astype(str).str.strip().str.upper()
    df[UTTID_COL] = df[UTTID_COL].astype(str)
    df[DIALOG_COL] = df[DIALOG_COL].astype(str)

    df[LABEL_COL] = (
        df[LABEL_COL].astype(str).str.strip().str.lower().replace(LABEL_MAP)
    )

    # Keep only the 6 labels (rows with missing or other labels are dropped here).
    df = df[df[LABEL_COL].isin(LABELS)].copy()
    print("After label filtering:", df.shape)
    print("Label counts:\n", df[LABEL_COL].value_counts())

    return df

# Read the three splits in train -> val -> test order, then report the kept row counts.
train_df, val_df, test_df = [
    load_and_filter_iemocap6(csv_path)
    for csv_path in (TRAIN_CSV, VAL_CSV, TEST_CSV)
]

print("Rows:", *(len(frame) for frame in (train_df, val_df, test_df)))


## 3. Reorder rows into turn order

In [None]:
def reorder_iemocap_csv(
    df: pd.DataFrame,
    dialog_col: str = DIALOG_COL,
    uttid_col: str = UTTID_COL,
    spk_col: str = SPEAKER_COL,
) -> pd.DataFrame:
    """Sort utterances inside each dialogue in turn order.

    Rows are ordered by dialogue id, then by the numeric suffix of the
    utterance id (``..._F003`` / ``..._M011``). When two rows share the same
    numeric index, the speaker matching the dialogue starter (the letter in
    the ``SesXXF`` / ``SesXXM`` prefix of the dialogue id, defaulting to "F"
    when the prefix is absent) comes first.

    Args:
        df: Input rows; not modified.
        dialog_col: Name of the dialogue-id column.
        uttid_col: Name of the utterance-id column.
        spk_col: Name of the speaker column ("F" / "M").

    Returns:
        A re-indexed copy of ``df`` in turn order, without helper columns.

    Raises:
        ValueError: If any utterance id does not end in ``_F<digits>`` or
            ``_M<digits>``.
    """
    df = df.copy()
    df[dialog_col] = df[dialog_col].astype(str)
    df[uttid_col] = df[uttid_col].astype(str)
    df[spk_col] = df[spk_col].astype(str).str.strip().str.upper()

    # Numeric index from ..._F003 / ..._M011. Check for non-matching ids
    # explicitly: casting the resulting NaN to int fails with an error that
    # does not say which ids are at fault.
    turn_idx = df[uttid_col].str.extract(r"_[FM](\d+)$")[0]
    unmatched = turn_idx.isna()
    if unmatched.any():
        examples = df.loc[unmatched, uttid_col].head(5).tolist()
        raise ValueError(
            f"{int(unmatched.sum())} value(s) in '{uttid_col}' do not end with "
            f"_F<digits> or _M<digits>, e.g. {examples}"
        )
    df["_idx"] = turn_idx.astype(int)

    # Starter speaker from Dialogue_ID like Ses01F_impro01 or Ses03M_...
    df["_starter"] = (
        df[dialog_col].str.extract(r"^Ses\d{2}([FM])")[0].fillna("F").str.upper()
    )
    df["_prio"] = (df[spk_col] != df["_starter"]).astype(int)  # 0 for starter, 1 for other

    df = df.sort_values([dialog_col, "_idx", "_prio"]).reset_index(drop=True)
    return df.drop(columns=["_idx", "_starter", "_prio"])


In [None]:
# Put every split into turn order (train, then val, then test).
train_df, val_df, test_df = (
    reorder_iemocap_csv(frame) for frame in (train_df, val_df, test_df)
)


## 4. Tokenizer, collator, metrics

In [None]:
# Tokenizer for MODEL_BASE (fast implementation requested) and the collator that
# pads the variable-length examples of a batch using that tokenizer.
tok = AutoTokenizer.from_pretrained(MODEL_BASE, use_fast=True)
collator = DataCollatorWithPadding(tokenizer=tok)

def compute_metrics(eval_pred):
    """Turn Trainer evaluation output into accuracy, weighted-F1 and macro-F1.

    ``eval_pred`` unpacks into ``(predictions, labels)``. If the predictions
    arrive wrapped in a tuple or list, the first element is taken as the logits.
    The predicted class is the arg-max over axis 1.
    """
    logits, gold = eval_pred
    if isinstance(logits, (tuple, list)):
        # Some Trainer versions hand the logits over as a 1-tuple.
        logits = logits[0]

    predicted = np.argmax(logits, axis=1)

    accuracy = accuracy_score(gold, predicted)
    f1_weighted = f1_score(gold, predicted, average="weighted")
    f1_macro = f1_score(gold, predicted, average="macro")
    return {
        "acc": accuracy,
        "weighted_f1": f1_weighted,
        "macro_f1": f1_macro,
    }

# Show the tokenizer's CLS/SEP tokens and their ids (the context builder inserts these ids by hand).
print("CLS:", tok.cls_token, tok.cls_token_id, "SEP:", tok.sep_token, tok.sep_token_id)


## 5. Build EmoBERTa-style context (`</s>` around the target only)

In [None]:
# ==========================
# TARGET-SEP-ONLY context builder (with SPACES around </s>)
#   - EXACTLY two </s>: before and after TARGET
#   - NO </s> between past/future utterances
#   - Adds spaces so tokenizer sees </s> cleanly (prevents token "glue")
#   - Target includes speaker name too
# ==========================

def _select_context_window(
    len_first: list[int],
    len_next: list[int],
    t: int,
    budget: int,
) -> tuple[list[int], list[int]]:
    """Pick contiguous past/future utterance indices around ``t`` that fit in ``budget`` tokens.

    Utterances are considered alternately, nearest first (t-1, t+1, t-2, t+2, ...).
    A side stops growing as soon as one utterance does not fit, so the selected
    context never skips an utterance. Token counts match how the sides are
    assembled afterwards: the first utterance of a side (in dialogue order) is
    encoded without a leading space, every later one with a leading space.

    Args:
        len_first: Token count of each utterance when it opens its side.
        len_next: Token count of each utterance when it follows another one.
        t: Index of the target utterance.
        budget: Tokens available for context (target and separators excluded).

    Returns:
        ``(left_idxs, right_idxs)``, both in dialogue order.
    """
    n = len(len_first)
    left_rev: list[int] = []    # collected nearest-first, reversed on return
    right_idxs: list[int] = []
    left_len = 0    # exact token count of the left side as it will be assembled
    left_tail = 0   # cost of the current left utterances once they all follow a new first one
    right_len = 0
    left_open = t > 0
    right_open = t + 1 < n
    step = 0

    while left_open or right_open:
        step += 1

        if left_open:
            li = t - step
            # li would become the new first utterance of the left side, so
            # everything selected so far is re-counted as a "following" piece.
            candidate_left = len_first[li] + left_tail
            if candidate_left + right_len <= budget:
                left_rev.append(li)
                left_len = candidate_left
                left_tail += len_next[li]
                left_open = li > 0
            else:
                left_open = False

        if right_open:
            ri = t + step
            add_len = len_next[ri] if right_idxs else len_first[ri]
            if left_len + right_len + add_len <= budget:
                right_idxs.append(ri)
                right_len += add_len
                right_open = ri + 1 < n
            else:
                right_open = False

    return left_rev[::-1], right_idxs


def build_context_dataset_target_sep_only(
    df: pd.DataFrame,
    tokenizer,
    max_length: int = 512,
    speaker_caps: bool = True,
    debug_n: int = 3,
    insert_space_between_utts: bool = True,  # readability WITHOUT adding </s>
    include_raw_text: bool = True,
    name_map: dict | None = None,
) -> Dataset:
    """Create a HF Dataset with pre-tokenized `input_ids` / `attention_mask` / `labels`.

    Each row of ``df`` whose label is in ``LABELS`` becomes one example of the form

        <s> LEFT_CONTEXT </s> TARGET </s> RIGHT_CONTEXT

    where every utterance is rendered as ``"NAME: text"``. Context is taken from
    the same dialogue, expanding outward from the target (t-1, t+1, t-2, ...)
    while the sequence fits in ``max_length`` tokens. The selected context is
    always contiguous: a side stops at the first utterance that does not fit.
    A target that does not fit on its own is truncated.

    Args:
        df: Utterance rows with the dialogue, utterance-id, speaker, text and
            label columns named by the module-level ``*_COL`` constants. Not modified.
        tokenizer: Tokenizer providing ``encode``, ``decode``, ``cls_token_id``
            and ``sep_token_id``.
        max_length: Maximum number of token ids per example, including ``<s>``.
        speaker_caps: Upper-case the speaker names.
        debug_n: Number of examples to print for inspection.
        insert_space_between_utts: Encode every non-first context utterance of a
            side with a leading space (no ``</s>`` is ever added between them).
        include_raw_text: Also store a readable string in ``context_text_raw``.
        name_map: Mapping ``"SesXX" + speaker`` -> name; defaults to ``NAME_MAP``.
            Speakers without an entry keep their raw speaker value.

    Returns:
        A ``Dataset`` with columns ``dialogue_id``, ``utterance_id``,
        ``input_ids``, ``attention_mask``, ``labels`` and, optionally,
        ``context_text_raw``.
    """
    df = df.copy()
    name_map = NAME_MAP if name_map is None else name_map

    # -------- normalize --------
    # fillna first: astype(str) alone would turn missing text into the word "nan".
    df[TEXT_COL] = df[TEXT_COL].fillna("").astype(str)
    df[SPEAKER_COL] = df[SPEAKER_COL].astype(str).str.strip().str.upper()
    df[LABEL_COL] = df[LABEL_COL].astype(str).str.strip().str.lower().replace(LABEL_MAP)

    # Keep only wanted labels
    df = df[df[LABEL_COL].isin(LABELS)].copy()

    # -------- ordering (IEMOCAP-style if possible, else numeric) --------
    # If every Utterance_ID ends with _F003 / _M011, use that numeric turn index.
    turn_ex = df[UTTID_COL].astype(str).str.extract(r"_[FM](\d+)$")[0]
    if turn_ex.notna().all():
        df["_turn"] = turn_ex.astype(int)
        df["_starter"] = (
            df[DIALOG_COL].astype(str).str.extract(r"^Ses\d{2}([FM])")[0].fillna("F").str.upper()
        )
        df["_prio"] = (df[SPEAKER_COL] != df["_starter"]).astype(int)
        df = df.sort_values([DIALOG_COL, "_turn", "_prio"]).reset_index(drop=True)
    else:
        # Fallback: numeric cast (only if your Utterance_ID is already numeric);
        # rows whose id cannot be parsed as a number are dropped.
        df[UTTID_COL] = pd.to_numeric(df[UTTID_COL], errors="coerce")
        df = df.dropna(subset=[DIALOG_COL, UTTID_COL]).copy()
        df[UTTID_COL] = df[UTTID_COL].astype(int)
        df = df.sort_values([DIALOG_COL, UTTID_COL]).reset_index(drop=True)

    # -------- speaker names (name_map if possible; else use SPEAKER) --------
    # The session is derived from Dialogue_ID.
    df["_session"] = df[DIALOG_COL].astype(str).str.extract(r"^(Ses\d{2})")[0]
    df["_actor"] = df["_session"].fillna("UNK") + df[SPEAKER_COL]
    df["_name"] = df["_actor"].map(name_map).fillna(df[SPEAKER_COL])

    if speaker_caps:
        df["_name"] = df["_name"].astype(str).str.upper()

    cls_id = tokenizer.cls_token_id  # <s>
    sep_id = tokenizer.sep_token_id  # </s>

    # Reserve one position for CLS; the rest of the sequence is built by hand.
    max_tokens = max_length - 1

    all_input_ids, all_attn, all_labels = [], [], []
    all_texts, all_dialog, all_turn = [], [], []

    dbg_printed = 0
    lengths = []
    sep_counts = []

    # Two encodings per utterance:
    # - without leading space: used for the target and for the first utterance of a side
    # - with leading space: used for every later utterance of a side
    def enc_no_space(x):
        return tokenizer.encode(x, add_special_tokens=False)

    def enc_with_space(x):
        return tokenizer.encode(" " + x, add_special_tokens=False)

    for d_id, g in df.groupby(DIALOG_COL, sort=False):
        names = g["_name"].tolist()
        utts = g[TEXT_COL].tolist()
        labs = g[LABEL_COL].tolist()
        turns = g[UTTID_COL].tolist()

        seg_text = [f"{nm}: {u}" for nm, u in zip(names, utts)]
        seg_ids0 = [enc_no_space(x) for x in seg_text]
        seg_ids1 = [enc_with_space(x) for x in seg_text] if insert_space_between_utts else seg_ids0

        len0 = [len(ids) for ids in seg_ids0]
        len1 = [len(ids) for ids in seg_ids1]

        for t, target_text in enumerate(seg_text):
            target_ids = seg_ids0[t][:]

            # Must fit: [SEP] + target + [SEP]
            base = 2 + len(target_ids)
            if base > max_tokens:
                # Truncate target if it's too long
                keep = max(0, max_tokens - 2)
                target_ids = target_ids[:keep]
                base = 2 + len(target_ids)

            # Contiguous context around the target within the remaining budget.
            left_idxs, right_idxs = _select_context_window(len0, len1, t, max_tokens - base)

            left_ids = []
            for k, idx in enumerate(left_idxs):
                left_ids += seg_ids0[idx] if k == 0 else seg_ids1[idx]

            right_ids = []
            for k, idx in enumerate(right_idxs):
                right_ids += seg_ids0[idx] if k == 0 else seg_ids1[idx]

            # Final seq: LEFT + [SEP] + TARGET + [SEP] + RIGHT
            seq_ids = left_ids + [sep_id] + target_ids + [sep_id] + right_ids
            seq_ids = seq_ids[:max_tokens]  # safety net; the budget above should already hold

            input_ids = [cls_id] + seq_ids
            input_ids = input_ids[:max_length]

            all_input_ids.append(input_ids)
            all_attn.append([1] * len(input_ids))
            all_labels.append(label2id[labs[t]])
            all_dialog.append(d_id)
            all_turn.append(turns[t])

            # Optional raw text (debug / inspection): exactly 2 </s>, with spaces around them.
            if include_raw_text:
                left_raw = " ".join(seg_text[i] for i in left_idxs).strip()
                right_raw = " ".join(seg_text[i] for i in right_idxs).strip()
                raw = f"<s> {left_raw} </s> {target_text} </s> {right_raw}".strip()
                raw = " ".join(raw.split())
                all_texts.append(raw)

            lengths.append(len(input_ids))
            sep_counts.append(input_ids.count(sep_id))

            if dbg_printed < debug_n:
                print("=" * 90)
                print(f"DEBUG {dbg_printed+1} | dialog={d_id} | target_turn={turns[t]} | label={labs[t]}")
                print(f"Left utts: {len(left_idxs)} | Right utts: {len(right_idxs)} | SEP count in input_ids: {sep_counts[-1]}")
                if include_raw_text:
                    parts = all_texts[-1].split("</s>")
                    print("\nRAW split:")
                    print("PAST   :", parts[0].replace("<s>", "").strip()[:220])
                    print("CURRENT:", (parts[1].strip() if len(parts) > 1 else "")[:220])
                    print("FUTURE :", (parts[2].strip() if len(parts) > 2 else "")[:220])
                print("\nDECODED (first 140 tokens):")
                print(tokenizer.decode(input_ids[:140], skip_special_tokens=False))
                dbg_printed += 1

    if lengths:
        print("\nToken length stats:",
              f"min={int(np.min(lengths))}, mean={float(np.mean(lengths)):.1f}, max={int(np.max(lengths))}, n={len(lengths)}")
        print("SEP counts stats:",
              f"min={int(np.min(sep_counts))}, mean={float(np.mean(sep_counts)):.2f}, max={int(np.max(sep_counts))}")
    else:
        # np.min / np.max raise on empty input; report the empty result instead.
        print("\nToken length stats: no examples were built (no rows left after filtering)")

    data = {
        "dialogue_id": all_dialog,
        "utterance_id": all_turn,
        "input_ids": all_input_ids,
        "attention_mask": all_attn,
        "labels": all_labels,
    }
    if include_raw_text:
        data["context_text_raw"] = all_texts

    return Dataset.from_dict(data)


def save_constructed_csv(ds: Dataset, out_csv: Path, id2label: dict | None = None) -> None:
    """Write the constructed examples to a CSV for manual inspection.

    Columns written: ``dialogue_id``, ``utterance_id``, ``label_id``, ``label``
    and ``context_text_raw``. ``label`` is looked up in ``id2label`` when a dict
    is given (falling back to the id as text), otherwise it is the id as text.
    ``context_text_raw`` is filled with empty strings when the dataset has no
    such column. The parent directory of ``out_csv`` is created if needed.
    """
    columns = ds.to_dict()
    label_ids = columns["labels"]

    if isinstance(id2label, dict):
        label_names = [id2label.get(int(lid), str(lid)) for lid in label_ids]
    else:
        label_names = [str(lid) for lid in label_ids]

    raw_texts = columns.get("context_text_raw", [""] * len(label_ids))

    table = pd.DataFrame({
        "dialogue_id": columns["dialogue_id"],
        "utterance_id": columns["utterance_id"],
        "label_id": label_ids,
        "label": label_names,
        "context_text_raw": raw_texts,
    })

    out_csv.parent.mkdir(parents=True, exist_ok=True)
    table.to_csv(out_csv, index=False)
    print("✅ Saved:", out_csv, "| rows:", len(table))


# ----------- BUILD + SAVE -----------
# Options shared by all three splits; only the number of debug printouts differs.
_BUILD_KWARGS = dict(
    max_length=MAX_LEN,
    speaker_caps=True,
    insert_space_between_utts=True,
    include_raw_text=True,
)

train_ds_full = build_context_dataset_target_sep_only(train_df, tok, debug_n=3, **_BUILD_KWARGS)
val_ds_full = build_context_dataset_target_sep_only(val_df, tok, debug_n=1, **_BUILD_KWARGS)
test_ds_full = build_context_dataset_target_sep_only(test_df, tok, debug_n=1, **_BUILD_KWARGS)

print("Sizes:", len(train_ds_full), len(val_ds_full), len(test_ds_full))

# Dump each constructed split to ./outputs for manual inspection.
_CONSTRUCTED_OUTPUTS = (
    (train_ds_full, OUTPUT_DIR / "train_constructed_targetSEPonly.csv"),
    (val_ds_full, OUTPUT_DIR / "val_constructed_targetSEPonly.csv"),
    (test_ds_full, OUTPUT_DIR / "test_constructed_targetSEPonly.csv"),
)
for _split_ds, _split_csv in _CONSTRUCTED_OUTPUTS:
    save_constructed_csv(_split_ds, _split_csv, id2label=id2label)


## 6. Hyperparameter search (Optuna)

In [None]:
def objective(trial):
    """Optuna objective: train once with a sampled learning rate and return the validation loss.

    Only the peak learning rate is tuned (log-uniform in ``[LR_LOW, LR_HIGH]``);
    every other hyperparameter comes from the module-level configuration. The
    run is seeded with ``SEED`` so that trials differ only in learning rate.

    Args:
        trial: The Optuna trial used to sample ``lr`` and to name the output directory.

    Returns:
        ``eval_loss`` on ``val_ds_full`` after training (lower is better).
    """
    import gc
    import inspect

    set_seed(SEED)

    lr = trial.suggest_float("lr", LR_LOW, LR_HIGH, log=True)

    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_BASE,
        num_labels=len(LABELS),
        label2id=label2id,
        id2label=id2label,
    ).to(DEVICE)

    # transformers renamed `evaluation_strategy` to `eval_strategy` and
    # Trainer's `tokenizer` to `processing_class`. Use whichever name the
    # installed version accepts, so the cell runs on old and new releases.
    args_params = inspect.signature(TrainingArguments.__init__).parameters
    eval_key = "eval_strategy" if "eval_strategy" in args_params else "evaluation_strategy"
    trainer_params = inspect.signature(Trainer.__init__).parameters
    tok_key = "processing_class" if "processing_class" in trainer_params else "tokenizer"

    args = TrainingArguments(
        output_dir=str(OUTPUT_DIR / f"optuna_lr_trial_{trial.number}"),
        save_strategy="no",

        learning_rate=lr,
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=BATCH_TRAIN,
        per_device_eval_batch_size=BATCH_EVAL,
        gradient_accumulation_steps=GRAD_ACCUM,

        weight_decay=WEIGHT_DECAY,
        warmup_ratio=WARMUP_RATIO,
        lr_scheduler_type=LR_SCHED,

        fp16=torch.cuda.is_available(),
        report_to="none",
        seed=SEED,
        logging_steps=200,
        **{eval_key: "epoch"},
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_ds_full,
        eval_dataset=val_ds_full,
        data_collator=collator,
        compute_metrics=compute_metrics,
        **{tok_key: tok},
    )

    try:
        trainer.train()
        out = trainer.evaluate(val_ds_full)
        return out["eval_loss"]  # minimize cross-entropy
    finally:
        # Release this trial's model before the next trial allocates a new one.
        del trainer, model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


In [None]:
# Tip for GitHub reproducibility: seed Optuna's sampler
# Seeded TPE sampler so the sequence of suggested learning rates is repeatable.
tpe_sampler = optuna.samplers.TPESampler(seed=SEED)
study = optuna.create_study(direction="minimize", sampler=tpe_sampler)
study.optimize(objective, n_trials=N_TRIALS)

# Learning rate of the trial with the lowest validation loss.
best_lr = study.best_params["lr"]
print("Best lr:", best_lr)
print("Best val loss:", study.best_value)


## 7. Final training across multiple seeds + evaluation

In [None]:
# One dict of test metrics per seed; filled by the training loop and turned into a DataFrame at the end.
rows = []

# ---------- callback: save at end of each epoch ----------
class SaveByEpochCallback(TrainerCallback):
    """Write model + tokenizer to ``<out_root>/epoch_NN`` at the end of every epoch.

    This is independent of the Trainer's own ``save_strategy`` checkpoints.
    """

    def __init__(self, out_root: Path, tokenizer):
        """Remember the destination root and tokenizer; create the root directory."""
        self.out_root = Path(out_root)
        self.tokenizer = tokenizer
        self.out_root.mkdir(parents=True, exist_ok=True)

    def on_epoch_end(self, args, state, control, **kwargs):
        """Save the current model and the tokenizer under a folder named after the epoch."""
        current_model = kwargs["model"]

        # state.epoch is rounded to the nearest integer; a missing value maps to epoch 0.
        if state.epoch is None:
            epoch_number = 0
        else:
            epoch_number = int(round(state.epoch))

        target_dir = self.out_root / f"epoch_{epoch_number:02d}"
        target_dir.mkdir(parents=True, exist_ok=True)

        current_model.save_pretrained(target_dir)
        self.tokenizer.save_pretrained(target_dir)
        print(f"✅ Saved epoch checkpoint to: {target_dir}")
        return control


import gc
import inspect

# Epochs for the final runs. Deliberately a named constant: it differs from
# EPOCHS, which is what the learning-rate search trains with.
FINAL_EPOCHS = 7

# transformers renamed `evaluation_strategy` to `eval_strategy` and Trainer's
# `tokenizer` to `processing_class`. Use whichever name the installed version accepts.
_EVAL_KEY = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments.__init__).parameters
    else "evaluation_strategy"
)
_TOK_KEY = (
    "processing_class"
    if "processing_class" in inspect.signature(Trainer.__init__).parameters
    else "tokenizer"
)

# Train once per seed with the tuned learning rate, keep the best checkpoint
# (by validation weighted-F1) and record its test metrics in `rows`.
for seed in SEEDS_FINAL:
    print("\n" + "=" * 20, "SEED", seed, "=" * 20)

    # Release the previous seed's model/trainer before allocating a new model.
    # The last seed's objects stay available after the loop.
    model = trainer = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    set_seed(seed)

    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_BASE,
        num_labels=len(LABELS),
        label2id=label2id,
        id2label=id2label,
    ).to(DEVICE)

    out_dir = OUTPUT_DIR / f"roberta_iemocap_final_seed{seed}"

    # Keep epoch checkpoints separate (so Trainer's checkpoint cleanup doesn't delete them)
    epoch_root = OUTPUT_DIR / f"epoch_checkpoints_seed{seed}"
    if epoch_root.exists():
        shutil.rmtree(epoch_root)
    epoch_root.mkdir(parents=True, exist_ok=True)

    epoch_saver = SaveByEpochCallback(epoch_root, tok)

    args = TrainingArguments(
        output_dir=str(out_dir),
        save_strategy="epoch",
        save_total_limit=2,

        load_best_model_at_end=True,
        metric_for_best_model="weighted_f1",
        greater_is_better=True,

        learning_rate=best_lr,
        num_train_epochs=FINAL_EPOCHS,
        per_device_train_batch_size=BATCH_TRAIN,
        per_device_eval_batch_size=BATCH_EVAL,
        gradient_accumulation_steps=GRAD_ACCUM,

        weight_decay=WEIGHT_DECAY,
        warmup_ratio=WARMUP_RATIO,
        lr_scheduler_type=LR_SCHED,

        fp16=torch.cuda.is_available(),
        report_to="none",
        seed=seed,
        logging_steps=200,
        **{_EVAL_KEY: "epoch"},
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_ds_full,
        eval_dataset=val_ds_full,
        data_collator=collator,
        compute_metrics=compute_metrics,
        callbacks=[epoch_saver],
        **{_TOK_KEY: tok},
    )

    trainer.train()

    best_ckpt = trainer.state.best_model_checkpoint
    print("Best checkpoint:", best_ckpt)

    # Save clean BEST folder
    best_dir = OUTPUT_DIR / f"roberta_iemocap_final_seed{seed}_BEST"
    if best_dir.exists():
        shutil.rmtree(best_dir)
    if best_ckpt is not None:
        shutil.copytree(best_ckpt, best_dir)
    else:
        # No best checkpoint was recorded, so there is nothing to copy;
        # save the model currently held by the trainer instead of crashing.
        print("⚠️ No best checkpoint recorded; saving the trainer's current model instead.")
        trainer.save_model(str(best_dir))
    tok.save_pretrained(best_dir)
    print("Saved BEST folder:", best_dir)

    test_metrics = trainer.evaluate(test_ds_full)
    print("TEST:", test_metrics)

    rows.append({
        "seed": seed,
        "best_dir": str(best_dir),
        "test_acc": float(test_metrics["eval_acc"]),
        "test_weighted_f1": float(test_metrics["eval_weighted_f1"]),
        "test_macro_f1": float(test_metrics["eval_macro_f1"]),
    })

# Per-seed results table (displayed as the cell output).
df = pd.DataFrame(rows)
df


### Notes
- For GitHub, consider moving **training** into a script (`train.py`) so it can be run headlessly and used in CI.
- Keep large model checkpoints out of git (use releases, Hugging Face Hub, or external storage).
