## Setup Environment

In [None]:
!pip install scispacy
!pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.0/en_core_sci_lg-0.5.0.tar.gz

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.0/en_core_sci_lg-0.5.0.tar.gz
  Using cached https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.0/en_core_sci_lg-0.5.0.tar.gz (532.3 MB)


In [None]:
import string
import random
import time
import re
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import spacy
import scispacy
import numpy as np
import multiprocessing as mp
import xml.etree.ElementTree as ETree
from sklearn.preprocessing import LabelEncoder
from sklearn.feature_extraction.text import TfidfVectorizer
from urllib.request import urlopen
from sklearn.cluster import KMeans
from sklearn import metrics

### Setup Globals

In [None]:
# Randomize random seed
# Seeds Python's built-in `random` module from the current wall-clock time, so
# each kernel start uses a different seed.
# NOTE(review): this only touches the stdlib `random` module; the KMeans model
# fitted later in this notebook is created without a `random_state`. Confirm
# whether reproducible clustering is wanted.
random.seed(time.time())

In [None]:
# Load the `en_core_sci_lg` pipeline once. The shared `nlp` object is used below
# for text normalisation, entity extraction and question similarity.
nlp = spacy.load("en_core_sci_lg")

## Dataset Processing

### Download Datasets

In [None]:
# Download QA from Medinfo 2019
# The sheet's browser ("/edit") link is rewritten into its CSV export link,
# which pandas then reads directly.
medinfo_url_raw = "https://docs.google.com/spreadsheets/d/1m9IrJtLY57P9TrJ5sBvSlSqSsmBPJ4mHckew17jhJfc/edit#gid=292450826"
medinfo_url = medinfo_url_raw.replace(
    '/edit#gid=',
    '/export?format=csv&gid=',
)
medInfo_raw = pd.read_csv(filepath_or_buffer=medinfo_url)

In [None]:
# Download QA from medQuAD
# The sheet's browser ("/edit") link is rewritten into its CSV export link,
# which pandas then reads directly.
medQuAD_url_raw = "https://docs.google.com/spreadsheets/d/11uRnezy8m8JuI8GbrG8PaBLIcWAZVdWO9ER0vjLLBIM/edit#gid=1364087034"
medQuAD_url = medQuAD_url_raw.replace(
    '/edit#gid=',
    '/export?format=csv&gid=',
)
medQuAD_raw = pd.read_csv(filepath_or_buffer=medQuAD_url)

In [None]:
# Download QA from LiveQA_MedicalTask_TREC2017 (XML) 
def parse_ref_question(data):
  """Flatten an NLM-QUESTION element that carries reference answers.

  Args:
    data: an ElementTree element with `Original-Question/MESSAGE`,
      `NIST-PARAPHRASE`, `NLM-Summary`, `ANNOTATIONS` and `ReferenceAnswers`
      children.

  Returns:
    A dict with the question text in its three forms, the FOCUS and TYPE
    annotation texts, and the answer texts. Answers found under `RefAnswer`
    come first, followed by those found under `ReferenceAnswer`.
  """
  reference_answers = data.find('ReferenceAnswers')

  answers = []
  for wrapper_tag in ('RefAnswer', 'ReferenceAnswer'):
    for wrapper in reference_answers.findall(wrapper_tag):
      answers.append(wrapper.find('ANSWER').text)

  original_message = data.find('Original-Question').find('MESSAGE')

  return {
    'original_question': original_message.text,
    'paraphrase_question': data.find('NIST-PARAPHRASE').text,
    'question_summary': data.find('NLM-Summary').text,
    'focuses': [focus.text for focus in data.find('ANNOTATIONS').findall('FOCUS')],
    'types': [qtype.text for qtype in data.find('ANNOTATIONS').findall('TYPE')],
    'answers': answers
  }

