
# 🚀 CPT/DAPT + Tiny Anchor with QLoRA (Minimal Notebook)
This notebook scaffolds **Domain-Adaptive Pretraining (CPT/DAPT)** on your docs using **QLoRA**, 
mixed with a **small instruction anchor (5–10%)** to preserve instruction-following.

**Steps**: chunk docs → build CPT dataset → build a tiny instruction anchor → collator (explained below) → interleave CPT + anchor → load 4-bit model + LoRA → train → guardrails → eval.


In [None]:

# # If on Colab/fresh VM, uncomment to install:
# !pip install -U transformers accelerate peft datasets bitsandbytes sentencepiece einops trl tensorboard

import os, re, json, math, glob, random
from pathlib import Path
from dataclasses import dataclass
from typing import List, Dict, Any

import torch
from datasets import load_dataset, interleave_datasets
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, Trainer, TrainingArguments
from transformers.trainer_callback import EarlyStoppingCallback
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Central run configuration; every later cell reads its settings from CFG.
CFG = dict(
    # Run / model identity
    run_name="dapt_gemma3_v0",
    model_name="google/gemma-3-27b-it",  # or base if available
    # Sequence packing
    block_size=2048,
    pack_factor=4,
    # CPT vs. instruction-anchor sampling mix (interleave probabilities)
    cpt_weight=0.9,
    anchor_weight=0.1,
    # Training loop
    epochs=1,
    lr=1.0e-4,
    warmup_ratio=0.03,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    eval_steps=250,
    save_steps=500,
    logging_steps=25,
    seed=42,
    use_bf16=True,
    gradient_checkpointing=True,
    # Paths
    output_dir="./outputs_dapt",
    raw_dir="./data/raw",
    cpt_jsonl="./data/processed/cpt.jsonl",
    anchor_jsonl="./data/processed/anchor_instr.jsonl",
    # LoRA adapter
    lora_r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)

# Make sure the processed-data and output directories exist before any cell writes to them.
Path("./data/processed").mkdir(parents=True, exist_ok=True)
Path(CFG["output_dir"]).mkdir(parents=True, exist_ok=True)

# Seed Python's and torch's RNGs from the single configured seed.
random.seed(CFG["seed"])
torch.manual_seed(CFG["seed"])

print("Config loaded:", CFG)



## What is a *collator*?
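A **collator** turns a list of variable-length examples into one uniform batch: it **pads** sequences to a common length, builds the **attention mask**, 
and keeps **labels** aligned with `input_ids`. For **CPT**, `labels == input_ids` (padding positions are set to -100), 
so the model learns next-token prediction on every real token. For **anchor** (instruction) examples, prompt tokens are also labelled -100, so the loss is computed only on the response.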


In [None]:

# --- Build CPT dataset from ./data/raw ---
def read_texts_from_dir(raw_dir, exts=(".txt", ".md")):
    """Read and lightly normalize every file under ``raw_dir`` ending in one of ``exts``.

    Files are found recursively and read in sorted path order. ``glob`` returns
    paths in a filesystem-dependent order, so sorting keeps the corpus order (and
    everything seeded downstream) reproducible across machines.

    Normalization strips spaces/tabs before line breaks, collapses 3+ newlines to a
    single blank line, and strips the whole text.

    Args:
        raw_dir: Root directory to search recursively.
        exts: File suffixes to include.

    Returns:
        List of normalized file contents, one entry per readable file.
    """
    paths = set()  # a set, so a path matched by more than one pattern is read only once
    for ext in exts:
        paths.update(glob.glob(os.path.join(raw_dir, f"**/*{ext}"), recursive=True))

    texts = []
    for p in sorted(paths):
        try:
            # errors="ignore" drops undecodable bytes, so only OS-level failures can raise here.
            with open(p, "r", encoding="utf-8", errors="ignore") as f:
                t = f.read()
        except OSError as e:
            # Unreadable entries (permissions, a directory named "*.txt", ...) are skipped.
            print("Skipping", p, e)
            continue
        t = re.sub(r"[ \t]+\n", "\n", t)
        t = re.sub(r"\n{3,}", "\n\n", t)
        texts.append(t.strip())
    return texts

def naive_paragraph_chunk(text, max_chars=4000, min_chars=1200):
    """Greedily group blank-line-separated paragraphs into chunks of roughly ``max_chars``.

    A chunk is closed before adding a paragraph that would push it past
    ``max_chars``, but only once it already holds at least ``min_chars``
    characters. Chunks can therefore exceed ``max_chars`` by up to roughly
    ``min_chars``.

    Paragraphs that are themselves longer than ``max_chars`` are split first, for
    example in files with no blank lines. Each split falls at the last newline
    (or, failing that, the last space) before the limit, or is a hard cut if
    there is neither. Without this, such a file became one unbounded chunk that
    then hit the tokenizer's ``max_length`` truncation downstream, silently
    losing the rest of the document.

    Args:
        text: Document text.
        max_chars: Soft upper bound on chunk length, in characters.
        min_chars: Minimum size a chunk must reach before it may be closed.

    Returns:
        List of chunk strings (paragraphs rejoined with blank lines); ``[]`` for blank input.
    """
    paras = []
    for raw in text.split("\n\n"):
        p = raw.strip()
        # max_chars > 0 guard: otherwise a zero cut position would loop forever.
        while max_chars > 0 and len(p) > max_chars:
            cut = p.rfind("\n", 0, max_chars)
            if cut <= 0:
                cut = p.rfind(" ", 0, max_chars)
            if cut <= 0:
                cut = max_chars
            paras.append(p[:cut].rstrip())
            p = p[cut:].lstrip()
        if p:
            paras.append(p)

    chunks, buf, size = [], [], 0
    for p in paras:
        if size + len(p) > max_chars and size >= min_chars:
            chunks.append("\n\n".join(buf))
            buf, size = [], 0
        buf.append(p)
        size += len(p) + 2  # +2 accounts for the "\n\n" separator added by join
    if buf:
        chunks.append("\n\n".join(buf))
    return chunks

def make_cpt_jsonl_from_dir(raw_dir, out_jsonl):
    """Chunk every document under ``raw_dir`` and write the chunks to ``out_jsonl``.

    Each output line is a JSON object ``{"text": chunk}``. Chunks shorter than
    200 characters are skipped. The file is overwritten, and the number of
    written chunks is printed.
    """
    documents = read_texts_from_dir(raw_dir)
    written = 0
    with open(out_jsonl, "w", encoding="utf-8") as sink:
        for doc in documents:
            for chunk in naive_paragraph_chunk(doc):
                if len(chunk) >= 200:
                    sink.write(json.dumps({"text": chunk}, ensure_ascii=False) + "\n")
                    written += 1
    print(f"Wrote {written} CPT chunks -> {out_jsonl}")

# Demo file if nothing exists: seed a synthetic doc only when raw_dir holds no
# .txt/.md files. If the processed jsonl is merely missing (e.g. deleted to force a
# rebuild), a real corpus must not get demo text mixed into it.
if not Path(CFG["cpt_jsonl"]).exists():
    _raw_dir = Path(CFG["raw_dir"])
    _raw_dir.mkdir(parents=True, exist_ok=True)
    if not any(_raw_dir.rglob("*.txt")) and not any(_raw_dir.rglob("*.md")):
        # Write under CFG["raw_dir"] (not a hard-coded path) so the build below picks it up.
        with open(_raw_dir / "demo.txt", "w", encoding="utf-8") as f:
            f.write("Screening Service...\n" * 50)
        print("Seeded demo doc ->", _raw_dir / "demo.txt")
    make_cpt_jsonl_from_dir(CFG["raw_dir"], CFG["cpt_jsonl"])
else:
    print("Found CPT jsonl:", CFG["cpt_jsonl"])


In [None]:

# --- Tiny anchor (Alpaca-like) ---
def write_demo_anchor(out_jsonl):
    """Write a two-record Alpaca-style demo anchor file to ``out_jsonl``.

    Each line is one JSON object with ``instruction`` / ``input`` / ``output`` keys.
    The file is overwritten, and a confirmation is printed.
    """
    records = (
        {
            "instruction": "Explain the Screening Service in one paragraph.",
            "input": "",
            "output": "It performs low-latency fraud checks by querying a local cache...",
        },
        {
            "instruction": "List two benefits of local cache screening.",
            "input": "",
            "output": "Lower latency and resilience to transient central outages.",
        },
    )
    lines = [json.dumps(rec, ensure_ascii=False) + "\n" for rec in records]
    with open(out_jsonl, "w", encoding="utf-8") as fh:
        fh.writelines(lines)
    print("Wrote anchor demo ->", out_jsonl)

# Reuse an existing anchor file; otherwise fall back to the two-record demo.
if Path(CFG["anchor_jsonl"]).exists():
    print("Found anchor jsonl:", CFG["anchor_jsonl"])
else:
    write_demo_anchor(CFG["anchor_jsonl"])


In [None]:

# --- Tokenizer & pack helpers ---
# Load the tokenizer for the configured model. When the checkpoint defines no pad
# token, reuse EOS so that batches can be padded.
tokenizer = AutoTokenizer.from_pretrained(
    CFG["model_name"],
    use_fast=True,
)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

def group_texts(examples, block_size):
    """Concatenate a batch of token sequences and re-split it into ``block_size`` blocks.

    Every column (``input_ids``, ``labels``, ``attention_mask``, ...) is flattened
    and sliced at the same offsets, so columns stay position-aligned. The trailing
    partial block is dropped.

    If the batch already carries ``labels`` (e.g. instruction data with prompt
    tokens masked to -100), those labels are packed alongside ``input_ids`` and
    kept. Otherwise ``labels`` are a copy of ``input_ids``, giving plain
    next-token prediction.

    Args:
        examples: Mapping of column name -> list of token lists (a batched ``map`` batch).
        block_size: Tokens per output block.

    Returns:
        Dict of column name -> list of ``block_size``-long lists, always including ``labels``.
    """
    # Flatten each column once; a nested comprehension is linear, while
    # ``sum(lists, [])`` is quadratic in the batch's total token count.
    concatenated = {k: [tok for seq in examples[k] for tok in seq] for k in examples.keys()}
    total_len = (len(concatenated["input_ids"]) // block_size) * block_size
    result = {
        k: [t[i:i + block_size] for i in range(0, total_len, block_size)]
        for k, t in concatenated.items()
    }
    # Only synthesize labels when none were provided: overwriting would erase -100 masks.
    if "labels" not in result:
        result["labels"] = [list(block) for block in result["input_ids"]]
    return result

def load_cpt_dataset(jsonl_path, block_size=2048, add_eos=True):
    """Load CPT chunks from ``jsonl_path``, tokenize them, and pack them into fixed blocks.

    Args:
        jsonl_path: JSONL file with one ``{"text": ...}`` object per line.
        block_size: Token length of each packed training block.
        add_eos: Append the tokenizer's EOS string to every chunk so that document
            boundaries survive packing.

    Returns:
        A ``datasets.Dataset`` produced by ``group_texts``. Each row holds a
        ``block_size``-long ``input_ids`` list, the other columns returned by the
        tokenizer, and ``labels``.
    """
    ds = load_dataset("json", data_files=jsonl_path, split="train")
    # Per-chunk truncation cap. It is derived from the ``block_size`` argument (not the
    # global config) so that truncation and packing agree when a caller overrides it.
    max_len = block_size * CFG["pack_factor"]

    def tok_fn(batch):
        texts = batch["text"]
        if add_eos and tokenizer.eos_token:
            texts = [t + tokenizer.eos_token for t in texts]
        # NOTE(review): add_special_tokens=False prepends no BOS token. Some model
        # families (e.g. Gemma) are documented to expect <bos>; confirm for your model.
        return tokenizer(texts, truncation=True, max_length=max_len, add_special_tokens=False)

    tokenized = ds.map(tok_fn, batched=True, remove_columns=ds.column_names)
    return tokenized.map(lambda e: group_texts(e, block_size), batched=True)

def build_anchor_prompt(rec):
    """Format an Alpaca-style record as an ``[INST] ... [/INST]`` prompt string.

    When ``input`` is non-empty, it is appended on a new line after
    ``instruction``. Missing or ``None`` fields are treated as empty strings.
    When some JSON rows omit a key, the column loaded by ``datasets`` holds
    ``None`` for those rows.

    NOTE(review): ``[INST]`` is the Llama-2/Mistral convention, not Gemma's chat
    template (which uses ``<start_of_turn>`` turns). Consider
    ``tokenizer.apply_chat_template`` so anchors match the format used at inference.

    Args:
        rec: Mapping with optional ``instruction`` and ``input`` keys.

    Returns:
        The prompt text, ending in a newline, ready to be followed by the response.
    """
    instr = (rec.get("instruction") or "").strip()
    inpt = (rec.get("input") or "").strip()
    body = f"{instr}\n{inpt}" if inpt else instr
    return f"[INST] {body} [/INST]\n"

def load_anchor_dataset(jsonl_path, pack=False):
    """Load Alpaca-style anchor records and tokenize them for response-only loss.

    Each record becomes ``build_anchor_prompt(rec) + output + EOS``. Prompt tokens
    get label ``-100``, so only the response (and EOS) contribute to the loss.

    Args:
        jsonl_path: JSONL file with ``instruction`` / ``input`` / ``output`` fields.
        pack: If True, concatenate examples into ``CFG["block_size"]`` blocks with
            ``group_texts``, which drops the trailing partial block. If False (the
            default), keep one example per row, truncated to ``CFG["block_size"]``
            tokens, and let the collator pad. Unpacked is the default because a
            small anchor set can otherwise collapse to zero full blocks.

    Returns:
        A ``datasets.Dataset`` with ``input_ids`` and ``labels`` columns.
    """
    ds = load_dataset("json", data_files=jsonl_path, split="train")
    block_size = CFG["block_size"]
    eos = [tokenizer.eos_token_id] if tokenizer.eos_token_id is not None else []

    def tok_map(batch):
        # A batched map receives a dict of column -> list (not a list of records),
        # so rebuild one dict per row before formatting prompts.
        n_rows = len(batch[next(iter(batch))])
        records = [{col: batch[col][i] for col in batch} for i in range(n_rows)]
        prompts = [build_anchor_prompt(r) for r in records]
        outs = [r.get("output") or "" for r in records]
        prompt_tok = tokenizer(prompts, add_special_tokens=False)
        out_tok = tokenizer(outs, add_special_tokens=False)
        input_ids, labels = [], []
        for p_ids, o_ids in zip(prompt_tok["input_ids"], out_tok["input_ids"]):
            ids = p_ids + o_ids + eos
            lab = [-100] * len(p_ids) + o_ids + eos
            if not pack:
                ids, lab = ids[:block_size], lab[:block_size]
                # Skip rows whose response was truncated away entirely: an all -100 row
                # gives no loss signal and can produce a NaN mean loss.
                if all(x == -100 for x in lab):
                    continue
            input_ids.append(ids)
            labels.append(lab)
        return {"input_ids": input_ids, "labels": labels}

    tokenized = ds.map(tok_map, batched=True, remove_columns=ds.column_names)
    if pack:
        tokenized = tokenized.map(lambda e: group_texts(e, block_size), batched=True)
    return tokenized

cpt_ds = load_cpt_dataset(CFG["cpt_jsonl"], block_size=CFG["block_size"])
anchor_ds = load_anchor_dataset(CFG["anchor_jsonl"])

if len(anchor_ds) == 0:
    # Interleaving with an empty source is meaningless (and may misbehave); train on CPT alone.
    print("WARNING: anchor dataset is empty; training on CPT data only.")
    train_ds = cpt_ds
else:
    train_ds = interleave_datasets(
        [cpt_ds, anchor_ds],
        probabilities=[CFG["cpt_weight"], CFG["anchor_weight"]],
        seed=CFG["seed"],
        # The default "first_exhausted" stops as soon as ANY source runs out. With a
        # small anchor set, that happens after ~len(anchor_ds) / anchor_weight rows, so
        # most of the CPT corpus would never be seen. "all_exhausted" oversamples the
        # anchor set until every source has been fully consumed at least once.
        stopping_strategy="all_exhausted",
    )

print("CPT examples:", len(cpt_ds))
print("Anchor examples:", len(anchor_ds))
print("Interleaved length:", len(train_ds))


In [None]:

# --- Collator ---
@dataclass
class DataCollatorForCausalPairs:
    """Right-pad pre-tokenized ``input_ids`` / ``labels`` pairs into a batch.

    ``input_ids`` are padded with the tokenizer's pad id and ``labels`` with
    ``-100``, which the loss ignores. ``attention_mask`` marks each row's real
    (unpadded) positions.

    Attributes:
        tokenizer: Any object exposing ``pad_token_id`` (and optionally
            ``eos_token_id``). It is used only to choose the padding value.
    """

    tokenizer: Any

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        input_seqs = [torch.tensor(f["input_ids"], dtype=torch.long) for f in features]
        label_seqs = [torch.tensor(f["labels"], dtype=torch.long) for f in features]

        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "eos_token_id", None)
        if pad_id is None:
            # Any id works: padded positions are masked out and labelled -100.
            pad_id = 0

        input_ids = torch.nn.utils.rnn.pad_sequence(input_seqs, batch_first=True, padding_value=pad_id)
        labels = torch.nn.utils.rnn.pad_sequence(label_seqs, batch_first=True, padding_value=-100)

        # Derive the mask from sequence lengths instead of ``input_ids != pad_id``.
        # When pad and EOS share an id (e.g. pad_token fell back to eos_token), the
        # comparison would also hide genuine EOS tokens inside the text.
        lengths = torch.tensor([len(s) for s in input_seqs], dtype=torch.long)
        positions = torch.arange(input_ids.size(1), dtype=torch.long)
        attention_mask = (positions.unsqueeze(0) < lengths.unsqueeze(1)).long()
        return {"input_ids": input_ids, "labels": labels, "attention_mask": attention_mask}

# Collator instance handed to the Trainer; it only reads the tokenizer's padding ids.
collator = DataCollatorForCausalPairs(tokenizer=tokenizer)


In [None]:

# --- Load 4-bit base & apply LoRA ---
# 4-bit NF4 quantization with double quantization; compute dtype follows the bf16 switch.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16 if CFG["use_bf16"] else torch.float16,
)
# NOTE(review): trust_remote_code=True executes Python code shipped with the Hub repo.
# Architectures built into transformers do not need it; consider disabling it.
model = AutoModelForCausalLM.from_pretrained(
    CFG["model_name"],
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=CFG["gradient_checkpointing"])

peft_cfg = LoraConfig(
    r=CFG["lora_r"],
    lora_alpha=CFG["lora_alpha"],
    lora_dropout=CFG["lora_dropout"],
    target_modules=CFG["target_modules"],
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, peft_cfg)

# Let PEFT do the counting: summing ``numel()`` directly under-counts bitsandbytes
# 4-bit weights (stored packed), which inflates the reported trainable percentage.
trainable, total = model.get_nb_trainable_parameters()
print(f"Trainable params: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)")


In [None]:

# --- Trainer ---
# Optional held-out dataset (e.g. a split of the CPT corpus) for eval loss / perplexity.
# Early stopping requires periodic evaluation, so it is enabled only when this is set.
eval_ds = None
_has_eval = eval_ds is not None

args = TrainingArguments(
    output_dir=CFG["output_dir"],
    run_name=CFG["run_name"],
    num_train_epochs=CFG["epochs"],
    learning_rate=CFG["lr"],
    warmup_ratio=CFG["warmup_ratio"],
    per_device_train_batch_size=CFG["per_device_train_batch_size"],
    gradient_accumulation_steps=CFG["gradient_accumulation_steps"],
    logging_steps=CFG["logging_steps"],
    save_steps=CFG["save_steps"],
    # `evaluation_strategy` was renamed `eval_strategy` and later removed from TrainingArguments.
    eval_strategy="steps" if _has_eval else "no",
    eval_steps=CFG["eval_steps"],
    # EarlyStoppingCallback asserts load_best_model_at_end and metric_for_best_model at train start.
    load_best_model_at_end=_has_eval,
    metric_for_best_model="eval_loss" if _has_eval else None,
    seed=CFG["seed"],
    bf16=CFG["use_bf16"],
    fp16=not CFG["use_bf16"],
    gradient_checkpointing=CFG["gradient_checkpointing"],
    lr_scheduler_type="cosine",
    report_to=["tensorboard"],
    optim="paged_adamw_8bit",
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    data_collator=collator,
    processing_class=tokenizer,  # replaces the deprecated `tokenizer=` argument
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)] if _has_eval else [],
)

# To start training, uncomment:
# trainer.train()



### Guardrails
- Start with **CPT:ANCHOR = 0.9 : 0.1**. If instruction-following drops, go to **0.85 : 0.15** or **0.8 : 0.2**.
- LR **1e-4** → if unstable or hallucinations increase, drop to **7e-5**.
- Limit to **1 epoch** first; only extend if perplexity & RAG metrics keep improving.
- Ensure **dedup & PII scrub** on your corpus.


In [None]:

# --- Simple Perplexity Eval (dev texts) ---
import numpy as np

@torch.no_grad()
def perplexity_on_texts(model, tokenizer, texts: List[str], max_length: int = 1024) -> float:
    """Return token-weighted perplexity of ``model`` over ``texts``.

    Each text is tokenized (truncated to ``max_length``) and scored independently.
    Assuming the HF causal-LM convention (labels shifted inside the model, ``loss``
    = mean NLL over predicted tokens), each text's loss is re-weighted by its
    number of predicted tokens (sequence length - 1). The result is therefore
    ``exp(total NLL / total tokens)``. An unweighted average over texts would let
    short texts count as much as long ones.

    Texts that tokenize to fewer than 2 tokens have nothing to predict and are
    skipped. Returns ``nan`` if no text is scorable. The model's train/eval mode
    is restored on exit.
    """
    was_training = model.training
    model.eval()
    total_nll, total_tokens = 0.0, 0
    try:
        for t in texts:
            enc = tokenizer(t, return_tensors="pt", truncation=True, max_length=max_length).to(model.device)
            n_pred = int(enc["input_ids"].shape[-1]) - 1
            if n_pred < 1:
                continue
            labels = enc["input_ids"].clone()
            out = model(**enc, labels=labels)
            # Undo the per-text mean so that every token contributes equally.
            total_nll += out.loss.detach().float().item() * n_pred
            total_tokens += n_pred
    finally:
        if was_training:
            model.train()
    if total_tokens == 0:
        return math.nan
    return math.exp(total_nll / total_tokens)

# Example (replace with held-out paragraphs):
# dev_texts = ["Held-out paragraph ...", "Another paragraph ..."]
# print("Perplexity:", perplexity_on_texts(model, tokenizer, dev_texts))

def run_rag_eval_stub(checkpoint_dir: str):
    """Placeholder for a RAG evaluation harness.

    Returns a dict of metric name -> ``None`` until a real harness is wired in.
    ``checkpoint_dir`` is currently unused.
    """
    metric_names = ("faithfulness", "completeness", "specificity", "conciseness", "schema_pass_rate")
    return dict.fromkeys(metric_names)
