In [None]:
from pathlib import Path
from prompts import *
import pandas as pd
import tiktoken
import pickle

In [None]:
# Load the mined interaction tables and the tokenizer used for all token counts below.

MINED_DATA_DIR = "../data/mined_data"

# Drug-drug interaction rows (provides the drug_1_name / drug_2_name columns used later).
ddi_subset = pd.read_csv(f"{MINED_DATA_DIR}/final_DDI.csv")
# Drug-protein interaction rows (provides the protein_name column used later).
dpi_subset = pd.read_csv(f"{MINED_DATA_DIR}/final_DPI.csv")

# Tokenizer matching the "gpt-4o" model name; every prompt length is measured with it.
enc = tiktoken.encoding_for_model("gpt-4o")

In [None]:
# Every distinct drug name that appears on either side of a drug-drug interaction row.
all_drugs = {
    *ddi_subset["drug_1_name"].unique(),
    *ddi_subset["drug_2_name"].unique(),
}
# Distinct protein names from the drug-protein table (kept as the array returned by .unique()).
all_proteins = dpi_subset["protein_name"].unique()

# Summarization Cost Estimate

In [None]:
def get_longest_text(entity_type, entity_list):
    """Find the entity whose PubMed text file contains the most characters.

    Each entity is expected to have a file named ``<entity>.txt`` under
    ``../data/background_information_data/<entity_type>_data/PubMed``.

    Args:
        entity_type: Folder prefix, e.g. ``"drug"`` or ``"protein"``.
        entity_list: Iterable of entity names to inspect.

    Returns:
        A ``(entity, character_count)`` tuple for the longest file. On a tie,
        the entity encountered first wins.

    Raises:
        FileNotFoundError: If an entity has no corresponding text file.
        ValueError: If ``entity_list`` is empty.
    """
    pubmed_dir = Path(f"../data/background_information_data/{entity_type}_data/PubMed")

    def _char_count(name):
        # Length is measured in characters of the decoded text, not in tokens.
        with (pubmed_dir / f"{name}.txt").open("r") as handle:
            return len(handle.read())

    sizes = {name: _char_count(name) for name in entity_list}
    return max(sizes.items(), key=lambda pair: pair[1])

def get_input_prompts(entity, entity_type):
    """Build the summarization prompt for one entity from its PubMed text file.

    The file ``<entity>.txt`` is read line by line and each line is labelled
    ``Abstract N: ...``.

    * For ``entity_type == "drug"`` every line is numbered starting at 1 and the
      result is inserted into ``DRUG_SUMMARIZATION_PROMPT``.
    * For any other type the first line is passed to
      ``PROTEIN_SUMMARIZATION_PROMPT`` as its own argument and the remaining
      lines are numbered starting at 1.

    Args:
        entity: Entity name, also the stem of the text file.
        entity_type: ``"drug"`` selects the drug prompt; anything else selects
            the protein prompt.

    Returns:
        The formatted prompt string.

    Raises:
        FileNotFoundError: If the entity's text file does not exist.
        IndexError: For non-drug entities whose file is empty.
    """
    source_file = Path(f"../data/background_information_data/{entity_type}_data/PubMed/{entity}.txt")
    with source_file.open("r") as handle:
        lines = handle.readlines()

    if entity_type == "drug":
        numbered = "".join(
            f"Abstract {position}: {text}\n"
            for position, text in enumerate(lines, start=1)
        )
        return DRUG_SUMMARIZATION_PROMPT.format(entity, numbered.strip())

    # Non-drug files: line 0 goes in as a separate prompt field, so numbering
    # of the remaining lines starts at 1.
    numbered = "".join(
        f"Abstract {position}: {lines[position]}\n"
        for position in range(1, len(lines))
    )
    return PROTEIN_SUMMARIZATION_PROMPT.format(entity, lines[0], numbered.strip())

def calculate_total_cost(input_prompt_tokens, output_prompt, total_ents,
                         input_price_per_million=2.5, output_price_per_million=10,
                         encoder=None):
    """Estimate the total cost of summarizing every entity at a fixed per-entity size.

    The cost of one entity is computed from the given input token count and the
    token count of ``output_prompt``, then multiplied by the number of entities.
    When the inputs come from the longest text, the result is an upper bound.

    Args:
        input_prompt_tokens: Sized collection of input tokens; only its length is used.
        output_prompt: Output text; it is tokenized here to obtain its token count.
        total_ents: Sized collection of entities; only its length is used.
        input_price_per_million: Price per 1,000,000 input tokens.
            NOTE(review): the default 2.5 is presumably the GPT-4o list price in
            USD - confirm against current pricing.
        output_price_per_million: Price per 1,000,000 output tokens.
            NOTE(review): the default 10 is presumably the GPT-4o list price in
            USD - confirm against current pricing.
        encoder: Object with an ``encode(text)`` method returning a sized token
            sequence. Defaults to the notebook-level ``enc`` tokenizer.

    Returns:
        Estimated total cost as a float, in the same currency unit as the prices.
    """
    if encoder is None:
        # Fall back to the notebook's shared tokenizer.
        encoder = enc
    output_prompt_tokens = encoder.encode(output_prompt)
    cost_per_entity = (
        (input_price_per_million * len(input_prompt_tokens))
        + (output_price_per_million * len(output_prompt_tokens))
    ) / 1_000_000
    return len(total_ents) * cost_per_entity

In [None]:
# Use the longest background text of each entity type as the worst-case input,
# so the totals printed below are upper limits rather than expected costs.
longest_drug_text = get_longest_text("drug", all_drugs)
longest_protein_text = get_longest_text("protein", all_proteins)

longest_drug_input_prompt = get_input_prompts(longest_drug_text[0], "drug")
longest_protein_input_prompt = get_input_prompts(longest_protein_text[0], "protein")

longest_drug_input_prompt_tokens = enc.encode(longest_drug_input_prompt)
longest_protein_input_prompt_tokens = enc.encode(longest_protein_input_prompt)

# Stored output texts used to size the output side of the estimate.
# NOTE(review): presumably model summaries generated for the longest inputs - confirm.
with Path("../data/sample_data/longest_drug_summarization_output.txt").open("r") as file:
    drug_output = file.read()

with Path("../data/sample_data/longest_protein_summarization_output.txt").open("r") as file:
    protein_output = file.read()

# Tokenize the outputs so the report shows token counts; len() of the raw
# string would report characters instead.
drug_output_tokens = enc.encode(drug_output)
protein_output_tokens = enc.encode(protein_output)

drug_cost = calculate_total_cost(longest_drug_input_prompt_tokens, drug_output, all_drugs)
protein_cost = calculate_total_cost(longest_protein_input_prompt_tokens, protein_output, all_proteins)

print(f"Longest drug text input tokens: {len(longest_drug_input_prompt_tokens)}")
print(f"Longest protein text input tokens: {len(longest_protein_input_prompt_tokens)}\n")

print(f"Longest drug text output tokens: {len(drug_output_tokens)}")
print(f"Longest protein text output tokens: {len(protein_output_tokens)}\n")

print(f"Upper limit total cost for drug summarization: {round(drug_cost, 2)}")
print(f"Upper limit total cost for protein summarization: {round(protein_cost, 2)}\n")

# Batching figures are half of the estimates above.
# NOTE(review): presumably reflects a 50% batch-processing discount - confirm against current pricing.
print(f"Batching Upper limit total cost for drug summarization: {round(drug_cost/2, 2)}")
print(f"Batching Upper limit total cost for protein summarization: {round(protein_cost/2, 2)}")