def parse_std_question(data):
  """Flatten an NLM-QUESTION element that has no reference-answer section.

  Annotations and answers are read from the first SUB-QUESTION when the
  question has any, otherwise from the question element itself.

  Args:
    data: an ElementTree element with a `MESSAGE` child and, optionally,
      `SUB-QUESTIONS`, `ANNOTATIONS` and `ANSWERS` children.

  Returns:
    A dict with the same keys as `parse_ref_question`. `paraphrase_question`
    and `question_summary` are always None. `focuses`, `types` and `answers`
    are empty lists when the corresponding section is missing.
  """
  # An Element's truth value reflects whether it has children, not whether it
  # was found, so every lookup is compared against None explicitly.
  sub_questions_node = data.find('SUB-QUESTIONS')
  if sub_questions_node is not None:
    sub_questions = sub_questions_node.findall('SUB-QUESTION')
  else:
    sub_questions = []

  source = sub_questions[0] if sub_questions else data

  answers = []
  answers_node = source.find('ANSWERS')
  if answers_node is not None:
    for entry in answers_node:
      answer_node = entry.find('ANSWER')
      # Also accept an ANSWER element placed directly under ANSWERS.
      if answer_node is None and entry.tag == 'ANSWER':
        answer_node = entry
      # Skip entries without text; there is nothing to use as an answer.
      if answer_node is not None and answer_node.text is not None:
        answers.append(answer_node.text)

  annotations = source.find('ANNOTATIONS')
  if annotations is not None:
    focuses = [f.text for f in annotations.findall('FOCUS')]
    types = [f.text for f in annotations.findall('TYPE')]
  else:
    focuses = []
    types = []

  return {
    'original_question': data.find('MESSAGE').text,
    'paraphrase_question': None,
    'question_summary': None,
    'focuses': focuses,
    'types': types,
    'answers': answers
  }

# The three LiveQA XML files: the test questions (with summaries and reference
# answers) and the two training sets.
liveQAMedTask_urls = [
      "https://raw.githubusercontent.com/abachaa/LiveQA_MedicalTask_TREC2017/master/TestDataset/TREC-2017-LiveQA-Medical-Test-Questions-w-summaries.xml",
      "https://raw.githubusercontent.com/abachaa/LiveQA_MedicalTask_TREC2017/master/TrainingDatasets/TREC-2017-LiveQA-Medical-Train-1.xml",
      "https://raw.githubusercontent.com/abachaa/LiveQA_MedicalTask_TREC2017/master/TrainingDatasets/TREC-2017-LiveQA-Medical-Train-2.xml",
]
liveQAMedTask_flattened_xml = []

for url in liveQAMedTask_urls:
  # Close the HTTP response as soon as the document has been parsed.
  with urlopen(url) as response:
    prstree = ETree.parse(response)
  root = prstree.getroot()

  for question in root.iter('NLM-QUESTION'):
    # Compare against None: an Element's truth value reflects whether it has
    # children, not whether it was found.
    if question.find('Original-Question') is not None:
      q = parse_ref_question(question)
    else:
      q = parse_std_question(question)

    liveQAMedTask_flattened_xml.append(q)

# One row per question, with the columns produced by the two parsers.
liveQAMedTask_raw = pd.DataFrame(liveQAMedTask_flattened_xml)

In [None]:
liveQAMedTask_raw

550

### Standardize Datasets

In [None]:
def normalize(text):
    """Lowercase `text`, parse it with `nlp`, and lazily yield token lemmas.

    Tokens whose text is found in `string.punctuation`, and tokens flagged as
    stop words, are left out.

    Returns:
        A generator of lemma strings. It can be consumed only once.
    """
    doc = nlp(text.lower())

    return (
        token.lemma_
        for token in doc
        if token.text not in string.punctuation and not token.is_stop
    )


