In [1]:
# import the dataset and clean it,
# then aggregate symptoms by disease

import json
import os
from typing import Any

import pandas as pd

# raw SymbiPredict 2022 CSV; disease_by_symptoms below expects a 'prognosis' column
SYMBIPREDICT_CSV = 'datasets/symbipredict_2022.csv'
data = pd.read_csv(SYMBIPREDICT_CSV)


def _symptom_record(column: str) -> dict[str, str]:
  """Turn a raw symptom column name into a ``{'name', 'slug'}`` record.

  Underscores and any run of whitespace both act as word separators, so a
  column such as ``'foul_smell_of urine'`` becomes name ``'Foul smell of urine'``
  and slug ``'foul-smell-of-urine'`` (no glued words, no double spaces).
  """
  words = column.replace('_', ' ').split()
  return {'name': ' '.join(words).capitalize(), 'slug': '-'.join(words).lower()}


def disease_by_symptoms(data: pd.DataFrame) -> list[dict[str, Any]]:
  """Aggregate a symptom matrix into one record per disease.

  ``data`` must have a ``prognosis`` label column; every other column is
  treated as a numeric symptom indicator (presumably 0/1 flags — verify
  against the CSV). A symptom is attached to a disease when its column sums
  to more than zero over all rows with that prognosis.

  Prognosis labels are whitespace-normalised (trimmed, inner runs collapsed)
  *before* grouping. As a result, labels that differ only by stray spaces
  merge into one disease instead of producing duplicate names and slugs.
  Rows with a missing prognosis are dropped by the grouping.

  Args:
    data: DataFrame with a ``prognosis`` column plus symptom columns.

  Returns:
    A list of dicts in first-appearance order of each disease. Each dict has
    keys ``name``, ``slug`` and ``symptoms``; ``symptoms`` is a list of
    ``{'name', 'slug'}`` dicts in column order.
  """
  # collapse whitespace so e.g. 'Fungal Infection ' and 'Fungal Infection' group together
  prognosis = data['prognosis'].str.split().str.join(' ')
  # a single groupby pass replaces re-filtering the whole frame per disease;
  # sort=False keeps the first-appearance order that .unique() used to give
  totals = data.drop(columns=['prognosis']).groupby(prognosis, sort=False).sum()

  disease_symptoms: list[dict[str, Any]] = []
  for disease, symptom_totals in totals.iterrows():
    # symptoms flagged in at least one row of this disease
    relevant_symptoms = symptom_totals[symptom_totals > 0].index.tolist()
    disease_symptoms.append(
      {
        'name': disease,
        'slug': disease.lower().replace(' ', '-'),
        'symptoms': [_symptom_record(symptom) for symptom in relevant_symptoms],
      }
    )

  # a list of disease records (not a dict keyed by disease)
  return disease_symptoms


diseases_json = disease_by_symptoms(data)

# write diseases_json to products/symbipredict_2022_cleaned.json,
# creating the output directory if it does not exist yet (fresh checkout)
os.makedirs('products', exist_ok=True)
with open('products/symbipredict_2022_cleaned.json', 'w', encoding='utf-8') as f:
  json.dump(diseases_json, f, indent=2)

print(f'Total diseases after cleaning: {len(diseases_json)}')
# preview a few records; a distinct loop name keeps the global DataFrame
# `data` (loaded from the CSV) from being overwritten with a dict
for disease in diseases_json[:10]:
  print(disease)

