In [None]:
import json
import random
import os

# Paths
DATASET_JSON = "/content/drive/MyDrive/gemma_finetune/dataset.json"
RAG_JSON = "/content/drive/MyDrive/gemma_finetune/rag-data.json"
USER_QUERY_JSON = "/content/drive/MyDrive/gemma_finetune/user-query.json"

# Split configuration
RAG_FRACTION = 0.7  # share of records used as retrieval documents; the rest become user queries
SPLIT_SEED = 42     # fixed seed so re-running this cell reproduces exactly the same split

# Load original dataset (shuffled and sliced below, so it must be a JSON list of records)
with open(DATASET_JSON, "r", encoding="utf-8") as f:
    dataset = json.load(f)

# Shuffle with a dedicated, seeded RNG: the split is reproducible and the global
# `random` state is left untouched. An unseeded shuffle produced a different split on
# every run, so a re-run could desynchronise user-query.json from any artefact built
# from an earlier rag-data.json (e.g. an encrypted RAG copy that is only created once).
random.Random(SPLIT_SEED).shuffle(dataset)

# 70% RAG, 30% user queries
split_idx = int(len(dataset) * RAG_FRACTION)
rag_data = dataset[:split_idx]
user_queries = dataset[split_idx:]
assert len(rag_data) + len(user_queries) == len(dataset), "split must partition the dataset"

# Save RAG data
with open(RAG_JSON, "w", encoding="utf-8") as f:
    json.dump(rag_data, f, indent=2)
print(f"RAG data saved: {len(rag_data)} items -> {RAG_JSON}")

# Save user queries
with open(USER_QUERY_JSON, "w", encoding="utf-8") as f:
    json.dump(user_queries, f, indent=2)
print(f"User queries saved: {len(user_queries)} items -> {USER_QUERY_JSON}")


RAG data saved: 334 items -> /content/drive/MyDrive/gemma_finetune/rag-data.json
User queries saved: 144 items -> /content/drive/MyDrive/gemma_finetune/user-query.json


In [None]:
!pip install transformers sentence-transformers faiss-cpu ipywidgets cryptography sacremoses spacy negspacy bitsandbytes accelerate -q

In [None]:
# Imports & Setup

import os, json, re, time, random, warnings
from typing import List, Dict, Any, Tuple
import numpy as np
import faiss
import torch
import spacy
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer
from cryptography.fernet import Fernet


# Paths & configuration

# Output directory for everything this notebook writes (created if missing).
BASE_DIR = "/content/drive/MyDrive/User_data_rag"
os.makedirs(BASE_DIR, exist_ok=True)

# Inputs produced by the split cell, plus the merged fine-tuned model, all live here.
_GEMMA_FINETUNE_DIR = "/content/drive/MyDrive/gemma_finetune"
RAG_JSON = f"{_GEMMA_FINETUNE_DIR}/rag-data.json"
ENCRYPTED_RAG_PATH = os.path.join(BASE_DIR, "rag_data.json.enc")
USER_QUERY_JSON = f"{_GEMMA_FINETUNE_DIR}/user-query.json"
FINETUNED_MODEL_DIR = f"{_GEMMA_FINETUNE_DIR}/gemma2b_qlora_ft_merged"

# Artefacts under BASE_DIR: FAISS index, Fernet key, audit log.
FAISS_INDEX_PATH, FERNET_KEY_PATH, AUDIT_LOG_PATH = (
    os.path.join(BASE_DIR, file_name)
    for file_name in ("rag_faiss.index", "fernet.key", "audit.log")
)

if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Device:", DEVICE)

# Retrieval settings
TOPK = 3             # documents kept per query; also the number of context slots in the prompt
SIM_THRESHOLD = 0.3  # minimum inner product between L2-normalised vectors (cosine similarity)
EMBED_MODEL = "all-MiniLM-L6-v2"

# NOTE(review): neither label set is referenced elsewhere in this notebook, and the
# general-purpose spaCy pipeline loaded later may not emit most of these labels — confirm.
PII_LABELS = set("PERSON PATIENT LOCATION ADDRESS EMAIL PHONE ID DATE".split())
MEDICAL_ENTS = set(
    "MEDICATION DISEASE DISEASES SYMPTOM SYMPTOMS SIGN_SYMPTOM PROCEDURE LABS VITALS AGE".split()
)


# Encryption helpers

