# Medical Vector Database (TF-IDF Version)

This notebook demonstrates:
- Normalized relational schema
- DuckDB integration
- TF-IDF embeddings (no external downloads)
- Python embedding and cosine-similarity functions registered as DuckDB SQL functions
- Fully SQL-driven similarity search


In [1]:
import duckdb
import pandas as pd
import numpy as np
import re
from sklearn.feature_extraction.text import TfidfVectorizer

# Shared TF-IDF vectorizer: it is fitted on the per-disease symptom texts in the
# embedding cell further down, and embed_symptoms() reuses the same instance so
# query vectors are built against the same fitted vocabulary.
vectorizer = TfidfVectorizer()

## Load CSV

In [2]:
# Relative path to the disease/symptom CSV; it is resolved against the
# directory the notebook kernel is running in.
csv_path = '../data/DerivedKnowledgeGraph_final.csv'
df = pd.read_csv(filepath_or_buffer=csv_path)
df.head(n=5)

Unnamed: 0,Diseases,Symptoms
0,abscess,"pain (0.318), fever (0.119), swelling (0.112),..."
1,acid reflux,"pain (0.225), nausea (0.140), pain in upper ab..."
2,acute renal failure,"kidney failure (0.127), weakness (0.087), diar..."
3,alcohol intoxication,"slurred speech (0.071), vomiting (0.064), unst..."
4,alcoholism,"seizures (0.125), depression (0.090), sadness ..."


## Parse Symptoms

In [3]:
def parse_symptoms(symptom_string):
    """Parse a ``"name (weight), name (weight), ..."`` string into pairs.

    Parameters
    ----------
    symptom_string : str
        Comma-separated symptoms, each followed by a numeric weight in
        parentheses, e.g. ``"pain (0.318), fever (0.119)"``.

    Returns
    -------
    list[tuple[str, float]]
        ``(symptom_name, weight)`` pairs in order of appearance, with names
        stripped and lower-cased. Non-string input (for example the NaN that
        pandas produces for an empty CSV cell, or None) yields an empty list
        instead of raising inside ``re.findall``.

    Raises
    ------
    ValueError
        If a parenthesised weight is made of digits/dots but is not a valid
        float (e.g. ``"(1.2.3)"``).
    """
    if not isinstance(symptom_string, str):
        return []

    # Lazy name capture stops at the first "(<digits/dots>)" group, so names
    # may contain spaces but never a comma.
    pattern = r"([^,]+?)\s*\(([\d\.]+)\)"
    return [
        (name.strip().lower(), float(weight))
        for name, weight in re.findall(pattern, symptom_string)
    ]

# One (disease, [(symptom, weight), ...]) record per CSV row. The first two
# columns are taken by position: disease name, then the raw symptom string.
structured_data = [
    (raw_name.strip().lower(), parse_symptoms(raw_symptoms))
    for raw_name, raw_symptoms in zip(df.iloc[:, 0], df.iloc[:, 1])
]

structured_data[:2]

