In [1]:
import sys
import time
from tqdm import tqdm
from openai import OpenAI, AzureOpenAI
import tiktoken
import os
import torch

from code.config import cfg, update_cfg
from code.utils import time_logger
from code.query_chatgpt import num_tokens_from_messages, get_context_window_size_limit
from code.data_utils.utils import (load_message, check_cache_response,
                                   load_gnn_predictions,
                                   save_chatcompletion, save_response,
                                   load_chatcompletion, load_response,
                                   clean_cache_chat_completion_response)
from code.data_utils.dataset import DatasetLoader
from code.query_chatgpt_corrector import query_chatgpt_batch

In [2]:
# Manual overrides of the default config for this run.
cfg.dataset = "ogbg-molbace"  # other options: "ogbg-molbbbp", "ogbg-molhiv"
cfg.llm.template = "CorrFSC-30"
cfg.gnn.model.name = "gin-v"
cfg.seed = 42
# To pick a specific chat model, set e.g.
# cfg.llm.model.name = "gpt-4-1106-preview"  # or "gpt-3.5-turbo-1106"
cfg.demo_test = True

# Hand-picked molecule indices per dataset; any other dataset falls back to a
# small default list.
DEMO_INDICES_BY_DATASET = {
    "ogbg-molbace": [101, 102, 103, 201, 202, 0, 1, 6, 239, 240],
    "ogbg-molbbbp": [422, 313, 354, 370, 120, 6, 291, 94, 8, 453],
    "ogbg-molhiv": [8773, 1975, 3969, 9063, 6750, 7305, 2191, 7171, 2213, 2190],
}
demo_list = list(DEMO_INDICES_BY_DATASET.get(cfg.dataset, [101, 102, 103]))

In [3]:
# Load the molecule dataset together with its raw SMILES strings.
dataloader = DatasetLoader(name=cfg.dataset, text='raw')
dataset = dataloader.dataset
smiles = dataloader.text

# Indices of the held-out test split, as a numpy array.
split_idx = dataset.get_idx_split()
test_indices = split_idx["test"].numpy()

# Stored GNN outputs for this dataset / model / seed, passed through a sigmoid
# (presumably the stored values are logits — verify against load_gnn_predictions)
# and squeezed of singleton dimensions.
predictions = load_gnn_predictions(
    dataset_name=cfg.dataset,
    gnn_model_name=cfg.gnn.model.name,
    feature=cfg.data.feature,
    lm_model_name=cfg.lm.model.name,
    seed=cfg.seed,
).sigmoid().squeeze().numpy()

# Ground-truth labels for every molecule, with singleton dimensions removed.
labels = dataset.y.squeeze().numpy()

In [4]:
# Per-minute request (RPM) and token (TPM) budgets — adjust to your plan.
rpm_limit = 3500
tpm_limit = 60000
# Maximum context length of the selected chat model.
cws_limit = get_context_window_size_limit(cfg.llm.model.name)

# Build the API client: Azure OpenAI when cfg.llm.provider == 'aoai',
# otherwise the public OpenAI endpoint.
if cfg.llm.provider == 'aoai':
    client = AzureOpenAI(
        api_key=os.getenv("AZURE_OPENAI_KEY"),
        api_version="2023-12-01-preview",
        azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
    )
else:
    # NOTE(review): the key is read from cfg here (Azure reads env vars above);
    # make sure it is not committed with the config.
    client = OpenAI(api_key=cfg.OPENAI_API_KEY)

# Load every prepared prompt, then keep only the ones to query: the hand-picked
# demo molecules in demo mode, otherwise the whole test split.
full_message_list = load_message(
    dataset_name=cfg.dataset, message_type=cfg.llm.template,
    gnn_model=cfg.gnn.model.name, seed=cfg.seed,
    demo_test=cfg.demo_test,
)
message_list = [
    full_message_list[idx]
    for idx in (demo_list if cfg.demo_test else test_indices)
]

