In [None]:
from openai import OpenAI
from pathlib import Path
from tqdm import tqdm
from prompts import *
import pandas as pd
import tiktoken
import pickle
import json
import time
import re

# Mined interaction tables (presumably drug–drug = DDI, drug–protein = DPI, judging by their columns).
ddi_subset = pd.read_csv("../data/mined_data/final_DDI.csv")
dpi_subset = pd.read_csv("../data/mined_data/final_DPI.csv")

# Every drug name appearing on either side of a DDI pair, as a Python set (unordered).
all_drugs = set(ddi_subset["drug_1_name"].unique()) | set(ddi_subset["drug_2_name"].unique())
# Distinct protein names from the DPI table, in first-appearance order (array returned by .unique()).
all_proteins = dpi_subset["protein_name"].unique()

# OpenAI() with no arguments; presumably picks up credentials from the environment (e.g. OPENAI_API_KEY).
client = OpenAI()
# Tokenizer matching the requested model, used below to budget batch sizes in tokens.
gpt_tokenizer = tiktoken.encoding_for_model("gpt-4o")

# Wiki Complexify 

In [None]:
def create_text_dict(entity_list, entity_type,
                     data_dir="../data/background_information_data", tokenizer=None):
    """Load each entity's cached Wiki text and count its tokens.

    Parameters
    ----------
    entity_list : iterable of str
        Entity names. Each must have a file
        ``{data_dir}/{entity_type}_data/Wiki/{entity}.txt``.
    entity_type : str
        Entity kind used in the directory name (e.g. ``"drug"``).
    data_dir : str or Path, optional
        Root of the background-information data; defaults to the notebook's usual location.
    tokenizer : object with ``encode(str)``, optional
        Defaults to the notebook-level ``gpt_tokenizer`` (resolved only when needed).

    Returns
    -------
    dict
        Maps ``entity -> (text, token_count)``, in the iteration order of ``entity_list``.

    Raises
    ------
    FileNotFoundError
        If an entity has no Wiki text file.
    """
    if tokenizer is None:
        tokenizer = gpt_tokenizer
    entity_text = {}
    for entity in entity_list:
        wiki_path = Path(f"{data_dir}/{entity_type}_data/Wiki/{entity}.txt")
        # Explicit UTF-8: Wiki text is full of non-ASCII characters, and the platform default
        # encoding (e.g. cp1252 on Windows) would fail or silently mis-decode them.
        text = wiki_path.read_text(encoding="utf-8")
        entity_text[entity] = (text, len(tokenizer.encode(text)))
    return entity_text

def create_formatted_inputs_for_complexify(entity, text):
    """Build one Batch API request line for the Wiki-complexify task.

    Parameters
    ----------
    entity : str
        Entity name; used both in the ``custom_id`` and as the first prompt argument.
    text : str
        The entity's Wiki text; second argument to ``WIKI_COMPLEXIFY_USER_PROMPT.format``.

    Returns
    -------
    dict
        ``{"custom_id": "<entity>-complexify", "method": "POST",
        "url": "/v1/chat/completions", "body": {...}}``, where the body asks ``gpt-4o``
        with a developer message plus the formatted user message.
    """
    request_body = {
        "model": "gpt-4o",
        "messages": [
            {"role": "developer", "content": WIKI_COMPLEXIFY_DEVELOPER_PROMPT},
            {"role": "user", "content": WIKI_COMPLEXIFY_USER_PROMPT.format(entity, text)},
        ],
    }
    return {
        "custom_id": f"{entity}-complexify",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": request_body,
    }

In [None]:
# Drugs whose Wiki text is shorter than this many tokens are not sent for complexification.
MIN_WIKI_TOKENS = 200

# `all_drugs` is a set of strings; its iteration order follows string hashes, which are
# randomized per Python process. Sorting makes the request order (and therefore the
# contents of batch_input.jsonl) reproducible across kernel restarts.
drugs_text = create_text_dict(sorted(all_drugs, key=str), "drug")

formatted_samples = [
    create_formatted_inputs_for_complexify(drug, text)
    for drug, (text, n_tokens) in drugs_text.items()
    if n_tokens >= MIN_WIKI_TOKENS
]

In [None]:
# Everything goes into one batch file only if the summed prompt tokens stay under this cap.
COMPLEXIFY_TOKEN_LIMIT = 90_000
complexify_batch_path = Path("../data/OAI/complexify/batch_input.jsonl")

# Prompt tokens over every message of every request.
total_tokens = sum(
    len(gpt_tokenizer.encode(message["content"]))
    for sample in formatted_samples
    for message in sample["body"]["messages"]
)

if total_tokens < COMPLEXIFY_TOKEN_LIMIT:
    print(f"Total tokens : {total_tokens}. Fine for batching everything.")
    complexify_batch_path.parent.mkdir(parents=True, exist_ok=True)
    with complexify_batch_path.open("w", encoding="utf-8") as file:
        for sample in formatted_samples:
            file.write(json.dumps(sample) + "\n")