Total diseases after cleaning: 41
{'name': 'Fungal Infection', 'slug': 'fungal-infection', 'symptoms': [{'name': 'Itching', 'slug': 'itching'}, {'name': 'Skin rash', 'slug': 'skin-rash'}, {'name': 'Nodal skin eruptions', 'slug': 'nodal-skin-eruptions'}, {'name': 'Dischromic patches', 'slug': 'dischromic-patches'}]}
{'name': 'Allergy', 'slug': 'allergy', 'symptoms': [{'name': 'Continuous sneezing', 'slug': 'continuous-sneezing'}, {'name': 'Shivering', 'slug': 'shivering'}, {'name': 'Chills', 'slug': 'chills'}, {'name': 'Watering from eyes', 'slug': 'watering-from-eyes'}]}
{'name': 'GERD', 'slug': 'gerd', 'symptoms': [{'name': 'Stomach pain', 'slug': 'stomach-pain'}, {'name': 'Acidity', 'slug': 'acidity'}, {'name': 'Ulcers on tongue', 'slug': 'ulcers-on-tongue'}, {'name': 'Vomiting', 'slug': 'vomiting'}, {'name': 'Cough', 'slug': 'cough'}, {'name': 'Chest pain', 'slug': 'chest-pain'}]}
{'name': 'Chronic Cholestasis', 'slug': 'chronic-cholestasis', 'symptoms': [{'name': 'Itching', 'slug':

In [2]:
# set up the LLM used to generate disease descriptions & vector embeddings:
# load the Gemini API key from the .env file and configure the generative model

import os

import google.generativeai as genai
from dotenv import load_dotenv

load_dotenv()

# hand the key (from .env or the process environment) to the SDK
GEMINI_API_KEY = os.getenv('GEMINI_API_KEY')
genai.configure(api_key=GEMINI_API_KEY)

# sampling settings for the disease-description generator
TXT_GEN_CONFIG = {
  'temperature': 0.6,
  'max_output_tokens': 2048,
  'response_mime_type': 'text/plain',
}
txt_gen_model = genai.GenerativeModel('gemini-1.5-pro', generation_config=TXT_GEN_CONFIG)


def embed(content: str):
  """Embed ``content`` with the ``models/text-embedding-004`` model.

  Returns the raw ``genai.embed_content`` response; callers in this notebook
  read its ``'embedding'`` entry (768 values with the default output size).
  """
  embedding_request = {
    'model': 'models/text-embedding-004',
    'content': content,
  }
  return genai.embed_content(**embedding_request)


In [5]:
# generate disease description and vector embedding
# finalize the data and write it to products/diseases.json


import time


def generate_disease_description(disease: dict[str, Any]) -> str:
  """Ask the Gemini text model for a short description of one disease.

  Args:
    disease: a record from ``disease_by_symptoms``. It needs ``name`` and a
      ``symptoms`` list of dicts that each carry ``name``.

  Returns:
    The model's response text with surrounding whitespace removed.

  Raises:
    Whatever ``generate_content`` or ``response.text`` raises. The disease name
    and error are printed first, so a failed batch run shows which record broke.
  """
  name = disease['name']
  symptoms = ', '.join(symptom['name'] for symptom in disease['symptoms'])
  prompt = f"""Write a brief description (max 100-150 words) of the disease "{name}" and its symptoms: {symptoms}"""

  try:
    response = txt_gen_model.generate_content(prompt)
    return response.text.strip()
  except Exception as e:
    print(f'Error generating description for disease: {name}')
    print(e)
    # bare `raise` re-raises the active exception with its original traceback
    raise


def disease_embedding_content(data: dict[str, Any]) -> str:
  """Build the text that is embedded for one disease record.

  The result is three ``Label: value`` lines (condition name, comma-separated
  symptom names, description) with no leading indentation on any line.
  Missing keys, or a ``description`` that is present but ``None``, fall back
  to empty strings.
  """
  name = data.get('name', '')
  symptoms = ', '.join(symptom['name'] for symptom in data.get('symptoms', []))
  # `or ''` also covers a description key that exists but holds None
  description = (data.get('description') or '').strip()
  # join explicit lines: the previous indented triple-quoted f-string leaked
  # its source indentation into the Symptoms/Description lines
  content = '\n'.join(
    [
      f'Condition: {name}',
      f'Symptoms: {symptoms}',
      f'Description: {description}',
    ]
  )
  return content.strip()


def generate_vector_embedding(data: dict[str, Any]):
  """Embed a disease record's condition/symptoms/description text.

  Returns the ``'embedding'`` entry of the ``embed`` response.
  """
  return embed(disease_embedding_content(data))['embedding']


# generate description and embedding for each disease
def generate_disease_data(disease: dict[str, Any]) -> dict[str, Any]:
  """Return a copy of ``disease`` enriched with ``description`` and ``embedding``.

  The input dict is not modified (a shallow copy is returned). The cleaned
  records passed in therefore stay as they were, even if a later API call
  fails. Key order in the result matches the previous in-place version:
  original keys, then ``description``, then ``embedding``.
  """
  enriched = {**disease, 'description': generate_disease_description(disease)}
  # the embedding text includes the description, so it must be set first
  enriched['embedding'] = generate_vector_embedding(enriched)
  return enriched


def batched_generate_disease_data(
  diseases: list[dict[str, Any]],
  batch_size: int = 5,
  delay: int = 10,
):
  """Lazily enrich ``diseases`` in batches, pausing between batches.

  Yields one ``generate_disease_data`` result per input record, in order.
  After every batch except the last, it prints a notice and sleeps ``delay``
  seconds. This is presumably to stay under the Gemini API rate limit.
  """
  total = len(diseases)
  for start in range(0, total, batch_size):
    stop = start + batch_size
    yield from (generate_disease_data(item) for item in diseases[start:stop])
    if stop >= total:
      continue
    print(f'Waiting for {delay} seconds before next batch')
    time.sleep(delay)


# stream the enriched records to products/diseases.json; an existing file is
# treated as a cache so the slow, rate-limited API calls are not repeated
DISEASES_PATH = 'products/diseases.json'

if os.path.exists(DISEASES_PATH):
  print(f'Using cached {DISEASES_PATH}')
else:
  os.makedirs(os.path.dirname(DISEASES_PATH), exist_ok=True)
  # write to a temp file and move it into place only after every record is
  # written: a failed run must not leave a truncated JSON file behind, which
  # the exists() check above would otherwise accept as a valid cache
  tmp_path = f'{DISEASES_PATH}.tmp'
  with open(tmp_path, 'w', encoding='utf-8') as f:
    f.write('[')
    for idx, disease in enumerate(batched_generate_disease_data(diseases_json)):
      if idx:
        f.write(',\n')
      json.dump(disease, f)
    f.write(']')
  os.replace(tmp_path, DISEASES_PATH)
  print('Disease data generation completed')

with open(DISEASES_PATH, 'r', encoding='utf-8') as f:
  diseases_data = json.load(f)

# preview a few records; `disease` (not `data`) so the global DataFrame is not shadowed
for disease in diseases_data[:10]:
  print(
    {
      'name': disease['name'],
      'description': disease['description'][:100],
      'embedding': disease['embedding'][:10],
    }
  )

print(
  f'Total diseases after generating description and embedding: {len(diseases_data)}'
)


Disease data generation completed
{'name': 'Fungal Infection', 'description': 'Fungal infections, also known as mycoses, are caused by microscopic fungi that thrive in warm, moist', 'embedding': [-0.051693015, -0.012622708, -0.025434237, -0.09953168, -0.034006976, 0.03242588, 0.0334174, -0.0689982, -0.016405754, -0.028993173]}
{'name': 'Allergy', 'description': 'An allergy is an immune system overreaction to normally harmless substances called allergens.  Expos', 'embedding': [0.018938275, 0.031708326, -0.043788694, -0.020188075, 0.024716083, 0.027433831, -0.029221075, -0.028782753, -0.0045079864, 0.024936184]}
{'name': 'GERD', 'description': 'Gastroesophageal reflux disease (GERD) is a chronic digestive disorder where stomach acid frequently', 'embedding': [0.0009153394, 0.04215058, -0.04263623, -0.053892054, 0.005846108, 0.046177518, -0.03259433, -0.04128031, -0.0019446728, 0.0713995]}
{'name': 'Chronic Cholestasis', 'description': 'Chronic cholestasis is a long-term liver condition 