## Filtering questions

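The goal of this notebook is to filter the questions based on whether the majority of their answer options (at least 3 of the 4) match a UMLS term whose semantic type (TUI) is one of the disorder TUIs. Questions that meet this criterion are saved as *valid*; all others are saved as *invalid*.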

The disease CUIs and TUIs are read from the `umls_terms.csv` file.

In [1]:
import json
import math
import matplotlib.pyplot as plt
import multiprocessing
import numpy as np
import pandas as pd
import pickle
import re
# in order to ignore the UserWarning: This pattern has match groups
import warnings
warnings.filterwarnings("ignore", 'This pattern has match groups')

from itertools import groupby
from nltk import word_tokenize
from tqdm import tqdm

In [3]:
# Input locations, given relative to the directory this notebook runs from.

# CSV with the UMLS terms.
umls_terms_path = '../..//data/umls_terms.csv'

# MedQA 4-option question files for the train / dev / test splits.
q_train_path, q_val_path, q_test_path = (
    '../../data/medqa/questions/4_options/train.jsonl',
    '../../data/medqa/questions/4_options/dev.jsonl',
    '../../data/medqa/questions/4_options/test.jsonl',
)

In [5]:
# Load the UMLS term table; this cell relies on its TUI, STR and CUI columns.
umls_terms = pd.read_csv(umls_terms_path)

# lines 56-67 from the SemGroups_2018.txt
disorders_tui = {
    "T020": "Acquired Abnormality",
    "T190": "Anatomical Abnormality",
    "T049": "Cell or Molecular Dysfunction",
    "T019": "Congenital Abnormality",
    "T047": "Disease or Syndrome",
    "T050": "Experimental Model of Disease",
    "T033": "Finding",
    "T037": "Injury or Poisoning",
    "T048": "Mental or Behavioral Dysfunction",
    "T191": "Neoplastic Process",
    "T046": "Pathologic Function",
    "T184": "Sign or Symptom"
}

print(f'Number of records in umls_terms before TUI filtering: {umls_terms.shape[0]}')

# preserve only data related to the diseases' TUIs
umls_terms = umls_terms.loc[umls_terms['TUI'].isin(disorders_tui.keys())].reset_index(drop=True)
print(f'Number of records in umls_terms after TUI filtering: {umls_terms.shape[0]}')

# one row does not have a STR value, hence drop it.
# The terms are lower-cased with the vectorised string accessor rather than a
# per-row Python lambda, and the column is rebuilt through `assign` so that no
# value is written into a frame derived from another one.
umls_terms = (
    umls_terms
    .dropna(subset=['STR'])
    .assign(STR=lambda terms: terms['STR'].str.lower())
)

# Distinct concept identifiers that remain after the filtering above.
umls_cuis = umls_terms['CUI'].unique()

Number of records in umls_terms before TUI filtering: 8851980
Number of records in umls_terms after TUI filtering: 1775855


In [6]:
def _load_jsonl(path):
    """Read a JSON Lines file into a list of parsed objects.

    The file is opened explicitly as UTF-8 so that the result does not depend
    on the platform's default encoding, and blank lines (for example a trailing
    newline at the end of the file) are skipped instead of failing in
    ``json.loads``.

    Args:
        path: Location of the ``.jsonl`` file.

    Returns:
        A list with one parsed JSON value per non-blank line, in file order.
    """
    with open(path, 'r', encoding='utf-8') as file:
        return [json.loads(line) for line in file if line.strip()]


# One list of question records per split.
q_train = _load_jsonl(q_train_path)
q_val = _load_jsonl(q_val_path)
q_test = _load_jsonl(q_test_path)

num_all = len(q_train) + len(q_val) + len(q_test)

print(f"Num of all questions: {num_all}: {len(q_train)} + {len(q_val)} + {len(q_test)}")

Num of all questions: 12723: 10178 + 1272 + 1273


In [7]:
def create_processor_chunks(data, number_of_processes):
    """Split ``data`` into consecutive slices, at most one per worker process.

    Every slice holds ``ceil(len(data) / number_of_processes)`` items except the
    last one, which may be shorter. Because of that rounding, fewer than
    ``number_of_processes`` slices can be produced (7 items for 6 processes
    give 4 slices), so callers should iterate over the result rather than index
    it by process number.

    Args:
        data: A sliceable sequence with a length (e.g. a list).
        number_of_processes: Positive number of workers to split the data for.

    Yields:
        Consecutive slices of ``data``; nothing at all when ``data`` is empty.

    Raises:
        ValueError: If ``number_of_processes`` is smaller than 1. As this is a
            generator, the error surfaces when iteration starts.
    """
    if number_of_processes < 1:
        raise ValueError(
            f"number_of_processes must be at least 1, got {number_of_processes}"
        )

    # An empty input would give a chunk size of 0, which range() rejects as a step.
    if len(data) == 0:
        return

    chunk_size = math.ceil(len(data) / number_of_processes)

    for start in range(0, len(data), chunk_size):
        yield data[start:start + chunk_size]

In [8]:
def filter_questions(question_list, all_valid, all_invalid, min_matches=3):
    """Sort questions into valid / invalid by how many options match a UMLS term.

    An option counts as a match when its lower-cased text occurs as a literal
    substring of at least one value in the ``STR`` column of the module-level
    ``umls_terms`` frame. A question is appended to ``all_valid`` when at least
    ``min_matches`` of its options match, otherwise to ``all_invalid``.

    The option text is matched literally (``regex=False``). Interpreting it as
    a regular expression made options containing characters such as ``(``,
    ``+`` or ``.`` either match the wrong terms or raise, which in turn marked
    the whole question as invalid.

    Args:
        question_list: Iterable of question dicts; each has an ``'options'``
            dict whose values are the option texts.
        all_valid: List-like object with ``append`` that receives the
            questions meeting the threshold.
        all_invalid: List-like object with ``append`` that receives all other
            questions.
        min_matches: Minimum number of matching options for a question to be
            considered valid. Defaults to 3, the threshold used so far.

    Returns:
        None. Results are delivered through ``all_valid`` and ``all_invalid``.
    """
    # The column lookup does not depend on the question, so do it once.
    term_strings = umls_terms['STR']

    for q_data in tqdm(question_list):
        counter = 0
        for option in q_data['options'].values():
            try:
                needle = option.lower()
            except AttributeError:
                # A non-string option cannot be matched; as before, a question
                # with an unusable option is treated as invalid.
                counter = 0
                break

            # `.any()` answers "is there a match?" without materialising the
            # filtered frame for every option.
            if term_strings.str.contains(needle, na=False, regex=False).any():
                counter += 1

        if counter >= min_matches:
            all_valid.append(q_data)
        else:
            all_invalid.append(q_data)

In [9]:
def multiprocess_exec(question_list, num_of_processors):
    """Run ``filter_questions`` over ``question_list`` in parallel processes.

    The questions are split with ``create_processor_chunks`` and one process is
    started per chunk actually produced. That can be fewer than
    ``num_of_processors``, so the chunks are iterated rather than indexed by
    process number, which used to raise ``IndexError`` in that case.

    Args:
        question_list: Sequence of question dicts to classify.
        num_of_processors: Upper bound on the number of worker processes.

    Returns:
        A ``(valid, invalid)`` tuple of plain lists. The order of the questions
        within each list depends on how the workers interleave and is not
        guaranteed to follow the input order.

    Raises:
        RuntimeError: If any worker process ends with a non-zero exit code, as
            the collected results would then be incomplete.
    """
    # Nothing to distribute: skip the manager and the chunking altogether.
    if len(question_list) == 0:
        return [], []

    # The context manager shuts the manager's server process down on exit.
    with multiprocessing.Manager() as manager:
        all_valid = manager.list()
        all_invalid = manager.list()
        jobs = []

        for chunk in create_processor_chunks(question_list, num_of_processors):
            p = multiprocessing.Process(target=filter_questions, args=(chunk, all_valid, all_invalid))
            jobs.append(p)
            p.start()

        for proc in jobs:
            proc.join()

        failed = [proc.exitcode for proc in jobs if proc.exitcode != 0]
        if failed:
            raise RuntimeError(
                f"{len(failed)} worker process(es) failed with exit codes {failed}"
            )

        # Copy out of the proxies while the manager is still alive.
        return list(all_valid), list(all_invalid)

In [None]:
# Classify the training questions using 6 worker processes.
q_train_valid, q_train_invalid = multiprocess_exec(q_train, 6)

# The two partitions together must account for every training question.
assert len(q_train) == len(q_train_valid) + len(q_train_invalid)
print("*** Train set ***")
print(f"Num of valid: {len(q_train_valid)}\t Num of invalid: {len(q_train_invalid)}")

questions_train_valid_path = "../../data/medqa/questions/4_options/[filtered]q_train_valid.json"
questions_train_invalid_path = "../../data/medqa/questions/4_options/[filtered]q_train_invalid.json"

# Write each partition to its own file as a single JSON array.
for _out_path, _questions in (
    (questions_train_valid_path, q_train_valid),
    (questions_train_invalid_path, q_train_invalid),
):
    with open(_out_path, 'w') as file:
        json.dump(_questions, file)

100%|██████████| 1697/1697 [2:43:16<00:00,  5.77s/it]  
100%|██████████| 1697/1697 [2:44:12<00:00,  5.81s/it]


In [10]:
# Classify the validation questions using 6 worker processes.
q_val_valid, q_val_invalid = multiprocess_exec(q_val, 6)

# The two partitions together must account for every validation question.
assert len(q_val) == len(q_val_valid) + len(q_val_invalid)
print("*** Val set ***")
print(f"Num of valid: {len(q_val_valid)}\t Num of invalid: {len(q_val_invalid)}")

questions_val_valid_path = "../../data/medqa/questions/4_options/[filtered]q_val_valid.json"
questions_val_invalid_path = "../../data/medqa/questions/4_options/[filtered]q_val_invalid.json"

# Write each partition to its own file as a single JSON array.
for _out_path, _questions in (
    (questions_val_valid_path, q_val_valid),
    (questions_val_invalid_path, q_val_invalid),
):
    with open(_out_path, 'w') as file:
        json.dump(_questions, file)

100%|██████████| 255/255 [20:43<00:00,  4.88s/it]
100%|██████████| 255/255 [20:43<00:00,  4.88s/it]
100%|██████████| 255/255 [21:02<00:00,  4.95s/it]
100%|██████████| 255/255 [21:06<00:00,  4.97s/it]
100%|██████████| 252/252 [21:14<00:00,  5.06s/it]


In [29]:
# Classify the test questions using 6 worker processes.
q_test_valid, q_test_invalid = multiprocess_exec(q_test, 6)

# The two partitions together must account for every test question.
assert len(q_test) == len(q_test_valid) + len(q_test_invalid)
print("*** Test set ***")
print(f"Num of valid: {len(q_test_valid)}\t Num of invalid: {len(q_test_invalid)}")

questions_test_valid_path = "../../data/medqa/questions/4_options/[filtered]q_test_valid.json"
questions_test_invalid_path = "../../data/medqa/questions/4_options/[filtered]q_test_invalid.json"

# Write each partition to its own file as a single JSON array.
for _out_path, _questions in (
    (questions_test_valid_path, q_test_valid),
    (questions_test_invalid_path, q_test_invalid),
):
    with open(_out_path, 'w') as file:
        json.dump(_questions, file)

100%|██████████| 208/208 [20:27<00:00,  5.90s/it]
100%|██████████| 213/213 [20:39<00:00,  5.82s/it]
100%|██████████| 213/213 [20:41<00:00,  5.83s/it]
100%|██████████| 213/213 [20:43<00:00,  5.84s/it]
100%|██████████| 213/213 [20:45<00:00,  5.85s/it]
100%|██████████| 213/213 [21:14<00:00,  5.99s/it]
