In [None]:
!pip install transformers sentence-transformers faiss-cpu ipywidgets cryptography sacremoses spacy negspacy bitsandbytes accelerate -q

In [3]:
import os, re, json, time, random, warnings
from typing import List, Dict, Any, Tuple
import numpy as np
import faiss
import torch
import spacy
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer
from cryptography.fernet import Fernet

DRIVE_BASE = "/content/drive/MyDrive/qwen_finetune"
ARTIFACTS_DIR = os.path.join(DRIVE_BASE, "artifacts")

RAG_TRAIN_JSON = os.path.join(ARTIFACTS_DIR, "rag-train.json")
USER_QUERY_JSON = os.path.join(ARTIFACTS_DIR, "rag-test.json")

FAISS_PATH = os.path.join(ARTIFACTS_DIR, "reports_index.faiss")
META_PATH = os.path.join(ARTIFACTS_DIR, "reports_metadata.json")
FINETUNED_DIR = os.path.join(DRIVE_BASE, "qwen1.5b_qlora_ft_merged")

EMBED_MODEL = "all-MiniLM-L6-v2"
RANDOM_SEED = 42
SIM_THRESHOLD = 0.30
TOPK = 3
MASK_PROB = 0.60
PII_LABELS = {"PERSON", "GPE", "LOC"}

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

ENCRYPTION_DIR = os.path.join(DRIVE_BASE, "encrypted_rag")
os.makedirs(ENCRYPTION_DIR, exist_ok=True)

ENCRYPTED_RAG_PATH = os.path.join(ENCRYPTION_DIR, "rag_data.json.enc")
FERNET_KEY_PATH = os.path.join(ENCRYPTION_DIR, "fernet.key")
RESULTS_PATH = os.path.join(ENCRYPTION_DIR, "user_query_results.json")

def get_or_create_fernet_key(path: str = FERNET_KEY_PATH) -> bytes:
    if os.path.exists(path):
        return open(path,"rb").read()
    key = Fernet.generate_key()
    open(path,"wb").write(key)
    return key

FERNET_KEY = get_or_create_fernet_key()
FERNET = Fernet(FERNET_KEY)

def encrypt_bytes(b: bytes) -> bytes:
    return FERNET.encrypt(b)

def decrypt_bytes(b: bytes) -> bytes:
    return FERNET.decrypt(b)

def write_encrypted_json(id_int: int, payload: Dict[str,Any]) -> str:
    path = os.path.join(ENCRYPTION_DIR, f"{id_int}.json.enc")
    raw = json.dumps(payload, default=str).encode("utf-8")
    cipher = encrypt_bytes(raw)
    with open(path,"wb") as fh:
        fh.write(cipher)
    return path

try:
    nlp = spacy.load("en_core_web_sm")
    print("Loaded spaCy NER model.")
except Exception as e:
    print("spaCy model not found, downloading...")
    import subprocess
    subprocess.run(["python", "-m", "spacy", "download", "en_core_web_sm"])
    nlp = spacy.load("en_core_web_sm")
    print("Loaded spaCy NER model.")

def simple_regex_mask(text: str) -> str:
    out = text
    out = re.sub(r"\b[A-Z][a-z]+ [A-Z][a-z]+\b", "[NAME]", out)
    out = re.sub(r"[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+", "[EMAIL]", out)
    out = re.sub(r"\b(\+?\d{1,3}[-.\s]?)?(\d{3}[-.\s]?\d{3}[-.\s]?\d{4})\b", "[PHONE]", out)
    return out

def mask_with_spacy(text: str) -> Tuple[str, Dict[str,str]]:
    doc = nlp(text)
    mask_map = {}
    masked = text

    entities = sorted(doc.ents, key=lambda x: x.start_char, reverse=True)

    for ent in entities:
        if ent.label_ in PII_LABELS:
            placeholder = f"<{ent.label_}_{len(mask_map)+1}>"
            mask_map[placeholder] = ent.text
            masked = masked[:ent.start_char] + placeholder + masked[ent.end_char:]

    masked = simple_regex_mask(masked)
    return masked, mask_map

print("Loading sentence-transformers embedder...")
embedder = SentenceTransformer(EMBED_MODEL, device=device)

print("Loading FAISS index and metadata...")
index = faiss.read_index(FAISS_PATH)
with open(META_PATH, "r", encoding="utf-8") as fh:
    meta = json.load(fh)

def retrieve_topk_query_pairs(current_query: str, current_idx: int = -1, topk: int = TOPK, sim_threshold: float = SIM_THRESHOLD) -> List[Dict[str,Any]]:
    fetch_k = max(10, topk * 4)

    q_emb = embedder.encode([current_query], convert_to_numpy=True).astype("float32")
    faiss.normalize_L2(q_emb)
    D, I = index.search(q_emb, fetch_k)

    hits = []
    for score, idx in zip(D[0], I[0]):
        if idx < 0 or idx >= len(meta):
            continue
        if float(score) < sim_threshold:
            continue

        if current_idx != -1 and idx == current_idx:
            continue

        rec = meta[int(idx)]
        hits.append({
            "similar_query": rec["query"],
            "report": rec["report"],
            "score": float(score),
            "orig_idx": int(idx)
        })

        if len(hits) >= topk:
            break

    return hits

