In [None]:
# Root of the parsing package; the import cell also puts it on sys.path so the
# project-local modules (preprocessing, data_parsers, knowledge_graph) resolve.
# NOTE(review): the recorded cell outputs show paths under './parsing_package',
# so they appear to come from a run with a different ROOT -- confirm before re-running.
ROOT = '../parsing_package'
# Base directory for all data files. The value already ends with '/': one cell relies
# on that (DATA_DIR + 'disease_data/2021'), while most others interpolate
# f'{DATA_DIR}/...', which produces the '//' visible in some printed paths.
DATA_DIR = f'{ROOT}/data/'

In [3]:
import requests

import pandas as pd
import json
import os
import re
import pickle
import gc
import random
from pprint import pprint
from collections import defaultdict, Counter, namedtuple, OrderedDict
from enum import Enum
from dataclasses import dataclass
from typing import NamedTuple, Set, Optional
import math

import networkx as nx
# import obonet
import csv
# from goatools import obo_parser

import sys
sys.path.insert(0, ROOT)
from preprocessing.disease_data import MeshData
# from criteria import Criteria, cuid2concept
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sb
import networkx as nx
%matplotlib inline
plt.rcParams['figure.figsize'] = (8, 6.0) # set default size of plots

In [4]:
def get_clinical_trial_data(nctid, timeout=30, session=None):
    """Fetch a single study record from the ClinicalTrials.gov v2 REST API.

    :param nctid: NCT identifier of the study, e.g. "NCT02370680".
    :param timeout: Seconds to wait for the server. Without a timeout a stalled
        connection would block the notebook kernel indefinitely.
    :param session: Optional object with a requests-style ``get(url, timeout=...)``
        method (for example a ``requests.Session`` reused across many studies).
        When omitted, the ``requests`` module itself is used.
    :return: The decoded JSON study (a dict) on HTTP 200. On any failure a dict of
        the form ``{"error": <message>}`` is returned instead of raising, so callers
        must check for the "error" key.
    """
    base_url = "https://clinicaltrials.gov/api/v2/studies"
    request_url = "{}/{}".format(base_url, nctid)

    try:
        http = requests if session is None else session
        response = http.get(request_url, timeout=timeout)

        if response.status_code == 200:
            return response.json()

        # Non-200: report the status and body rather than raising.
        return {"error": "Failed to fetch data. Status code: {}, Message: {}".format(response.status_code, response.text)}

    except Exception as e:
        # Deliberately broad: connection errors, timeouts and JSON decoding errors
        # are all reported through the same error-dict contract.
        return {"error": str(e)}

In [5]:
# Study used throughout this notebook; swap in any valid NCT ID.
nctid = "NCT02370680"
trial_data = get_clinical_trial_data(nctid)

# get_clinical_trial_data signals failure through an "error" entry instead of raising.
if "error" in trial_data:
    fetch_error_report = "Error fetching data: {}".format(trial_data['error'])
    pprint(fetch_error_report)

In [6]:
trial_data.keys()

dict_keys(['protocolSection', 'derivedSection', 'hasResults'])

In [7]:
def print_keys(d, parent_key=''):
    """
    Print the dotted path of every key reachable through nested dictionaries.

    Keys are printed depth-first in insertion order, each parent before its
    children. Only dict values are descended into; lists and scalars end the walk.

    :param d: The value to inspect; anything that is not a dict prints nothing
    :param parent_key: Dotted path of the enclosing dictionary ('' at the top level)
    """
    if not isinstance(d, dict):
        return
    for name, child in d.items():
        dotted = '{}.{}'.format(parent_key, name) if parent_key else name
        print(dotted)
        print_keys(child, dotted)

In [8]:
print_keys(trial_data)

protocolSection
protocolSection.identificationModule
protocolSection.identificationModule.nctId
protocolSection.identificationModule.orgStudyIdInfo
protocolSection.identificationModule.orgStudyIdInfo.id
protocolSection.identificationModule.organization
protocolSection.identificationModule.organization.fullName
protocolSection.identificationModule.organization.class
protocolSection.identificationModule.briefTitle
protocolSection.identificationModule.officialTitle
protocolSection.identificationModule.acronym
protocolSection.statusModule
protocolSection.statusModule.statusVerifiedDate
protocolSection.statusModule.overallStatus
protocolSection.statusModule.expandedAccessInfo
protocolSection.statusModule.expandedAccessInfo.hasExpandedAccess
protocolSection.statusModule.startDateStruct
protocolSection.statusModule.startDateStruct.date
protocolSection.statusModule.primaryCompletionDateStruct
protocolSection.statusModule.primaryCompletionDateStruct.date
protocolSection.statusModule.primaryComp

