In [1]:
import os
import json
import numpy as np
import string
import time

import nltk 
nltk.download('punkt')
nltk.download('stopwords')
from nltk import skipgrams, ngrams
from nltk.corpus import stopwords 
# reason for using snowball: https://stackoverflow.com/questions/10554052/what-are-the-major-differences-and-benefits-of-porter-and-lancaster-stemming-alg
from nltk.stem.snowball import SnowballStemmer
from nltk.tokenize import word_tokenize

from IPython.display import clear_output

[nltk_data] Downloading package punkt to /home/gustaw/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to /home/gustaw/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [2]:
questions_data_path = '../data/medqa_data/questions/US_qbank.jsonl'
textbooks_data_dir = '../data/medqa_data/textbooks/'

In [3]:
all_questions_data = []
with open(questions_data_path, 'r') as file:
    for line in file:
        all_questions_data.append(json.loads(line))

Questions which have valid answer connected to the evidence (from the paper)

In [4]:
chlamydia_question = '''A 27-year-old male presents to urgent care complaining of pain with urination. He reports that the pain started 3 days ago. He has never experienced these symptoms before. He denies
gross hematuria or pelvic pain. He is sexually active with his girlfriend, and they consistently use condoms. When asked about recent travel, he admits to recently returning from a
boys’ trip” in Cancun where he had unprotected sex 1 night with a girl he met at a bar. The patients medical history includes type I diabetes that is controlled with an insulin pump.
His mother has rheumatoid arthritis. The patients temperature is 99 F (37.2 C), blood pressure is 112/74 mmHg, and pulse is 81/min. On physical examination, there are no lesions of
the penis or other body rashes. No costovertebral tenderness is appreciated. A urinalysis reveals no blood, glucose, ketones, or proteins but is positive for leukocyte esterase. A urine
microscopic evaluation shows a moderate number of white blood cells but no casts or crystals. A urine culture is negative. Which of the following is the most likely cause for the
patient’s symptoms?'''

sod1_question = '''A 57-year-old man presents to his primary care physician with a 2-month history of right upper and lower extremity weakness. He noticed the weakness when he started falling far
more frequently while running errands. Since then, he has had increasing difficulty with walking and lifting objects. His past medical history is significant only for well-controlled
hypertension, but he says that some members of his family have had musculoskeletal problems. His right upper extremity shows forearm atrophy and depressed reflexes while his right
lower extremity is hypertonic with a positive Babinski sign. Which of the following is most likely associated with the cause of this patients symptoms?
'''

In [5]:
corpus = {}
for textbook_name in os.listdir(textbooks_data_dir):
    textbook_path = textbooks_data_dir + '/' + textbook_name
    with open(textbook_path, 'r') as textbook_file:
        textbook_content = textbook_file.read()
        corpus[textbook_name] = textbook_content

## Pointwise Mutual Information

In [6]:
stop_words = stopwords.words('english')
snowball_stemmer = SnowballStemmer(language='english') 
# do not remove the '-' and '/'
custom_string_punctuation = string.punctuation.replace('-','').replace('/','')
punctuation = str.maketrans('', '', custom_string_punctuation)

In [7]:
def preprocess_content(content):
    # lowercase and create tokens
    tokens = word_tokenize(content.lower())
    # remove stepwords
    tokens = [x for x in tokens if x not in stop_words]
    # stemming
    tokens = [snowball_stemmer.stem(x) for x in tokens]
    return ' '.join(tokens)

In [8]:
def preprocess_corpus(corpus):
    counter = 0
    for name, content in corpus.items():
        # TODO: removal of non-medical terms using MetaMap
        corpus[name] = preprocess_content(content).translate(punctuation).replace('“','').replace('’','')
        counter += 1
        clear_output(wait=True)
        print(f'Processed textbook {name} ({counter}/{len(corpus.items())})')


In [9]:
def preprocess_questions(questions):
    stop_words = stopwords.words('english')
    snowball_stemmer = SnowballStemmer(language='english') 
    counter = 0
    
    for question in questions:
        question['question'] = preprocess_content(question['question']).translate(punctuation).replace('“','').replace('’','')
        for option, value in question['options'].items():
            question['options'][option] = preprocess_content(value).translate(punctuation).replace('“','').replace('’','')
        counter+=1
        clear_output(wait=True)
        print(f'Processed question ({counter}/{len(questions)})')

In [10]:
preprocess_corpus(corpus)
preprocess_questions(all_questions_data)

Processed question (14369/14369)


### Prepare data

In [11]:
corpus_joined = ' '.join([x for x in corpus.values()])

corpus_tokens = word_tokenize(corpus_joined)

corpus_unigrams = ngrams(corpus_joined.split(), 1)
corpus_unigrams_freq = nltk.FreqDist(corpus_unigrams)

corpus_bigrams = ngrams(corpus_joined.split(), 2)
corpus_bigrams_freq = nltk.FreqDist(corpus_bigrams)

corpus_trigrams = ngrams(corpus_joined.split(), 2)
corpus_trigrams_freq = nltk.FreqDist(corpus_trigrams)

selected_questions_data = [all_questions_data[13182], all_questions_data[13984]]

In [12]:
# # skipgrams contain the ngrams

# sentence = 'this is a foo bar sentences and i want to ngramize it'

# n = 1
# bigrams = ngrams(sentence.split(), n)
# one_skip_bigrams = skipgrams(sentence.split(), 2, 2)
# # for grams in one_skip_bigrams:
# #   print( grams)

In [13]:
# test_question = selected_questions_data[0]['question']
# bigrams = ngrams(test_question.split(), 2)