def get_or_create_fernet_key(path: str = FERNET_KEY_PATH) -> bytes:
    """Return the Fernet key stored at ``path``, generating and persisting one if absent.

    Args:
        path: Key file location (defaults to ``FERNET_KEY_PATH``).

    Returns:
        The raw key bytes exactly as stored on disk.

    NOTE(review): the key file sits in the same directory as the ciphertext it
    protects, so anyone who can read that directory can decrypt everything —
    consider keeping the key outside Drive (e.g. a secrets manager).
    """
    if os.path.exists(path):
        with open(path, "rb") as fh:
            return fh.read()

    key = Fernet.generate_key()
    try:
        # O_EXCL: never silently overwrite a key that appeared after the exists()
        # check — replacing it would make previously written ciphertext undecryptable.
        # 0o600 restricts the file to its owner (best effort; some FUSE mounts ignore modes).
        fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600)
    except FileExistsError:
        with open(path, "rb") as fh:
            return fh.read()
    with os.fdopen(fd, "wb") as fh:
        fh.write(key)
    return key

# Module-wide cipher built once from the persisted (or freshly generated) key.
FERNET_KEY = get_or_create_fernet_key()
FERNET = Fernet(FERNET_KEY)

def encrypt_bytes(b: bytes) -> bytes:
    """Encrypt ``b`` with the module-wide Fernet key and return the resulting token."""
    token = FERNET.encrypt(b)
    return token

def decrypt_bytes(b: bytes) -> bytes:
    """Decrypt a Fernet token ``b`` (as produced by ``encrypt_bytes``) back to plaintext bytes."""
    plaintext = FERNET.decrypt(b)
    return plaintext

def write_encrypted_json(id_int: int, payload: Dict[str,Any]) -> str:
    """Encrypt ``payload`` as JSON and write it to ``BASE_DIR/<id_int>.json.enc``.

    Values that ``json`` cannot serialise natively are stringified (``default=str``).
    Any existing file for the same id is overwritten.

    Returns:
        The path of the file written.
    """
    target_path = os.path.join(BASE_DIR, f"{id_int}.json.enc")
    token = encrypt_bytes(json.dumps(payload, default=str).encode("utf-8"))
    with open(target_path, "wb") as out_file:
        out_file.write(token)
    return target_path

def write_encrypted_rag(rag_data: List[Dict[str,Any]], path: str = ENCRYPTED_RAG_PATH):
    """Encrypt the whole RAG corpus as a single JSON document and write it to ``path``.

    Values that ``json`` cannot serialise natively are stringified (``default=str``).
    Prints the destination when done; returns None.
    """
    serialized = json.dumps(rag_data, default=str).encode("utf-8")
    with open(path, "wb") as out_file:
        out_file.write(encrypt_bytes(serialized))
    print(f"Encrypted RAG saved to {path}")

def load_decrypted_rag(path: str = ENCRYPTED_RAG_PATH) -> List[Dict[str,Any]]:
    """Read the encrypted RAG file at ``path``, decrypt it, and parse the JSON corpus."""
    with open(path, "rb") as enc_file:
        token = enc_file.read()
    return json.loads(decrypt_bytes(token))


# Encrypt RAG data if not already encrypted, or if the plaintext split is newer

# Previously the ciphertext was only ever written once. If the split cell is re-run,
# rag-data.json changes but the old encrypted copy would keep being used, so this
# notebook's user queries could overlap the retrieval corpus. Re-encrypt whenever
# the plaintext is newer; if the plaintext has been deleted, trust the ciphertext.
# NOTE(review): the plaintext rag-data.json still sits on Drive next to this copy.
_rag_plaintext_newer = (
    os.path.exists(ENCRYPTED_RAG_PATH)
    and os.path.exists(RAG_JSON)
    and os.path.getmtime(RAG_JSON) > os.path.getmtime(ENCRYPTED_RAG_PATH)
)

if not os.path.exists(ENCRYPTED_RAG_PATH) or _rag_plaintext_newer:
    if _rag_plaintext_newer:
        print("Plaintext RAG is newer than the encrypted copy; re-encrypting.")
    with open(RAG_JSON, "r", encoding="utf-8") as fh:
        rag_data = json.load(fh)
    write_encrypted_rag(rag_data)
else:
    print("Encrypted RAG already exists.")

# Load decrypted RAG and build FAISS

rag_data = load_decrypted_rag()
print(f"Loaded {len(rag_data)} RAG documents.")

embedder = SentenceTransformer(EMBED_MODEL, device=DEVICE)
EMB_DIM = embedder.get_sentence_embedding_dimension()
print("Embedding dim:", EMB_DIM)

