In [1]:
!pip install transformers datasets faiss-gpu faiss-cpu

Collecting datasets
  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting faiss-gpu
  Downloading faiss_gpu-1.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.4 kB)
Collecting faiss-cpu
  Downloading faiss_cpu-1.9.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.4 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.1.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m22.7 M

In [3]:
import faiss
import numpy as np
from transformers import AutoTokenizer, TFAutoModel
from datasets import load_dataset
import tensorflow as tf
import json
from tqdm import tqdm

# Dataset and model configuration
DATASET_NAME = "QuyenAnhDE/Diseases_Symptoms"
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
INDEX_FILE = "diseases_symptoms_index.faiss"
EMBEDDINGS_FILE = "diseases_symptoms_embeddings.npy"
DOCS_FILE = "diseases_symptoms_docs.json"

# Set up device
device = tf.device('/GPU:0') if tf.config.list_physical_devices('GPU') else tf.device('/CPU:0')
print(f"Using device: {device}")

# Load the Diseases_Symptoms dataset
print("Loading dataset...")
dataset = load_dataset(DATASET_NAME, split="train")



Using device: <tensorflow.python.eager.context._EagerDeviceContext object at 0x7c7f39d625c0>
Loading dataset...


Repo card metadata block was not found. Setting CardData to empty.


In [4]:
dataset

Dataset({
    features: ['Code', 'Name', 'Symptoms', 'Treatments'],
    num_rows: 400
})

In [6]:

# Preprocess dataset into a format suitable for retrieval
def preprocess_dataset(dataset):
    """Extract and preprocess diseases and symptoms data."""
    documents = []
    for entry in dataset:
        disease = entry["Name"]
        symptoms = entry["Symptoms"]
        documents.append({
            "disease": disease,
            "symptoms": symptoms,
            "combined": f"Disease: {disease}. Symptoms: {symptoms}"
        })
    return documents

print("Processing dataset...")
documents = preprocess_dataset(dataset)

# Save the documents for later use
with open(DOCS_FILE, "w", encoding="utf-8") as f:
    json.dump(documents, f)
print(f"Processed and saved {len(documents)} records.")


Processing dataset...
Processed and saved 400 records.


In [8]:

# Load embedding model
print("Loading embedding model...")
tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL_NAME)
model = TFAutoModel.from_pretrained(EMBEDDING_MODEL_NAME)

# Generate embeddings for each document
def generate_embeddings(documents, batch_size=16):
    """Generate embeddings for documents."""
    embeddings = []
    for i in tqdm(range(0, len(documents), batch_size), desc="Generating embeddings"):
        batch_texts = [doc["combined"] for doc in documents[i:i + batch_size]]
        inputs = tokenizer(batch_texts, return_tensors="tf", padding=True, truncation=True, max_length=512)
        with tf.device('/GPU:0'):
            outputs = model(inputs)
        batch_embeddings = tf.reduce_mean(outputs.last_hidden_state, axis=1).numpy()
        embeddings.extend(batch_embeddings)
    return np.array(embeddings)

print("Generating embeddings...")
embeddings = generate_embeddings(documents)

# Save embeddings for reuse
np.save(EMBEDDINGS_FILE, embeddings)
print(f"Saved embeddings for {len(embeddings)} records.")

# Build and save FAISS index
print("Building FAISS index...")
dimension = embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(embeddings.astype("float32"))
faiss.write_index(index, INDEX_FILE)
print(f"Saved FAISS index to {INDEX_FILE}.")


Loading embedding model...


Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFBertModel: ['embeddings.position_ids']
- This IS expected if you are initializing TFBertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFBertModel were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertModel for predictions without further training.


Generating embeddings...


Generating embeddings: 100%|██████████| 25/25 [00:04<00:00,  6.11it/s]

Saved embeddings for 400 records.
Building FAISS index...
Saved FAISS index to diseases_symptoms_index.faiss.





In [9]:

# Load retrieval components
def load_retrieval_components():
    """Load documents, embeddings, and FAISS index."""
    with open(DOCS_FILE, "r", encoding="utf-8") as f:
        docs = json.load(f)
    embeddings = np.load(EMBEDDINGS_FILE)
    index = faiss.read_index(INDEX_FILE)
    return docs, embeddings, index

# Retrieve top-k relevant documents
def retrieve_documents(query, docs, index, top_k=5):
    """Retrieve relevant documents based on a query."""
    inputs = tokenizer([query], return_tensors="tf", padding=True, truncation=True, max_length=512)
    with tf.device('/GPU:0'):
        outputs = model(inputs)
    query_embedding = tf.reduce_mean(outputs.last_hidden_state, axis=1).numpy().astype("float32")
    distances, indices = index.search(query_embedding, top_k)
    return [docs[idx] for idx in indices[0]]

# Test retrieval system
print("Testing retrieval system...")
docs, embeddings, index = load_retrieval_components()

test_query = "What are the symptoms of diabetes?"
retrieved_docs = retrieve_documents(test_query, docs, index)

print(f"Query: {test_query}")
print("\nRetrieved Documents:")
for i, doc in enumerate(retrieved_docs, 1):
    print(f"{i}. Disease: {doc['disease']}")
    print(f"   Symptoms: {doc['symptoms']}")

print("Retrieval system ready.")

Testing retrieval system...
Query: What are the symptoms of diabetes?

Retrieved Documents:
1. Disease: Insulin Overdose
   Symptoms: Low blood sugar (hypoglycemia) symptoms (e.g., confusion, sweating)
2. Disease: Diabetic Kidney Disease
   Symptoms: High blood pressure, increased need to urinate, swelling in the legs and ankles, fatigue, nausea, loss of appetite
3. Disease: Type 2 Diabetes
   Symptoms: Fatigue, Increased hunger, Slow healing of wounds
4. Disease: Hyperosmotic Hyperketotic State
   Symptoms: Increased blood sugar levels, increased ketone production, excessive thirst, frequent urination, fatigue, confusion, fruity breath odor
5. Disease: Chronic Kidney Disease
   Symptoms: Fatigue, swelling of the legs or ankles, decreased appetite, difficulty concentrating, increased urination or urine changes, blood in urine, high blood pressure
Retrieval system ready.
