In [3]:
hf_token = ""

In [6]:
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from transformers import LlamaForCausalLM, LlamaTokenizer
import torch

# Step 1: Initialize the FAISS Vector Store
def initialize_vector_store(data_path, embedding_model_name="sentence-transformers/all-MiniLM-L6-v2"):
    """
    Initialize FAISS vector store from a text dataset.

    Args:
        data_path (str): Path to the text dataset.
        embedding_model_name (str): HuggingFace embedding model name.

    Returns:
        FAISS: Initialized FAISS vector store.
    """
    
    output_file = 'fine_tuning_diseases.txt'

    with open(output_file, "r") as f:
        lines = f.readlines()
    # Split data into individual entries for FAISS
    retriever_data = [{"text": line.strip()} for line in lines]
 

    # Step 2: Build FAISS Index
    embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2"
    embedding_model = HuggingFaceEmbeddings(model_name=embedding_model_name)

    # Initialize FAISS retriever
    vectorstore = FAISS.from_texts(
        texts=[doc["text"] for doc in retriever_data],
        embedding=embedding_model
    )
    return vectorstore

# Step 2: Load LLaMA Model and Tokenizer
def load_llama_model_and_tokenizer(model_path="meta-llama/Llama-2-7b-chat-hf"):
    """
    Load the LLaMA model and tokenizer.

    Args:
        model_path (str): Path to the LLaMA model.

    Returns:
        LlamaForCausalLM, LlamaTokenizer: Loaded model and tokenizer.
    """
    tokenizer = LlamaTokenizer.from_pretrained(model_path)
    tokenizer.pad_token = tokenizer.eos_token
    model = LlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map="auto")
    return model, tokenizer



In [7]:

# Initialize FAISS vector store
data_path = "fine_tuning_diseases.txt"
vectorstore = initialize_vector_store(data_path)

# Load LLaMA model and tokenizer
model_path = "meta-llama/Llama-2-7b-chat-hf"
model, tokenizer = load_llama_model_and_tokenizer(model_path)


Loading checkpoint shards: 100%|██████████| 2/2 [00:18<00:00,  9.35s/it]
Some parameters are on the meta device because they were offloaded to the disk.


In [12]:
# Step 3: Define the Disease Identification Function
def identify_disease_with_context(symptoms, vectorstore, model, tokenizer, max_length=250, k=1):
    """
    Predicts a disease based on symptoms and retrieved context.

    Args:
        symptoms (str): Input symptoms provided by the user.
        vectorstore (FAISS): FAISS vector store for context retrieval.
        model: LLaMA model.
        tokenizer: Tokenizer for the LLaMA model.
        max_length (int): Maximum length for the generated response.
        k (int): Number of top documents to retrieve for context.

    Returns:
        str: Predicted disease or an appropriate message.
    """
    # Retrieve context from the vector store
    query = f"Symptoms: {symptoms}"
    results = vectorstore.similarity_search(query, k=k)
    context = " ".join([doc.page_content for doc in results]) if results else "No relevant context available."

    # Create input text with retrieved context
    input_text = (
        "You are a medical assistant trained to predict diseases based on symptoms.\n\n"
        f"Context: {context}\n"
        f"Symptoms: {symptoms}\n\n"
        "Based on the context and symptoms, provide the name of the disease. "
        "If the context does not contain the answer, respond with: 'I don't know the answer.'"
    )
    input_ids = tokenizer.encode(input_text, return_tensors="pt")
    input_ids = input_ids.to(model.device)
    
    # Generate response
    outputs = model.generate(
        input_ids,
        max_length=min(len(input_ids[0]) + 50, tokenizer.model_max_length),
        top_k=50,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id,
    )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Extract diagnosis from the response
    response_lines = response.strip().split("\n")
    for line in response_lines:
        if "this could be the disease:" in line:
            # Extract the disease name after the phrase
            disease_name = line.split("this could be the disease:")[-1].strip()
            if disease_name:
                # Remove any trailing period
                disease_name = disease_name.rstrip(".")
                return disease_name  # Return the extracted disease name
    return response




