In [28]:
from tqdm import tqdm
import openai

from code.config import cfg
from code.query_chatgpt import num_tokens_from_messages, query_chatgpt_batch
from code.utils import set_seed
from code.data_utils.dataset import DatasetLoader
from code.data_utils.utils import load_message, save_chatcompletion, save_response, clean_cache_chat_completion_response

In [29]:
# Seed the random number generators via the project's set_seed helper
set_seed(cfg.seed)

# Per-run overrides of the loaded config
cfg.demo_test = True  # demo mode: smaller rate limits, only the first cfg.num_sample prompts
cfg.llm.model.name = "gpt-3.5-turbo-1106"
cfg.llm.template = "CorrFS-5"
cfg.dataset = "ogbg-molbace"  # alternative: "ogbg-molhiv"

In [30]:
# OpenAI rate limits: requests per minute (RPM) and tokens per minute (TPM).
# Adjust both to match your account's plan; demo mode uses much smaller values.
if cfg.demo_test:
    rpm_limit, tpm_limit = 500, 1000
else:
    rpm_limit, tpm_limit = 3500, 60000

# OpenAI API client, authenticated with the key stored in the config
client = openai.OpenAI(api_key=cfg.OPENAI_API_KEY)

In [31]:
# Load the dataset together with its raw text (SMILES strings)
dataloader = DatasetLoader(name=cfg.dataset, text='raw')
dataset = dataloader.dataset
smiles = dataloader.text

# Prompt messages prepared for this dataset and prompt template
message_list = load_message(
    dataset_name=cfg.dataset,
    message_type=cfg.llm.template,
    demo_test=cfg.demo_test,
)
# Demo mode: keep only the first cfg.num_sample prompts
message_list = message_list[:cfg.num_sample] if cfg.demo_test else message_list

In [32]:
# Inspect the first prompt: a system instruction followed by few-shot user/assistant turns
message_list[0]

[{'role': 'system',
  'content': 'You are an expert in Graph Machine Learning, specializing in correcting predictions made by Graph Neural Networks (GNNs). The prediction is about whether the molecule inhibits human β-secretase 1(BACE-1).'},
 {'role': 'user',
  'content': "The molecule's SMILES string is Clc1ccc(nc1)C(=O)Nc1cc(C2(N=C(N)c3c(C2)cccc3)C)c(F)cc1, and its prediction given by the GNN model is False with predicted probability 0.8050. Provide corrected prediction. "},
 {'role': 'assistant', 'content': 'True.'},
 {'role': 'user',
  'content': "The molecule's SMILES string is Fc1cc(cc(F)c1)CC(NC(=O)C(N1CCC(NC(=O)C)(CC2CC2)C1=O)C)C(O)C[NH2+]Cc1cc(OC)ccc1, and its prediction given by the GNN model is True with predicted probability 0.5800. Provide corrected prediction. "},
 {'role': 'assistant', 'content': 'True.'},
 {'role': 'user',
  'content': "The molecule's SMILES string is S1(=O)(=O)N(c2cc(cc3c2n(cc3CC)CC1)C(=O)N[C@H]([C@H](O)C[NH2+]C1CCOCC1)Cc1ccccc1)CC, and its prediction 

In [33]:
# Query the LLM in token-bounded batches so that no batch reaches the TPM limit,
# then persist all results and remove the intermediate cache files.
chat_completion_list = []  # chat completions for every message, in message order
response_list = []         # responses for every message, in message order


def _query_batch(batch_messages, start_id):
    """Send one batch to the LLM and append its results to the running lists.

    `start_id` is the index in `message_list` of the batch's first message and is
    forwarded to `query_chatgpt_batch` as `batch_start_id`.
    """
    batch_chat_completions, batch_responses = query_chatgpt_batch(
        client=client, dataset_name=cfg.dataset,
        llm_model=cfg.llm.model.name, template=cfg.llm.template,
        batch_message_list=batch_messages, batch_start_id=start_id,
        rpm_limit=rpm_limit
    )
    chat_completion_list.extend(batch_chat_completions)
    response_list.extend(batch_responses)


batch_message_list = []
batch_message_token_num = 0
batch_start_id = 0
# Named progress_desc (not `display`) so IPython's display() is not shadowed for later cells.
progress_desc = "Query {} {}".format(cfg.dataset, cfg.llm.model.name)
for message_id, message in enumerate(tqdm(message_list, desc=progress_desc)):
    message_token_num = num_tokens_from_messages(
        messages=message, original_model=cfg.llm.model.name
    )
    # Flush the pending batch before adding this message would reach the TPM limit.
    # The non-empty check avoids sending an empty batch when a single message
    # alone meets the limit; such a message is then sent in a batch of its own.
    if batch_message_list and batch_message_token_num + message_token_num >= tpm_limit:
        _query_batch(batch_message_list, batch_start_id)
        batch_message_list = []
        batch_message_token_num = 0
        batch_start_id = message_id
    batch_message_list.append(message)
    batch_message_token_num += message_token_num
    # NOTE(review): pacing between consecutive batches to respect the per-minute
    # limits is presumably handled inside query_chatgpt_batch — confirm.

# Send whatever remains, which always includes the final message.
if batch_message_list:
    _query_batch(batch_message_list, batch_start_id)

# Save all chat completions
save_chatcompletion(
    dataset_name=cfg.dataset, chat_completion=chat_completion_list,
    template=cfg.llm.template, demo_test=cfg.demo_test
)
# Save all responses
save_response(
    dataset_name=cfg.dataset, list_response=response_list,
    template=cfg.llm.template, demo_test=cfg.demo_test
)
# Clean all cache files written for this dataset/template during querying
clean_cache_chat_completion_response(
    dataset_name=cfg.dataset, template=cfg.llm.template
)

Query ogbg-molbace gpt-3.5-turbo-1106:  10%|█         | 2/20 [00:03<00:30,  1.72s/it]

Created directory /home/zhiqiang/LLMaGML/output/chat_completion/ogbg-molbace/cache_chat_completion/CorrFS-5/
Created directory /home/zhiqiang/LLMaGML/output/response/ogbg-molbace/cache_response/CorrFS-5/


Query ogbg-molbace gpt-3.5-turbo-1106: 100%|██████████| 20/20 [00:41<00:00,  2.09s/it]


In [34]:
# Inspect the first LLM response.
# NOTE(review): the displayed format ('Prediction: ...; Possibility: ...') differs from
# the few-shot assistant answers ('True.') — confirm downstream parsing expects it.
response_list[0]

'Prediction: True; Possibility: 0.9243'