In [None]:
import concurrent.futures
import json
import re

from openai import OpenAI
from tqdm import tqdm

from virtual_lab.constants import CONSISTENT_TEMPERATURE
from virtual_lab.run_meeting import run_meeting
from virtual_lab.utils import get_messages, get_pubmed_central_article

from nanobody_constants import (
    background_prompt,
    nanobody_prompt,
    discussions_phase_to_dir,
    model,
    finetuning_base_model,
    immunologist,
    machine_learning_specialist,
    computational_biologist,
)

In [None]:
# Mapping from each fine-tuning topic to the agent responsible for it.
# Keys are listed agent by agent; insertion order is immunologist topics,
# then the machine learning topic, then the computational biology topics.
topic_to_agent = {
    **dict.fromkeys(
        ("nanobodies", "SARS-CoV-2 spike protein", "SARS-CoV-2 variants KP.3 and JN.1"),
        immunologist,
    ),
    **dict.fromkeys(("ESM",), machine_learning_specialist),
    **dict.fromkeys(("AlphaFold-Multimer", "Rosetta"), computational_biologist),
}

In [None]:
# Agent fine-tuning queries
# Submit one individual meeting per topic to a thread pool. Each meeting asks the topic's
# agent for five literature search queries; run_meeting is given the fine-tuning
# discussions directory and "<topic>_queries" as the save name.
with concurrent.futures.ThreadPoolExecutor() as executor:
    query_future_to_topic = {
        executor.submit(
            run_meeting,
            meeting_type="individual",
            team_member=agent,
            agenda=f"{background_prompt} {nanobody_prompt} You are responsible for understanding the topic \"{topic}\" in the context of designing nanobody binders for SARS-CoV-2. You need to fine-tune yourself on the relevant literature on {topic} to improve your ability to design SARS-CoV-2 nanobody binders. Please write out a series of five distinct search queries that you want to run to find relevant scientific papers on {topic}. Include both queries about {topic} generally as well as queries about how {topic} relates to designing nanobody binders for SARS-CoV-2. Please provide the queries in Python syntax as a list of double-quoted strings.",
            agenda_questions=(
                f"What are the queries that you want to perform to identify the relevant literature on {topic} (as a list of double-quoted strings in Python syntax)?",),
            save_dir=discussions_phase_to_dir["finetuning"],
            save_name=f"{topic.replace(' ', '_')}_queries",
            temperature=CONSISTENT_TEMPERATURE,
            model=model,
        ): topic
        for topic, agent in topic_to_agent.items()
    }
    concurrent.futures.wait(query_future_to_topic)

# wait() never re-raises exceptions from worker threads, so a failed meeting would
# otherwise go unnoticed until a later cell cannot find its output file.
for query_future, query_topic in query_future_to_topic.items():
    query_error = query_future.exception()
    if query_error is not None:
        print(f"Query meeting failed for {query_topic}: {query_error!r}")

In [None]:
# Agent fine-tuning queries
# Parse the list of search queries out of each topic's saved query discussion.
query_list_pattern = re.compile(r'\[\s*(".*?"\s*(,\s*".*?"\s*)*)?,?\s*\]')

# The pattern above accepts a trailing comma before the closing bracket (valid Python
# syntax, which is what the agent is asked for), but json.loads rejects it.
trailing_comma_pattern = re.compile(r',\s*\]$')

topic_to_queries = {}

for topic in topic_to_agent:
    # Get query path for topic
    query_path = discussions_phase_to_dir["finetuning"] / f"{topic.replace(' ', '_')}_queries.json"

    # Load query discussion
    with open(query_path) as f:
        query_discussion = json.load(f)

    # The queries are expected in the last message of the discussion
    query_message = query_discussion[-1]["message"]
    pattern_result = query_list_pattern.search(query_message)

    # Check if pattern is matched
    if pattern_result is None:
        print(f"No queries found for {query_path}")
        continue

    # Normalize the matched list to strict JSON before parsing
    query_list_text = trailing_comma_pattern.sub("]", pattern_result.group())

    # A match that is still not valid JSON is reported and skipped like a missing match
    try:
        queries = json.loads(query_list_text)
    except json.JSONDecodeError as error:
        print(f"Could not parse queries for {query_path}: {error}")
        continue

    topic_to_queries[topic] = queries

print(topic_to_queries)