In [9]:
# Predict disease based on symptoms
user_symptoms = "I have deep, constant pain in my belly and back, and I feel a pulse near my bellybutton."
diagnosis = identify_disease_with_context(user_symptoms, vectorstore, model, tokenizer)

print(f"Predicted Disease: {diagnosis}")

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


Predicted Disease: Aneurysm (abdominal aortic)


In [10]:
# Predict disease based on symptoms
user_symptoms = "I’m experiencing a thin, gray vaginal discharge with a fishy odor and itching. What might this indicate?"
diagnosis = identify_disease_with_context(user_symptoms, vectorstore, model, tokenizer)

print(f"Predicted Disease: {diagnosis}")

Predicted Disease: Bacterial vaginosis


In [15]:
import pandas as pd

def evaluate_model(test_file_path, vectorstore, model, tokenizer, k=1):
    """
    Evaluate the model on a test dataset.

    Args:
        test_file_path (str): Path to the CSV file containing the test dataset.
        vectorstore (FAISS): FAISS vector store for context retrieval.
        model: LLaMA model.
        tokenizer: Tokenizer for the LLaMA model.
        k (int): Number of top documents to retrieve for context.

    Returns:
        accuracy (float): Accuracy of the model.
        mismatched_cases (list): List of mismatched cases.
    """
    # Load the test dataset
    test_df = pd.read_csv(test_file_path)

    correct_predictions = 0
    total_predictions = len(test_df)
    mismatched_cases = []

    for _, row in test_df.iterrows():
        # Extract symptoms and actual disease
        symptoms = row["symptoms"]
        actual_disease = row["disease"].strip().lower()

        # Predict disease using the model
        predicted_disease = identify_disease_with_context(symptoms, vectorstore, model, tokenizer, k=k)
        predicted_disease = predicted_disease.lower()

        # Compare predicted and actual diseases
        if predicted_disease == actual_disease:
            correct_predictions += 1
        else:
            mismatched_cases.append({
                "Symptoms": symptoms,
                "Actual Disease": actual_disease,
                "Predicted Disease": predicted_disease
            })

    # Calculate accuracy
    accuracy = correct_predictions / total_predictions
    return accuracy, mismatched_cases

# File path to the test dataset
test_file_path = "symptom_diseases_test.csv"

# Evaluate the model
accuracy, mismatched_cases = evaluate_model(test_file_path, vectorstore, model, tokenizer)

# Display the results
print(f"Accuracy: {accuracy * 100:.2f}%")
if mismatched_cases:
    print("\nMismatched Cases:")
    for case in mismatched_cases:
        print(f"Symptoms: {case['Symptoms']}")
        print(f"Actual Disease: {case['Actual Disease']}")
        print(f"Predicted Disease: {case['Predicted Disease']}")
        print("-" * 50)




Accuracy: 91.30%

Mismatched Cases:
Symptoms: I have deep, constant pain in my belly and back, and I feel a pulse near my bellybutton. What could this be?
Actual Disease: abdominal aortic aneurysm
Predicted Disease: aneurysm (abdominal aortic)
--------------------------------------------------
Symptoms: I’m having trouble swallowing, and sometimes it feels like food is stuck in my throat. I’ve also lost some weight. What might be going on?
Actual Disease: achalasia
Predicted Disease: swallowing problems
--------------------------------------------------
Symptoms: I feel severe pain in my upper right belly that spreads to my shoulder, and I’ve been feeling nauseous and feverish. What could this mean?
Actual Disease: acute cholecystitis
Predicted Disease: cholecystitis (acute)
--------------------------------------------------
Symptoms: I’ve been losing hearing on one side, and there’s ringing in my ear. Sometimes I feel dizzy and off balance. What might this be?
Actual Disease: acoustic