In [9]:
nctid = "NCT02370680"  # Replace with a valid NCT ID
trial_data = get_clinical_trial_data(nctid)

if "error" in trial_data:
    # Stop with the real cause. Carrying on would only fail further down with an
    # unrelated-looking TypeError once the missing fields are iterated.
    raise RuntimeError("Error fetching data: {}".format(trial_data['error']))


# Output field name -> dotted path into the ClinicalTrials.gov v2 study JSON.
# NOTE(review): 'intervention_mesh_terms' reads derivedSection.conditionBrowseModule;
# the v2 schema also has an interventionBrowseModule -- confirm which one the
# downstream extractors expect before changing it.
attributes = {
    "nct_id": "protocolSection.identificationModule.nctId",
    "arm_group": "protocolSection.armsInterventionsModule.armGroups",
    "intervention": "protocolSection.armsInterventionsModule.interventions",
    "condition": "protocolSection.conditionsModule.conditions",
    "intervention_mesh_terms": "derivedSection.conditionBrowseModule.meshes",
    "event_groups": "resultsSection.adverseEventsModule.eventGroups",
    "primary_outcome": "protocolSection.outcomesModule.primaryOutcomes",
    "secondary_outcome": "protocolSection.outcomesModule.secondaryOutcomes",
    "eligibility_criteria": "protocolSection.eligibilityModule.eligibilityCriteria",
    'brief_summary': 'protocolSection.descriptionModule.briefSummary',
    'phase': 'protocolSection.designModule.phases',
    'enrollment': 'protocolSection.designModule.enrollmentInfo',
    'gender_sex': 'protocolSection.eligibilityModule.sex',
    'minimum_age': 'protocolSection.eligibilityModule.minimumAge',
    'maximum_age': 'protocolSection.eligibilityModule.maximumAge'
}


def _lookup_path(record, dotted_path):
    """Return the value at ``dotted_path`` inside nested dicts, or None if any step is missing."""
    val = record
    for component in dotted_path.split("."):
        if not isinstance(val, dict) or component not in val:
            return None
        val = val[component]
    return val


def _legacy_arm_group(arm_group):
    """Return a copy of an API arm group with 'label' renamed to 'arm_group_label'."""
    legacy = dict(arm_group)  # copy so the raw API response is left untouched
    legacy['arm_group_label'] = legacy.pop('label')
    return legacy


def _legacy_intervention(intervention):
    """Return a copy of an API intervention using the legacy field names."""
    legacy = dict(intervention)  # copy so the raw API response is left untouched
    # The API reports enum-style types such as DIETARY_SUPPLEMENT; replace the
    # underscores before title-casing so the result reads "Dietary Supplement".
    legacy['intervention_type'] = legacy.pop('type').replace('_', ' ').title()
    legacy['intervention_name'] = legacy.pop('name')
    if 'otherNames' in legacy:
        legacy['other_name'] = legacy.pop('otherNames')
    # Guarded like 'otherNames': an intervention need not list any arm labels,
    # and get_arm_text already falls back to a default arm when the key is absent.
    if 'armGroupLabels' in legacy:
        legacy['arm_group_label'] = legacy.pop('armGroupLabels')
    return legacy


parsed_trial = {attribute: _lookup_path(trial_data, path) for attribute, path in attributes.items()}

# Missing lists stay None (get_arm_text treats a non-list arm_group as "use the default arm").
if parsed_trial['arm_group'] is not None:
    parsed_trial['arm_group'] = [_legacy_arm_group(arm_group) for arm_group in parsed_trial['arm_group']]

if parsed_trial['intervention'] is not None:
    parsed_trial['intervention'] = [_legacy_intervention(intervention) for intervention in parsed_trial['intervention']]

# Wrap the adverse-event groups in the nested legacy layout. As before, the
# 'event_groups' key is only removed when it actually held data.
if parsed_trial.get('event_groups') is not None:
    parsed_trial['clinical_results'] = {
        "reported_events": {
            "group_list": {"group": parsed_trial.pop('event_groups')}
        }
    }
else:
    parsed_trial['clinical_results'] = {}
parsed_trial