else:
    # Without this branch the cell would silently write nothing, and a stale
    # batch_input.jsonl from an earlier run could be picked up downstream.
    print(f"Total tokens : {total_tokens} >= {COMPLEXIFY_TOKEN_LIMIT}. "
          "batch_input.jsonl was NOT written; split the requests into several batch files.")

In [None]:
# Optional: load batch_input.jsonl back as a list of request dicts for inspection.
with Path("../data/OAI/complexify/batch_input.jsonl").open('r') as file:
    s = list(map(json.loads, file))

# Molecular Interactions

In [None]:
def create_formatted_inputs_for_MI(row):
    """Build one Batch API request line for the molecular-interactions task of a drug pair.

    Parameters
    ----------
    row : namedtuple
        A row from ``DataFrame.itertuples`` exposing ``drug_1_name``, ``drug_2_name``,
        ``drug_1_SMILES`` and ``drug_2_SMILES``.

    Returns
    -------
    dict
        Request with ``custom_id`` ``"<drug_1_name>-<drug_2_name>-MI"``. Only the two SMILES
        strings are passed to ``MOLECULAR_INTERACTIONS_USER_PROMPT.format``; the drug names
        appear solely in the id.

    NOTE(review): the Batch API requires ``custom_id`` to be unique within a file. Duplicate
    drug pairs in ``ddi_subset`` would collide here — confirm pairs are unique.
    """
    user_prompt = MOLECULAR_INTERACTIONS_USER_PROMPT.format(row.drug_1_SMILES, row.drug_2_SMILES)
    messages = [
        {"role": "developer", "content": MOLECULAR_INTERACTIONS_DEVELOPER_PROMPT},
        {"role": "user", "content": user_prompt},
    ]
    return {
        "custom_id": f"{row.drug_1_name}-{row.drug_2_name}-MI",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {"model": "gpt-4o", "messages": messages},
    }

In [None]:
def save_batch(batch, idx, out_dir="../data/OAI/molecular_interactions"):
    """Write a list of Batch API requests to ``{out_dir}/batch_{idx}_input.jsonl``.

    Each request is serialized with ``json.dumps`` on its own line (JSON Lines). An existing
    file with the same name is overwritten.

    Parameters
    ----------
    batch : list of dict
        Requests to write, in order.
    idx : int
        Batch index used in the file name.
    out_dir : str or Path, optional
        Target directory; defaults to the molecular-interactions folder and is created if missing.
    """
    out_path = Path(out_dir) / f"batch_{idx}_input.jsonl"
    # Create the folder up front so a fresh checkout doesn't fail with FileNotFoundError.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w", encoding="utf-8") as file:
        file.writelines(json.dumps(sample) + "\n" for sample in batch)

In [None]:
# One chat-completions request per DDI row, in DataFrame row order.
formatted_samples = [
    create_formatted_inputs_for_MI(ddi_row)
    for ddi_row in ddi_subset.itertuples(index=False)
]

In [None]:
# Split the MI requests into files whose summed prompt tokens stay within this cap.
MI_BATCH_TOKEN_LIMIT = 90_000


def _count_prompt_tokens(request):
    """Return the number of gpt-4o tokens across all messages of one batch request."""
    return sum(len(gpt_tokenizer.encode(message["content"]))
               for message in request["body"]["messages"])


batch_id = 0
batch = []
batch_tokens = 0
for sample in formatted_samples:
    sample_tokens = _count_prompt_tokens(sample)  # tokenize each request exactly once
    if sample_tokens > MI_BATCH_TOKEN_LIMIT:
        # A request that alone exceeds the cap can never fit in any file; fail loudly
        # instead of looping forever while writing empty batch files.
        raise ValueError(
            f"Request {sample['custom_id']!r} alone has {sample_tokens} tokens, "
            f"exceeding the {MI_BATCH_TOKEN_LIMIT}-token batch limit."
        )
    if batch and batch_tokens + sample_tokens > MI_BATCH_TOKEN_LIMIT:
        # Adding this request would overflow the current file: flush it and start a new one.
        save_batch(batch, batch_id)
        batch_id += 1
        batch = []
        batch_tokens = 0
    batch.append(sample)
    batch_tokens += sample_tokens

# Flush the remainder; skip it when there were no requests at all, to avoid an empty file.
if batch:
    save_batch(batch, batch_id)

In [None]:
# Sanity check: reading every MI batch file back in index order must reproduce
# formatted_samples exactly. Globbing (instead of hard-coding batch_0/batch_1) works for
# any number of batches, and it also flags leftover files from an earlier run that made more batches.
mi_batch_paths = sorted(
    Path("../data/OAI/molecular_interactions").glob("batch_*_input.jsonl"),
    key=lambda path: int(path.stem.split("_")[1]),  # numeric order: batch_2 before batch_10
)
assert mi_batch_paths, "no MI batch files found"

read_back_samples = []
for batch_path in mi_batch_paths:
    with batch_path.open("r", encoding="utf-8") as file:
        read_back_samples.extend(json.loads(line) for line in file)

assert len(read_back_samples) == len(formatted_samples)
assert read_back_samples == formatted_samples, \
    "batch files do not match formatted_samples (stale files from an earlier run?)"