In [None]:
def clean_focus_text(focus_text):
  """Reduce a focus (drug / condition) string to lowercase words.

  Args:
    focus_text: the raw focus string.

  Returns:
    The text lowercased, with every non-letter replaced by a space, the
    standalone unit "mg" removed, runs of spaces collapsed, and leading and
    trailing spaces stripped.
  """
  # Remove special characters and lowercase data
  clean = re.sub('[^a-zA-Z]', ' ', focus_text.lower())

  # Remove measurement units that may have slipped by. Only the whole word is
  # removed, so names that merely contain "mg" are left intact.
  clean = re.sub(r'\bmg\b', '', clean)

  # Only letters and spaces remain at this point, so collapsing spaces is enough.
  clean = re.sub(' +', ' ', clean)

  return clean.strip()

def clean_answer_text(answer_text):
  """Collapse each run of spaces, tabs and newlines into one space and strip the ends."""
  collapsed = re.sub('[ \t\n]+', ' ', answer_text)
  return collapsed.strip()

In [None]:
# MedInfo
def parse_medinfo_question_type(question_type):
  """Map a MedInfo question-type label onto one of the shared categories.

  Args:
    question_type: the label to look up; matching is exact and case-sensitive.

  Returns:
    The category name, or None when the label is not recognised.
  """
  # (category, labels) pairs, checked in this order.
  category_labels = (
    ('usage', ('usage', 'usage/time', 'stopping/tapering')),
    ('composition', ('ingredient',)),
    ('alternative', ('alternatives', 'brand names', 'comparison')),
    ('appearance', ('appearance',)),
    ('dosage', ('dose', 'dose/potency')),
    ('interaction', ('interaction', 'contraindication')),
    ('general', ('information', 'indication', 'action', 'action/time', 'pronounce name', 'availability', 'time/duration', 'action/effectiveness', 'storage and disposal', 'manufacturer')),
    ('side effect', ('side effects', 'overdose', 'forget a dose', 'stopping/side effects')),
  )

  for category, labels in category_labels:
    if question_type in labels:
      return category

  return None

def parse_medinfo_data(data):
  """Turn one raw MedInfo row into the standardized record.

  Args:
    data: a row exposing 'Question', 'Question Type', 'Focus (Drug)' and
      'Answer'.

  Returns:
    A dict with the raw and normalized question, its word list, the mapped
    type (as a one-element list and as `primary_type`), the cleaned focus and
    the cleaned answer. Returns None when the question type is not recognised.
  """
  mapped_type = parse_medinfo_question_type(data['Question Type'].lower())
  if mapped_type is None:
    return None

  # A MedInfo row carries exactly one question type.
  types = [mapped_type]

  words = list(normalize(data['Question']))
  focus_text = ' '.join(normalize(data['Focus (Drug)']))

  return {
    'question_raw': data['Question'],
    'question_proc': ' '.join(words),
    'question_words': words,
    'types': types,
    'primary_type': types[0],
    'focus': clean_focus_text(focus_text),
    'answer': clean_answer_text(data['Answer'])
  }

# Convert every MedInfo row, then keep only the rows the parser did not reject.
parsed_medinfo = [parse_medinfo_data(row) for _, row in medInfo_raw.iterrows()]

medInfo = pd.DataFrame([record for record in parsed_medinfo if record is not None])


In [None]:
medInfo