In [None]:
# Agent fine-tuning papers
# For every topic, run one individual meeting per search query (in parallel within a topic).
# Each meeting asks the agent to search PubMed Central with that query and list the papers
# it wants to fine-tune on; run_meeting is given "<topic>_papers_<n>" as the save name.
for topic, agent in topic_to_agent.items():
    # Topics whose query list could not be extracted are absent from topic_to_queries,
    # so skip them instead of failing with a KeyError.
    if topic not in topic_to_queries:
        print(f"No queries available for {topic}")
        continue

    paper_future_to_save_name = {}

    with concurrent.futures.ThreadPoolExecutor() as executor:
        for query_num, query in enumerate(topic_to_queries[topic]):
            save_name = f"{topic.replace(' ', '_')}_papers_{query_num + 1}"

            # Skip queries whose result file already exists from a previous run
            if (discussions_phase_to_dir["finetuning"] / f"{save_name}.json").exists():
                continue

            paper_future = executor.submit(
                run_meeting,
                meeting_type="individual",
                team_member=agent,
                agenda=f"{background_prompt} {nanobody_prompt} You are responsible for understanding the topic \"{topic}\" in the context of designing nanobody binders for SARS-CoV-2. You need to fine-tune yourself on the relevant literature on {topic} to improve your ability to design SARS-CoV-2 nanobody binders. Please use PubMed Central and search for relevant papers on {topic} using the query \"{query}\" and request 100 articles with abstracts only. Read all of the abstracts and based on each abstract individually, decide whether you want to fine-tune yourself on the full text of that paper. Include as many papers as possible, but only include papers that are directly relevant to {topic}. Please provide the PMCIDs and titles of all the papers that you wish to fine-tune yourself on as a Python dictionary mapping PMCID as a double-quoted string to title as a double-quoted string.",
                agenda_questions=(
                    "What are the PMCIDs and titles of the papers you wish to fine-tune yourself on (as a Python dictionary mapping PMCID as a double-quoted string to title as double-quoted string)?",),
                save_dir=discussions_phase_to_dir["finetuning"],
                save_name=save_name,
                temperature=CONSISTENT_TEMPERATURE,
                model=model,
                pubmed_search=True,
            )
            paper_future_to_save_name[paper_future] = save_name

        concurrent.futures.wait(paper_future_to_save_name)

    # wait() never re-raises exceptions from worker threads, so report failed meetings explicitly
    for paper_future, save_name in paper_future_to_save_name.items():
        paper_error = paper_future.exception()
        if paper_error is not None:
            print(f"Paper meeting failed for {save_name}: {paper_error!r}")

In [None]:
# Download papers from PubMed Central for fine-tuning
# For each topic: collect the PMCID -> title dictionaries from the saved paper discussions,
# de-duplicate by title, download each article, and write one chat-format training example
# per title/content entry to "<topic>_training_data.jsonl".
pmcid_to_title_pattern = re.compile(r'\{\s*(".*?"\s*:\s*".*?"\s*(,\s*".*?"\s*:\s*".*?"\s*)*)?\}')

# One papers file is expected per search query, and five queries are requested per topic
expected_num_paper_files = 5

for topic, agent in topic_to_agent.items():
    pmcids, titles = set(), set()

    # Get all paper paths for a topic
    paper_paths = sorted(discussions_phase_to_dir["finetuning"].glob(f"{topic.replace(' ', '_')}_papers_*.json"))

    # Check if all papers results are present
    if len(paper_paths) != expected_num_paper_files:
        print(f"Missing papers for {topic}")
        continue

    # Extract PMC IDs and titles from each papers file
    for paper_path in paper_paths:
        # Load paper discussion
        with open(paper_path) as f:
            paper_discussion = json.load(f)

        # The dictionary is read from the second message of the discussion
        paper_message = paper_discussion[1]["message"]
        pattern_result = pmcid_to_title_pattern.search(paper_message)

        # Check if pattern is matched
        if pattern_result is None:
            print(f"No papers found for {paper_path}")
            continue

        # A match that is not valid JSON is reported and skipped like a missing match,
        # rather than aborting the remaining files and topics
        try:
            pmcid_to_title = json.loads(pattern_result.group())
        except json.JSONDecodeError as error:
            print(f"Could not parse papers for {paper_path}: {error}")
            continue

        # Add PMC IDs and titles to set, avoiding duplicates by title (case-insensitive)
        for pmcid, title in pmcid_to_title.items():
            title = title.lower()

            if title not in titles:
                pmcids.add(pmcid)
                titles.add(title)

    # Download papers by PMC ID and format for fine-tuning.
    # Iterate in sorted order: set iteration order for strings can differ between runs,
    # which would reorder the training file each time it is regenerated.
    training_data = []
    for pmcid in tqdm(sorted(pmcids)):
        title, content = get_pubmed_central_article(pmcid=pmcid)

        # A missing title is treated as a failed download and the paper is skipped
        if title is None:
            continue

        # One example per entry: the title first, then each element of content
        # (presumably paragraphs of the article -- confirm against get_pubmed_central_article)
        training_data += [
            {"messages": [{"role": "system", "content": agent.prompt},
                          {"role": "user", "content": ""},
                          {"role": "assistant", "content": paragraph}]}
            for paragraph in [title] + content
        ]

    print(f"Number of papers for {topic}: {len(pmcids):,}")
    print(f"Number of examples for {topic}: {len(training_data):,}")

    # Save training data in jsonl format
    with open(discussions_phase_to_dir["finetuning"] / f"{topic.replace(' ', '_')}_training_data.jsonl", "w") as f:
        f.write("\n".join(json.dumps(example) for example in training_data))

