In [1]:
import sys
import time
from tqdm import tqdm
from openai import OpenAI, AzureOpenAI
import os

from code.config import cfg, update_cfg
from code.utils import time_logger
from code.query_chatgpt import num_tokens_from_messages, get_context_window_size_limit
from code.data_utils.utils import (load_message, check_cache_response,
                                   save_chatcompletion, save_response,
                                   load_chatcompletion, load_response,
                                   clean_cache_chat_completion_response)
from code.data_utils.dataset import DatasetLoader
from code.query_chatgpt_explainer import query_chatgpt_batch

In [2]:
# manual cfg settings
cfg.dataset = "ogbg-molbace" # ogbg-molhiv
cfg.llm.template = "EP"
# cfg.llm.model.name = "gpt-4-1106-preview" # gpt-3.5-turbo-1106, gpt-4-1106-preview 
cfg.demo_test = True

In [3]:
# RPM limit (adjust according to your plan)
rpm_limit = 3500
# TPM limit (adjust according to your plan)
tpm_limit = 60000
# Context window size
cws_limit = get_context_window_size_limit(cfg.llm.model.name)

# Set up OpenAI API
client = AzureOpenAI(
    api_key=os.getenv("AZURE_OPENAI_KEY"),
    api_version="2023-12-01-preview",
    azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
) if cfg.llm.provider == 'aoai' else OpenAI(api_key=cfg.OPENAI_API_KEY)

message_list = load_message(
    dataset_name=cfg.dataset, message_type=cfg.llm.template,
    gnn_model=cfg.gnn.model.name, seed=cfg.seed,
    demo_test=cfg.demo_test
)
if cfg.demo_test:
    message_list = message_list[:cfg.num_sample]

In [4]:
message_list[0]

[{'role': 'system',
  'content': 'You are an AI molecule property analysis assistant specializing in explaining predictions made by Graph Neural Networks (GNNs). The prediction is about whether the molecule inhibits human β-secretase 1(BACE-1).'},
 {'role': 'user',
  'content': "The molecule-241's SMILES string is s1cc(cc1)-c1cc2c(nc(N)cc2)cc1. The molecule-241's prediction given by GNNs is False with predicted probability 0.9668. The true label of the molecule-241 is True. Indentify insightful features to explain why GNNs make this prediction for molecule-241."}]

In [5]:
# Save all queries
chat_completion_list = []
response_list = []

# Run batch queries
batch_message_list = []
batch_message_token_num = 0
batch_start_id = 0
display = "Query {} {}".format(cfg.dataset, cfg.llm.model.name)
for message_id, message in enumerate(tqdm(message_list, desc=display)):

    num_tokens = num_tokens_from_messages(
        messages=message, original_model=cfg.llm.model.name
    )
    if num_tokens > tpm_limit:
        sys.exit("Message token number is large than limit {}.".format(tpm_limit))
    while num_tokens >= cws_limit:
        print("Message context length is {}, larger than Context Window Size limit {}.".format(
            num_tokens, cws_limit
        ))
        str_len = len(message[1]["content"])
        message[1]["content"] = (message[1]["content"][:int(str_len / 3)] +
                                 message[1]["content"][-int(str_len / 3):])
        num_tokens = num_tokens_from_messages(
            messages=message, original_model=cfg.llm.model.name
        )
        print("Message token number is reduced to {}.".format(num_tokens))
    batch_message_token_num += num_tokens

    if (batch_message_token_num >= tpm_limit) and (message_id < len(message_list) - 1):

        batch_chat_completion_list, batch_response_list = query_chatgpt_batch(
            client=client, dataset_name=cfg.dataset,
            llm_model=cfg.llm.model.name, template=cfg.llm.template,
            gnn_model=cfg.gnn.model.name, seed=cfg.seed,
            batch_message_list=batch_message_list, batch_start_id=batch_start_id,
            rpm_limit=rpm_limit
        )
        chat_completion_list += batch_chat_completion_list
        response_list += batch_response_list
        batch_message_list = [message]
        batch_message_token_num = num_tokens_from_messages(
            messages=message, original_model=cfg.llm.model.name
        )
        batch_start_id = message_id

    elif message_id == len(message_list) - 1:
        batch_message_list.append(message)

        batch_chat_completion_list, batch_response_list = query_chatgpt_batch(
            client=client, dataset_name=cfg.dataset,
            llm_model=cfg.llm.model.name, template=cfg.llm.template,
            gnn_model=cfg.gnn.model.name, seed=cfg.seed,
            batch_message_list=batch_message_list, batch_start_id=batch_start_id,
            rpm_limit=rpm_limit
        )
        chat_completion_list += batch_chat_completion_list
        response_list += batch_response_list

    else:
        batch_message_list.append(message)
        
# Save all chat completion
save_chatcompletion(
    dataset_name=cfg.dataset, chat_completion=chat_completion_list,
    gnn_model=cfg.gnn.model.name, seed=cfg.seed,
    template=cfg.llm.template, demo_test=cfg.demo_test
)
# Save all responses
save_response(
    dataset_name=cfg.dataset, list_response=response_list,
    gnn_model=cfg.gnn.model.name, seed=cfg.seed,
    template=cfg.llm.template, demo_test=cfg.demo_test
)
# Clean all cache files
clean_cache_chat_completion_response(
    dataset_name=cfg.dataset, template=cfg.llm.template,
    gnn_model=cfg.gnn.model.name, seed=cfg.seed,
)

Query ogbg-molbace gpt-3.5-turbo-1106:   5%|▌         | 1/20 [00:00<00:07,  2.56it/s]

Created directory /home/zhiqiang/LLMaGML/output/chat_completion/ogbg-molbace/cache_chat_completion/EP-gin-v-42/
Created directory /home/zhiqiang/LLMaGML/output/response/ogbg-molbace/cache_response/EP-gin-v-42/


Query ogbg-molbace gpt-3.5-turbo-1106: 100%|██████████| 20/20 [02:22<00:00,  7.13s/it]