Unnamed: 0,question_raw,question_proc,question_words,types,primary_type,focus,answer
0,how does rivatigmine and otc sleep medicine in...,rivatigmine otc sleep medicine interact,"[rivatigmine, otc, sleep, medicine, interact]",[interaction],interaction,rivastigmine,tell your doctor and pharmacist what prescript...
1,how does valium affect the brain,valium affect brain,"[valium, affect, brain]",[general],general,valium,Diazepam is a benzodiazepine that exerts anxio...
2,what is morphine,morphine,[morphine],[general],general,morphine,Morphine is a pain medication of the opiate fa...
3,what are the milligrams for oxycodone e,milligram oxycodone e,"[milligram, oxycodone, e]",[dosage],dosage,oxycodone er,… 10 mg … 20 mg … 40 mg … 80 mg ...
4,81% aspirin contain resin and shellac in it. ?,81 aspirin contain resin shellac,"[81, aspirin, contain, resin, shellac]",[composition],composition,aspirin,Inactive Ingredients Ingredient Name
...,...,...,...,...,...,...,...
648,how soon does losartan affect blood pressure,soon losartan affect blood pressure,"[soon, losartan, affect, blood, pressure]",[general],general,losartan,The effect of losartan is substantially presen...
649,how do steroids effect the respiratory system,steroid effect respiratory system,"[steroid, effect, respiratory, system]",[general],general,steroid,Several efforts have been made to show the ben...
650,why am i so cold taking bystolic b p med,cold take bystolic b p me,"[cold, take, bystolic, b, p, me]",[side effect],side effect,bystolic,Feeling cold is found among people who take By...
651,pneumococcal vaccine how often,pneumococcal vaccine,"[pneumococcal, vaccine]",[usage],usage,pneumococcal vaccine,CDC recommends routine administration of pneum...


In [None]:
# MedQuAD
def parse_question_type_from_text(question_text):
  """Guess the question categories from keywords in the question text.

  Matching is case-sensitive and every keyword is lowercase, so callers should
  pass lowercased text.

  Args:
    question_text: the question to classify.

  Returns:
    A list of matching category names in a fixed order (usage, composition,
    alternative, appearance, dosage, interaction, general, side effect).
    The list is empty when nothing matches.
  """
  def mentions(*phrases):
    # True when any of the phrases occurs anywhere in the question.
    return any(phrase in question_text for phrase in phrases)

  types = []

  ## Usage
  # Matched as whole words: a plain substring test for "use" also fires on
  # "cause" and "because", which wrongly tags "what causes ..." as usage.
  if re.search(r'\b(?:use|used|uses)\b', question_text):
    types.append('usage')

  ## Composition
  if mentions('ingredients', 'made of'):
    types.append('composition')

  ## Alternative
  if mentions('alternatives', 'substitute'):
    types.append('alternative')

  ## Appearance
  if mentions('smells like', 'looks like', 'taste like', 'color'):
    types.append('appearance')

  ## Dosage
  if mentions('dosage', 'dose'):
    types.append('dosage')

  ## Interaction
  if mentions('interact', 'used with'):
    types.append('interaction')

  ## General
  if mentions('what is', 'should i know', 'what are'):
    types.append("general")

  ## Side Effect
  if mentions('adverse effects', 'side effects', 'reactions', 'cause'):
    types.append('side effect')

  return types

def parse_medquad_data(data):
  """Turn one raw MedQuAD record into the standardized record.

  Args:
    data: a string whose first line holds the question ("Question: ..."),
      whose second line holds the source URL, and whose remainder is the
      answer ("Answer: ...").

  Returns:
    A dict with the raw and normalized question, its word list, the detected
    types, the first detected type as `primary_type`, the cleaned text of the
    first named entity as `focus`, and the cleaned answer. Returns None when
    the record is not a string, has fewer than three lines, matches no question
    type, or contains no named entity.
  """
  # Anything that is not text (for example NaN for an empty cell) cannot be parsed.
  if not isinstance(data, str):
    return None

  # Split on the first two newlines only, so everything after the URL line
  # stays together as the answer even when the answer spans several lines.
  parsed = data.split("\n", 2)
  if len(parsed) < 3:
    return None

  # parsed[1] is the URL line; it is not used in the output.
  question = parsed[0].replace('Question: ', '').strip()
  answer = parsed[2].replace('Answer: ', '').strip()

  ## Remove the also called portion of the question
  question = question.replace('Also called: ', '')

  types = parse_question_type_from_text(question.lower())

  if len(types) == 0:
    return None

  # Check for an entity before normalizing, so records that are going to be
  # dropped do not pay for a second spaCy parse.
  tokens = nlp(question)
  if len(tokens.ents) == 0:
    return None
  focus = clean_focus_text(tokens.ents[0].text)

  words = list(normalize(question))

  return {
    'question_raw': question,
    'question_proc': ' '.join(words),
    'question_words': words,
    'types': types,
    'primary_type': types[0],
    'focus': focus,
    'answer': clean_answer_text(answer)
  }