# Embed every report, then L2-normalise in place so that the inner-product index
# scores are cosine similarities.
texts = [doc["report"] for doc in rag_data]
embs = embedder.encode(texts, convert_to_numpy=True).astype("float32")
faiss.normalize_L2(embs)

index = faiss.IndexFlatIP(EMB_DIM)
index.add(embs)
faiss.write_index(index, FAISS_INDEX_PATH)
print(f"FAISS index built and saved at {FAISS_INDEX_PATH}")

# Row i of `meta` describes vector i of `index` (both follow rag_data order);
# retrieve_topk maps FAISS ids back through this list.
META_PATH = os.path.join(BASE_DIR, "rag_metadata.json")
meta = [
    {"id": doc_id, "report": doc["report"], "query": doc.get("query", "")}
    for doc_id, doc in enumerate(rag_data)
]
# NOTE(review): this writes every report (and query) to disk in PLAINTEXT, next to the
# encrypted copy, which undermines encrypting the RAG corpus — consider encrypting it.
with open(META_PATH, "w", encoding="utf-8") as fh:
    json.dump(meta, fh, indent=2)


# Load fine-tuned Gemma2 model

# bitsandbytes 8-bit quantisation settings for the merged checkpoint.
bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
)

tokenizer = AutoTokenizer.from_pretrained(FINETUNED_MODEL_DIR, use_fast=True)

