\
## Key steps:
### 1. Convert a user's question into a TRAPI JSON query
### 2. Validate the TRAPI JSON format
### 3. Refine the TRAPI JSON by selecting similar categories and predicates
### 4. ID formatting
### 5. Query, rank, and visualization

In [None]:
import sys
sys.path.append('../src')
import TCT as TCT
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import openai
import json
import ipywidgets as widgets
from IPython.display import display
# Confidential keys for OpenAI.
# Replace confidential_key.json with your own file; it must provide the
# "Organization" and "API_key" entries read below.
with open("confidential_key.json", encoding="utf-8") as key_file:
    confi = json.load(key_file)

openai.organization = confi['Organization']
openai.api_key = confi['API_key']


In [None]:
# Step 1: List all the APIs in the translator system
APInames = TCT.list_Translator_APIs()
print(len(APInames))

# Step 2: Get metaKG from Translator APIs (this only applies to the Translator APIs)
metaKG = TCT.get_KP_metadata(APInames)

# Persist the metaKG to ../metaData/metaKG.csv and reload it from there.
metaKG.to_csv('../metaData/metaKG.csv', index=False)
metaKG = pd.read_csv('../metaData/metaKG.csv')

# NOTE(review): the 'KG_category' column is used as the predicate vocabulary
# here — confirm the column naming against TCT.get_KP_metadata.
All_predicates = list(set(metaKG['KG_category']))

# Category vocabulary = every category seen as an edge subject or object.
KG_category = list(set(list(metaKG['Subject'].unique()) + list(metaKG['Object'].unique())))
KG_predicates = list(metaKG['KG_category'].unique())

# Step 3: Load the query template
query_json_temp = TCT.load_json_template()
# The template text is embedded in the GPT prompt, which asks for "standard
# json" output. str() on a dict gives a Python repr (single quotes, True/None),
# so serialise dicts as real JSON; anything else keeps the old str() behaviour.
if isinstance(query_json_temp, dict):
    query_json = json.dumps(query_json_temp)
else:
    query_json = str(query_json_temp)

In [None]:
def convert_Question2Query(question):
    """Ask GPT-4 to translate a natural-language question into a TRAPI query.

    The prompt lists the predicates and categories available in the
    Translator metaKG (module-level ``KG_predicates`` and ``KG_category``)
    together with the TRAPI message template (module-level ``query_json``),
    and asks the model to fill that template for *question*.

    Args:
        question: Free-text biomedical question.

    Returns:
        The value ``TCT.extract_json`` produces from the model's reply. Later
        cells index it as ``['message']['query_graph']``.
    """
    # Sort the vocabularies so the prompt is identical across kernel runs:
    # iteration order of a set of strings depends on hash randomisation.
    predicates = ','.join(sorted(set(KG_predicates)))
    categories = ','.join(sorted(set(KG_category)))
    input_text = (
        f"We know the available predicates in the KG are: {predicates}. "
        f"We also know the available categories in the KGs are {categories}. "
        f"We also know a TRAPI message template is {query_json}. "
        f"With the question of {question} "
        "What is the json format of message to represent this question? "
        "The following rules for the output: "
        "1) The result must be just a json format with the same structure as the template; "
        "2) categories should be replaced from the categories in the KG; "
        "3) predicates can be replaced from the predicates in the KG; "
        "4) can use the name to fill the ids; "
        "5) the output must start with '{' and end with '}', and be a standard json format. "
        "At least one ids should be given and No annotations are needed!"
    )
    query_json_cur = TCT.ask_chatGPT4(input_text)
    query_json_cur_clean = TCT.extract_json(query_json_cur)
    return query_json_cur_clean