# Parse every raw MedQuAD record, then keep only the ones the parser did not reject.
parsed_medquad = list(map(parse_medquad_data, medQuAD_raw['Answer']))

medQuAD = pd.DataFrame([record for record in parsed_medquad if record is not None])


In [None]:
medQuAD

Unnamed: 0,question_raw,question_proc,question_words,types,primary_type,focus,answer
0,What is (are) Polycystic ovary syndrome ? (Pol...,polycystic ovary syndrome polycystic ovary pol...,"[polycystic, ovary, syndrome, polycystic, ovar...",[general],general,polycystic ovary syndrome,Polycystic ovary syndrome is a condition in wh...
1,What causes Polycystic ovary syndrome ? (Polyc...,cause polycystic ovary syndrome polycystic ova...,"[cause, polycystic, ovary, syndrome, polycysti...","[usage, side effect]",usage,polycystic ovary syndrome,PCOS is linked to changes in hormone levels th...
2,What causes Noonan syndrome ?,cause noonan syndrome,"[cause, noonan, syndrome]","[usage, side effect]",usage,noonan syndrome,Noonan syndrome is linked to defects in severa...
3,What are the complications of Noonan syndrome ?,complication noonan syndrome,"[complication, noonan, syndrome]",[general],general,complications,- Buildup of fluid in tissues of body (lymphed...
4,What are the symptoms of Neurofibromatosis-Noo...,symptom neurofibromatosis-noonan syndrome nfns...,"[symptom, neurofibromatosis-noonan, syndrome, ...",[general],general,symptoms,What are the signs and symptoms of Neurofibrom...
...,...,...,...,...,...,...,...
1550,What should I do if I forget a dose of Glimepi...,forget dose glimepiride,"[forget, dose, glimepiride]",[dosage],dosage,dose,"Before you start to take glimepiride, ask you ..."
1551,What are the side effects or risks of Glimepir...,effect risk glimepiride,"[effect, risk, glimepiride]","[general, side effect]",general,side effects,This medication may cause changes in your bloo...
1552,What to do in case of emergency or overdose of...,case emergency overdose glimepiride,"[case, emergency, overdose, glimepiride]",[dosage],dosage,emergency,"In case of overdose, call your local poison co..."
1553,What other information should I know about Gli...,information know glimepiride,"[information, know, glimepiride]",[general],general,information,Keep all appointments with your doctor and the...


In [None]:
# LiveQAMed

def parse_liveqamed_question_type(question_type):
  """Map a LiveQA question-type label onto one of the shared categories.

  Args:
    question_type: the label to look up; matching is exact and case-sensitive.

  Returns:
    The category name, or None when the label is not recognised.
  """
  # (category, labels) pairs, checked in this order.
  category_labels = (
    ('usage', ('usage', 'tapering')),
    ('composition', ('ingredient',)),
    ('alternative', ('alternative', 'comparison')),
    ('dosage', ('dosage',)),
    ('interaction', ('effect', 'interaction', 'contraindication')),
    ('general', ('treatment', 'prevention', 'diagnosis', 'cause', 'information', 'symptom', 'storage_disposal', 'action', 'susceptibility', 'indication', 'lifestyle_diet', 'prognosis', 'person_organization')),
    ('side effect', ('side_effect', 'complication')),
  )

  for category, labels in category_labels:
    if question_type in labels:
      return category

  return None

def parse_liveqamed_data(data):
  """Turn one flattened LiveQA row into the standardized record.

  Args:
    data: a row exposing 'types', 'focuses', 'answers', 'paraphrase_question'
      and 'original_question'.

  Returns:
    A dict with the raw and normalized question, its word list, the mapped
    types, the first mapped type as `primary_type`, the cleaned first focus
    (None when the row has no focus) and the cleaned first answer. Returns
    None when no type is recognised or the row has no answers.
  """
  mapped = (parse_liveqamed_question_type(t.lower()) for t in data['types'])
  types = [ts for ts in mapped if ts is not None]

  # Prefer the paraphrase when there is one; fall back to the original message.
  question = data['paraphrase_question'] or data['original_question']

  if len(types) == 0 or len(data['answers']) == 0:
    return None

  # normalize() runs the spaCy pipeline itself; no separate parse of the
  # question is needed here.
  words = list(normalize(question))
  focus = clean_focus_text(data['focuses'][0]) if len(data['focuses']) > 0 else None

  return {
    'question_raw': question,
    'question_proc': ' '.join(words),
    'question_words': words,
    'types': types,
    'primary_type': types[0],
    'focus': focus,
    'answer': clean_answer_text(data['answers'][0])
  }

# Convert every LiveQA row, then keep only the rows the parser did not reject.
parsed_liveqamed = [parse_liveqamed_data(row) for _, row in liveQAMedTask_raw.iterrows()]

liveQAMedTask = pd.DataFrame([record for record in parsed_liveqamed if record is not None])

In [None]:
liveQAMedTask

Unnamed: 0,question_raw,question_proc,question_words,types,primary_type,focus,answer
0,What is the relationship between Noonan syndro...,relationship noonan syndrome polycystic renal ...,"[relationship, noonan, syndrome, polycystic, r...",[interaction],interaction,noonan syndrome,Noonan's syndrome is an eponymic designation t...
1,Do 5 mg. Zolmitriptan tabkets contain gluten?,5 mg zolmitriptan tabket contain gluten,"[5, mg, zolmitriptan, tabket, contain, gluten]",[composition],composition,zolmitriptan,Zolmitriptan tablets are available as 2.5 mg (...
2,Are amphetamine salts of 20 mg dosage gluten f...,amphetamine salt 20 mg dosage gluten free,"[amphetamine, salt, 20, mg, dosage, gluten, free]",[composition],composition,amphetamine salts,Active Ingredients Amphetamine Aspartate Amphe...
3,What are the treatments and precautions for VD...,treatment precaution vdrl positive syphilis pa...,"[treatment, precaution, vdrl, positive, syphil...","[general, general, general]",general,vdrl positive,"Syphilis If the RPR, VDRL, or TRUST tests are ..."
4,How much glucagon is in my GlucaGen kit?,glucagon glucagen kit,"[glucagon, glucagen, kit]",[composition],composition,glucagen hypokit,"GLUCAGEN glucagon hydrochloride injection, pow..."
...,...,...,...,...,...,...,...
99,To what extent does Effexor cause ED?,extent effexor cause ed,"[extent, effexor, cause, ed]",[side effect],side effect,effextor,The recommended starting dose for Effexor is 7...
100,How long has Non-aspirin NSAID been implicated...,long non-aspirin nsaid implicate erectile dysf...,"[long, non-aspirin, nsaid, implicate, erectile...",[side effect],side effect,nsaids,Non-aspirin NSAID use was associated with an i...
101,"What is aortic stenosis, and is there anything...",aortic stenosis,"[aortic, stenosis]","[general, general]",general,aeortic stenosis,The aorta is the main artery that carries bloo...
102,What can cause white cells ti uprate,cause white cell ti uprate,"[cause, white, cell, ti, uprate]",[general],general,white cells uprate,A high white blood cell count usually indicate...


In [None]:
# Stack the three standardized dataframes into one, renumbering the rows from 0.
datasetdf = pd.concat(
    objs=[medInfo, medQuAD, liveQAMedTask],
    ignore_index=True,
)

In [None]:
len(datasetdf)

2312

## Feature Engineering - TF-IDF

In [None]:
# One text per row, "<focus> <primary_type>", used as the clustering document.
comb_frame = datasetdf['focus'].str.cat(" " + datasetdf['primary_type'])

# Fit the TF-IDF vocabulary on those texts and vectorize them in one step.
vectorizerTF = TfidfVectorizer()
X_focus_type = vectorizerTF.fit_transform(comb_frame)

## Clustering


In [None]:
# Cluster the TF-IDF rows with k-means.
model = KMeans(
    n_clusters=1000,
    init='k-means++',
    n_init=15,
    max_iter=100,
)
model.fit(X_focus_type)

KMeans(max_iter=100, n_clusters=1000, n_init=15)

In [None]:
def predict_cluster(input):
    """Return the cluster predicted by `model` for each text in `input`.

    The texts are vectorized with the already-fitted `vectorizerTF` before
    being passed to the model.
    """
    documents = list(input)
    features = vectorizerTF.transform(documents)
    return model.predict(features)


In [None]:
# Store the "<focus> <primary_type>" text used for clustering alongside each row.
datasetdf['cluster_key'] = datasetdf['focus'].str.cat(" " + datasetdf['primary_type'])

In [None]:
# Record the cluster the model assigns to each row's cluster key.
cluster_keys = datasetdf['cluster_key']
datasetdf['cluster_pred'] = predict_cluster(cluster_keys)


In [None]:
def answer(str_input, min_similarity=0.8):
    """Look up the stored answer that best matches a free-text question.

    The question is classified into a type and a focus, the dataset rows with
    the same primary type and a matching focus select a cluster, and the most
    similar stored question inside that cluster supplies the answer.

    Args:
        str_input: the user's question.
        min_similarity: similarity the best stored question must exceed for
            its answer to be returned.

    Returns:
        A one-row pandas Series holding the answer of the best match, or an
        apology string when the question cannot be classified, no row matches,
        or the best match is not similar enough.
    """
    no_match = "Sorry, we couldnt find any information to match your question"

    # The type keywords are all lowercase, so match against lowercased text.
    q_types = parse_question_type_from_text(str_input.lower())
    if len(q_types) == 0:
        return no_match
    q_primary_type = q_types[0]
    input_tokens = nlp(str_input)

    if len(input_tokens.ents) > 0:
        focuses = [clean_focus_text(input_tokens.ents[0].text)]
    else:
        focuses = clean_focus_text(' '.join(normalize(str_input))).split(' ')

    # Drop empty terms (an empty alternative would match every focus) and
    # escape the rest, since they are joined into a regular expression.
    focus_terms = [re.escape(term) for term in focuses if term]
    if len(focus_terms) == 0:
        return no_match

    # na=False keeps rows without a focus out of the mask instead of producing NaN.
    focus_mask = datasetdf.focus.str.contains('|'.join(focus_terms), na=False)
    candidates = datasetdf[(datasetdf.primary_type == q_primary_type) & focus_mask]
    if candidates.empty:
        return no_match

    cluster_keys = candidates.focus.str.cat(" " + candidates.primary_type)
    candidate_clusters = predict_cluster(list(cluster_keys))

    # Several candidate rows may fall into different clusters; use the most
    # common one (the lowest cluster number wins a tie).
    prediction_inp = int(pd.Series(candidate_clusters).mode().iloc[0])

    # Recommendation Logic is kept super-simple for current implementation.
    temp_df = datasetdf[datasetdf['cluster_pred'] == prediction_inp].copy()
    if temp_df.empty:
        return no_match

    temp_df['similarity'] = [input_tokens.similarity(nlp(qr)) for qr in temp_df['question_raw']]

    # Compare a single scalar rather than a Series, so ties cannot make the
    # condition ambiguous.
    best_index = temp_df['similarity'].idxmax()
    if temp_df.loc[best_index, 'similarity'] > min_similarity:
        return temp_df.loc[[best_index], 'answer']

    return no_match


In [None]:
answer('what is Polycystic ovary syndrome')

654    PCOS is linked to changes in hormone levels th...
Name: answer, dtype: object