[('abscess',
  [('pain', 0.318),
   ('fever', 0.119),
   ('swelling', 0.112),
   ('redness', 0.094),
   ('chills', 0.092),
   ('infection', 0.083),
   ('cyst', 0.047),
   ('tenderness', 0.037),
   ('rectal pain', 0.026),
   ('lesion', 0.025),
   ('lump', 0.023),
   ('sore throat', 0.021),
   ('facial swelling', 0.016),
   ('pimple', 0.016),
   ('discomfort', 0.014),
   ('difficulty swallowing', 0.013),
   ('cavity', 0.013),
   ('night sweats', 0.007),
   ('severe pain', 0.007),
   ('abdominal pain', 0.007),
   ('painful swallowing', 0.007),
   ('back pain', 0.006)]),
 ('acid reflux',
  [('pain', 0.225),
   ('nausea', 0.14),
   ('pain in upper abdomen', 0.064),
   ('abdominal pain', 0.058),
   ('sadness', 0.052),
   ('depression', 0.052),
   ('vomiting', 0.048),
   ('anxiety', 0.045),
   ('diarrhea', 0.044),
   ('discomfort', 0.034),
   ('indigestion', 0.03),
   ('chills', 0.022),
   ('constipation', 0.021),
   ('dizziness', 0.021),
   ('heartburn', 0.02),
   ('chest pressure', 0.017),


## Initialize DuckDB

In [4]:
# Open a DuckDB connection on the database file 'medical.db' (relative path).
# Every object below is created with IF NOT EXISTS, so re-running this cell
# against an existing file leaves already-created objects as they are.
con = duckdb.connect('medical.db')

# Sequences that provide the default ids for the disease and symptom tables.
con.execute("CREATE SEQUENCE IF NOT EXISTS disease_seq START 1;")
con.execute("CREATE SEQUENCE IF NOT EXISTS symptom_seq START 1;")

# Lookup table of diseases; names are unique, ids default to the sequence.
con.execute("""
CREATE TABLE IF NOT EXISTS disease (
    disease_id INTEGER PRIMARY KEY DEFAULT nextval('disease_seq'),
    name TEXT UNIQUE
);
""")

# Lookup table of symptoms, same shape as `disease`.
con.execute("""
CREATE TABLE IF NOT EXISTS symptom (
    symptom_id INTEGER PRIMARY KEY DEFAULT nextval('symptom_seq'),
    name TEXT UNIQUE
);
""")

# Link table: at most one row per (disease, symptom) pair, carrying the
# incidence weight parsed from the CSV. No foreign keys are declared.
con.execute("""
CREATE TABLE IF NOT EXISTS disease_symptom (
    disease_id INTEGER,
    symptom_id INTEGER,
    incidence FLOAT,
    PRIMARY KEY (disease_id, symptom_id)
);
""")

# Embedding vectors keyed by disease_id. NOTE: no primary key or uniqueness
# constraint is declared, so the schema itself does not prevent several
# embedding rows for the same disease.
con.execute("""
CREATE TABLE IF NOT EXISTS disease_embedding (
    disease_id INTEGER,
    embedding DOUBLE[]
);
""")

<_duckdb.DuckDBPyConnection at 0x7d4871b67170>

## Populate Tables

In [5]:
# Statements for the two lookup tables: insert the name if it is not there yet,
# then read back the id stored for that name.
ADD_DISEASE_SQL = "INSERT OR IGNORE INTO disease (name) VALUES (?)"
FIND_DISEASE_SQL = "SELECT disease_id FROM disease WHERE name = ?"
ADD_SYMPTOM_SQL = "INSERT OR IGNORE INTO symptom (name) VALUES (?)"
FIND_SYMPTOM_SQL = "SELECT symptom_id FROM symptom WHERE name = ?"

for disease, symptoms in structured_data:
    con.execute(ADD_DISEASE_SQL, [disease])
    (disease_id,) = con.execute(FIND_DISEASE_SQL, [disease]).fetchone()

    for symptom, incidence in symptoms:
        con.execute(ADD_SYMPTOM_SQL, [symptom])
        (symptom_id,) = con.execute(FIND_SYMPTOM_SQL, [symptom]).fetchone()

        # Link row; a (disease_id, symptom_id) pair that already exists is
        # skipped because of the table's composite primary key.
        con.execute("""
            INSERT OR IGNORE INTO disease_symptom
            VALUES (?, ?, ?)
        """, [disease_id, symptom_id, incidence])

## Compute Disease Embeddings (TF-IDF, Fully Offline)

In [6]:
# Build one pseudo-document per disease, fit TF-IDF on those documents, and
# store each resulting dense vector in `disease_embedding`.
diseases = con.execute("SELECT disease_id, name FROM disease").fetchall()

disease_texts = []
disease_ids = []

for disease_id, name in diseases:
    symptoms = con.execute("""
        SELECT s.name, ds.incidence
        FROM disease_symptom ds
        JOIN symptom s ON ds.symptom_id = s.symptom_id
        WHERE ds.disease_id = ?
    """, [disease_id]).fetchall()

    # Encode incidence as term frequency: each symptom name is repeated
    # int(incidence * 10) times (truncated), but always at least once.
    weighted_text = []
    for symptom_name, incidence in symptoms:
        repetitions = max(1, int(incidence * 10))
        weighted_text.extend([symptom_name] * repetitions)

    disease_texts.append(" ".join(weighted_text))
    disease_ids.append(disease_id)

tfidf_matrix = vectorizer.fit_transform(disease_texts)

# `disease_embedding` has no key and the database lives in a file, so a plain
# INSERT would append a second set of rows every time this cell is re-run.
# Those older rows come from a different fit of the vectorizer and may not
# even have the same length. Replace the table contents atomically instead.
con.execute("BEGIN TRANSACTION")
try:
    con.execute("DELETE FROM disease_embedding")
    for i, disease_id in enumerate(disease_ids):
        vector = tfidf_matrix[i].toarray()[0]
        con.execute(
            "INSERT INTO disease_embedding VALUES (?, ?)",
            [disease_id, vector.tolist()]
        )
    con.execute("COMMIT")
except Exception:
    # Keep whatever embeddings were stored before this cell started.
    con.execute("ROLLBACK")
    raise

print("Embeddings computed successfully (offline).")

Embeddings computed successfully (offline).


## Register SQL Functions

In [7]:
def embed_symptoms(symptom_list):
    """Turn a list of symptom strings into a TF-IDF vector.

    The strings are joined with spaces and passed through the notebook-level
    `vectorizer` (which must already be fitted). Returns the dense vector as a
    plain list, or None when `symptom_list` is None or empty.
    """
    has_symptoms = symptom_list is not None and len(symptom_list) > 0
    if not has_symptoms:
        return None

    joined = " ".join(symptom_list)
    sparse_row = vectorizer.transform([joined])
    return sparse_row.toarray()[0].tolist()

def cosine(v1, v2):
    """Cosine similarity of two numeric sequences of equal length.

    Returns a Python float. If either vector has zero norm the result is 0.0,
    which avoids dividing by zero.
    """
    a = np.array(v1)
    b = np.array(v2)

    # Check the norms one at a time so the second is only computed when the
    # first is non-zero.
    norm_a = np.linalg.norm(a)
    if norm_a == 0:
        return 0.0
    norm_b = np.linalg.norm(b)
    if norm_b == 0:
        return 0.0

    return float(np.dot(a, b) / (norm_a * norm_b))

# Register the two Python helpers on this connection under the same names so
# the queries below can call embed_symptoms(...) and cosine(...) from SQL.
# Only return types are declared here; argument types are not specified.
con.create_function("embed_symptoms", embed_symptoms, return_type="DOUBLE[]")
con.create_function("cosine", cosine, return_type="DOUBLE")

<_duckdb.DuckDBPyConnection at 0x7d4871b67170>

## SQL-Driven Helper Functions

In [8]:
def disease_similarity(d1, d2):
    """Cosine similarity between the stored embeddings of two diseases.

    Parameters
    ----------
    d1, d2 : str
        Disease names as stored in `disease.name`.

    Returns
    -------
    float or None
        The value of the SQL `cosine` function for the two embeddings, or
        None when no row matches (at least one name has no disease/embedding).

    Notes
    -----
    Names are passed as bound parameters rather than spliced into the SQL
    text, so names containing an apostrophe (e.g. "parkinson's disease") work
    and cannot alter the query.
    """
    query = """
    SELECT
        cosine(e1.embedding, e2.embedding) AS similarity
    FROM disease_embedding e1, disease d1, disease_embedding e2, disease d2
    WHERE d1.disease_id = e1.disease_id
          AND d2.disease_id = e2.disease_id
          AND d1.name = ?
          AND d2.name = ?;
    """
    row = con.execute(query, [d1, d2]).fetchone()
    # fetchone() gives None when the join produced no rows.
    return row[0] if row is not None else None

# Example: compare three diseases against liver cancer.
for other_disease in ('colon cancer', 'cirrhosis of the liver', 'common cold'):
    print(
        f"Disease similarity of {other_disease} and liver cancer: ",
        disease_similarity(other_disease, 'liver cancer'),
    )

Disease similarity of colon cancer and liver cancer:  0.3124396075604429
Disease similarity of cirrhosis of the liver and liver cancer:  0.5681451154620615
Disease similarity of common cold and liver cancer:  0.1154716329368155


In [9]:
def rank_diseases(symptom_list):
    """Rank diseases by similarity to a list of query symptoms.

    Parameters
    ----------
    symptom_list : iterable of str
        Symptoms describing the query; each item is converted with `str`.

    Returns
    -------
    list[tuple[str, float]]
        `(disease_name, similarity)` rows with similarity > 0, highest first.
        An empty or None `symptom_list` returns an empty list without
        querying the database.

    Notes
    -----
    The symptoms are sent as a single bound list parameter instead of being
    quoted into the SQL text, so symptoms containing an apostrophe work and
    cannot alter the query.
    """
    symptoms = [] if symptom_list is None else [str(s) for s in symptom_list]
    if not symptoms:
        return []

    # The cast gives the parameter an explicit list-of-text type before it is
    # handed to the registered embed_symptoms function.
    query = """
    SELECT d.name,
           cosine(
               embed_symptoms(CAST(? AS VARCHAR[])),
               de.embedding
            ) AS similarity
    FROM disease d, disease_embedding de
    WHERE d.disease_id = de.disease_id
          AND similarity > 0
    ORDER BY similarity DESC;
    """
    return con.execute(query, [symptoms]).fetchall()

# Example: diseases ranked for a patient reporting fever and cough.
example_symptoms = ['fever', 'cough']
rank_diseases(example_symptoms)

[('lung cancer', 0.40553979272363117),
 ('neutropenia', 0.39178540168096154),
 ('chronic obstructive pulmonary disease', 0.37256866563078134),
 ('upper respiratory infection', 0.3390834449914737),
 ('pleural effusion', 0.32891418898032254),
 ('pneumonia', 0.32570797767656867),
 ('lupus', 0.3197374424421778),
 ('bronchitis', 0.3196381611445093),
 ('congestive heart failure', 0.29887658877884565),
 ('hiv/aids', 0.29217915783811077),
 ('common cold', 0.27047643256789866),
 ('asthma', 0.26097873634139335),
 ('multiple myeloma', 0.25555970503955755),
 ('hepatitis b', 0.23816205496845264),
 ('sleep apnea', 0.23273754682543676),
 ('meningitis', 0.2302834615812488),
 ('pulmonary hypertension', 0.21849552591063434),
 ('sinusitis', 0.21780861331835244),
 ('cardiomyopathy', 0.21670881242534837),
 ('chronic kidney disease', 0.2106614002520254),
 ('seasonal allergies', 0.20262351137745313),
 ('psoriasis', 0.1915666812764938),
 ('breast cancer', 0.15245731127893775),
 ("parkinson's disease", 0.14834

In [10]:
def rank_similar_diseases(disease_name, top_k=5):
    """Return the `top_k` diseases most similar to `disease_name`.

    Parameters
    ----------
    disease_name : str
        Name of the reference disease as stored in `disease.name`.
    top_k : int, default 5
        Maximum number of rows to return; must be non-negative.

    Returns
    -------
    list[tuple[str, float]]
        `(disease_name, similarity)` rows ordered by similarity, highest
        first. Every disease with an embedding is scored, so the reference
        disease itself is among the candidates.

    Raises
    ------
    ValueError
        If `top_k` is negative or cannot be converted to an int.

    Notes
    -----
    `disease_name` is a bound parameter rather than text spliced into the
    SQL, so names containing an apostrophe (e.g. "parkinson's disease") work
    and cannot alter the query.
    """
    limit = int(top_k)
    if limit < 0:
        raise ValueError("top_k must be non-negative")

    # Only `limit`, already coerced to a plain int, is formatted into the SQL.
    query = f"""
    SELECT
        d.name,
        cosine(
            (
                SELECT de.embedding
                FROM disease_embedding de
                JOIN disease d2 ON d2.disease_id = de.disease_id
                WHERE d2.name = ?
            ),
            de.embedding
        ) AS similarity
    FROM disease d, disease_embedding de
    WHERE d.disease_id = de.disease_id
    ORDER BY similarity DESC
    LIMIT {limit};
    """
    return con.execute(query, [disease_name]).fetchall()

# Example: the five nearest neighbours of the common cold (default top_k).
rank_similar_diseases(disease_name='common cold')

[('common cold', 1.0000000000000002),
 ('seasonal allergies', 0.599995427045222),
 ('sinusitis', 0.5734005892130247),
 ('upper respiratory infection', 0.5328000202062183),
 ('bronchitis', 0.4457861797652012)]