In [None]:
# Example questions
# Example questions; the trailing notes record outcomes from earlier manual runs.
question1 = "What genes or proteins interact with KRAS?"  # successful running
question2 = "What drugs may treat Type 2 diabetes?"  # successful running
question3 = "what disease we can consider to treat with drug Olaparib?"  # successful running
question4 = "What could be potential targets for ovarian cancer?"
question5 = "What are the drugs or small molecules that target the gene KRAS?"  # successful running
question6 = "What diseases co-occur with covid-19?"  # successful running
question7 = "What symptoms are associated with long covid?"  # no results
question8 = "What genes are associated with apoptosis?"  # successful running
question9 = "What drugs increase the risk of liver cancer?"  # not successful, possibly because of the edge direction
question10 = "which drugs are in clinical trial for liver cancer?"


In [None]:
# Please input your question here
# Free-text box where the user types the question for Translator.
question = widgets.Textarea(
    value='',
    placeholder='Ask a question to Translator',
    description='Question:',
    disabled=False,
    layout=widgets.Layout(width='80%', height='100px'),
)
display(question)


In [None]:
# Read the text typed into the Question widget. `question` is rebound to that
# string because later cells pass it to convert_Question2Query. Keep a handle
# on the widget so that re-running this cell re-reads it instead of failing
# with AttributeError on str.value.
if isinstance(question, widgets.Textarea):
    question_widget = question
question = question_widget.value
print("The question you asked is : ")
print(question)


In [None]:
# Send the question to GPT-4 (via convert_Question2Query) to draft a TRAPI query.
# The bare expression on the last line makes Jupyter display the result.
query_json_cur_clean = convert_Question2Query(question)
query_json_cur_clean

In [None]:
# add a widget to ask whether a user would like to refine the category or predicate
# Yes/No switch: should the user hand-pick node categories before querying?
refine_category = widgets.RadioButtons(
    options=['Yes', 'No'],
    value='No',
    description='Refine category?',
    disabled=False,
)
display(refine_category)


In [None]:
if refine_category.value == 'Yes':
    # Candidate categories from the metaKG vocabulary, as suggested by TCT for this query.
    similar_category = TCT.get_similar_category(query_json_cur_clean, KG_category)

    print(query_json_cur_clean)

    def _category_picker(label):
        """Build a multi-select over ``similar_category`` captioned *label*."""
        return widgets.SelectMultiple(
            options=similar_category,
            value=[],
            description=label,
            disabled=False,
            layout=widgets.Layout(width='80%', height='300px'),
        )

    # "Node 1" is the picker for query-graph node n0, "Node 2" for node n1.
    category_n1 = _category_picker('Node 1')
    display(category_n1)
    category_n2 = _category_picker('Node 2')
    display(category_n2)

In [None]:
# update categories
def _apply_category_choice(picker, node_key, label):
    """Copy a non-empty picker selection into the categories of query node *node_key*."""
    if picker.value:
        print(f"updated node {label}!")
        query_json_cur_clean['message']['query_graph']['nodes'][node_key]['categories'] = list(picker.value)


# update categories (only when the user chose to refine them)
if refine_category.value == 'Yes':
    _apply_category_choice(category_n1, 'n0', '1')
    _apply_category_choice(category_n2, 'n1', '2')
print(query_json_cur_clean)

In [None]:
# add a widget to ask whether a user would like to refine the category or predicate
# Yes/No switch: should the user hand-pick the edge predicates before querying?
refine_predicates = widgets.RadioButtons(
    options=['Yes', 'No'],
    value='No',
    description='Refine predicates?',
    disabled=False,
)
display(refine_predicates)


In [None]:
if refine_predicates.value == 'Yes':
    print(question)
    print(query_json_cur_clean)
    # Candidate predicates from the metaKG vocabulary, as suggested by TCT for this query.
    similar_predicate = TCT.get_similar_predicate(query_json_cur_clean, All_predicates)

    # Multi-select over the candidate predicates for the query edge.
    picker_options = dict(
        options=similar_predicate,
        value=[],
        description='Predicates',
        disabled=False,
        layout=widgets.Layout(width='80%', height='300px'),
    )
    predicate_e01 = widgets.SelectMultiple(**picker_options)
    display(predicate_e01)

