<a href="https://colab.research.google.com/github/xtchen64/virtual-doctor-chatbot/blob/main/notebooks/embedding_based_retrieval.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Disease Diagnosis via Embedding-Based Retrieval
Written By: Qingyang Xu  
Date Created: 11/24/2023  
Last Modified: 12/30/2023  

### Overview

- Section 1. Test embedding-based retrieval tool `DONE`

  - Faiss: https://github.com/facebookresearch/faiss

- Section 2. Test Word2Vec using BERT `DONE`

  - BERT: https://huggingface.co/transformers/v3.0.2/installation.html

- Section 3. Diagnosis via symptom embedding retrieval with ClinicalBERT `DONE`

  - ClinicalBERT: https://huggingface.co/medicalai/ClinicalBERT

## Section 1. Test embedding-based retrieval tool `DONE`

- Reproduce the sanity-check results from the Faiss "Getting started" guide: https://github.com/facebookresearch/faiss/wiki/Getting-started

In [None]:
# Faiss (on machines without a CUDA GPU, install `faiss-cpu` instead)
!pip install faiss-gpu

In [None]:
import numpy as np

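d = 64                           # vector dimension
nb = 100000                      # database size
nq = 10000                       # number of queries
np.random.seed(1234)             # make results reproducible
# Random vectors whose first coordinate is shifted by index / 1000, so a query's
# nearest neighbors tend to be database vectors with nearby indices
xb = np.random.random((nb, d)).astype('float32')
xb[:, 0] += np.arange(nb) / 1000.
xq = np.random.random((nq, d)).astype('float32')
xq[:, 0] += np.arange(nq) / 1000.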

In [None]:
import faiss                   # make faiss available

index = faiss.IndexFlatL2(d)   # build the index
print(index.is_trained)
index.add(xb)                  # add vectors to the index
print(index.ntotal)

In [None]:
k = 4                          # we want to see 4 nearest neighbors
D, I = index.search(xb[:5], k) # sanity check: each vector's nearest neighbor should be itself
print(I)
print(D)
D, I = index.search(xq, k)     # actual search
print("neighbors of the first 5 queries")
print(I[:5])                   # neighbors of the first 5 queries

print("neighbors of the last 5 queries")
print(I[-5:])                  # neighbors of the last 5 queries
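
Faiss's `IndexFlatL2` performs exact, brute-force search: every query is compared against every stored vector, and the returned `D` holds *squared* L2 distances. As a minimal sketch of that behavior (pure Python, no Faiss; the function `flat_l2_search` and the toy data are hypothetical), the code below also shows why the sanity check above should report each database vector as its own nearest neighbor at distance 0.

```python
import heapq

def flat_l2_search(database, queries, k):
    """Exact k-nearest-neighbor search by squared L2 distance (brute force)."""
    all_distances, all_indices = [], []
    for q in queries:
        # Squared Euclidean distance from this query to every database vector
        scored = [
            (sum((qi - xi) ** 2 for qi, xi in zip(q, x)), idx)
            for idx, x in enumerate(database)
        ]
        # Keep the k smallest distances (ties broken by the lower index)
        best = heapq.nsmallest(k, scored)
        all_distances.append([dist for dist, _ in best])
        all_indices.append([idx for _, idx in best])
    return all_distances, all_indices

# Tiny database of 2-D vectors; querying with database vectors is the sanity check
xb = [[0.0, 0.0], [1.0, 0.0], [0.0, 2.0], [3.0, 3.0]]
D, I = flat_l2_search(xb, xb[:2], k=2)
print(I)  # [[0, 1], [1, 0]]: each row starts with the query's own index
print(D)  # [[0.0, 1.0], [0.0, 1.0]]: first column is the distance to itself
```

For 100,000 vectors of dimension 64 this Python loop would be far too slow; Faiss runs the same exhaustive computation with optimized CPU and GPU kernels.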

## Section 2. Test Word2Vec using BERT `DONE`

In [None]:
!pip install transformers

In [None]:
from transformers import BertTokenizer, BertModel
import torch

# Load pre-trained BERT model and tokenizer
model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)

# Input text for inference
input_text = "I have a high fever."

# Tokenize and encode the input text
tokens = tokenizer.encode(input_text, add_special_tokens=True)
input_ids = torch.tensor(tokens).unsqueeze(0)  # Add batch dimension

In [None]:
print(f"tokens: {tokens}")

In [None]:
# Forward pass through the BERT model
with torch.no_grad():
    outputs = model(input_ids)

In [None]:
# Token-level embeddings: shape (batch_size, sequence_length, hidden_size)
last_hidden_states = outputs.last_hidden_state

# Sentence-level vector: the [CLS] token's hidden state passed through a dense layer + tanh,
# often used as input to classification heads
pooler_output = outputs.pooler_output

# Convert the PyTorch tensor to a NumPy array if needed for further processing
numpy_output = last_hidden_states.numpy()

# Inspect the output shapes
print("Embeddings shape:", last_hidden_states.shape)
print("Pooler output shape:", pooler_output.shape)

## Section 3. Diagnosis via symptom embedding retrieval with ClinicalBERT `DONE`

- ClinicalBERT: Modeling Clinical Notes and Predicting Hospital Readmission

### 3.1 Test each module



In [None]:
!pip install transformers

In [None]:
import numpy as np
import torch
from transformers import BertModel, BertTokenizer
from transformers import AutoTokenizer, AutoModel
import torch.nn.functional as F

In [None]:
bert_model = "medicalai/ClinicalBERT"

if bert_model == "medicalai/ClinicalBERT":
  # Load pre-trained ClinicalBERT tokenizer and model
  tokenizer = AutoTokenizer.from_pretrained("medicalai/ClinicalBERT")
  model = AutoModel.from_pretrained("medicalai/ClinicalBERT")
elif bert_model == "bert-base-uncased":
  # Load pre-trained model tokenizer and model
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
  model = BertModel.from_pretrained('bert-base-uncased')
else:
  raise ValueError(f"Please input a valid BERT model ID {bert_model}")

# Encode text as a single vector by mean-pooling the last hidden layer over all tokens
def encode_text(text):
    inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.last_hidden_state.mean(dim=1)  # shape (batch_size, hidden_size)
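
`encode_text` averages `last_hidden_state` over every token position, including `[CLS]` and `[SEP]`. For a single string that is fine, but if several strings were batched with `padding=True`, the padding positions would be averaged in too. A common remedy is to weight the mean by the tokenizer's `attention_mask`; in PyTorch that is roughly `(h * m.unsqueeze(-1)).sum(1) / m.sum(1, keepdim=True)` with `h = last_hidden_state` and `m = attention_mask`. The sketch below shows the same arithmetic on plain Python lists (the function `masked_mean_pool` and the toy vectors are hypothetical).

```python
def masked_mean_pool(token_vectors, attention_mask):
    """Average token vectors, ignoring positions where attention_mask is 0."""
    dim = len(token_vectors[0])
    totals = [0.0] * dim
    count = 0
    for vec, keep in zip(token_vectors, attention_mask):
        if keep:
            # Accumulate only real (non-padding) tokens
            totals = [t + v for t, v in zip(totals, vec)]
            count += 1
    if count == 0:
        raise ValueError("attention_mask has no active tokens")
    return [t / count for t in totals]

# Three real tokens followed by one padding token (mask = 0)
tokens = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [100.0, 100.0]]
mask = [1, 1, 1, 0]
print(masked_mean_pool(tokens, mask))  # [3.0, 4.0]
```

Without the mask, the padding vector would drag the mean to `[27.25, 28.0]`.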

In [None]:
# List of medical symptoms
symptoms = ["shortness of breath", "fever", "cough", "fatigue", "sleeplessness"]
# Target symptom
target_symptom = "insomnia"

# Encode the symptoms and the target symptom
symptom_embeddings = torch.stack([encode_text(symptom) for symptom in symptoms])
target_embedding = encode_text(target_symptom)

# Check the shapes of the embeddings
print(f"Symptom Embeddings Shape: {symptom_embeddings.shape}")
print(f"Target Embedding Shape: {target_embedding.shape}")

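# Calculate cosine similarities; squeeze the stacked embeddings from
# (num_symptoms, 1, hidden_size) to (num_symptoms, hidden_size) so they broadcast against the target
cosine_similarities = F.cosine_similarity(symptom_embeddings.squeeze(), target_embedding)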

# Debugging information
print(f"Cosine Similarities: {cosine_similarities}")

# Find the best match
best_match_score = torch.max(cosine_similarities)
print(f"Best Match Score: {best_match_score}")

# Find the best match index
best_match_index = torch.argmax(cosine_similarities).item()
print(f"Best Match Index: {best_match_index}")

best_match_symptom = symptoms[best_match_index]
print(f"The symptom most similar to '{target_symptom}' is '{best_match_symptom}'.")
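
`F.cosine_similarity` divides the dot product of two vectors by the product of their norms, so the score depends on direction rather than magnitude, and `torch.argmax` then picks the closest candidate. The stdlib-only sketch below reproduces those two steps on made-up 3-D vectors (the vectors and the names `cosine` and `best_match` are illustrative assumptions, not ClinicalBERT outputs).

```python
import math

def cosine(u, v):
    """Cosine similarity: dot(u, v) / (|u| * |v|)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def best_match(target, candidates):
    """Return (index, score) of the candidate most similar to target."""
    scores = [cosine(target, c) for c in candidates]
    best_index = max(range(len(scores)), key=scores.__getitem__)
    return best_index, scores[best_index]

# Toy embeddings: the last candidate points in the same direction as the target
candidates = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [2.0, 2.0, 0.0]]
target = [1.0, 1.0, 0.0]
index, score = best_match(target, candidates)
print(index, round(score, 4))  # 2 1.0
```

The third candidate is twice as long as the target yet still scores 1.0, which is why cosine similarity suits comparing embeddings whose norms vary.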

In [None]:
diseases = {
    "Panic disorder": ['Anxiety and nervousness', 'Depression', 'Shortness of breath', 'Depressive or psychotic symptoms', 'Sharp chest pain', 'Dizziness', 'Insomnia', 'Abnormal involuntary movements', 'Chest tightness', 'Palpitations', 'Irregular heartbeat', 'Breathing fast'],
    "Vocal cord polyp": ['Hoarse voice', 'Sore throat', 'Difficulty speaking', 'Cough', 'Nasal congestion', 'Throat swelling', 'Diminished hearing', 'Lump in throat', 'Throat feels tight', 'Difficulty in swallowing', 'Skin swelling', 'Retention of urine'],
    "Turner syndrome": ['Groin mass', 'Leg pain', 'Hip pain', 'Suprapubic pain', 'Blood in stool', 'Lack of growth', 'Diminished hearing', 'Depression', 'Emotional symptoms', 'Elbow weakness', 'Back weakness', 'Pus in sputum'],
    "Cryptorchidism": ['Symptoms of the scrotum and testes', 'Swelling of scrotum', 'Pain in testicles', 'Flatulence', 'Pus draining from ear', 'Jaundice', 'Mass in scrotum', 'Lack of growth', 'White discharge from eye', 'Irritable infant'],
    "Poisoning due to ethylene glycol": ['Abusing alcohol', 'Fainting', 'Hostile behavior', 'Drug abuse', 'Depressive or psychotic symptoms', 'Sharp abdominal pain', 'Feeling ill', 'Vomiting', 'Headache', 'Depression', 'Nausea', 'Diarrhea'],
    "Atrophic vaginitis": ['Vaginal itching', 'Vaginal dryness', 'Painful urination', 'Involuntary urination', 'Pain during intercourse', 'Frequent urination', 'Lower abdominal pain', 'Suprapubic pain', 'Vaginal discharge', 'Blood in urine', 'Hot flashes', 'Intermenstrual bleeding'],
    "Fracture of the hand": ['Hand or finger pain', 'Wrist pain', 'Hand or finger swelling', 'Arm pain', 'Wrist swelling', 'Arm stiffness or tightness', 'Arm swelling', 'Hand or finger stiffness or tightness', 'Wrist stiffness or tightness'],
    "Cellulitis or abscess of mouth": ['Lip swelling', 'Sore throat', 'Toothache', 'Abnormal appearing skin', 'Skin lesion', 'Difficulty in swallowing', 'Acne or pimples', 'Dry lips', 'Facial pain', 'Mouth ulcer', 'Throat swelling', 'Skin growth'],
}

In [None]:
def compute_symptom_matching_score(patient_symptoms, disease_symptoms, verbose=False):

  # Encode the disease symptoms once; shape (num_disease_symptoms, 1, hidden_size)
  disease_symptom_embeddings = torch.stack([encode_text(symptom) for symptom in disease_symptoms])

  average_matching_score = 0

  for target_symptom in patient_symptoms:

    # Encode the current patient symptom; shape (1, hidden_size)
    target_embedding = encode_text(target_symptom)

    # Cosine similarity between this patient symptom and every disease symptom
    cosine_similarities = F.cosine_similarity(disease_symptom_embeddings.squeeze(), target_embedding)

    # Score, index, and text of the most similar disease symptom
    best_match_score = torch.max(cosine_similarities).numpy()
    best_match_index = torch.argmax(cosine_similarities).item()
    best_match_symptom = disease_symptoms[best_match_index]

    if verbose:
      print(f"Disease Embedding Shape: {disease_symptom_embeddings.shape}")
      print(f"Target Embedding Shape: {target_embedding.shape}")
      print(f"Cosine Similarities: {cosine_similarities}")
      print(f"Best Match Score: {best_match_score}")
      print(f"Best Match Index: {best_match_index}")
      print(f"The symptom most similar to '{target_symptom}' is '{best_match_symptom}'. \n")

    average_matching_score += best_match_score

  # Average the best-match scores over all patient symptoms
  average_matching_score /= len(patient_symptoms)

  if verbose: print(f"average_matching_score: {average_matching_score}")

  return average_matching_score
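
In symbols, with $P$ the patient's symptoms, $D$ one disease's symptom list, and $e(\cdot)$ the mean-pooled embedding returned by `encode_text`, the function above computes:

$$
\mathrm{score}(P, D) = \frac{1}{|P|} \sum_{p \in P} \max_{d \in D} \cos\big(e(p), e(d)\big),
\qquad
\cos(u, v) = \frac{u \cdot v}{\lVert u \rVert \, \lVert v \rVert}
$$

Each patient symptom only needs *some* close disease symptom, so disease symptoms the patient does not report do not lower the score. Identical strings have identical embeddings and a cosine of 1, which is why 'Vocal cord polyp' scores 1.0 in the end-to-end test in Section 3.2; `find_best_diagnosis` returns $\arg\max_{D} \mathrm{score}(P, D)$ over the diseases in the dictionary.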

In [None]:
patient_symptoms = ['Anxiety and nervousness', 'Depression', 'Shortness of breath']
disease_symptoms = ['Anxiety and nervousness', 'Depression', 'Shortness of breath', 'Depressive or psychotic symptoms']

compute_symptom_matching_score(patient_symptoms, disease_symptoms, verbose=True)

In [None]:
def find_best_diagnosis(patient_symptoms, diseases, verbose=False):

  max_matching_score = -np.inf
  best_diagnosis = "None"

  for disease in diseases:
    disease_symptoms = diseases[disease]
    average_matching_score = compute_symptom_matching_score(patient_symptoms, disease_symptoms)
    if verbose:
      print(f"disease: {disease}")
      print(f"average_matching_score: {average_matching_score} \n")

    if average_matching_score > max_matching_score:
      max_matching_score = average_matching_score
      best_diagnosis = disease

  if verbose:
    print(f"best_diagnosis: {best_diagnosis}")
    print(f"max_matching_score: {max_matching_score}")

  return best_diagnosis
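
`find_best_diagnosis` calls `compute_symptom_matching_score` once per disease, and each call re-encodes both the disease's symptom list and every patient symptom, so strings shared across diseases (e.g. 'Depression', 'Sore throat') are encoded repeatedly. One possible restructuring, sketched below, caches embeddings so each distinct string is encoded once while keeping the same average best-match score and argmax selection. It assumes `embed` is any function mapping a string to a vector; here it is a hypothetical toy lookup table standing in for `encode_text`, and `diagnose` is a hypothetical name.

```python
import math

def cosine(u, v):
    """Cosine similarity: dot(u, v) / (|u| * |v|)."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def diagnose(patient_symptoms, diseases, embed):
    """Return (best_disease, best_score) by average best-match cosine similarity."""
    cache = {}

    def get(text):
        # Encode each distinct string only once
        if text not in cache:
            cache[text] = embed(text)
        return cache[text]

    best_disease, best_score = None, -math.inf
    for disease, disease_symptoms in diseases.items():
        disease_vecs = [get(s) for s in disease_symptoms]
        # Average, over patient symptoms, of the best match among disease symptoms
        score = sum(
            max(cosine(get(p), d) for d in disease_vecs) for p in patient_symptoms
        ) / len(patient_symptoms)
        if score > best_score:
            best_disease, best_score = disease, score
    return best_disease, best_score

# Hypothetical 2-D "embeddings" standing in for encode_text
toy_vectors = {
    "Hoarse voice": [1.0, 0.1],
    "Sore throat": [0.9, 0.2],
    "Insomnia": [0.1, 1.0],
    "Dizziness": [0.2, 0.9],
}
toy_diseases = {
    "Panic disorder": ["Insomnia", "Dizziness"],
    "Vocal cord polyp": ["Hoarse voice", "Sore throat"],
}
best, score = diagnose(["Hoarse voice", "Sore throat"], toy_diseases, toy_vectors.__getitem__)
print(best, round(score, 4))  # Vocal cord polyp 1.0
```

With ClinicalBERT the cache would hold the tensors returned by `encode_text`, trading a little memory for far fewer forward passes.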

In [None]:
find_best_diagnosis(patient_symptoms, diseases, verbose=True)

### 3.2 Test end-to-end


In [1]:
!pip install transformers



In [2]:
import numpy as np
import torch
from transformers import BertModel, BertTokenizer
from transformers import AutoTokenizer, AutoModel
import torch.nn.functional as F

BERT_MODEL_ID = "medicalai/ClinicalBERT"
#BERT_MODEL_ID = "bert-base-uncased"

DISEASES = {
    "Panic disorder": ['Anxiety and nervousness', 'Depression', 'Shortness of breath', 'Depressive or psychotic symptoms', 'Sharp chest pain', 'Dizziness', 'Insomnia', 'Abnormal involuntary movements', 'Chest tightness', 'Palpitations', 'Irregular heartbeat', 'Breathing fast'],
    "Vocal cord polyp": ['Hoarse voice', 'Sore throat', 'Difficulty speaking', 'Cough', 'Nasal congestion', 'Throat swelling', 'Diminished hearing', 'Lump in throat', 'Throat feels tight', 'Difficulty in swallowing', 'Skin swelling', 'Retention of urine'],
    "Turner syndrome": ['Groin mass', 'Leg pain', 'Hip pain', 'Suprapubic pain', 'Blood in stool', 'Lack of growth', 'Diminished hearing', 'Depression', 'Emotional symptoms', 'Elbow weakness', 'Back weakness', 'Pus in sputum'],
    "Cryptorchidism": ['Symptoms of the scrotum and testes', 'Swelling of scrotum', 'Pain in testicles', 'Flatulence', 'Pus draining from ear', 'Jaundice', 'Mass in scrotum', 'Lack of growth', 'White discharge from eye', 'Irritable infant'],
    "Poisoning due to ethylene glycol": ['Abusing alcohol', 'Fainting', 'Hostile behavior', 'Drug abuse', 'Depressive or psychotic symptoms', 'Sharp abdominal pain', 'Feeling ill', 'Vomiting', 'Headache', 'Depression', 'Nausea', 'Diarrhea'],
    "Atrophic vaginitis": ['Vaginal itching', 'Vaginal dryness', 'Painful urination', 'Involuntary urination', 'Pain during intercourse', 'Frequent urination', 'Lower abdominal pain', 'Suprapubic pain', 'Vaginal discharge', 'Blood in urine', 'Hot flashes', 'Intermenstrual bleeding'],
    "Fracture of the hand": ['Hand or finger pain', 'Wrist pain', 'Hand or finger swelling', 'Arm pain', 'Wrist swelling', 'Arm stiffness or tightness', 'Arm swelling', 'Hand or finger stiffness or tightness', 'Wrist stiffness or tightness'],
    "Cellulitis or abscess of mouth": ['Lip swelling', 'Sore throat', 'Toothache', 'Abnormal appearing skin', 'Skin lesion', 'Difficulty in swallowing', 'Acne or pimples', 'Dry lips', 'Facial pain', 'Mouth ulcer', 'Throat swelling', 'Skin growth'],
}

print(f"Loading BERT model: {BERT_MODEL_ID}")

if BERT_MODEL_ID == "medicalai/ClinicalBERT":
  # Load pre-trained ClinicalBERT tokenizer and model
  tokenizer = AutoTokenizer.from_pretrained("medicalai/ClinicalBERT")
  model = AutoModel.from_pretrained("medicalai/ClinicalBERT")
elif BERT_MODEL_ID == "bert-base-uncased":
  # Load pre-trained model tokenizer and model
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
  model = BertModel.from_pretrained('bert-base-uncased')
else:
  raise ValueError(f"Please input a valid BERT model ID {BERT_MODEL_ID}")

print(f"Loaded BERT model: {BERT_MODEL_ID}")

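# Function to encode text to embeddings (mean-pooled last hidden layer)
def encode_text(text):
    # truncation=True clips inputs longer than the model's maximum length
    # instead of letting the forward pass fail on over-long sequences
    inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.last_hidden_state.mean(dim=1)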

# Function to compute the average matching score between patient and disease symptoms
def compute_symptom_matching_score(patient_symptoms, disease_symptoms, verbose=False):

  print(f"patient_symptoms: {patient_symptoms}")
  print(f"disease_symptoms: {disease_symptoms}")

  average_matching_score = 0

  disease_symptom_embeddings = torch.stack([encode_text(symptom) for symptom in disease_symptoms])

  for target_symptom in patient_symptoms:

    with torch.no_grad():
        target_embedding = encode_text(target_symptom)
        cosine_similarities = F.cosine_similarity(disease_symptom_embeddings.squeeze(), target_embedding)

    # Find the best match
    best_match_score = torch.max(cosine_similarities).numpy()

    if verbose:
      print(f"Disease Embedding Shape: {disease_symptom_embeddings.shape}")
      print(f"Target Embedding Shape: {target_embedding.shape}")
      print(f"Cosine Similarities: {cosine_similarities}")
      print(f"Best Match Score: {best_match_score}")

    average_matching_score += best_match_score

  average_matching_score /= len(patient_symptoms)

  if verbose: print(f"average_matching_score: {average_matching_score}")

  return average_matching_score

# Function to find the best diagnosis to match patient symptoms
def find_best_diagnosis(patient_symptoms, verbose=False):

  print(f"patient_symptoms: {patient_symptoms}")

  max_matching_score = -np.inf
  best_diagnosis = "None"

  for disease in DISEASES:
    disease_symptoms = DISEASES[disease]
    average_matching_score = compute_symptom_matching_score(patient_symptoms, disease_symptoms, verbose=False)
    if verbose:
      print(f"disease: {disease}")
      print(f"average_matching_score: {average_matching_score} \n")

    if average_matching_score > max_matching_score:
      max_matching_score = average_matching_score
      best_diagnosis = disease

  if verbose:
    print(f"best_diagnosis: {best_diagnosis}")
    print(f"max_matching_score: {max_matching_score} \n")

  return best_diagnosis, max_matching_score

Loading BERT model: medicalai/ClinicalBERT


Loaded BERT model: medicalai/ClinicalBERT


In [3]:
patient_symptoms = ['Hoarse voice', 'Sore throat', 'Difficulty speaking']
find_best_diagnosis(patient_symptoms, verbose=True)

patient_symptoms: ['Hoarse voice', 'Sore throat', 'Difficulty speaking']
patient_symptoms: ['Hoarse voice', 'Sore throat', 'Difficulty speaking']
disease_symptoms: ['Anxiety and nervousness', 'Depression', 'Shortness of breath', 'Depressive or psychotic symptoms', 'Sharp chest pain', 'Dizziness', 'Insomnia', 'Abnormal involuntary movements', 'Chest tightness', 'Palpitations', 'Irregular heartbeat', 'Breathing fast']
disease: Panic disorder
average_matching_score: 0.6685559352238973 

patient_symptoms: ['Hoarse voice', 'Sore throat', 'Difficulty speaking']
disease_symptoms: ['Hoarse voice', 'Sore throat', 'Difficulty speaking', 'Cough', 'Nasal congestion', 'Throat swelling', 'Diminished hearing', 'Lump in throat', 'Throat feels tight', 'Difficulty in swallowing', 'Skin swelling', 'Retention of urine']
disease: Vocal cord polyp
average_matching_score: 1.0 

patient_symptoms: ['Hoarse voice', 'Sore throat', 'Difficulty speaking']
disease_symptoms: ['Groin mass', 'Leg pain', 'Hip pain', 'S

('Vocal cord polyp', 1.0)