# NOTE(review): trust_remote_code=True permits running Python code shipped inside
# FINETUNED_MODEL_DIR; confirm this merged checkpoint actually needs it.
_model_load_kwargs = dict(
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(FINETUNED_MODEL_DIR, **_model_load_kwargs)

print("Fine-tuned model loaded in 8-bit.")


# NER + simple PII masking

# Load the small English spaCy pipeline, downloading it on first use.
# Only a missing pipeline package (reported by spaCy as OSError) triggers the download;
# any other failure (e.g. a broken spaCy install) now surfaces directly instead of
# being hidden behind a download attempt that cannot fix it.
# NOTE(review): this is a general-purpose English pipeline; confirm it emits the
# clinical labels the rest of the notebook cares about (it likely does not).
try:
    ner_pipe = spacy.load("en_core_web_sm")
except OSError:
    import subprocess, sys
    subprocess.check_call([sys.executable, "-m", "spacy", "download", "en_core_web_sm"])
    ner_pipe = spacy.load("en_core_web_sm")

def simple_mask(text: str) -> Tuple[str, Dict[str,str]]:
    """Replace likely PHI in ``text`` with numbered placeholder tokens.

    Four regex heuristics run in order: two-word capitalised names, e-mail
    addresses, phone-like digit runs, then numeric dates. Each distinct matched
    string is replaced by ``<PHI_{KIND}_{n}>``, where n counts distinct values per
    kind, starting at 1. The heuristics are deliberately loose. For example, any two
    adjacent capitalised words are treated as a person name, and dash-separated
    dates such as ``12-05-2023`` are caught by the phone pattern first.

    Args:
        text: Free-text query to mask.

    Returns:
        ``(masked_text, mask_map)``, where ``mask_map`` maps each placeholder token
        back to the exact original substring so callers can re-identify output.
    """
    mask_map: Dict[str, str] = {}

    def _mask_kind(source: str, pattern: str, kind: str) -> str:
        # Substitute only at the regex match positions. The previous findall() +
        # str.replace() approach also rewrote unrelated substrings: masking
        # "John Smith" turned "John Smithson" into "<PHI_PERSON_1>son". It also
        # minted a fresh token for every repeat of the same value.
        token_for: Dict[str, str] = {}

        def _substitute(match: re.Match) -> str:
            value = match.group(0)
            if value not in token_for:
                token = f"<PHI_{kind}_{len(token_for) + 1}>"
                token_for[value] = token
                mask_map[token] = value
            return token_for[value]

        return re.sub(pattern, _substitute, source)

    out = _mask_kind(text, r"\b[A-Z][a-z]+ [A-Z][a-z]+\b", "PERSON")
    out = _mask_kind(out, r"[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+", "EMAIL")
    out = _mask_kind(out, r"\b(\+?\d[\d\-\s]{7,}\d)\b", "PHONE")
    out = _mask_kind(out, r"\b\d{1,2}[\/\-\.\s]\d{1,2}[\/\-\.\s]\d{2,4}\b", "DATE")
    return out, mask_map


# Entity extraction helper

def extract_entities(text: str) -> Dict[str,Any]:
    """Run the spaCy pipeline over ``text`` and group entity spans by upper-cased label.

    Returns:
        A dict mapping label -> list of ``{"text", "start", "end"}`` dicts, where
        start/end are character offsets into ``text``, in the order spaCy yields them.
    """
    grouped: Dict[str, Any] = {}
    for span in ner_pipe(text).ents:
        record = {"text": span.text, "start": span.start_char, "end": span.end_char}
        grouped.setdefault(span.label_.upper(), []).append(record)
    return grouped


# Negation marking

# Cue words that negate a following finding ("no fever", "denied chest pain", ...).
NEG_WORDS = r"\b(no|not|den(?:y|ies|ied)|without|absent|negative for)\b"
# A cue only governs its own clause: the look-back window is cut at clause punctuation
# followed by whitespace, or at a line break (so decimals like "38.5" do not cut it).
_NEG_SCOPE_BREAK = r"[.;!?]\s+|\n"

def mark_negation(entities: Dict[str,Any], text: str, window_chars:int=40) -> Dict[str,Any]:
    """Flag each entity as negated when a negation cue precedes it in the same clause.

    Args:
        entities: label -> list of entity dicts carrying a character ``"start"``
            offset into ``text`` (missing ``start`` is treated as 0).
        text: The text the offsets refer to.
        window_chars: How many characters before the entity start to inspect.

    Returns:
        The same ``entities`` dict; every entity dict is updated in place with a
        boolean ``"negated"`` key.
    """
    for recs in entities.values():
        for r in recs:
            s = r.get("start", 0)
            context = text[max(0, s - window_chars):s].lower()
            # Keep only the part of the window after the last clause break, so that
            # "No fever. Patient has cough" does not mark "cough" as negated.
            context = re.split(_NEG_SCOPE_BREAK, context)[-1]
            r["negated"] = re.search(NEG_WORDS, context) is not None
    return entities

# RAG retrieval

def retrieve_topk(query_vec: np.ndarray, topk:int=TOPK, sim_threshold:float=SIM_THRESHOLD) -> List[Dict[str,Any]]:
    """Return up to ``topk`` metadata records whose similarity to ``query_vec`` clears ``sim_threshold``.

    The query is copied to float32 and L2-normalised, so the inner-product index
    scores are cosine similarities. ``topk * 4`` candidates are fetched, so that
    filtering out invalid ids or low scores can still leave ``topk`` hits. Results
    keep FAISS's ranking order and are the shared ``meta`` dicts, not copies.
    """
    query = query_vec.reshape(1, -1).astype("float32")
    faiss.normalize_L2(query)
    scores, ids = index.search(query, topk * 4)

    hits: List[Dict[str, Any]] = []
    for score, doc_id in zip(scores[0], ids[0]):
        # FAISS pads missing results with id -1; also guard against ids beyond meta.
        in_range = 0 <= doc_id < len(meta)
        if in_range and not score < sim_threshold:
            hits.append(meta[int(doc_id)])
            if len(hits) >= topk:
                break
    return hits


# Prompt builder (training format EXACT)

def build_prompt(query: str, retrieved: List[Dict[str,Any]]) -> str:
    """Render the fixed-format prompt: query, numbered context slots, then the instruction.

    There are always at least ``TOPK`` numbered context slots. Slots without a
    retrieved report are emitted as an empty ``"n. "`` line, so the layout matches
    the training format regardless of how many documents were retrieved.
    """
    context_lines = [f"{n}. {doc['report']}" for n, doc in enumerate(retrieved, start=1)]
    context_lines += [f"{n}. " for n in range(len(retrieved) + 1, TOPK + 1)]

    sections = [
        "[QUERY]",
        query,
        "",
        "[RETRIEVED_CONTEXT]",
        *context_lines,
        "",
        "[INSTRUCTION]",
        "Generate a clinical report. Do NOT include disclaimers.",
    ]
    return "\n".join(sections)


# STRICT FIX — remove prompt echoing

def generate_from_model(prompt: str, max_tokens: int = 256) -> str:
    """Greedy-decode a completion for ``prompt`` and return only the newly generated text.

    Args:
        prompt: Fully formatted prompt string.
        max_tokens: Upper bound on newly generated tokens (``max_new_tokens``).

    Returns:
        The decoded continuation with special tokens removed and surrounding
        whitespace stripped; the prompt itself is not echoed back.
    """
    # A single prompt needs no padding. `truncation=True` without a `max_length` was a
    # no-op that only emitted the "Asking to truncate to max_length..." warning, so both
    # flags are dropped. Over-long prompts are still passed through untruncated, as before.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    output_ids = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        do_sample=False,  # greedy decoding: same prompt -> same output
        pad_token_id=tokenizer.eos_token_id,
    )[0]

    # The generated sequence starts with the prompt tokens; keep only what follows them.
    prompt_len = inputs["input_ids"].shape[1]
    new_token_ids = output_ids[prompt_len:]
    return tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()