In [None]:
# Upload fine-tuning data
# Upload each topic's training file and record the returned file ID per topic.
topic_to_id = {}

client = OpenAI()

for topic in topic_to_agent:
    path = discussions_phase_to_dir["finetuning"] / f"{topic.replace(' ', '_')}_training_data.jsonl"

    # Use a context manager so the file handle is closed once the upload call returns
    with open(path, "rb") as training_file:
        file_object = client.files.create(
            file=training_file,
            purpose="fine-tune"
        )

    topic_to_id[topic] = file_object.id

# Save topic file IDs
with open(discussions_phase_to_dir["finetuning"] / "topic_to_id.json", "w") as f:
    json.dump(topic_to_id, f)

In [None]:
# Load topic file IDs
# Reading the mapping back from disk lets the following cells run without repeating the
# uploads (presumably the reason it is saved and reloaded -- confirm).
with open(discussions_phase_to_dir["finetuning"] / "topic_to_id.json") as f:
    topic_to_id = json.load(f)

# Sorted so that positional indexing into topics does not depend on file order
topics = sorted(topic_to_id)

# The following cells use `client`, which is otherwise only created in the upload cell;
# create it here as well so they do not fail with a NameError when the upload is skipped.
client = OpenAI()

In [None]:
# Launch fine-tuning jobs
# Variant that launches a job for every topic (currently disabled):
# for topic, file_id in topic_to_id.items():
#     client.fine_tuning.jobs.create(
#         training_file=file_id,
#         model=finetuning_base_model,
#     )

# Launch a single job, using the uploaded file of the topic at index 3 of the sorted topics
job_request = {
    "training_file": topic_to_id[topics[3]],
    "model": finetuning_base_model,  # TODO: swap for GPT-4o, not mini
}
client.fine_tuning.jobs.create(**job_request)

In [None]:
# Status of the first job in the fine-tuning job listing.
# Request a single job and read it from the first page; converting the whole listing to a
# list iterates through every page of past jobs just to read one entry.
client.fine_tuning.jobs.list(limit=1).data[0].status

In [None]:
# Mapping from topic to fine-tuned model; starts empty
topic_to_model = dict()

In [None]:
# Ask a model a single question as the machine learning specialist and print the reply.
# Swap in a fine-tuned model ID to compare it with the base model, e.g.:
# finetuned_model = "ft:gpt-4o-mini-2024-07-18:personal::AlQbTW9G"
finetuned_model = finetuning_base_model

agent, query = (
    machine_learning_specialist,
    "What are the three algorithms of the DeepRank-GNN-esm model?",
)

# Create an assistant that carries the agent's title and prompt
assistant = client.beta.assistants.create(
    model=finetuned_model,
    name=agent.title,
    instructions=agent.prompt,
)

# Start a thread containing only the user question
thread = client.beta.threads.create()
client.beta.threads.messages.create(
    thread_id=thread.id,
    role="user",
    content=query,
)

# Run the assistant on the thread with the shared temperature setting
run = client.beta.threads.runs.create_and_poll(
    thread_id=thread.id,
    assistant_id=assistant.id,
    model=finetuned_model,
    temperature=CONSISTENT_TEMPERATURE,
)

# Fetch the thread's messages and print the text of the last one
messages = get_messages(client=client, thread_id=thread.id)
last_message = messages[-1]

print(last_message["content"][0]["text"]["value"])