In [None]:
# update predicates
# update predicates
# predicate_e01 only exists when the user opted into refinement, and an empty
# selection must not wipe the predicates GPT put on the edge.
if refine_predicates.value == 'Yes' and len(predicate_e01.value) > 0:
    print("updated predicates!")
    query_json_cur_clean['message']['query_graph']['edges']['e1']['predicates'] = list(predicate_e01.value)

In [None]:
# Validate the format of the json
# Show the current query, then run TCT's TRAPI validation against the metaKG vocabularies.
print("The current json format is: ", query_json_cur_clean, sep="\n")
TCT.TRAPI_json_validation(query_json_cur_clean, All_predicates, KG_category)

In [None]:
# id formatting
# id formatting: pass the query through TCT.format_id (presumably this normalises
# the node ids GPT filled in, possibly plain names, into CURIEs; confirm in TCT),
# then show the result.
query_json_cur_clean = TCT.format_id(query_json_cur_clean)
print(query_json_cur_clean)

In [None]:
# Step Select APIs
# Step: select APIs. n0's categories are passed as subjects and n1's as objects;
# TCT.select_API presumably matches them against the metaKG edges.
qg_nodes = query_json_cur_clean['message']['query_graph']['nodes']
input_node1_category = qg_nodes['n0']['categories']
input_node2_category = qg_nodes['n1']['categories']

sele_APIs = TCT.select_API(
    sub_list=input_node1_category,
    obj_list=input_node2_category,
    metaKG=metaKG,
)

print("all relevant APIs in Translator:")
print(sele_APIs)
print(len(sele_APIs))



In [None]:
# Optional: if format_id did not yield the CURIE you want, override the ids by hand, e.g.:
#query_json_cur_clean['message']['query_graph']['nodes']['n0']['ids'] = ['MONDO:0008170']

In [None]:
# get API URLs
# get API URLs for the selected APIs
API_URLs = TCT.get_Translator_API_URL(sele_APIs, APInames)
print(API_URLs)

# Step 5: Query Translator APIs (one worker per API).
# Clamp to at least one worker: if parallel_api_query passes max_workers to a
# concurrent.futures executor, 0 raises ValueError when no API was selected.
result = TCT.parallel_api_query(API_URLs,
                                query_json=query_json_cur_clean,
                                max_workers=max(1, len(API_URLs)))

# Step 6: Parse results
result_parsed = TCT.parse_KG(result)


In [None]:
# Step 7: Ranking the results. This ranking method is based on the number of unique primary infores. It can only be used to rank the results with one defined node.
# Step 7: Rank the results by the number of unique primary infores.
# This ranking only works for queries with one defined node, here n0; its first
# CURIE (e.g. "NCBIGene:1017") is the anchor for ranking and visualisation.
input_node1 = query_json_cur_clean['message']['query_graph']['nodes']['n0']
input_node1_ids = input_node1.get('ids') or []
if not input_node1_ids:
    # Fail with a readable message rather than a bare KeyError/TypeError.
    raise ValueError("Ranking needs at least one CURIE in the 'ids' of node n0; "
                     "set query_json_cur_clean['message']['query_graph']['nodes']['n0']['ids'].")
input_node1_id = input_node1_ids[0]
result_ranked_by_primary_infores = TCT.rank_by_primary_infores(result_parsed, input_node1_id)


In [None]:
# print results 
# Print the size of the ranked results and display the first five rows
# (.shape/.head suggest a pandas DataFrame; confirm against TCT).
print(result_ranked_by_primary_infores.shape)
result_ranked_by_primary_infores.head(5)

In [None]:
# Step 8: Visualize the results
# Step 8: Visualize the ranked one-hop results around the input CURIE
# (num_of_nodes=30, fontsize=10 as passed to TCT).
TCT.visulization_one_hop_ranking(
    result_ranked_by_primary_infores,
    result_parsed,
    num_of_nodes=30,
    input_query=input_node1_id,
    fontsize=10,
)