def build_prompt(query: str, retrieved_pairs: List[Dict[str,Any]]) -> str:
    prompt_lines = []
    prompt_lines.append("[QUERY]")
    prompt_lines.append(query.strip())
    prompt_lines.append("")
    prompt_lines.append("[SIMILAR_CASES]")
    for i, pair in enumerate(retrieved_pairs, start=1):
        prompt_lines.append(f"Case {i}:")
        prompt_lines.append(f"Query: {pair['similar_query']}")
        prompt_lines.append(f"Report: {pair['report']}")
        prompt_lines.append("")
    prompt_lines.append("[INSTRUCTION]")
    prompt_lines.append("Generate a clinical report. Do NOT include disclaimers.")
    prompt = "\n".join(prompt_lines)
    return prompt

print("Loading fine-tuned Qwen model...")

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B", trust_remote_code=True)

if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({"pad_token":"<|pad|>"})

bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False
)

model = AutoModelForCausalLM.from_pretrained(
    FINETUNED_DIR,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True
)
print("Fine-tuned Qwen model loaded!")

def generate_from_model(prompt: str, max_tokens: int = 256) -> str:
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(model.device)

    out = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
        temperature=0.1
    )[0]

    prompt_len = inputs["input_ids"].shape[1]
    gen_ids = out[prompt_len:]
    text = tokenizer.decode(gen_ids, skip_special_tokens=True)
    return text.strip()

print("Loading user queries...")
with open(USER_QUERY_JSON, "r", encoding="utf-8") as fh:
    user_queries = json.load(fh)

print(f"Processing {len(user_queries)} user queries...")

results = []
num_queries = len(user_queries)
num_masked = int(num_queries * MASK_PROB)

shuffled_indices = list(range(num_queries))
random.shuffle(shuffled_indices)
masked_indices = set(shuffled_indices[:num_masked])

for i, rec in enumerate(user_queries, 1):
    original_query = rec.get("query", "").strip()
    report = rec.get("report", "")

    consent_flag = False if (i - 1) in masked_indices else True

    if not consent_flag:
        processed_query, mask_map = mask_with_spacy(original_query)
        print(f"Query {i}: MASKED")
        print(f"  Original: {original_query}")
        print(f"  Masked: {processed_query}")
        print(f"  Mask map: {mask_map}")
    else:
        processed_query = original_query
        mask_map = {}
        print(f"Query {i}: UNMASKED")
        print(f"  Query: {original_query}")

    hits = retrieve_topk_query_pairs(processed_query, current_idx=-1, topk=TOPK, sim_threshold=SIM_THRESHOLD)

    prompt = build_prompt(processed_query, hits)

    answer = generate_from_model(prompt)

    if mask_map:
        for placeholder, original_text in mask_map.items():
            answer = answer.replace(placeholder, original_text)

    payload = {
        "id": i,
        "original_query": original_query,
        "processed_query": processed_query,
        "mask_map": mask_map,
        "retrieved_count": len(hits),
        "retrieved_cases": [{"query": hit["similar_query"], "report": hit["report"]} for hit in hits],
        "final_answer": answer,
        "consent_given": consent_flag
    }

    write_encrypted_json(i, payload)
    results.append(payload)

    print(f"[{i}/{len(user_queries)}] Generated answer: {len(answer)} chars")
    print("---")

with open(RESULTS_PATH, "w", encoding="utf-8") as fh:
    json.dump(results, fh, indent=2)

print(f"Pipeline completed! Processed {len(user_queries)} queries")
print(f"{num_masked}/{len(user_queries)} queries masked ({MASK_PROB*100}%)")
print(f"Individual results encrypted in {ENCRYPTION_DIR}")

Device: cuda
Loaded spaCy NER model.
Loading sentence-transformers embedder...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Loading FAISS index and metadata...
Loading fine-tuned Qwen model...


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]



Fine-tuned Qwen model loaded!
Loading user queries...
Processing 96 user queries...
Query 1: MASKED
  Original: My name is Virgil, 72. I have <PERSON_1>'s disease. What medical therapy?
  Masked: My name is Virgil, 72. I have <PERSON_1>'s disease. What medical therapy?
  Mask map: {}


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


[1/96] Generated answer: 86 chars
---
Query 2: MASKED
  Original: My name is <GPE_1>, 30. I have chlamydia. What antibiotic?
  Masked: My name is <GPE_1>, 30. I have chlamydia. What antibiotic?
  Mask map: {}
[2/96] Generated answer: 80 chars
---
Query 3: UNMASKED
  Query: I'm <PERSON_2>, 59. My <PERSON_1> for infection is expensive. Alternative?
[3/96] Generated answer: 115 chars
---
Query 4: MASKED
  Original: My name is <PERSON_1>, 40. I have cold. [NAME] safe?
  Masked: My name is <PERSON_1>, 40. I have cold. [NAME] safe?
  Mask map: {}
[4/96] Generated answer: 106 chars
---
Query 5: MASKED
  Original: My name is <PERSON_1>, 34. I have razor burn. Is hydrocortisone safe?
  Masked: My name is <PERSON_1>, 34. I have razor burn. Is hydrocortisone safe?
  Mask map: {}
[5/96] Generated answer: 56 chars
---
Query 6: MASKED
  Original: I'm <PERSON_1>, 64. [NAME] for PAH is unaffordable. Options?
  Masked: I'm <PERSON_1>, 64. [NAME] for PAH is unaffordable. Options?
  Mask map: {}
[6/96] G