{'nct_id': 'NCT02370680',
 'arm_group': [{'type': 'EXPERIMENTAL',
   'description': 'Aspirin run-in, followed with Durlaza™, one capsule QD (quaque die), for 14 ± 4 days and an in-patient visit',
   'interventionNames': ['Drug: Aspirin', 'Drug: Durlaza™'],
   'arm_group_label': 'Durlaza™, 1 capsule'},
  {'type': 'EXPERIMENTAL',
   'description': 'in a rollover with 10 subjects from the first arm, an aspirin run-in, followed by Durlaza™, two capsules QD, for 14 ± 4 days and an in-patient visit',
   'interventionNames': ['Drug: Aspirin', 'Drug: Durlaza™'],
   'arm_group_label': 'Durlaza™, 2 capsules'}],
 'intervention': [{'description': 'steady-state run-in prior to Durlaza treatment',
   'intervention_type': 'Drug',
   'intervention_name': 'Aspirin',
   'other_name': ['Bayer aspirin'],
   'arm_group_label': ['Durlaza™, 1 capsule', 'Durlaza™, 2 capsules']},
  {'description': 'comparison of different numbers of capsules',
   'intervention_type': 'Drug',
   'intervention_name': 'Durlaza™',

## Create MedEx input and run MedEx

In [10]:
from data_parsers.external_tools.medex import medex_input
from data_parsers.external_tools import medex
import tempfile
import json
import subprocess
import shlex

[nltk_data] Downloading package stopwords to
[nltk_data]     /dfs/scratch1/myasu/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger_eng to
[nltk_data]     /dfs/scratch1/myasu/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger_eng is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /dfs/scratch1/myasu/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!


In [11]:
result = {}

classpath = f'{ROOT}/resources/medex/Medex_UIMA_1.3.8/bin:{ROOT}/resources/medex/Medex_UIMA_1.3.8/lib/*'
args_template = "java -Xmx1024m -cp {0} org.apache.medex.Main -i {1} -o {2} -b n -f y -d y -t n"

# Everything happens inside the temporary directory because it is deleted as soon
# as the `with` block exits; the parsed MedEx info ends up in `parsed_trial`.
with tempfile.TemporaryDirectory() as basedir:
    # create medex input
    medex_input._generate_medex_inputs(parsed_trial, result)
    input_dir = os.path.join(basedir, 'inputs')
    os.makedirs(input_dir)
    medex_input_path = os.path.join(input_dir, 'medex_input.json')
    with open(medex_input_path, 'w') as f:
        json.dump(result, f)

    # Run medex
    output_path = os.path.join(basedir, "outputs")
    medex_output_dir = os.path.join(output_path, "data")
    os.makedirs(medex_output_dir)
    args = args_template.format(classpath, medex_input_path, medex_output_dir)
    print(args)
    # Split the template first and substitute afterwards, so a path containing
    # spaces stays a single argument instead of being broken up by shlex.
    medex_cmd = [token.format(classpath, medex_input_path, medex_output_dir)
                 for token in shlex.split(args_template)]
    # check=True: stop here if the Java process fails, rather than handing a
    # missing or partial output directory to the parser below.
    p = subprocess.run(medex_cmd, check=True)

    medex_output_parser = medex.MedexOutputParser(base_paths=[output_path])

    medex_output_parser.fill_medex_info(parsed_trial)

java -Xmx1024m -cp ./parsing_package/resources/medex/Medex_UIMA_1.3.8/bin:./parsing_package/resources/medex/Medex_UIMA_1.3.8/lib/* org.apache.medex.Main -i /tmp/user/19668/tmp7h7i_rno/inputs/medex_input.json -o /tmp/user/19668/tmp7h7i_rno/outputs/data -b n -f y -d y -t n
Loading configuration files ...
Processing file NCT02370680_arm_0.txt ...
Processing file NCT02370680_drug_0.txt ...
Processing file NCT02370680_drug_1_othernames.txt ...
Processing file NCT02370680_arm_1.txt ...
Processing file NCT02370680_drug_0_othernames.txt ...
Processing file NCT02370680_drug_1.txt ...
total time:57868


## Criteria2Query parsing

In [12]:
from data_parsers.external_tools.criteria2query import create_nctids_file
from data_parsers import CriteriaOutputParser
import time

# The temporary directory is removed when the `with` block exits, so the
# Criteria2Query output is parsed into `parsed_trial` before leaving it.
with tempfile.TemporaryDirectory() as basedir:
    # create input
    input_dir = os.path.join(basedir, 'inputs')
    os.makedirs(input_dir)
    crit_input_path = os.path.join(input_dir, 'crit_input.txt')
    with open(crit_input_path, 'w') as f:
        f.write(parsed_trial['eligibility_criteria'])

    # Run crit2query
    output_path = os.path.join(basedir, "outputs")
    crit_output_dir = os.path.join(output_path, "data")
    os.makedirs(crit_output_dir)
    # Pass the command as an argument list so paths containing spaces survive intact.
    crit_cmd = [
        "java", "-Xmx4096m",
        "-jar", f"{ROOT}/resources/criteria2query.jar",
        "--input", crit_input_path,
        "--outputDir", crit_output_dir,
    ]
    args = " ".join(crit_cmd)
    print(args)
    # check=True: fail here if the Java process exits non-zero, instead of
    # failing later on a missing output.json.
    p = subprocess.run(crit_cmd, check=True)

    parsed_trial['ec_umls'] = CriteriaOutputParser.parse_crit_output_from_file(os.path.join(crit_output_dir, "output.json"))

java -Xmx4096m -jar ./parsing_package/resources/criteria2query.jar --input /tmp/user/19668/tmp0ym1nay_/inputs/crit_input.txt --outputDir /tmp/user/19668/tmp0ym1nay_/outputs/data


[main] INFO edu.stanford.nlp.parser.lexparser.LexicalizedParser - Loading parser from serialized file edu/columbia/dbmi/ohdsims/model/wsjPCFG.ser.gz ... done [0.4 sec].
[main] INFO edu.stanford.nlp.ie.AbstractSequenceClassifier - Loading classifier from edu/columbia/dbmi/ohdsims/model/c2q_all_model_advanced.ser.gz ... done [0.7 sec].
[main] INFO edu.stanford.nlp.parser.lexparser.LexicalizedParser - Loading parser from serialized file edu/columbia/dbmi/ohdsims/model/wsjPCFG.ser.gz ... done [0.2 sec].
[main] INFO edu.stanford.nlp.pipeline.StanfordCoreNLP - Adding annotator tokenize
[main] INFO edu.stanford.nlp.pipeline.TokenizerAnnotator - No tokenizer type provided. Defaulting to PTBTokenizer.
[main] INFO edu.stanford.nlp.pipeline.StanfordCoreNLP - Adding annotator ssplit
[main] INFO edu.stanford.nlp.pipeline.StanfordCoreNLP - Adding annotator pos
[main] INFO edu.stanford.nlp.tagger.maxent.MaxentTagger - Loading POS tagger from edu/stanford/nlp/models/pos-tagger/english-left3words/engli

## Disease Extraction

In [13]:
def get_disease_data(root_dir):
    """Construct a MeshData object from the three data files under ``root_dir``.

    The files are passed in the order (d2021.bin, c2021.bin, disease_mappings.tsv),
    i.e. the mesh, sub-mesh and DisGeNET files as the original variable names call
    them, restricted to the categories C, F01, F02 and F03.
    """
    def in_root(filename):
        return os.path.join(root_dir, filename)

    return MeshData(
        in_root('d2021.bin'),
        in_root("c2021.bin"),
        in_root('disease_mappings.tsv'),
        filter_categ=["C", "F01", "F02", "F03"],
    )
# Relies on DATA_DIR ending with '/'.
mesh_dis_data = get_disease_data(DATA_DIR + 'disease_data/2021')

In [14]:
from data_parsers import DiseaseExtract
disease_matcher = DiseaseExtract(data_year=2021, data_dir=f'{DATA_DIR}/')
parsed_trial['mesh_ids'] = disease_matcher.get_disease_ids(parsed_trial)

## Drug mapping

In [15]:
from data_parsers import DrugMatcher, get_intervention_drug_ids

In [16]:
drug_matcher = DrugMatcher(data_paths={
        'drug_data': f'{DATA_DIR}/drug_data/drugs_all_03_04_21.pkl',
        'pubchem_synonyms': f'{DATA_DIR}/drug_data/pubchem-drugbankid-synonyms.json',
        'rxnorm2drugbank-umls': f'{DATA_DIR}/drug_data/rxnorm2drugbank-umls.pkl',
        'RXNCONSO': f'{DATA_DIR}/drug_data/RXNCONSO.RRF'
})

14315it [00:06, 2094.70it/s]
1077691it [00:03, 342358.26it/s]


In [17]:
interventions = parsed_trial['intervention']
for intervention in interventions:
    get_intervention_drug_ids(drug_matcher, intervention, parsed_trial)

## Outcome measures

In [18]:
from data_parsers import OutcomeMeasureExtract
outcome_extractor = OutcomeMeasureExtract(
    f'{DATA_DIR}/outcome_data/clusters-outcome-measures.txt')

outcome_extractor.load_phrase_models(f'{DATA_DIR}/outcome_data')

In [19]:
outcome_extractor.populate_cids(parsed_trial)

## UMLS Mapping

In [20]:
from data_parsers import UMLSConceptSearcher

# The UMLS API key is a credential and must not live in the notebook: read it from
# the environment. The key that used to be committed here should be treated as
# leaked and rotated.
umls_api_key = os.environ.get('UMLS_API_KEY')
if not umls_api_key:
    raise RuntimeError("Set the UMLS_API_KEY environment variable to your UMLS API key before running this cell.")

umls_concept_searcher = UMLSConceptSearcher(api_key=umls_api_key,
                                           version='2020AB', cache_dir=f'{DATA_DIR}/population_data/umls_search_cache')
# Smoke test: look up a single well-known term and display the result.
umls_concept_searcher.search_term('cancer')

UMLS Concept Cache loaded, found 272990 entries


{'ui': 'C0006826',
 'concept': {'atoms': 'https://uts-ws.nlm.nih.gov/rest/content/2019AB/CUI/C0006826/atoms',
  'suppressible': False,
  'cvMemberCount': 0,
  'atomCount': 317,
  'definitions': 'https://uts-ws.nlm.nih.gov/rest/content/2019AB/CUI/C0006826/definitions',
  'classType': 'Concept',
  'relationCount': 109,
  'status': 'R',
  'majorRevisionDate': '09-03-2019',
  'name': 'Malignant Neoplasms',
  'attributeCount': 0,
  'ui': 'C0006826',
  'relations': 'https://uts-ws.nlm.nih.gov/rest/content/2019AB/CUI/C0006826/relations',
  'defaultPreferredAtom': 'https://uts-ws.nlm.nih.gov/rest/content/2019AB/CUI/C0006826/atoms/preferred',
  'semanticTypes': [{'uri': 'https://uts-ws.nlm.nih.gov/rest/semantic-network/2019AB/TUI/T191',
    'name': 'Neoplastic Process'}],
  'dateAdded': '09-30-1990'},
 'uri': 'https://uts-ws.nlm.nih.gov/rest/content/2019AB/CUI/C0006826',
 'name': 'Malignant Neoplasms',
 'rootSource': 'MTH'}

In [21]:
from data_parsers.umls_utils import UMLSUtils
umls_utils = UMLSUtils(f'{DATA_DIR}/population_data/umls-install/2020AB/')

7964596it [00:09, 831797.65it/s]


In [22]:
umls_utils.load_relations()

29124214it [01:30, 322639.26it/s]


In [23]:
parsed_trial.keys()

dict_keys(['nct_id', 'arm_group', 'intervention', 'condition', 'intervention_mesh_terms', 'event_groups', 'primary_outcome', 'secondary_outcome', 'eligibility_criteria', 'brief_summary', 'phase', 'enrollment', 'gender_sex', 'minimum_age', 'maximum_age', 'clinical_results', 'medex_raw', 'medex_processed', 'ec_umls', 'mesh_ids'])

In [24]:
# NOTE(review): presumably switches off live UMLS lookups so only cached results
# are used -- confirm against UMLSConceptSearcher.
umls_concept_searcher.set_umls_search(False)
criteria_all = parsed_trial['ec_umls']
# criteria_all is indexed as criteria_all[category][inclusion] -> iterable of criteria.
for category in criteria_all:
    criteria_by_inclusion = criteria_all[category]
    for inclusion in criteria_by_inclusion:
        for criterion in criteria_by_inclusion[inclusion]:
            criterion.map_concept(umls_concept_searcher)

### UMLS parents cutoff

In [25]:
# Start from an empty cuid2parents mapping on the shared UMLSUtils instance.
umls_utils.cuid2parents = {}
criteria_all = parsed_trial['ec_umls']
for category in criteria_all:
    criteria_by_inclusion = criteria_all[category]
    for inclusion in criteria_by_inclusion:
        for criterion in criteria_by_inclusion[inclusion]:
            # Criteria that were not mapped to a concept have nothing to look up.
            if criterion.concept is None:
                continue
            criterion.parents = umls_utils.parents(criterion.concept['ui'])

### TF-IDF matcher

In [26]:
from data_parsers import UMLSTFIDFMatcher

tfidf_matcher = UMLSTFIDFMatcher(umls_utils.cuid2concept, f'{DATA_DIR}/population_data/', None)

In [27]:
tfidf_matcher.populate_result_single(parsed_trial['ec_umls'])

### Map to parents based on frequency

In [28]:
from data_parsers.population_extract import UMLSGraphClipper

In [29]:
# Load the saved graph-clipper state and keep its 'cuid2term' entry.
# pickle.load executes code from the file, so only use it on trusted project data.
filepath = f"{DATA_DIR}/population_data/umls_graph_clipper_output.pkl"
with open(filepath, "rb") as f:
    g_clipper_state = pickle.load(f)
cuid2term = g_clipper_state['cuid2term']

## Load KG

In [30]:
entity2cid_path = f'{DATA_DIR}/kg_data/kg-entity2cid-31_7_21.pkl'
with open(entity2cid_path, 'rb') as f:
    entity2cid = pickle.load(f)

In [31]:
node_feats_path = f'{DATA_DIR}/kg_data/node_features_armtext.pkl'
with open(node_feats_path, 'rb') as f:
    node_feats = pickle.load(f)

In [32]:
node_feats.head(1)

Unnamed: 0,node_id,emb,etype
0,KG00000000,"[-0.031794183, 0.023484211, -0.7169324, -0.145...",DISEASE


In [33]:
from knowledge_graph import KnowledgeGraphBuilder, BioKG
from knowledge_graph.kg import Entity,EntityKey, EntityType, Source, Relation, UnionFind
from knowledge_graph.build_graph import TrialGraphBuilder

In [34]:
ext_basepath = f'{DATA_DIR}/kg_data/external_data'
builder = KnowledgeGraphBuilder(disease_matcher.mesh_dis_data, drug_matcher.drug_data, ext_basepath, 
                                cuid2term, umls_utils, umls_graph_clip_threshold=10,
                                build_ae=False)

./parsing_package/data//kg_data/external_data/go/go-basic.obo: fmt(1.2) rel(2021-02-01) 47,291 GO Terms; optional_attrs(relationship xref)


In [35]:
# NOTE(review): this cell ends by replacing builder.biokg with a fresh BioKG(), and the
# next cell starts the same way, so whatever these two steps add to the graph is
# discarded. It looks like a leftover experiment; confirm the private steps have no
# other side effects on `builder` before deleting it.
builder.biokg = BioKG()
builder._disease_disease()
builder._mesh_children()
builder.biokg = BioKG()

100%|███████████████████████████████████| 5751/5751 [00:00<00:00, 113489.70it/s]


In [36]:
# Build the knowledge graph from scratch: reset to an empty BioKG, then run the
# builder's private construction steps one by one in this fixed order.
# NOTE(review): presumably each step adds the edge type its name suggests and later
# steps may depend on earlier ones -- confirm in KnowledgeGraphBuilder before reordering.
builder.biokg = BioKG()
builder._disease_disease()
builder._drug_drug()
builder._protein_protein()
builder._function_function()
builder._protein_function()
builder._drug_protein()
builder._drug_phenotype()
builder._disease_gene()

# Clinical-trial-specific steps
builder._umls()
# The builder was constructed with build_ae=False, so these adverse-event steps are skipped.
if builder.build_ae:
    builder._ae_ae()
    builder._ae_protein()
builder._primary_outcomes(f'{DATA_DIR}/outcome_data/clusters-outcome-measures.txt')

100%|███████████████████████████████████| 5751/5751 [00:00<00:00, 111978.69it/s]
14315it [00:01, 7769.65it/s]
100%|████████████████████████████████| 387626/387626 [00:22<00:00, 17053.53it/s]
100%|██████████████████████████████████| 47291/47291 [00:00<00:00, 49382.89it/s]
100%|███████████████████████████████| 118856/118856 [00:00<00:00, 256342.75it/s]
100%|█████████████████████████████████| 22477/22477 [00:00<00:00, 104949.23it/s]


1084


100%|█████████████████████████████████| 83652/83652 [00:00<00:00, 132118.04it/s]


1741


100%|██████████████████████████████| 242889/242889 [00:00<00:00, 1980378.57it/s]
100%|█████████████████████████████████| 84038/84038 [00:00<00:00, 130348.28it/s]


12807 6655 3189


7964596it [00:43, 184455.24it/s]


In [37]:
# Hardcoded instead of being read from the fetched study's top-level 'hasResults' field.
# NOTE(review): presumably intentional, to keep results out of the trial graph -- confirm.
parsed_trial['has_results'] = False

In [38]:
trial_builder = TrialGraphBuilder(builder, parsed_trial)

In [39]:
trial_builder.build(use_population=True)

In [40]:
trial_builder.arm_labels

{'durlaza™, 1 capsule': 0, 'durlaza™, 2 capsules': 1}

In [41]:
# Union the endpoints of every 'KG-MERGE-SAME' edge and count how many such edges exist.
uf = UnionFind()
cnt = 0
for src, dst, edge_attrs in builder.biokg.graph.edges(data=True):
    if edge_attrs['relation'] != 'KG-MERGE-SAME':
        continue
    cnt += 1
    uf.union(src, dst)

In [42]:
cnt

7732

In [43]:
from knowledge_graph.node_features import TrialAttributeFeatures

trial_attribute_featurizer = TrialAttributeFeatures(attributes=('age', 'gender', 'enrollment', 'phase'))

def _phase_feature_vec(phases):
    v = [0] * 5
    for phase in phases:
        if phase in ['EARLY_PHASE1', 'PHASE1']:
            v[1] = 1
        elif phase == 'N/A':
            v[0] = 1
        elif phase == 'PHASE2':
            v[2] = 1
        elif phase == 'PHASE3':
            v[3] = 1
        elif phase == 'PHASE4':
            v[4] = 1
#         elif phase == 'Phase 1/Phase 2':
#             v[1] = 1
#             v[2] = 1
#         elif phase == 'Phase 2/Phase 3':
#             v[2] = 1
#             v[3] = 1
        else:
            raise RuntimeError(f"Unknown phase: {phase}")
    return v

def _enrollment_feat(enrollment):
    is_anticipated = False
    if type(enrollment) == dict:
        if enrollment['type'] == 'ANTICIPATED':
            is_anticipated = True
        return [math.log(1 + enrollment['count']), int(is_anticipated)]
    if np.isnan(enrollment):
        return [0, 0]
    return [math.log(1 + enrollment), 0]


def _sex_vec(sex):
    if sex is None or type(sex) == float:
        return [0, 0, 0]
    sex_to_feats = {
        'ALL': [1, 0, 0],
        "MALE": [0, 1, 0],
        "FEMALE": [0, 0, 1]
    }
    return sex_to_feats[sex]
    
def features(self, trial_row):
    """Build the attribute feature vector for one trial.

    Computes the phase, enrollment, sex and min/max age vectors for ``trial_row``
    and concatenates them in the order given by ``self.attributes``.

    :param self: Featurizer object providing ``_age_vec`` and ``attributes``.
    :param trial_row: Mapping with 'phase', 'enrollment', 'gender_sex',
        'minimum_age' and 'maximum_age' entries.
    :return: 1-D numpy array of the concatenated features.
    :raises RuntimeError: If ``self.attributes`` contains an unknown attribute name.
    """
    data = {
        'phase_vec': _phase_feature_vec(trial_row['phase']),
        'enrollment_vec': _enrollment_feat(trial_row['enrollment']),
        'gender_sex_vec': _sex_vec(trial_row['gender_sex']),
        # A missing (falsy) age is passed to _age_vec as 0.0.
        'minimum_age_vec': self._age_vec(trial_row['minimum_age'] or 0.0),
        'maximum_age_vec': self._age_vec(trial_row['maximum_age'] or 0.0),
    }

    # Attribute name -> keys of `data` to append, in order.
    # 'age_vec_2' is never computed in this function, so requesting 'age_class'
    # raises KeyError, exactly as before.
    vec_keys_by_attribute = {
        'phase': ('phase_vec',),
        'enrollment': ('enrollment_vec',),
        'gender': ('gender_sex_vec',),
        'age': ('minimum_age_vec', 'maximum_age_vec'),
        'age_class': ('age_vec_2',),
    }

    feats = []
    for attribute in self.attributes:
        if attribute not in vec_keys_by_attribute:
            raise RuntimeError(f"Unknown attributes ({attribute}) for features")
        for vec_key in vec_keys_by_attribute[attribute]:
            feats.extend(data[vec_key])
    return np.array(feats)

trial_attribute_feats = features(trial_attribute_featurizer, parsed_trial)

In [44]:
def get_arm_text(row):
    """Assemble the free-text description of every arm of a trial.

    :param row: Parsed trial mapping with 'nct_id', 'brief_summary', 'condition',
        'primary_outcome', 'eligibility_criteria', 'intervention' and 'arm_group'.
        Optional fields may be None (absent from the API response) or a float NaN
        (missing in a DataFrame row); both are treated as empty.
    :return: ``(arm2text, nct2text)``. ``arm2text`` maps ``(nct_id, arm_index)`` to the
        space-joined intervention, disease, outcome, arm, summary, intervention
        description and criteria text. ``nct2text`` maps ``nct_id`` to
        ``[disease_text, outcome_text, summary, criteria]`` and stays empty when the
        trial has an empty arm list.
    """
    def _missing(value):
        return value is None or isinstance(value, float)

    def _text(value):
        return '' if _missing(value) else value

    def _items(value):
        return [] if _missing(value) else value

    arm2text = {}
    nct2text = {}
    summary = _text(row['brief_summary'])
    criteria = _text(row['eligibility_criteria'])
    disease_text = ''.join(disease + " " for disease in _items(row['condition']))
    outcome_text = ''.join(pom.get('measure', '') + " " for pom in _items(row.get('primary_outcome')))

    # NOTE(review): when several interventions name the same arm, the last one
    # overwrites the earlier ones. Kept as-is so arm text stays comparable with
    # previously generated data -- confirm whether concatenation was intended.
    arm2intervention = {}
    for intervention in _items(row['intervention']):
        intervention_text = intervention['intervention_name'] + ' '
        intervention_desc = intervention.get('description', '') + ' '
        arm_group_label = intervention.get('arm_group_label', ['default'])
        if not isinstance(arm_group_label, list):
            arm_group_label = ['default']
        for arm_label in arm_group_label:
            arm2intervention[arm_label.lower()] = (intervention_text, intervention_desc)

    # Trials without an arm list get a single synthetic 'default' arm.
    arms = row['arm_group']
    if not isinstance(arms, list):
        arms = [{'arm_group_label': 'default', 'arm_group_type': ''}]

    for idx, arm in enumerate(arms):
        label = arm['arm_group_label']
        arm_text = label + " " + arm.get('description', '')
        intervention_text, intervention_desc = arm2intervention.get(label.lower(), ('', ''))
        all_text = " ".join([intervention_text, disease_text, outcome_text, arm_text, summary, intervention_desc, criteria])
        arm2text[row['nct_id'], idx] = all_text

    # The trial-level entry does not depend on the arm, so write it once.
    if arms:
        nct2text[row['nct_id']] = [disease_text, outcome_text, summary, criteria]
    return arm2text, nct2text
    

In [45]:
# One record per arm: the arm's outgoing KG edges (targets resolved through the
# union-find and entity2cid), its text, and the trial-level attribute features.
# Note that `trial_data` is rebound here from the raw API response to this list.
trial_data = []
arm2text, _ = get_arm_text(parsed_trial)
for arm_label, arm_idx in trial_builder.arm_labels.items():
    arm_edges = builder.biokg.graph.edges(nbunch=[trial_builder.arm_key(arm_idx)], data=True, keys=True)
    trial_arm_data = [
        {
            'kg_id': entity2cid[uf.find_parent(v)],
            'relation': data['relation'],
            'key': k,
            'data': data,
        }
        for u, v, k, data in arm_edges
    ]
    trial_data.append({
        'nct_id': parsed_trial['nct_id'],
        'arm_label': arm_label,
        'arm_idx': arm_idx,
        'trial_arm_edges': trial_arm_data,
        'arm_text': arm2text[parsed_trial['nct_id'], arm_idx],
        'trial_attribute_feats_vec': trial_attribute_feats,
    })
    
#     if data['relation'] == 'KG-MERGE-SAME':
#         continue
#     cidu = entity2concept[uf.find_parent(u)]
#     cidv = entity2concept[uf.find_parent(v)]
#     if not g.has_edge(cidu, cidv, k):
#         g.add_edge(cidu, cidv, key=k, relation=data['relation'], source=data['source'], attrs=[data['attrs']],
#                    extra_attrs={})
#     elif data['attrs']:
#         g.edges[cidu, cidv, k]['attrs'].append(data['attrs'])

In [46]:
trial_data

[{'nct_id': 'NCT02370680',
  'arm_label': 'durlaza™, 1 capsule',
  'arm_idx': 0,
  'trial_arm_edges': [{'kg_id': 'KG00000863',
    'relation': 'study-disease',
    'key': Relation(name='study-disease', source=<Source.CLINICAL_TRIAL: (9,)>),
    'data': {'relation': 'study-disease',
     'source': <Source.CLINICAL_TRIAL: (9,)>,
     'attrs': {}}},
   {'kg_id': 'KG00122148',
    'relation': 'primary_outcome',
    'key': Relation(name='primary_outcome', source=<Source.CLINICAL_TRIAL: (9,)>),
    'data': {'relation': 'primary_outcome',
     'source': <Source.CLINICAL_TRIAL: (9,)>,
     'attrs': {}}},
   {'kg_id': 'KG00020807',
    'relation': 'arm_tests_drug',
    'key': Relation(name='arm_tests_drug', source=<Source.CLINICAL_TRIAL: (9,)>),
    'data': {'relation': 'arm_tests_drug',
     'source': <Source.CLINICAL_TRIAL: (9,)>,
     'attrs': {}}},
   {'kg_id': 'KG00074632',
    'relation': 'eligibility-exclusion',
    'key': Relation(name='eligibility-exclusion', source=<Source.CLINICAL_TR

In [47]:
# Save the per-arm records. The file is named after the trial that was actually
# parsed (not a hardcoded ID), and the output directory is created if needed.
trial_dump_dir = f'{ROOT}/tmp'
os.makedirs(trial_dump_dir, exist_ok=True)
with open(f"{trial_dump_dir}/trial_data_{parsed_trial['nct_id']}.pkl", 'wb') as f:
    pickle.dump(trial_data, f)