In [2]:
# Install every third-party package this notebook imports. `%pip` (rather than
# `!pip`) installs into the environment of the running kernel. `peft`, `datasets`
# and `scikit-learn` are imported further down and were missing from this list.
%pip install -q transformers sentence-transformers faiss-cpu ipywidgets cryptography sacremoses spacy negspacy bitsandbytes accelerate peft datasets scikit-learn

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.6/23.6 MB[0m [31m122.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m897.5/897.5 kB[0m [31m66.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.4/59.4 MB[0m [31m13.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m86.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for negspacy (pyproject.toml) ... [?25l[?25hdone


In [1]:
# Mount Google Drive at /content/drive; every path used later in this notebook
# lives under that mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
# Interactive Hugging Face login. The captured output of this cell echoes the
# stored token name, so clear the output before sharing the notebook.
!huggingface-cli login


    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    To log in, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible): 
Add token as git credential? (Y/n) n
Token is valid (permission: write).
The token `token9911` has been saved to /root/.cache/huggingface/stored_tokens
Your token has been saved to /root/.cache/huggingface/token
Login successful.
The current active token is: `token9911`


In [None]:
import os, re, json, time, random, math, warnings
from typing import List, Dict, Any, Tuple
import numpy as np
import faiss
import torch
import spacy
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments, Trainer
from sentence_transformers import SentenceTransformer
from datasets import Dataset, DatasetDict
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from cryptography.fernet import Fernet
from sklearn.model_selection import train_test_split
from datetime import datetime

# ---- Runtime ---------------------------------------------------------------
# Resolve the compute device once and reuse it for both the banner and later code.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Torch:", torch.__version__)
print("Device:", device)

# ---- Directory layout on Google Drive --------------------------------------
DRIVE_BASE = "/content/drive/MyDrive/gemma_finetune"
ARTIFACTS_DIR = os.path.join(DRIVE_BASE, "artifacts")
FINETUNED_DIR = os.path.join(DRIVE_BASE, "gemma2b_qlora_ft_merged")

# Raw input: a JSON list of {query, report} pairs.
DATASET_JSON = os.path.join(DRIVE_BASE, "dataset.json")

# Train/test split of the raw data, as consumed by the retrieval index.
RAG_TRAIN_JSON = os.path.join(ARTIFACTS_DIR, "rag-train.json")
RAG_TEST_JSON = os.path.join(ARTIFACTS_DIR, "rag-test.json")

# Prompt/target pairs written for fine-tuning.
FINETUNE_TRAIN_JSONL = os.path.join(ARTIFACTS_DIR, "finetune-train.jsonl")
FINETUNE_TEST_JSONL = os.path.join(ARTIFACTS_DIR, "finetune-test.jsonl")

# Serialized FAISS index and the metadata rows aligned with it.
FAISS_PATH = os.path.join(ARTIFACTS_DIR, "reports_index.faiss")
META_PATH = os.path.join(ARTIFACTS_DIR, "reports_metadata.json")

for _required_dir in (DRIVE_BASE, ARTIFACTS_DIR, FINETUNED_DIR):
    os.makedirs(_required_dir, exist_ok=True)

# ---- Retrieval / masking settings ------------------------------------------
EMBED_MODEL = "all-MiniLM-L6-v2"
DEFAULT_BASE_MODEL = "google/gemma-2b"
RANDOM_SEED = 42
SIM_THRESHOLD = 0.30   # minimum inner-product score for a retrieved case to be kept
TOPK = 3               # retrieved cases placed in each prompt
MASK_PROB = 0.60       # probability that a query gets masked
TRAIN_SIZE = 0.80      # train fraction of the train/test split

# ---- Fine-tuning hyperparameters -------------------------------------------
NUM_EPOCHS = 20
PER_DEVICE_BATCH_SIZE = 1
GRAD_ACCUM_STEPS = 8
LEARNING_RATE = 2e-4
MAX_LENGTH = 200       # token budget per training example
LORA_R = 32
LORA_ALPHA = 64
LORA_DROPOUT = 0.05

# ---- Reproducibility -------------------------------------------------------
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(RANDOM_SEED)

# Load the small English spaCy pipeline used for NER-based masking, downloading
# it on first use.
try:
    nlp = spacy.load("en_core_web_sm")
    print("Loaded spaCy NER model.")
except OSError:
    # spacy.load raises OSError when the model package is not installed.
    print("spaCy model not found, downloading...")
    import subprocess
    import sys
    # Use the kernel's own interpreter so the model lands in the environment that
    # is actually running this notebook; check=True surfaces a failed download
    # here instead of as a confusing second load error below.
    subprocess.run([sys.executable, "-m", "spacy", "download", "en_core_web_sm"], check=True)
    nlp = spacy.load("en_core_web_sm")
    print("Loaded spaCy NER model.")

# spaCy entity labels that mask_with_spacy replaces with placeholders: people
# (PERSON) plus the two location labels (GPE, LOC). Entities with any other
# label are left in the text.
PII_LABELS = {"PERSON", "GPE", "LOC"}

def simple_regex_mask(text: str) -> str:
    """Replace name-, email- and phone-shaped substrings with fixed tags.

    Three substitutions are applied in order, each on the output of the
    previous one:

    1. two adjacent capitalized words  -> ``[NAME]``
    2. anything shaped like an email   -> ``[EMAIL]``
    3. 10-digit phone numbers, with an optional country prefix -> ``[PHONE]``

    Args:
        text: The text to scrub.

    Returns:
        The text with every match replaced by its tag.
    """
    substitutions = (
        (r"\b[A-Z][a-z]+ [A-Z][a-z]+\b", "[NAME]"),
        (r"[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+", "[EMAIL]"),
        (r"\b(\+?\d{1,3}[-.\s]?)?(\d{3}[-.\s]?\d{3}[-.\s]?\d{4})\b", "[PHONE]"),
    )
    scrubbed = text
    for pattern, tag in substitutions:
        scrubbed = re.sub(pattern, tag, scrubbed)
    return scrubbed

def mask_with_spacy(text: str) -> Tuple[str, Dict[str,str]]:
    """Mask PII entities found by spaCy, then apply the regex fallback.

    Every entity whose label is in ``PII_LABELS`` is replaced by a placeholder
    of the form ``<LABEL_n>``. Entities are processed from the end of the text
    towards the start, so ``n`` counts upwards from the last masked entity.

    Args:
        text: The text to mask.

    Returns:
        A ``(masked_text, mask_map)`` pair, where ``mask_map`` maps each
        placeholder to the original entity text it replaced.
    """
    placeholders: Dict[str, str] = {}
    result = text

    # Going right-to-left keeps the character offsets of the entities that are
    # still to be processed valid after each splice.
    right_to_left = sorted(nlp(text).ents, key=lambda e: e.start_char, reverse=True)
    pii_entities = [ent for ent in right_to_left if ent.label_ in PII_LABELS]

    for ent in pii_entities:
        tag = f"<{ent.label_}_{len(placeholders)+1}>"
        placeholders[tag] = ent.text
        result = result[:ent.start_char] + tag + result[ent.end_char:]

    return simple_regex_mask(result), placeholders

def apply_masking_to_dataset(records: List[Dict[str,str]], mask_prob: float = MASK_PROB) -> List[Dict[str,str]]:
    """Randomly mask the queries of a list of {query, report} records.

    Records whose query or report is missing, ``None`` or blank are dropped.
    Each remaining record is masked with probability ``mask_prob``; masked
    records also keep the unmasked text under ``"original_query"``.

    Args:
        records: Input records, each a dict with ``"query"`` and ``"report"``.
        mask_prob: Probability in [0, 1] that a given record is masked.

    Returns:
        A new list of records with keys ``query``, ``report``, ``masked`` and,
        for masked records only, ``original_query``.
    """
    masked_records = []
    masked_count = 0

    for rec in records:
        # `or ""` also covers fields that are present but set to None.
        query = (rec.get("query") or "").strip()
        report = (rec.get("report") or "").strip()

        if not query or not report:
            continue

        # One random draw per valid record decides whether it is masked.
        if random.random() < mask_prob:
            masked_query, _ = mask_with_spacy(query)
            masked_records.append({
                "query": masked_query,
                "report": report,
                "original_query": query,  # Keep original for reference
                "masked": True
            })
            masked_count += 1
        else:
            masked_records.append({
                "query": query,
                "report": report,
                "masked": False
            })

    # Guard the percentage so an empty (or fully filtered) input does not raise
    # ZeroDivisionError.
    total = len(masked_records)
    masked_pct = (masked_count / total * 100) if total else 0.0
    print(f"Applied masking: {masked_count}/{total} masked ({masked_pct:.1f}%)")
    return masked_records

# Create the train/test split files WITH MASKING
def create_rag_split_files():
    """Split dataset.json into masked rag-train.json / rag-test.json files.

    The raw list is split with ``train_test_split`` (train fraction
    ``TRAIN_SIZE``, seed ``RANDOM_SEED``), each split is passed through
    ``apply_masking_to_dataset`` with ``MASK_PROB``, and both results are
    written to ``RAG_TRAIN_JSON`` / ``RAG_TEST_JSON``.

    Returns:
        ``(train_records, test_records)`` as written to disk.

    Raises:
        FileNotFoundError: if ``DATASET_JSON`` does not exist.
        ValueError: if the file's top-level JSON value is not a list.
    """
    if not os.path.exists(DATASET_JSON):
        raise FileNotFoundError(f"Put your dataset.json at: {DATASET_JSON}")

    with open(DATASET_JSON, "r", encoding="utf-8") as fh:
        raw_data = json.load(fh)

    # Explicit check instead of `assert`, which is skipped under `python -O`.
    if not isinstance(raw_data, list):
        raise ValueError("dataset.json must be a list of {query, report} pairs.")
    print(f"Loaded {len(raw_data)} examples")

    train_data, test_data = train_test_split(raw_data, train_size=TRAIN_SIZE, random_state=RANDOM_SEED)

    # APPLY MASKING TO BOTH SPLITS
    print("Applying masking to RAG train split...")
    train_data_masked = apply_masking_to_dataset(train_data, mask_prob=MASK_PROB)

    print("Applying masking to RAG test split...")
    test_data_masked = apply_masking_to_dataset(test_data, mask_prob=MASK_PROB)

    # Save the MASKED split files
    with open(RAG_TRAIN_JSON, "w", encoding="utf-8") as fh:
        json.dump(train_data_masked, fh, indent=2)

    with open(RAG_TEST_JSON, "w", encoding="utf-8") as fh:
        json.dump(test_data_masked, fh, indent=2)

    # Report the split from TRAIN_SIZE so the message stays correct if the
    # ratio is changed in the config.
    print(f"Created MASKED RAG split files:")
    print(f"  - {RAG_TRAIN_JSON}: {len(train_data_masked)} examples ({TRAIN_SIZE:.0%}, {MASK_PROB*100}% masked)")
    print(f"  - {RAG_TEST_JSON}: {len(test_data_masked)} examples ({1 - TRAIN_SIZE:.0%}, {MASK_PROB*100}% masked)")

    return train_data_masked, test_data_masked

# Reuse the split files from a previous run when both exist; otherwise build them.
# Cached files are loaded as-is, so changes to the masking or split settings only
# take effect after the files are deleted.
if not (os.path.exists(RAG_TRAIN_JSON) and os.path.exists(RAG_TEST_JSON)):
    print("Creating new MASKED RAG split files...")
    rag_train_data, rag_test_data = create_rag_split_files()
else:
    print("Loading existing MASKED RAG split files...")
    with open(RAG_TRAIN_JSON, "r", encoding="utf-8") as fh:
        rag_train_data = json.load(fh)
    with open(RAG_TEST_JSON, "r", encoding="utf-8") as fh:
        rag_test_data = json.load(fh)

# Sentence embedder shared by index construction and query-time retrieval.
print("Loading sentence-transformers embedder...")
embedder = SentenceTransformer(EMBED_MODEL, device=device)
EMB_DIM = embedder.get_sentence_embedding_dimension()
print("Embedding dim:", EMB_DIM)

# Index the (possibly masked) training queries. Vectors are L2-normalized before
# being added to an inner-product index, so search scores are cosine similarities.
print("Building FAISS index over MASKED RAG-TRAIN QUERIES...")
train_queries = [record["query"] for record in rag_train_data]
if not train_queries:
    raise RuntimeError("No queries found in rag-train split.")

query_embs = embedder.encode(train_queries, convert_to_numpy=True).astype("float32")
faiss.normalize_L2(query_embs)
index = faiss.IndexFlatIP(EMB_DIM)
index.add(query_embs)
faiss.write_index(index, FAISS_PATH)

# Row i of `meta` describes vector i of the index; retrieval relies on that alignment.
meta = [
    {"id": position, "query": record["query"], "report": record["report"]}
    for position, record in enumerate(rag_train_data)
]
with open(META_PATH, "w", encoding="utf-8") as fh:
    json.dump(meta, fh, indent=2)
print(f"FAISS index saved to {FAISS_PATH}. Metadata saved to {META_PATH}")

def retrieve_topk_query_pairs(current_query: str, current_idx: int = -1, topk: int = TOPK, sim_threshold: float = SIM_THRESHOLD) -> List[Dict[str,Any]]:
    """Return the most similar indexed training queries and their reports.

    Args:
        current_query: Text to search for.
        current_idx: Index position to exclude from the results (used to stop a
            training record from retrieving itself); ``-1`` disables exclusion.
        topk: Number of hits after which the scan stops.
        sim_threshold: Minimum similarity score a hit must reach.

    Returns:
        A list of dicts with keys ``similar_query``, ``report``, ``score`` and
        ``orig_idx``, in the order returned by the index search.
    """
    # Over-fetch so that enough candidates remain after filtering.
    fetch_k = max(10, topk * 4)

    query_vec = embedder.encode([current_query], convert_to_numpy=True).astype("float32")
    faiss.normalize_L2(query_vec)
    scores, positions = index.search(query_vec, fetch_k)

    results = []
    for raw_score, raw_pos in zip(scores[0], positions[0]):
        # Positions outside `meta` (e.g. negative filler ids) are ignored.
        out_of_range = raw_pos < 0 or raw_pos >= len(meta)
        is_self = current_idx != -1 and raw_pos == current_idx
        if out_of_range or float(raw_score) < sim_threshold or is_self:
            continue

        entry = meta[int(raw_pos)]
        results.append({
            "similar_query": entry["query"],
            "report": entry["report"],
            "score": float(raw_score),
            "orig_idx": int(raw_pos)
        })

        if len(results) >= topk:
            break

    return results

def build_prompt(query: str, retrieved_pairs: List[Dict[str,Any]]) -> str:
    """Assemble the text prompt for one query and its retrieved cases.

    The prompt has three sections: ``[QUERY]`` with the stripped query,
    ``[SIMILAR_CASES]`` with one numbered Query/Report entry per retrieved
    pair (each followed by a blank line), and ``[INSTRUCTION]``.

    Args:
        query: The user query; surrounding whitespace is removed.
        retrieved_pairs: Dicts providing ``similar_query`` and ``report``.

    Returns:
        The sections joined with newlines.
    """
    sections = ["[QUERY]", query.strip(), "", "[SIMILAR_CASES]"]
    for case_no, case in enumerate(retrieved_pairs, start=1):
        sections.extend([
            f"Case {case_no}:",
            f"Query: {case['similar_query']}",
            f"Report: {case['report']}",
            "",
        ])
    sections += ["[INSTRUCTION]", "Generate a clinical report. Do NOT include disclaimers."]
    return "\n".join(sections)

def create_finetune_jsonl(records: List[Dict[str,str]], out_path: str, apply_masking: bool = True, mask_prob: float = MASK_PROB, sim_threshold:float = SIM_THRESHOLD, records_in_index: bool = True):
    """Write a JSONL file of {input, output} fine-tuning examples.

    For every record the top similar training cases are retrieved, a prompt is
    built from the query and those cases, and the record's report becomes the
    target. Records with no retrieval hit are skipped.

    Args:
        records: Records with ``query`` and ``report`` (and optionally
            ``original_query`` / ``masked`` as produced by the masking step).
        out_path: Destination ``.jsonl`` path (overwritten).
        apply_masking: When True, each query is additionally (re-)masked with
            probability ``mask_prob``.
        mask_prob: Probability of the additional masking step.
        sim_threshold: Minimum similarity for a retrieved case.
        records_in_index: True when ``records`` is the same list, in the same
            order, that the retrieval index was built from. Only then does a
            record's position identify its own index entry, which is excluded
            from retrieval. Pass False for held-out data: its positions refer
            to unrelated index entries that must not be excluded. The default
            keeps the previous behaviour.

    Returns:
        The number of examples written.
    """
    lines = []
    skipped = 0
    additional_masked_count = 0
    sample_inputs = []

    for idx, rec in enumerate(records):
        # `or ""` also covers fields that are present but set to None.
        original_query = (rec.get("query") or "").strip()
        report = (rec.get("report") or "").strip()
        if not original_query or not report:
            continue

        # Self-exclusion is only meaningful when this record is itself in the index.
        current_idx = idx if (apply_masking and records_in_index) else -1

        hits = retrieve_topk_query_pairs(original_query, current_idx, topk=TOPK, sim_threshold=sim_threshold)
        if len(hits) == 0:
            skipped += 1
            continue

        # Use the query as-is (already masked from RAG files) unless the
        # additional masking step fires below.
        final_query = original_query

        do_additional_mask = (random.random() < mask_prob) if apply_masking else False

        if do_additional_mask:
            # Mask starting from the unmasked text when it is available, so a
            # record left unmasked by the split step can still be masked here.
            final_query, _ = mask_with_spacy(rec.get("original_query", original_query))
            additional_masked_count += 1
            mask_status = "ADDITIONALLY_MASKED"
        else:
            mask_status = "MASKED" if rec.get("masked", False) else "UNMASKED"

        prompt = build_prompt(final_query, hits)
        obj = {"input": prompt, "output": report}
        lines.append(json.dumps(obj, ensure_ascii=False))

        if len(sample_inputs) < 3:
            sample_inputs.append({
                "original_query": rec.get("original_query", original_query),
                "input_query": final_query,
                "mask_status": mask_status,
                "retrieved_count": len(hits),
                "retrieved_queries": [hit["similar_query"][:50] + "..." for hit in hits],
                "final_prompt": prompt[:500] + "..." if len(prompt) > 500 else prompt
            })

    # Write first, so the "Wrote ..." message below is only printed once it is true.
    with open(out_path, "w", encoding="utf-8") as fh:
        for line in lines:
            fh.write(line + "\n")

    print(f"Wrote {len(lines)} examples to {out_path} (skipped {skipped} examples due to no retrieval hits).")
    if apply_masking:
        print(f"Additional masking applied: {additional_masked_count}/{len(lines)}")

    # NOTE: "Original" below prints the unmasked query into the cell output.
    print("SAMPLE INPUTS THAT WILL GO TO LLM:")
    for i, sample in enumerate(sample_inputs, 1):
        print(f"SAMPLE {i} - {sample['mask_status']}:")
        print(f"Original: {sample['original_query']}")
        print(f"Input:    {sample['input_query']}")
        print(f"Retrieved: {sample['retrieved_count']} cases from RAG-TRAIN")
        print(f"Retrieved queries: {sample['retrieved_queries']}")
        print(f"Prompt preview: {sample['final_prompt'][:200]}...")
        print()

    return len(lines)

# Create finetune files - BOTH WITH MASKING
# The train records are the ones the index was built from, so each one must skip
# its own index entry. The test records are not in the index, so no entry is excluded.
print("Creating finetune-train.jsonl and finetune-test.jsonl ...")
n_train = create_finetune_jsonl(rag_train_data, FINETUNE_TRAIN_JSONL, apply_masking=True, mask_prob=MASK_PROB, sim_threshold=SIM_THRESHOLD, records_in_index=True)
n_test = create_finetune_jsonl(rag_test_data, FINETUNE_TEST_JSONL, apply_masking=True, mask_prob=MASK_PROB, sim_threshold=SIM_THRESHOLD, records_in_index=False)

# Verification: confirm both JSONL files exist, count their lines and preview
# the first training example.

time.sleep(2)

def _count_jsonl_lines(path: str) -> int:
    """Return the number of lines in a UTF-8 text file, closing it afterwards."""
    with open(path, "r", encoding="utf-8") as fh:
        return sum(1 for _ in fh)

if os.path.exists(FINETUNE_TRAIN_JSONL):
    print(f" Finetune train file created: {FINETUNE_TRAIN_JSONL}")
    train_size = _count_jsonl_lines(FINETUNE_TRAIN_JSONL)
    print(f" Finetune train examples: {train_size}")
else:
    print(f" Finetune train file missing: {FINETUNE_TRAIN_JSONL}")

if os.path.exists(FINETUNE_TEST_JSONL):
    print(f" Finetune test file created: {FINETUNE_TEST_JSONL}")
    test_size = _count_jsonl_lines(FINETUNE_TEST_JSONL)
    print(f" Finetune test examples: {test_size}")
else:
    print(f" Finetune test file missing: {FINETUNE_TEST_JSONL}")

print("Sample training example:")
if os.path.exists(FINETUNE_TRAIN_JSONL):
    with open(FINETUNE_TRAIN_JSONL, "r", encoding="utf-8") as fh:
        first_line = fh.readline()
    # An empty file has no first line; print nothing in that case.
    if first_line:
        print(json.dumps(json.loads(first_line), indent=2))
else:
    print("Finetune train file not available for preview")


print("Preparing for QLoRA fine-tuning...")

tokenizer = AutoTokenizer.from_pretrained(DEFAULT_BASE_MODEL, use_fast=True)
if tokenizer.pad_token is None:
    # Reuse the existing EOS token for padding rather than adding a new token.
    # A newly added token would get an id beyond the model's embedding table,
    # which is never resized in this notebook.
    tokenizer.pad_token = tokenizer.eos_token

def load_jsonl_to_dataset(jsonl_path: str):
    """Read a JSONL file into a ``datasets.Dataset``.

    Blank lines are ignored; every other line is parsed as one JSON object.
    If the file does not exist, a warning is printed and an empty dataset is
    returned instead of raising.

    Args:
        jsonl_path: Path of the ``.jsonl`` file.

    Returns:
        A ``Dataset`` with one row per parsed line.
    """
    if not os.path.exists(jsonl_path):
        print(f"Warning: {jsonl_path} not found")
        return Dataset.from_list([])

    with open(jsonl_path, "r", encoding="utf-8") as fh:
        stripped_lines = (raw.strip() for raw in fh)
        parsed = [json.loads(text) for text in stripped_lines if text]
    return Dataset.from_list(parsed)

# Load from finetune files
train_ds = load_jsonl_to_dataset(FINETUNE_TRAIN_JSONL)
val_ds = load_jsonl_to_dataset(FINETUNE_TEST_JSONL)  # Using test as validation
print("Train/examples:", len(train_ds), "Val/examples:", len(val_ds))

# Stop here if there is nothing to train on. Raising (instead of calling exit(),
# which shuts the notebook kernel down and discards its state) fails this cell
# with a readable error and halts "Run All".
if len(train_ds) == 0:
    raise RuntimeError("ERROR: No training data found. Check file paths and Drive sync.")

def tokenize_and_mask(example):
    """Tokenize one {input, output} pair for causal-LM fine-tuning.

    The prompt and the target (``"\\n\\n" + output + EOS``) are tokenized
    separately and concatenated. Prompt positions get label ``-100`` so only
    target tokens contribute to the loss.

    When the pair exceeds ``MAX_LENGTH`` tokens, the prompt is shortened from
    the right first, so the target is not the part that gets cut. The target is
    only shortened if it alone needs more than half of ``MAX_LENGTH`` and the
    prompt also needs room. Truncating the concatenated text from the right
    would remove the target first and leave long prompts with no supervised
    tokens at all.

    Args:
        example: Mapping with string fields ``"input"`` and ``"output"``.

    Returns:
        Dict with ``input_ids``, ``attention_mask`` and ``labels`` lists of
        equal length (at most ``MAX_LENGTH``).
    """
    inp = example["input"]
    tgt = example["output"]

    # Special tokens (e.g. BOS) belong to the prompt only; the target is a plain
    # continuation ending in the EOS token.
    prompt_ids = list(tokenizer(inp, add_special_tokens=True)["input_ids"])
    target_ids = list(tokenizer("\n\n" + tgt + tokenizer.eos_token, add_special_tokens=False)["input_ids"])

    if len(prompt_ids) + len(target_ids) > MAX_LENGTH:
        # The target keeps everything that fits next to the full prompt, and at
        # least half the budget when the prompt is the part that is too long.
        min_target_budget = max(1, MAX_LENGTH // 2)
        target_keep = min(len(target_ids), max(min_target_budget, MAX_LENGTH - len(prompt_ids)))
        target_ids = target_ids[:target_keep]
        prompt_ids = prompt_ids[:MAX_LENGTH - len(target_ids)]

    input_ids = prompt_ids + target_ids
    return {
        "input_ids": input_ids,
        "attention_mask": [1] * len(input_ids),
        "labels": [-100] * len(prompt_ids) + target_ids,
    }

train_ds_tokenized = train_ds.map(tokenize_and_mask, remove_columns=train_ds.column_names, num_proc=1)
val_ds_tokenized = val_ds.map(tokenize_and_mask, remove_columns=val_ds.column_names, num_proc=1)

def _has_supervised_tokens(example) -> bool:
    """True when at least one label is not the ignore index (-100)."""
    return any(label != -100 for label in example["labels"])

# Drop examples where every label is -100 (the target was truncated away): they
# carry no training signal and give an undefined loss when they fill a batch.
_n_train_before, _n_val_before = len(train_ds_tokenized), len(val_ds_tokenized)
train_ds_tokenized = train_ds_tokenized.filter(_has_supervised_tokens)
val_ds_tokenized = val_ds_tokenized.filter(_has_supervised_tokens)
print(
    "Dropped examples without supervised tokens - "
    f"train: {_n_train_before - len(train_ds_tokenized)}, val: {_n_val_before - len(val_ds_tokenized)}"
)

# DataCollatorWithPadding pads input_ids/attention_mask but not `labels`, so it
# fails as soon as a batch holds sequences of different lengths. The seq2seq
# collator also pads `labels`, using -100 so the padding is ignored by the loss.
from transformers import DataCollatorForSeq2Seq
data_collator = DataCollatorForSeq2Seq(tokenizer, padding=True, label_pad_token_id=-100)

# Base model: loaded with 8-bit weights via bitsandbytes and placed automatically.
print("Loading base model in 8-bit...")
bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False
)

_base_model_kwargs = {
    "quantization_config": bnb_config,
    "device_map": "auto",
    "trust_remote_code": True,
}
model = AutoModelForCausalLM.from_pretrained(DEFAULT_BASE_MODEL, **_base_model_kwargs)

# Prepare the quantized model for training before attaching the adapters.
model = prepare_model_for_kbit_training(model)

# LoRA adapters on the attention projection layers.
target_modules = ["q_proj","v_proj","k_proj","o_proj"]
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=target_modules,
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    bias="none"
)

model = get_peft_model(model, lora_config)
_trainable_param_count = sum(param.numel() for param in model.parameters() if param.requires_grad)
print("LoRA adapters added. Trainable params:", _trainable_param_count)

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
out_dir = os.path.join(ARTIFACTS_DIR, "qlora_runs")
os.makedirs(out_dir, exist_ok=True)

# Evaluation and best-checkpoint selection both need a validation set. Without
# one, Trainer rejects eval_strategy="steps" / load_best_model_at_end=True, so
# they are switched off instead of failing at construction time.
has_eval_data = len(val_ds_tokenized) > 0

# Shared interval for evaluation and checkpointing: every 200 steps, or a fifth
# of the training-set size for small datasets.
checkpoint_interval = 200 if len(train_ds_tokenized) > 200 else max(1, len(train_ds_tokenized) // 5)

training_args = TrainingArguments(
    output_dir=out_dir,
    per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
    per_device_eval_batch_size=PER_DEVICE_BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM_STEPS,
    eval_strategy="steps" if has_eval_data else "no",
    save_strategy="steps",
    eval_steps=checkpoint_interval,
    save_steps=checkpoint_interval,
    logging_steps=50,
    learning_rate=LEARNING_RATE,
    num_train_epochs=NUM_EPOCHS,
    fp16=True,
    warmup_ratio=0.03,
    load_best_model_at_end=has_eval_data,
    greater_is_better=False,
    ddp_find_unused_parameters=False,
    push_to_hub=False,

    max_grad_norm=1.0,              # gradient clipping
    dataloader_pin_memory=False,
    gradient_checkpointing=True,
    remove_unused_columns=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds_tokenized,
    eval_dataset=val_ds_tokenized if has_eval_data else None,
    data_collator=data_collator,
)


print("Applied fixes: gradient clipping (max_grad_norm=1.0), lower learning rate, gradient checkpointing")
trainer.train()

# Always keep a copy of the adapter weights on their own.
lora_weights_dir = os.path.join(ARTIFACTS_DIR, "lora_weights")
os.makedirs(lora_weights_dir, exist_ok=True)
model.save_pretrained(lora_weights_dir)
print("Saved LoRA weights to:", lora_weights_dir)

# Merging is best-effort, but the outcome is recorded so the final message says
# what was actually written to FINETUNED_DIR.
print("Merging LoRA adapters into base model ...")
merge_succeeded = False
try:
    model = model.merge_and_unload()
    merge_succeeded = True
except Exception as e:
    print("merge_and_unload() not available or failed:", e)

print("Saving final model to:", FINETUNED_DIR)
model.save_pretrained(FINETUNED_DIR)
tokenizer.save_pretrained(FINETUNED_DIR)
if merge_succeeded:
    print("Final fine-tuned model saved to Drive at:", FINETUNED_DIR)
else:
    # `model` is still the adapter-wrapped model here, so the directory does
    # not hold a merged model.
    print("WARNING: merge failed; the model saved at", FINETUNED_DIR, "is NOT merged. Adapter weights are in:", lora_weights_dir)

# Update artifacts manifest
# The split and masking figures are derived from TRAIN_SIZE / MASK_PROB so the
# manifest and the summary below cannot drift from the configuration. They are
# the configured values, not measured proportions of the files.
_mask_pct_label = f"{MASK_PROB:.0%}"
_split_label = f"{round(TRAIN_SIZE * 100)}-{round((1 - TRAIN_SIZE) * 100)}"

artifacts_info = {
    "rag_train": RAG_TRAIN_JSON,
    "rag_test": RAG_TEST_JSON,
    "finetune_train": FINETUNE_TRAIN_JSONL,
    "finetune_test": FINETUNE_TEST_JSONL,
    "faiss_index": FAISS_PATH,
    "meta": META_PATH,
    "lora_weights": lora_weights_dir,
    "merged_model": FINETUNED_DIR,
    "notes": {
        "split_ratio": _split_label,
        "mask_probability": MASK_PROB,
        "sim_threshold": SIM_THRESHOLD,
        "topk": TOPK,
        "learning_rate": LEARNING_RATE,
        "stability_fixes": "gradient_clipping, lower_lr, checkpointing",
        "description": f"ALL datasets contain {_mask_pct_label} masked queries for consistent training and evaluation"
    }
}
with open(os.path.join(ARTIFACTS_DIR, "artifacts_manifest.json"), "w", encoding="utf-8") as fh:
    json.dump(artifacts_info, fh, indent=2)


print(f"  RAG Train: {RAG_TRAIN_JSON} ({_mask_pct_label} masked)")
print(f"   RAG Test: {RAG_TEST_JSON} ({_mask_pct_label} masked)")
print(f"   Finetune Train: {FINETUNE_TRAIN_JSONL} ({_mask_pct_label} masked)")
print(f"   Finetune Test: {FINETUNE_TEST_JSONL} ({_mask_pct_label} masked)")