In [5]:
# Inspect the first prompt: a system instruction, then few-shot user/assistant
# demonstrations; the last message is the question about the target molecule.
message_list[0]

[{'role': 'system',
  'content': 'You are an expert in Graph Machine Learning, specializing in correcting predictions made by Graph Neural Networks (GNNs). The prediction is about whether the molecule inhibits human β-secretase 1(BACE-1).'},
 {'role': 'user',
  'content': "The molecule-645's SMILES string is Fc1cc(cc(F)c1)CC(NC(=O)C)C(O)C[NH2+]C1(CCC(=O)NC1)c1cc(ccc1)C(C)(C)C, and its prediction given by the GNN model is False with predicted probability 0.8573. The molecule is an organic cation obtained by protonation of the secondary amino group of rolapitant. It is an ammonium ion derivative and an organic cation. It is a conjugate acid of a rolapitant. Provide corrected prediction."},
 {'role': 'assistant', 'content': 'Prediction: True.'},
 {'role': 'user',
  'content': "The molecule-713's SMILES string is O(c1cc(ccc1OC)C1(N=C(N)N(C)C1=O)c1cc(ccc1)-c1cccnc1)C1CCCC1, and its prediction given by the GNN model is False with predicted probability 0.5115. The molecule is a member of the 

In [6]:
# Query the LLM for every prompt in message_list. Prompts are grouped into batches
# whose combined token count stays below the per-minute token budget (tpm_limit);
# each batch is sent with query_chatgpt_batch. Afterwards all chat completions and
# responses are saved and the per-batch cache files are cleaned up.
chat_completion_list = []
response_list = []


def _count_tokens(message):
    """Return the token count of ``message`` for the configured chat model."""
    return num_tokens_from_messages(
        messages=message, original_model=cfg.llm.model.name
    )


def _fit_context_window(message, num_tokens):
    """Drop trailing knowledge messages until ``message`` fits below ``cws_limit``.

    The first message (instruction) and the last message (question) are always
    kept; the messages in between are removed from the end one at a time.

    Returns:
        (message, num_tokens): the possibly-reduced prompt and its token count.
    Exits (sys.exit) if the prompt is still too long once every knowledge
    message has been removed.
    """
    print("Message context length is {}, larger than Context Window Size limit {}.".format(
        num_tokens, cws_limit
    ))
    print("Reducing message...")
    instruction, knowledge, question = message[0], message[1:-1], message[-1]
    while num_tokens >= cws_limit:
        if not knowledge:
            # Nothing left to drop; without this guard the loop would never end.
            sys.exit(
                "Message token number {} still exceeds Context Window Size limit {} "
                "after removing all knowledge messages.".format(num_tokens, cws_limit)
            )
        # NOTE(review): in few-shot templates the knowledge alternates user/assistant
        # turns; dropping one message at a time may leave an unanswered demo question.
        knowledge = knowledge[:-1]
        message = [instruction] + knowledge + [question]
        num_tokens = _count_tokens(message)
    print("Message token number is reduced to {}.".format(num_tokens))
    return message, num_tokens


def _query_batch(batch_message_list, batch_start_id):
    """Send one batch through query_chatgpt_batch and collect its outputs."""
    batch_chat_completion_list, batch_response_list = query_chatgpt_batch(
        client=client, dataset_name=cfg.dataset,
        llm_model=cfg.llm.model.name, template=cfg.llm.template,
        gnn_model=cfg.gnn.model.name, seed=cfg.seed,
        batch_message_list=batch_message_list, batch_start_id=batch_start_id,
        rpm_limit=rpm_limit
    )
    chat_completion_list.extend(batch_chat_completion_list)
    response_list.extend(batch_response_list)


# Run batch queries
batch_message_list = []
batch_message_token_num = 0
batch_start_id = 0
# Named `progress_desc` rather than `display` so IPython's display() stays usable.
progress_desc = "Query {} {}".format(cfg.dataset, cfg.llm.model.name)
for message_id, message in enumerate(tqdm(message_list, desc=progress_desc)):
    num_tokens = _count_tokens(message)
    # Trim to the context window before the TPM check, so a prompt that fits
    # once trimmed is not rejected.
    if num_tokens >= cws_limit:
        message, num_tokens = _fit_context_window(message, num_tokens)
    if num_tokens > tpm_limit:
        sys.exit("Message token number is larger than limit {}.".format(tpm_limit))

    # Send the pending batch first if adding this message would reach the TPM
    # budget. This applies to every message, including the last one; a pending
    # batch is never empty here.
    if batch_message_list and batch_message_token_num + num_tokens >= tpm_limit:
        _query_batch(batch_message_list, batch_start_id)
        batch_message_list = []
        batch_message_token_num = 0
        batch_start_id = message_id

    batch_message_list.append(message)
    batch_message_token_num += num_tokens

    # The last message closes out the final batch.
    if message_id == len(message_list) - 1:
        _query_batch(batch_message_list, batch_start_id)

# Save all chat completion
save_chatcompletion(
    dataset_name=cfg.dataset, chat_completion=chat_completion_list,
    gnn_model=cfg.gnn.model.name, seed=cfg.seed,
    template=cfg.llm.template, demo_test=cfg.demo_test
)
# Save all responses
save_response(
    dataset_name=cfg.dataset, list_response=response_list,
    gnn_model=cfg.gnn.model.name, seed=cfg.seed,
    template=cfg.llm.template, demo_test=cfg.demo_test
)
# Clean all cache files
clean_cache_chat_completion_response(
    dataset_name=cfg.dataset, template=cfg.llm.template,
    gnn_model=cfg.gnn.model.name, seed=cfg.seed,
)

Query ogbg-molbace gpt-3.5-turbo-1106:  10%|█         | 1/10 [00:00<00:03,  2.53it/s]

Created directory /home/zhiqiang/LLMaGML/output/chat_completion/ogbg-molbace/cache_chat_completion/CorrFSC-30-gin-v-42/
Created directory /home/zhiqiang/LLMaGML/output/response/ogbg-molbace/cache_response/CorrFSC-30-gin-v-42/


Query ogbg-molbace gpt-3.5-turbo-1106: 100%|██████████| 10/10 [00:35<00:00,  3.57s/it]


In [7]:
# Show each LLM response next to the GNN probability and ground-truth label of the
# molecule it answers. message_list was built from demo_list in demo mode and from
# the test split otherwise, so use the matching index list here.
queried_indices = demo_list if cfg.demo_test else test_indices
# NOTE(review): assumes response_list is in the same order as message_list —
# confirm against query_chatgpt_batch.
if len(response_list) != len(queried_indices):
    print("Warning: {} responses for {} queried molecules; only matched pairs are shown.".format(
        len(response_list), len(queried_indices)))
for position, (index, response) in enumerate(zip(queried_indices, response_list)):
    pred, label = predictions[index], labels[index]
    print("Response: {}".format(response))
    print("{} Molecule ID: {}; Prediction: {:.4f}; Label: {}.\n".format(position, index, pred, label))

Response: Prediction: False; 
Explanation: The provided SMILES string is not matching with the description of N-[(2,3-dihydro-1,4-benzothiazol-2-yl)methyl]-N-(2,3-dihydro-1,4-benzothiazol-2-yl)-L-methionine, so the given prediction is not applicable.

0 Molecule ID: 101; Prediction: 0.9729; Label: 0.

Response: Prediction: False; Explanation: The correct prediction of the molecule-102 is False with a predicted probability of 0.9927. Although the provided SMILES string belongs to amorolfine, an antifungal medication that is typically used topically for the treatment of onychomycosis, the predicted label seems to be incorrect based on the context provided. Therefore, it's essential to revisit and verify the prediction considering the presented information regarding the role and nature of the molecule.
1 Molecule ID: 102; Prediction: 0.9927; Label: 0.

Response: Prediction: True; Explanation: The molecule exhibits anti-HIV-1 activity and is a cyclodepsipeptide derivative, indicating poten