#Setup


In [1]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [2]:
import os
import json
import pickle
import numpy as np
import csv
from tqdm import tqdm
from collections import defaultdict

In [3]:
base = "/content/gdrive/MyDrive/Spring 2022/CS 263/Final"

#Preprocess the knowledge graph

In [None]:
relation_types = set()

def parse_kg():
  relas_lst =  []
  ptr_to_pref_name = {}
  ptr_to_name = defaultdict(list)
  name_to_ptr = {}
  # read entities
  with open(f"{base}/entities.csv") as f:
    reader = csv.reader(f)
    next(reader)
    for row in reader:
      entity_id, entity, entity_type = row[0:3]
      if entity_id not in ptr_to_pref_name:
        ptr_to_pref_name[entity_id] = entity
      ptr_to_name[entity_id].append(entity)
      name_to_ptr[entity] = entity_id
  # read relations
  with open(f"{base}/relations.csv") as f:
    reader = csv.reader(f)
    next(reader)
    for row in reader:
      head, tail, rtype = row[0:3]
      if rtype not in relation_types:
        relation_types.add(rtype)

      # eliminate relations with invalid entites
      skip = False
      if head not in ptr_to_name:
        skip = True
        print(f"head {head} no id match")
      if tail not in ptr_to_name:
        skip = True
        print(f"tail {tail} no id match")
      if skip: continue
      
      relas_lst.append((head, tail, rtype))
  # remove duplicate names
  ptr_to_name = {id: list(set(names)) for (id, names) in ptr_to_name.items()}

  return (relas_lst, ptr_to_pref_name, ptr_to_name, name_to_ptr)

relas_lst, ptr_to_pref_name, ptr_to_name, name_to_ptr = parse_kg()
relation_types = list(relation_types)

tail fms-like no id match


In [None]:
ptr_lst, names_lst = [], []
for key, val in ptr_to_pref_name.items():
    ptr_lst.append(key)
    names_lst.append(val)

In [None]:
neph_root = f"{base}/data/neph"
os.system(f'mkdir -p "{neph_root}"')

with open(f"{neph_root}/vocab.txt", "w") as fout:
    for name in names_lst:
        print (name, file=fout)

with open(f"{neph_root}/ptrs.txt", "w") as fout:
    for ptr in ptr_lst:
        print (ptr, file=fout)

In [None]:
id2concept = ptr_lst

In [None]:
import networkx as nx

def construct_graph():
    concept2id = {w: i for i, w in enumerate(id2concept)}
    id2relation = relation_types
    relation2id = {r: i for i, r in enumerate(id2relation)}
    graph = nx.MultiDiGraph()
    attrs = set()
    for relation in relas_lst:
        subj = concept2id[relation[0]]
        obj = concept2id[relation[1]]
        rel = relation2id[relation[2]]
        weight = 1.
        graph.add_edge(subj, obj, rel=rel, weight=weight)
        attrs.add((subj, obj, rel))
        graph.add_edge(obj, subj, rel=rel + len(relation2id), weight=weight)
        attrs.add((obj, subj, rel + len(relation2id)))
    output_path = f"{base}/data/neph/neph.graph"
    nx.write_gpickle(graph, output_path)
    return concept2id, id2relation, relation2id, graph

concept2id, id2relation, relation2id, KG = construct_graph()

#Preprocess QA into statements

In [71]:
import json

with open(f"{base}/qa.json") as f:
  data = json.load(f)

name = "neph"

In [72]:
import re

ALPHA = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
alpha = "abcdefghijklmnopqrstuvwxyz"
ALPHA_SIZE = 26
QBODY = re.compile(r"(?:[\d]+ *\.)(.*)")

def parse(text):
    prev = text
    ret = []
    for i in range(ALPHA_SIZE):
      # curr_split = re.compile(r"(.*)(?:[ \n…])+(?:Answer )?(?:" + a + r"\. )(.*)")
      curr_split = re.compile(r"(.*)(?:[ \n…\u2028\xa0])+(?:Answer )?(?:[" + alpha[i] + ALPHA[i] + r"]\.)(.*)")
      split = curr_split.match(prev)
      if split is None:
        ret.append(prev)
        break
      ret.append(split.group(1))
      prev = split.group(2)
    qbody = QBODY.match(ret[0])
    assert qbody is not None, "invalid question body" # checks for "<NUM>. ..." pattern
    ret[0] = qbody.group(1)
    return ret

# specify max number of answer choices we allow
MAX_CHOICES = 5

def preprocess(i, qa):
    global max_choices
    full_text = ' '.join(qa["body"])
    parsed = parse(full_text)
    assert len(parsed) > 1, "could not find answer choices" # if ret == 1, then only the question body is found
    assert len(parsed) - 1 <= MAX_CHOICES, f"more than {MAX_CHOICES} answer choices" # limit to 5 answer choices
    id = f"{name}-{i:05d}"
    answerKey = qa["answer"]
    assert answerKey <= ALPHA[len(parsed) - 2], f"correct answer not in choices (answer {answerKey} but max choice {ALPHA[len(parsed) - 2]})"
    stem      = parsed[0].strip()
    choices   = [{"label": ALPHA[i-1], "text": parsed[i].strip()} for i in range(1, len(parsed))]
    # pad to max # choices
    for i in range(len(choices), MAX_CHOICES): choices.append({"label": ALPHA[i], "text": ""})
    stmts     = [{"statement": stem +" "+ c["text"]} for c in choices]
    ex_obj    = {"id": id, 
                  "question": {"stem": stem, "choices": choices}, 
                  "answerKey": answerKey, 
                  "statements": stmts,
                  # "original": full_text
                }
      
    return ex_obj
  
failed = []
examples = []
for i, qa in enumerate(data[:]):
  try:
    examples.append(preprocess(i, qa))
  except Exception as e: # store failed inputs for inspection
    # print(f"{i} failed (failed[{len(failed)}])\n\t{e}\n\t{' '.join(qa['body'])}")
    failed.append({"qa": qa, "index": i, "reason": e})

In [73]:
print(f"{len(data)} original, {len(examples)} success, {len(failed)} fail")

1529 original, 1428 success, 101 fail


In [74]:
examples[0:5]

[{'answerKey': 'B',
  'id': 'neph-00000',
  'question': {'choices': [{'label': 'A', 'text': 'UPEP and SPEP'},
    {'label': 'B', 'text': 'SPEP, Serum IFE, and Serum FLC assay'},
    {'label': 'C', 'text': 'UPEP, SPEP, and Serum IFE'},
    {'label': 'D', 'text': 'UPEP, SPEP, and Serum FLC Assay.'},
    {'label': 'E', 'text': ''}],
   'stem': 'A 73-year-old female with a past medical history of IDDM T2, hypertension, and hyperlipidemia presented to the hospital with 1 month of neck and shoulder pain with associated paresthesia and weakness in the fingertips. The pain radiates down both arms. There is no history of recent fall or trauma. Current medications include: lantus 20 units QD, lisinopril 20 mg QD, and atorvastatin 40mg QD. She was also taking ibuprofen 600 mg 3-4 times daily for 3 weeks for the pain with no significant relief. Initial laboratory tests show a serum creatinine 1.48 mg/dL, calcium 14.1 mg/dL, and elevated protein gap of 6.2. The remainder of her chemistry panel is n

In [75]:
nephqa_root = f"{base}/data/nephqa"
os.system(f'mkdir -p "{nephqa_root}/statement"')

with open(f"{nephqa_root}/statement/{name}.statement.jsonl", 'w') as fout:
  for dic in examples:
      print(json.dumps(dic), file=fout)