# fdist = nltk.FreqDist(bigrams)

# # for x in fdist.items():
# #     print(x)
    

In [14]:
def calculate_joint_probability(combined_corpus_tokens, ngram1, ngram2, window):
    ngram1_joined = ' '.join(ngram1)
    ngram2_joined = ' '.join(ngram2)  
    
    joint_occurences = 0

    for i in range(0, len(combined_corpus_tokens) - window + 1):
        corpus_window = ' '.join(combined_corpus_tokens[i:i+window])
        if ngram1_joined in corpus_window and ngram2_joined in corpus_window:
            joint_occurences += 1
        
    return joint_occurences/len(combined_corpus_tokens)
        
    

def calculate_average_pmi_per_answer(
        combined_corpus_tokens,
        corpus_unigrams_freq, 
        corpus_bigrams_freq, 
        corpus_trigrams_freq, 
        question, 
        answer_option, 
        window):
    combined_corpus_tokens_len = len(combined_corpus_tokens)
    question_unigrams = ngrams(question.split(), 1)
    question_bigrams = ngrams(question.split(), 2)  
    question_trigrams = ngrams(question.split(), 3)    
    answer_option_unigrams = ngrams(answer_option.split(), 1)
    answer_option_bigrams = ngrams(answer_option.split(), 2)
    
    pmi_score = []
    start_time = time.time()
    
    for answer_bigram in answer_option_bigrams:
        p_y = corpus_bigrams_freq[answer_bigram] / combined_corpus_tokens_len
#         print(f'p_y={p_y}')
        for question_bigram in list(question_bigrams):
            counter = 0
            p_x = corpus_bigrams_freq[question_bigram] / combined_corpus_tokens_len
#             print(f'p_x={p_x}')
            p_x_y = calculate_joint_probability(combined_corpus_tokens, question_bigram, answer_bigram, 10)
#             print(f'p_x_y={p_x_y}')
            
            if p_x != 0 and p_y != 0:
                score = p_x_y / p_x / p_y
                if score != 0:
                    print('non zero')
                    pmi_score.append(np.log(score))
            counter += 1
            print(f'Question bigram {question_bigram} completed. {counter}/{len(list(question_bigrams))} in ({time.time() - start_time}s)')
    return np.average(pmi_score)

In [15]:
test_question = selected_questions_data[0]
A_score = calculate_average_pmi_per_answer(
    combined_corpus_tokens=corpus_tokens,
    corpus_unigrams_freq=corpus_unigrams_freq, 
    corpus_bigrams_freq=corpus_bigrams_freq, 
    corpus_trigrams_freq=corpus_trigrams_freq, 
    question=test_question['question'], 
    answer_option=test_question['options']['A'], 
    window=10)

print(A_score)

Question bigram ('27-year-old', 'male') completed. 1/0 in (2.9936442375183105s)
Question bigram ('male', 'present') completed. 1/0 in (6.068277597427368s)
Question bigram ('present', 'urgent') completed. 1/0 in (9.087483644485474s)
Question bigram ('urgent', 'care') completed. 1/0 in (12.277844190597534s)
Question bigram ('care', 'complain') completed. 1/0 in (15.415304899215698s)
Question bigram ('complain', 'pain') completed. 1/0 in (18.494123697280884s)
Question bigram ('pain', 'urin') completed. 1/0 in (21.674288511276245s)
Question bigram ('urin', 'report') completed. 1/0 in (24.836106061935425s)
Question bigram ('report', 'pain') completed. 1/0 in (27.985004425048828s)
Question bigram ('pain', 'start') completed. 1/0 in (31.25678515434265s)
Question bigram ('start', '3') completed. 1/0 in (34.59342813491821s)
Question bigram ('3', 'day') completed. 1/0 in (37.58857488632202s)
Question bigram ('day', 'ago') completed. 1/0 in (40.64483141899109s)
Question bigram ('ago', 'never') co

In [None]:
# test_corpus = '''Lorem ipsum dolor sit amet consectetur adipiscing elit Integer a lobortis nisl eget suscipit justo Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia curae Lorem ipsum dolor sit amet consectetur adipiscing elit Integer a lobortis nisl eget suscipit justo Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia curae Lorem ipsum dolor sit amet consectetur adipiscing elit Integer a lobortis nisl eget suscipit justo Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia curae'''
# test_corpus_tokens = word_tokenize(test_corpus)

# question = 'dolor sit'
# answer = 'adipiscing elit'

# answer_option_bigrams = ngrams(answer.split(), 2)
# # print(list(answer_option_bigrams))
# for answer in answer_option_bigrams:
#     print(answer)

# # ngram1 = ('dolor', 'sit')
# # ngram2 = ('adipiscing', 'elit')
# # print(calculate_average_pmi_per_answer(test_corpus_tokens, 'Lorem ipsum dolor sit amet', 'adipiscing elit', 10))


In [None]:
import nltk
from nltk.collocations import *
from nltk.tokenize import word_tokenize

text = "this is a foo bar bar black sheep  foo bar bar black sheep foo bar bar black  sheep shep bar bar black sentence"

trigram_measures = nltk.collocations.TrigramAssocMeasures()
finder = TrigramCollocationFinder.from_words(word_tokenize(text))

for i in finder.score_ngrams(trigram_measures.pmi):
    print(i)

In [None]:
bigrams_corpus = ngrams(corpus_joined.split(), 2)

In [None]:
bigrams_freq = nltk.FreqDist(bigrams_corpus)

In [None]:
x = ngrams(corpus_joined.split(), 1)