In [1]:
# install required packages
!pip install bitsandbytes transformers datasets peft faiss-cpu accelerate sentence-transformers -q

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.4/59.4 MB[0m [31m11.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.6/23.6 MB[0m [31m37.8 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:


import json
import random
import numpy as np
from transformers import AutoTokenizer


# CONFIG

DATASET_PATH = "/content/dataset.json"   # modify if needed
MODEL_NAME = "Qwen/Qwen2.5-1.5B"         # <-- changed to Qwen 2.5 1.5B tokenizer
SAMPLE_SIZE = 100                        # how many examples to sample
PERCENTILE = 98                          # target percentile
ROUND_MULTIPLE = 64                      # round up to nearest 64
 
# LOAD TOKENIZER
 
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


# LOAD DATASET

with open(DATASET_PATH, "r") as f:
    data = json.load(f)

print(f"Loaded dataset examples: {len(data)}")


# SAMPLE SUBSET

sampled = random.sample(data, min(SAMPLE_SIZE, len(data)))


# BUILD TRAINING PROMPTS & MEASURE TOKEN LENGTHS

def build_prompt(query, retrieved_chunks):
    # retrieved_chunks = list of strings (we simulate 3 empty ones for now)
    c1, c2, c3 = retrieved_chunks

    return f"""
[QUERY]
{query}

[RETRIEVED_CONTEXT]
1. {c1}
2. {c2}
3. {c3}

[INSTRUCTION]
Generate a clinical report. Do NOT include disclaimers.
""".strip()


lengths = []

print("Tokenizing samples...")
for item in sampled:
    q = item["query"]

    # For this stage we don't need real retrieval, so use blanks
    prompt = build_prompt(q, ["", "", ""])

    tokens = tokenizer.encode(prompt)
    lengths.append(len(tokens))

lengths = np.array(lengths)


# COMPUTE OPTIMAL MAX SEQ LEN

raw_percentile = int(np.percentile(lengths, PERCENTILE))

# Round up to nearest multiple of 64
def round_up(x, multiple):
    return ((x + multiple - 1) // multiple) * multiple

optimal_seq = round_up(raw_percentile, ROUND_MULTIPLE)


# RESULTS


print("TOKEN LENGTH ANALYSIS")

print("Sample count:", len(sampled))
print("Min tokens:", lengths.min())
print("Median tokens:", int(np.median(lengths)))
print("Max tokens:", lengths.max())
print(f"{PERCENTILE}th percentile:", raw_percentile)
print("Rounded optimal seq length:", optimal_seq)


print(f"\n Recommended max_seq_length = {optimal_seq}\n")


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

Loaded dataset examples: 478
Tokenizing samples...
TOKEN LENGTH ANALYSIS
Sample count: 100
Min tokens: 50
Median tokens: 56
Max tokens: 64
98th percentile: 62
Rounded optimal seq length: 64

 Recommended max_seq_length = 64



In [3]:


import json

with open("dataset.json","r") as f:
    data = json.load(f)

for i, d in enumerate(data[:10]):
    print(i, d)


bad = [i for i, d in enumerate(data) if "report" not in d]
print(" Missing report count:", len(bad))
print(" Indices with missing report:", bad[:50])


import json

path = "/content/dataset.json"   # <-- update if different

# Load
with open(path, "r") as f:
    data = json.load(f)

print("Original size:", len(data))

# Clean broken entries
cleaned = [d for d in data if "report" in d and isinstance(d["report"], str) and len(d["report"].strip()) > 0]

print("Cleaned size:", len(cleaned))
print("Removed:", len(data) - len(cleaned))

# Overwrite the file
with open(path, "w") as f:
    json.dump(cleaned, f, indent=2)

print(" dataset.json cleaned & overwritten successfully.")


0 {'query': "I'm John, 35. Diagnosed with strep throat. What antibiotic is standard?", 'report': 'Penicillin or Amoxicillin are first-line treatments for strep throat.'}
1 {'query': 'My name is Lisa, 28. I have bacterial vaginosis. What medication works?', 'report': 'Metronidazole oral tablets or vaginal gel are standard treatments for BV.'}
2 {'query': "I'm David, 70. Diagnosed with pneumonia. What antibiotics are used?", 'report': 'Community-acquired pneumonia is typically treated with Azithromycin or Amoxicillin.'}
3 {'query': 'My name is Sarah, 42. I have shingles. What medication helps?', 'report': 'Antiviral medications like Acyclovir or Valacyclovir are prescribed for shingles.'}
4 {'query': "I'm Mike, 58. Diagnosed with gout attack. What treats the pain?", 'report': 'Colchicine or NSAIDs like Indomethacin are used for acute gout pain.'}
5 {'query': 'My name is Anna, 31. I have urinary tract infection. What antibiotic?', 'report': 'Nitrofurantoin or Trimethoprim-sulfamethoxazole

In [4]:


 
# 1) LOAD DATASET
 
import json
import random
from tqdm import tqdm

with open("dataset.json", "r") as f:
    data = json.load(f)

print("Loaded dataset examples:", len(data))

 
# 2) BUILD FAISS INDEX FROM REPORTS ONLY
 

from sentence_transformers import SentenceTransformer
import faiss
import numpy as np

embedder = SentenceTransformer("all-mpnet-base-v2")

reports = [d["report"] for d in data]
report_embeddings = embedder.encode(reports, convert_to_numpy=True)

dim = report_embeddings.shape[1]
index = faiss.IndexFlatIP(dim)
faiss.normalize_L2(report_embeddings)
index.add(report_embeddings)

# map index → report text
id2report = {i: reports[i] for i in range(len(reports))}

print("FAISS index built with", len(reports), "reports")

 
# 3) RETRIEVAL FUNCTION (TOP-3)
 
def retrieve_top3(query_text):
    q_emb = embedder.encode([query_text], convert_to_numpy=True)
    faiss.normalize_L2(q_emb)

    scores, hits = index.search(q_emb, 10)  # fetch top-10 to boost filtering
    hits = hits[0].tolist()

    # Simply pick top 3 reports for token estimation
    retrieved = [id2report[i] for i in hits[:3]]
    return retrieved

 
# 4) BUILD FULL PROMPT (REAL TRAINING FORMAT)
 
def build_prompt(query, chunks):
    prompt = f"""
[QUERY]
{query}

[RETRIEVED_CONTEXT]
1. {chunks[0] if len(chunks)>0 else ""}
2. {chunks[1] if len(chunks)>1 else ""}
3. {chunks[2] if len(chunks)>2 else ""}

[INSTRUCTION]
You a medical assistant generate a clinical report. Do NOT include disclaimers.
Given query and the relevant documents make a report using them.
""".strip()
    return prompt

 
# 5) TOKENIZE WITH QWEN TOKENIZER
 

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B")

# Sample randomly to avoid tokenizing all 500
sample_data = random.sample(data, min(100, len(data)))

token_lengths = []

print("Tokenizing real training prompts...")

for sample in tqdm(sample_data):
    q = sample["query"]
    chunks = retrieve_top3(q)
    full_prompt = build_prompt(q, chunks)

    tokens = tokenizer(full_prompt, return_tensors="pt")
    token_lengths.append(tokens.input_ids.shape[1])

 
# 6) ANALYZE LENGTHS
 
import numpy as np

arr = np.array(token_lengths)

print("\nTOKEN LENGTH ANALYSIS (FULL RAG PROMPTS)")
print("Sample count:", len(arr))
print("Min tokens:", arr.min())
print("Median tokens:", np.median(arr))
print("Max tokens:", arr.max())
print("98th percentile:", np.percentile(arr, 98))

recommended = int(np.percentile(arr, 98)) + 50  # safety pad
print("\nRecommended max_seq_length =", recommended)


Loaded dataset examples: 478


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

FAISS index built with 478 reports
Tokenizing real training prompts...


100%|██████████| 100/100 [00:01<00:00, 80.84it/s]


TOKEN LENGTH ANALYSIS (FULL RAG PROMPTS)
Sample count: 100
Min tokens: 111
Median tokens: 130.0
Max tokens: 148
98th percentile: 148.0

Recommended max_seq_length = 198





In [None]:


import os, re, json, time, random, math, warnings
from typing import List, Dict, Any, Tuple
import numpy as np
import faiss
import torch
import spacy
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments, Trainer
from sentence_transformers import SentenceTransformer
from datasets import Dataset, DatasetDict
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from cryptography.fernet import Fernet
from sklearn.model_selection import train_test_split
from datetime import datetime

print("Torch:", torch.__version__)
print("Device:", "cuda" if torch.cuda.is_available() else "cpu")


# Paths & basic config

DRIVE_BASE = "/content/drive/MyDrive/qwen_finetune"
os.makedirs(DRIVE_BASE, exist_ok=True)

DATASET_JSON = os.path.join(DRIVE_BASE, "dataset.json")  # put your dataset.json here
ARTIFACTS_DIR = os.path.join(DRIVE_BASE, "artifacts")
os.makedirs(ARTIFACTS_DIR, exist_ok=True)

FAISS_PATH = os.path.join(ARTIFACTS_DIR, "reports_index.faiss")
META_PATH = os.path.join(ARTIFACTS_DIR, "reports_metadata.json")

TRAIN_JSONL = os.path.join(ARTIFACTS_DIR, "train_final.jsonl")
VAL_JSONL = os.path.join(ARTIFACTS_DIR, "val_final.jsonl")

FINETUNED_DIR = os.path.join(DRIVE_BASE, "qwen2p5_qlora_ft_merged")
os.makedirs(FINETUNED_DIR, exist_ok=True)


# User-tunable settings

EMBED_MODEL = "all-MiniLM-L6-v2"
BIOMED_NER = "d4data/biomedical-ner-all"
DEFAULT_BASE_MODEL = "Qwen/Qwen2.5-1.5B"   
RANDOM_SEED = 42
SIM_THRESHOLD = 0.30
TOPK = 3
MASK_PROB = 0.60
TEST_SIZE = 0.30

NUM_EPOCHS = 20
PER_DEVICE_BATCH_SIZE = 1
GRAD_ACCUM_STEPS = 8
LEARNING_RATE = 3e-4
MAX_LENGTH = 200
LORA_R = 32
LORA_ALPHA = 64
LORA_DROPOUT = 0.05

device = "cuda" if torch.cuda.is_available() else "cpu"
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)


# Utility functions (masking / NER / simple PHI masking)

PII_LABELS = {"PERSON","PATIENT","LOCATION","ADDRESS","EMAIL","PHONE","ID","DATE"}

def simple_regex_mask(text: str) -> str:
    out = text
    out = re.sub(r"\b[A-Z][a-z]+ [A-Z][a-z]+\b", "[NAME]", out)
    out = re.sub(r"[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+", "[EMAIL]", out)
    out = re.sub(r"\b(\+?\d{1,3}[-.\s]?)?(\d{3}[-.\s]?\d{3}[-.\s]?\d{4})\b", "[PHONE]", out)
    out = re.sub(r"\b\d{1,2}[\/\-\.\s]\d{1,2}[\/\-\.\s]\d{2,4}\b", "[DATE]", out)
    return out

try:
    from transformers import pipeline as hf_pipeline
    ner_pipe = hf_pipeline("ner", model=BIOMED_NER, tokenizer=BIOMED_NER, aggregation_strategy="simple", device=0 if device=="cuda" else -1)
    print("Loaded biomedical NER pipeline.")
except Exception as e:
    ner_pipe = None
    print("No biomedical NER available:", e)

def mask_with_ner(text: str) -> Tuple[str, Dict[str,str]]:
    if ner_pipe is None:
        return simple_regex_mask(text), {}
    try:
        res = ner_pipe(text)
    except Exception:
        return simple_regex_mask(text), {}
    mask_map = {}
    masked = text
    spans = []
    for ent in res:
        label = (ent.get("entity_group") or ent.get("entity") or "").upper()
        token = ent.get("word") or ent.get("entity") or ent.get("word")
        s = ent.get("start", None); e = ent.get("end", None)
        if s is None or e is None:
            continue
        if label in PII_LABELS:
            spans.append((s,e,label,token))
    spans_sorted = sorted(spans, key=lambda x:x[0], reverse=True)
    counters={}
    for s,e,label,orig in spans_sorted:
        counters[label] = counters.get(label,0)+1
        token = f"<PHI_{label}_{counters[label]}>"
        masked = masked[:s] + token + masked[e:]
        mask_map[token] = orig
    masked = simple_regex_mask(masked)
    return masked, mask_map


# Load dataset.json from Drive

if not os.path.exists(DATASET_JSON):
    raise FileNotFoundError(f"Put your dataset.json at: {DATASET_JSON}")

with open(DATASET_JSON, "r", encoding="utf-8") as fh:
    raw_data = json.load(fh)

assert isinstance(raw_data, list), "dataset.json must be a list of {query, report} pairs."
print(f"Loaded {len(raw_data)} examples")


# Build embedding model & FAISS on train reports

print("Loading sentence-transformers embedder...")
embedder = SentenceTransformer(EMBED_MODEL, device=device)
EMB_DIM = embedder.get_sentence_embedding_dimension()
print("Embedding dim:", EMB_DIM)

train_list, val_list = train_test_split(raw_data, test_size=TEST_SIZE, random_state=RANDOM_SEED)
print(f"Train size: {len(train_list)}, Val size: {len(val_list)}")

print("Building FAISS index over train reports...")
reports = [rec["report"] for rec in train_list]
if len(reports) == 0:
    raise RuntimeError("No reports found in train split.")
report_embs = embedder.encode(reports, convert_to_numpy=True).astype("float32")
faiss.normalize_L2(report_embs)
index = faiss.IndexFlatIP(EMB_DIM)
index.add(report_embs)
faiss.write_index(index, FAISS_PATH)
meta = [{"id": i, "report": reports[i], "query": train_list[i].get("query","")} for i in range(len(reports))]
with open(META_PATH, "w", encoding="utf-8") as fh:
    json.dump(meta, fh, indent=2)
print(f"FAISS index saved to {FAISS_PATH}. Metadata saved to {META_PATH}")


# Retrieval helper (entity/med boosting)

def extract_entities_simple(text: str) -> Dict[str,List[str]]:
    out = {}
    if ner_pipe is not None:
        try:
            res = ner_pipe(text)
            for ent in res:
                label = (ent.get("entity_group") or ent.get("entity") or "").upper()
                token = ent.get("word") or ent.get("entity") or ent.get("word")
                out.setdefault(label, []).append(token)
            return out
        except Exception:
            pass
    tokens = re.findall(r"\b[A-Za-z0-9\-]{3,}\b", text)
    words = [t.lower() for t in tokens]
    disease_candidates = [w for w in words if re.search(r"(itis|osis|emia|oma|drome|disease|infection|pneumonia|shingles|strep|gout|urinary|pneumonia)", w)]
    med_candidates = [w for w in words if re.search(r"(cin|cillin|azole|cycline|mycin|vir|azole|statin|profen|ibuprofen|aspirin|colchicine|metronidazole|azithromycin|amoxicillin)", w)]
    if disease_candidates: out["DISEASE"] = disease_candidates
    if med_candidates: out["MEDICATION"] = med_candidates
    return out

def retrieve_topk_reports(query: str, topk: int = TOPK, sim_threshold: float = SIM_THRESHOLD, fetch_k:int=None) -> List[Dict[str,Any]]:
    if fetch_k is None:
        fetch_k = max(10, topk * 4)
    q_emb = embedder.encode([query], convert_to_numpy=True).astype("float32")
    faiss.normalize_L2(q_emb)
    D,I = index.search(q_emb, fetch_k)
    hits = []
    meta_local = meta
    query_ents = extract_entities_simple(query)
    for score, idx in zip(D[0], I[0]):
        if idx < 0 or idx >= len(meta_local):
            continue
        if float(score) < sim_threshold:
            continue
        rec = meta_local[int(idx)]
        rec_ents = extract_entities_simple(rec.get("report","") + " " + (rec.get("query","") or ""))
        overlap = 0
        for ent_type in ["MEDICATION","DISEASE","DISEASES","DISEASE"]:
            if ent_type in query_ents and ent_type in rec_ents:
                overlap += len(query_ents[ent_type]) * len(rec_ents[ent_type])
        boosted_score = float(score) * (1.0 + overlap/10.0) if overlap > 0 else float(score)
        hits.append({"idx": rec["id"], "report": rec["report"], "score": float(score), "boosted_score": boosted_score, "orig_idx": int(idx)})
    hits = sorted(hits, key=lambda x: x["boosted_score"], reverse=True)
    return hits[:topk]


# Build train/val jsonl files for finetuning

def build_prompt(query: str, chunks: List[str]) -> str:
    prompt_lines = []
    prompt_lines.append("[QUERY]")
    prompt_lines.append(query.strip())
    prompt_lines.append("")
    prompt_lines.append("[RETRIEVED_CONTEXT]")
    for i, c in enumerate(chunks, start=1):
        prompt_lines.append(f"{i}. {c.strip()}")
    for i in range(len(chunks)+1, TOPK+1):
        prompt_lines.append(f"{i}. ")
    prompt_lines.append("")
    prompt_lines.append("[INSTRUCTION]")
    prompt_lines.append("Generate a clinical report. Do NOT include disclaimers.")
    prompt = "\n".join(prompt_lines)
    return prompt

def create_dataset_jsonl(records: List[Dict[str,str]], out_path: str, mask_prob: float = MASK_PROB, sim_threshold:float = SIM_THRESHOLD):
    lines = []
    skipped = 0
    for rec in records:
        query = rec.get("query","").strip()
        report = rec.get("report","").strip()
        if not query or not report:
            continue
        hits = retrieve_topk_reports(query, topk=TOPK, sim_threshold=sim_threshold)
        if len(hits) == 0:
            skipped += 1
            continue
        chunks = [h["report"] for h in hits]
        do_mask = (random.random() < mask_prob)
        if do_mask:
            masked_query, _ = mask_with_ner(query)
            input_query = masked_query
        else:
            input_query = query
        prompt = build_prompt(input_query, chunks)
        obj = {"input": prompt, "output": report}
        lines.append(json.dumps(obj, ensure_ascii=False))
    with open(out_path, "w", encoding="utf-8") as fh:
        fh.write("\n".join(lines))
    print(f"Wrote {len(lines)} examples to {out_path} (skipped {skipped} examples due to no retrieval hits).")
    return len(lines)

print("Creating train_final.jsonl and val_final.jsonl ...")
n_train = create_dataset_jsonl(train_list, TRAIN_JSONL, mask_prob=MASK_PROB, sim_threshold=SIM_THRESHOLD)
n_val = create_dataset_jsonl(val_list, VAL_JSONL, mask_prob=0.0, sim_threshold=SIM_THRESHOLD)

print("Sample training example:")
with open(TRAIN_JSONL, "r", encoding="utf-8") as fh:
    for i, line in enumerate(fh):
        if i>0: break
        print(json.dumps(json.loads(line), indent=2))


# QLoRA finetuning using peft + bitsandbytes (8-bit quant)

print("Preparing for QLoRA fine-tuning (8-bit quantization)...")

tokenizer = AutoTokenizer.from_pretrained(DEFAULT_BASE_MODEL, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({"pad_token":"<|pad|>"})

def load_jsonl_to_dataset(jsonl_path: str):
    records = []
    with open(jsonl_path, "r", encoding="utf-8") as fh:
        for ln in fh:
            ln = ln.strip()
            if not ln: continue
            obj = json.loads(ln)
            records.append(obj)
    return Dataset.from_list(records)

train_ds = load_jsonl_to_dataset(TRAIN_JSONL)
val_ds = load_jsonl_to_dataset(VAL_JSONL)
print("Train/examples:", len(train_ds), "Val/examples:", len(val_ds))

def tokenize_and_mask(example):
    inp = example["input"]
    tgt = example["output"]
    full_text = inp + "\n\n" + tgt + tokenizer.eos_token
    tokenized_full = tokenizer(full_text, truncation=True, max_length=MAX_LENGTH)
    input_tokenized = tokenizer(inp, truncation=True, max_length=MAX_LENGTH)
    input_len = len(input_tokenized["input_ids"])
    labels = tokenized_full["input_ids"].copy()
    for i in range(input_len):
        if i < len(labels):
            labels[i] = -100
    tokenized_full["labels"] = labels
    return tokenized_full

train_ds_tokenized = train_ds.map(tokenize_and_mask, remove_columns=train_ds.column_names, num_proc=1)
val_ds_tokenized = val_ds.map(tokenize_and_mask, remove_columns=val_ds.column_names, num_proc=1)

from transformers import DataCollatorWithPadding
data_collator = DataCollatorWithPadding(tokenizer, padding=True)

print("Loading base model in 8-bit (this may take a while)...")
bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False
)

model = AutoModelForCausalLM.from_pretrained(
    DEFAULT_BASE_MODEL,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True
)

model = prepare_model_for_kbit_training(model)

target_modules = ["q_proj","v_proj","k_proj","o_proj"]
lora_config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    target_modules=target_modules,
    lora_dropout=LORA_DROPOUT,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, lora_config)
print("LoRA adapters added. Trainable params:", sum(p.numel() for p in model.parameters() if p.requires_grad))

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
out_dir = os.path.join(ARTIFACTS_DIR, f"qlora_runs_{timestamp}")
os.makedirs(out_dir, exist_ok=True)

training_args = TrainingArguments(
    output_dir=out_dir,
    per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
    per_device_eval_batch_size=PER_DEVICE_BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM_STEPS,
    eval_strategy="steps",
    save_strategy="steps",
    eval_steps=200 if len(train_ds_tokenized)>200 else max(1, len(train_ds_tokenized)//5),
    save_steps=200 if len(train_ds_tokenized)>200 else max(1, len(train_ds_tokenized)//5),
    logging_steps=50,
    learning_rate=LEARNING_RATE,
    num_train_epochs=NUM_EPOCHS,
    fp16=True,
    warmup_ratio=0.03,
    load_best_model_at_end=True,
    greater_is_better=False,
    ddp_find_unused_parameters=False,
    push_to_hub=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds_tokenized,
    eval_dataset=val_ds_tokenized if len(val_ds_tokenized)>0 else None,
    data_collator=data_collator,
)

print("Starting training...")
trainer.train()

lora_weights_dir = os.path.join(ARTIFACTS_DIR, "lora_weights")
os.makedirs(lora_weights_dir, exist_ok=True)
model.save_pretrained(lora_weights_dir)
print("Saved LoRA weights to:", lora_weights_dir)

print("Merging LoRA adapters into base model ...")
try:
    model = model.merge_and_unload()
except Exception as e:
    print("merge_and_unload() not available or failed:", e)
    # proceed with saving current model (adapters present)

print("Saving final model to:", FINETUNED_DIR)
model.save_pretrained(FINETUNED_DIR)
tokenizer.save_pretrained(FINETUNED_DIR)
print("Final fine-tuned model saved to Drive at:", FINETUNED_DIR)

artifacts_info = {
    "faiss_index": FAISS_PATH,
    "meta": META_PATH,
    "train_jsonl": TRAIN_JSONL,
    "val_jsonl": VAL_JSONL,
    "lora_weights": lora_weights_dir,
    "merged_model": FINETUNED_DIR,
    "notes": {
        "sim_threshold": SIM_THRESHOLD,
        "topk": TOPK,
        "mask_prob": MASK_PROB
    }
}
with open(os.path.join(ARTIFACTS_DIR, "artifacts_manifest.json"), "w", encoding="utf-8") as fh:
    json.dump(artifacts_info, fh, indent=2)


