## Notebook with finetuning of a BERT Medical model

In [1]:
import kagglehub
from kagglehub import KaggleDatasetAdapter


file_path = "DiseaseAndSymptoms.csv"

symptoms_df = kagglehub.load_dataset(
  KaggleDatasetAdapter.PANDAS,
  "choongqianzheng/disease-and-symptoms-dataset",
  file_path,
)

print("First 5 records:", symptoms_df.head())

  from .autonotebook import tqdm as notebook_tqdm
  symptoms_df = kagglehub.load_dataset(


First 5 records:             Disease   Symptom_1              Symptom_2              Symptom_3  \
0  Fungal infection     itching              skin_rash   nodal_skin_eruptions   
1  Fungal infection   skin_rash   nodal_skin_eruptions    dischromic _patches   
2  Fungal infection     itching   nodal_skin_eruptions    dischromic _patches   
3  Fungal infection     itching              skin_rash    dischromic _patches   
4  Fungal infection     itching              skin_rash   nodal_skin_eruptions   

              Symptom_4 Symptom_5 Symptom_6 Symptom_7 Symptom_8 Symptom_9  \
0   dischromic _patches       NaN       NaN       NaN       NaN       NaN   
1                   NaN       NaN       NaN       NaN       NaN       NaN   
2                   NaN       NaN       NaN       NaN       NaN       NaN   
3                   NaN       NaN       NaN       NaN       NaN       NaN   
4                   NaN       NaN       NaN       NaN       NaN       NaN   

  Symptom_10 Symptom_11 Symptom_1

In [2]:
file_path = "Disease precaution.csv"

precaution_df = kagglehub.load_dataset(
  KaggleDatasetAdapter.PANDAS,
  "choongqianzheng/disease-and-symptoms-dataset",
  file_path,
)

print("First 5 records:", precaution_df.head())

  precaution_df = kagglehub.load_dataset(


First 5 records:           Disease                      Precaution_1  \
0   Drug Reaction                   stop irritation   
1         Malaria          Consult nearest hospital   
2         Allergy                    apply calamine   
3  Hypothyroidism                     reduce stress   
4       Psoriasis  wash hands with warm soapy water   

                   Precaution_2        Precaution_3  \
0      consult nearest hospital    stop taking drug   
1               avoid oily food  avoid non veg food   
2       cover area with bandage                 NaN   
3                      exercise         eat healthy   
4  stop bleeding using pressure      consult doctor   

                  Precaution_4  
0                    follow up  
1           keep mosquitos out  
2  use ice to compress itching  
3             get proper sleep  
4                   salt baths  


In [3]:
import pandas as pd

canonical_descriptions = []
grouped = symptoms_df.groupby('Disease')
for disease, group in grouped:
    symptoms = ', '.join(
        sorted(set(
            symptom
            for i in range(1, 18)
            for symptom in group[f'Symptom_{i}']
            if pd.notna(symptom)
        ))
    )
    description = f"{disease} is characterized by the following symptoms: {symptoms}."
    canonical_descriptions.append({'Disease': disease, 'Canonical Description': description})

canonical_df = pd.DataFrame(canonical_descriptions)

print(canonical_df.head())

                                   Disease  \
0  (vertigo) Paroymsal  Positional Vertigo   
1                                     AIDS   
2                                     Acne   
3                      Alcoholic hepatitis   
4                                  Allergy   

                               Canonical Description  
0  (vertigo) Paroymsal  Positional Vertigo is cha...  
1  AIDS is characterized by the following symptom...  
2  Acne is characterized by the following symptom...  
3  Alcoholic hepatitis is characterized by the fo...  
4  Allergy is characterized by the following symp...  


### Phase 1 – Build the Knowledge Base
We consolidate every disease into a single canonical description, encode those descriptions with a sentence-transformer to obtain dense vectors, and store the vectors inside a FAISS index so we can later retrieve the closest diseases for any symptom query.

In [13]:
import json
from dataclasses import dataclass
from pathlib import Path
from typing import List, Sequence
import numpy as np
from sentence_transformers import SentenceTransformer
import faiss

@dataclass
class RetrievedDisease:
    disease: str
    description: str
    similarity: float

class DiseaseKnowledgeBase:
    """Builds and queries a FAISS vector index over canonical disease descriptions."""
    def __init__(self, embedding_model: str = "sentence-transformers/all-MiniLM-L12-v2"):
        self.embedding_model_name = embedding_model
        self.embedder = SentenceTransformer(self.embedding_model_name)
        self.index = None
        self.embeddings = None
        self.diseases: List[str] = []
        self.descriptions: List[str] = []

    def build(self, descriptions: Sequence[str], diseases: Sequence[str]):
        self.descriptions = list(descriptions)
        self.diseases = list(diseases)
        self.embeddings = self.embedder.encode(
            self.descriptions, convert_to_numpy=True, normalize_embeddings=True
        )
        dim = self.embeddings.shape[1]
        self.index = faiss.IndexFlatIP(dim)
        self.index.add(self.embeddings)
        print(f"Indexed {self.index.ntotal} diseases (dim={dim}).")
        return self

    def persist(self, output_dir: Path):
        if self.index is None or self.embeddings is None:
            raise ValueError("Call build() before persisting the index.")
        output_dir.mkdir(parents=True, exist_ok=True)
        faiss.write_index(self.index, str(output_dir / "faiss.index"))
        np.save(output_dir / "embeddings.npy", self.embeddings)
        metadata = {
            "embedding_model": self.embedding_model_name,
            "diseases": self.diseases,
            "descriptions": self.descriptions
        }
        with open(output_dir / "metadata.json", "w", encoding="utf-8") as f:
            json.dump(metadata, f, ensure_ascii=False, indent=2)
        print(f"Persisted vector store to {output_dir.resolve()}")

    def search(self, query: str, top_k: int = 5) -> List[RetrievedDisease]:
        if self.index is None:
            raise ValueError("Index not built. Call build() before search().")
        query_embedding = self.embedder.encode(
            [query], convert_to_numpy=True, normalize_embeddings=True
        )
        distances, indices = self.index.search(query_embedding, top_k)
        results: List[RetrievedDisease] = []
        for score, idx in zip(distances[0], indices[0]):
            results.append(
                RetrievedDisease(
                    disease=self.diseases[idx],
                    description=self.descriptions[idx],
                    similarity=float(score)
                )
            )
        return results

In [14]:
kb = DiseaseKnowledgeBase()
kb.build(
    descriptions=canonical_df['Canonical Description'],
    diseases=canonical_df['Disease']
 )
kb.persist(Path("../models/vector_knowledge_base"))

Indexed 41 diseases (dim=384).
Persisted vector store to /home/adam/projects/medical_diagnose_ml/ml/models/vector_knowledge_base


In [15]:
def evaluate_retrieval(kb: DiseaseKnowledgeBase, top_k: int = 5):
    queries = canonical_df['Canonical Description'].tolist()
    labels = canonical_df['Disease'].tolist()
    correct_at_1 = 0
    correct_at_k = 0
    for text, label in zip(queries, labels):
        hits = kb.search(text, top_k=top_k)
        if hits and hits[0].disease == label:
            correct_at_1 += 1
        if any(hit.disease == label for hit in hits):
            correct_at_k += 1
    total = len(labels)
    print(f"Top-1 accuracy: {correct_at_1/total:.2%}")
    print(f"Top-{top_k} accuracy: {correct_at_k/total:.2%}")

evaluate_retrieval(kb, top_k=5)

Top-1 accuracy: 100.00%
Top-5 accuracy: 100.00%


In [16]:
sample_queries = [
    "Patient reports acne-like skin lesions, papules, and inflammation.",
    "Patient complains about yellowing of skin, dark urine, and loss of appetite.",
    "Patient experiences high fever, chills, and severe body aches.",
 ]

for query in sample_queries:
    print(f"Query: {query}")
    hits = kb.search(query, top_k=3)
    for hit in hits:
        print(f"  {hit.disease} (similarity={hit.similarity:.4f})")
    print()

Query: Patient reports acne-like skin lesions, papules, and inflammation.
  Acne (similarity=0.5848)
  Psoriasis (similarity=0.4453)
  Fungal infection (similarity=0.3952)

Query: Patient complains about yellowing of skin, dark urine, and loss of appetite.
  Hepatitis D (similarity=0.4740)
  Hepatitis C (similarity=0.4670)
  Hepatitis B (similarity=0.4561)

Query: Patient experiences high fever, chills, and severe body aches.
  Pneumonia (similarity=0.4764)
  Common Cold (similarity=0.4465)
  Tuberculosis (similarity=0.4457)



**How to use this knowledge base**
1. Refresh `canonical_df` with the data-loading cells.
2. Run the class definition cell, then execute the build/persist cell (`kb = DiseaseKnowledgeBase(); kb.build(...); kb.persist(...)`).
3. Use `kb.search("symptom description", top_k)` in any cell or reuse the provided demo cell to retrieve the closest diseases.

### Phase 2 – Interactive Advisor
Encapsulate retrieval + precaution lookup so a single method call returns the top suspected diseases together with their recommended precautions.

In [None]:
from typing import Dict
import pandas as pd


class DiseaseAdvisor:
    """Returns likely diseases plus their precautions for a user query."""
    def __init__(self, knowledge_base: DiseaseKnowledgeBase, precaution_df: pd.DataFrame, default_top_k: int = 2):
        self.kb = knowledge_base
        self.default_top_k = default_top_k
        self.precaution_lookup = self._build_precaution_lookup(precaution_df)

    @staticmethod
    def _build_precaution_lookup(df: pd.DataFrame) -> Dict[str, list]:
        precaution_cols = [col for col in df.columns if col.lower().startswith("precaution")]
        lookup: Dict[str, list] = {}
        for _, row in df.iterrows():
            disease = row['Disease']
            steps = [str(row[col]).strip() for col in precaution_cols if pd.notna(row[col])]
            if steps:
                lookup[disease] = steps
        return lookup

    def advise(self, query: str, top_k: int | None = None):
        k = top_k or self.default_top_k
        hits = self.kb.search(query, top_k=k)
        results = []
        for hit in hits:
            results.append({
                "disease": hit.disease,
                "similarity": hit.similarity,
                "description": hit.description,
                "precautions": self.precaution_lookup.get(hit.disease, [])
            })
        return results

In [None]:
advisor = DiseaseAdvisor(kb, precaution_df, default_top_k=2)

demo_queries = [
    "Patient reports acne-like skin lesions, papules, and inflammation.",
    "Patient complains about yellowing of skin, dark urine, and loss of appetite.",
    "Patient experiences high fever, chills, and severe body aches.",
]

In [None]:
for query in demo_queries:
    print(f"Query: {query}")
    recommendations = advisor.advise(query)
    for rec in recommendations:
        print(f"  Disease: {rec['disease']} (similarity={rec['similarity']:.4f})")
        if rec['precautions']:
            for precaution in rec['precautions']:
                print(f"    - {precaution}")
        else:
            print("    - No precautions listed in dataset.")
    print()