# Load user queries and run the pipeline: a seeded random 60% of queries simulate
# "no consent" (PII-masked before embedding and prompting); the other 40% stay unmasked.

MASK_FRACTION = 0.6
MASK_SEED = 42  # fixed seed: the masked/unmasked assignment is reproducible across runs

with open(USER_QUERY_JSON, "r", encoding="utf-8") as fh:
    user_queries = json.load(fh)

results = []
num_queries = len(user_queries)
num_masked = int(num_queries * MASK_FRACTION)

# Dedicated RNG: reproducible and independent of any other use of the global `random` state.
shuffled_indices = list(range(num_queries))
random.Random(MASK_SEED).shuffle(shuffled_indices)
masked_indices = set(shuffled_indices[:num_masked])

for i, rec in enumerate(user_queries, 1):
    query_text = rec.get("query", "").strip()
    # consent=False <=> this query was drawn into the masked subset.
    consent_flag = (i - 1) not in masked_indices

    # NOTE(review): entities come from the *unmasked* text, so they may contain raw
    # PHI even for non-consenting queries.
    entities = extract_entities(query_text)
    entities = mark_negation(entities, query_text)

    if consent_flag:
        processed_query, mask_map = query_text, {}
    else:
        processed_query, mask_map = simple_mask(query_text)

    q_vec = embedder.encode([processed_query], convert_to_numpy=True).astype("float32")[0]
    retrieved = retrieve_topk(q_vec, topk=TOPK)
    prompt = build_prompt(processed_query, retrieved)
    answer = generate_from_model(prompt)

    # Re-identify: swap each placeholder the model echoed back for its original value
    # (a no-op when nothing was masked).
    for tok, orig in mask_map.items():
        answer = answer.replace(tok, orig)

    rec_id = i
    payload = {
        "id": rec_id,
        "query": query_text,
        "processed_query": processed_query,
        "mask_map": mask_map,
        "entities": entities,
        "retrieved_docs": retrieved,
        "final_answer": answer,
        "consent": consent_flag
    }
    write_encrypted_json(rec_id, payload)
    results.append(payload)
    print(f"[{i}/{len(user_queries)}] Query processed. Consent={consent_flag}. Final answer length: {len(answer)} chars.")

# NOTE(review): this aggregate is written in PLAINTEXT. It contains the raw queries,
# mask_map (placeholder -> original PHI) and re-identified answers, which defeats the
# per-record encryption above. Consider encrypting it or dropping PHI fields first.
OUT_PATH = os.path.join(BASE_DIR, "user_query_results.json")
with open(OUT_PATH, "w", encoding="utf-8") as fh:
    json.dump(results, fh, indent=2)
print(f"All results saved to {OUT_PATH}")


Device: cuda
Encrypted RAG already exists.
Loaded 334 RAG documents.
Embedding dim: 384
FAISS index built and saved at /content/drive/MyDrive/User_data_rag/rag_faiss.index
Fine-tuned model loaded in 8-bit.


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


[1/144] Query processed. Consent=False. Final answer length: 109 chars.
[2/144] Query processed. Consent=False. Final answer length: 122 chars.
[3/144] Query processed. Consent=True. Final answer length: 86 chars.
[4/144] Query processed. Consent=False. Final answer length: 96 chars.
[5/144] Query processed. Consent=False. Final answer length: 121 chars.
[6/144] Query processed. Consent=False. Final answer length: 89 chars.
[7/144] Query processed. Consent=False. Final answer length: 88 chars.
[8/144] Query processed. Consent=False. Final answer length: 94 chars.
[9/144] Query processed. Consent=True. Final answer length: 76 chars.
[10/144] Query processed. Consent=False. Final answer length: 112 chars.
[11/144] Query processed. Consent=False. Final answer length: 81 chars.
[12/144] Query processed. Consent=True. Final answer length: 108 chars.
[13/144] Query processed. Consent=True. Final answer length: 82 chars.
[14/144] Query processed. Consent=False. Final answer length: 117 chars.