In [None]:
import concurrent.futures
import json
import re

from openai import OpenAI

from virtual_lab.constants import CONSISTENT_TEMPERATURE
from virtual_lab.run_meeting import run_meeting
from virtual_lab.utils import get_messages, get_pubmed_central_article

from nanobody_constants import (
    background_prompt,
    nanobody_prompt,
    discussions_phase_to_dir,
    model,
    finetuning_base_model,
    immunologist,
    machine_learning_specialist,
    computational_biologist,
)

In [None]:
# Topic to agent mapping
# Built from an agent -> topics grouping; the flattened dictionary keeps the
# same keys, values, and insertion order (the order later cells iterate in).
topic_to_agent = {
    topic: agent
    for agent, topics in (
        (
            immunologist,
            (
                "nanobodies",
                "SARS-CoV-2 spike protein",
                "SARS-CoV-2 variants KP.3 and JN.1",
            ),
        ),
        (machine_learning_specialist, ("ESM",)),
        (computational_biologist, ("AlphaFold-Multimer", "Rosetta")),
    )
    for topic in topics
}

In [None]:
# Agent fine-tuning queries
# Every (topic, agent) pair gets its own individual meeting in which the agent is
# asked to propose five literature search queries for that topic. The meetings are
# submitted to a thread pool so they run concurrently; each call is given the
# fine-tuning discussions directory as save_dir and "<topic>_queries" (spaces
# replaced by underscores) as save_name.
with concurrent.futures.ThreadPoolExecutor() as executor:
    # Keep a handle on every future so that failures can be attributed to a topic
    future_to_topic = {
        executor.submit(
            run_meeting,
            meeting_type="individual",
            team_member=agent,
            agenda=f"{background_prompt} {nanobody_prompt} You are responsible for understanding the topic \"{topic}\" in the context of designing nanobody binders for SARS-CoV-2. You need to fine-tune yourself on the relevant literature on {topic} to improve your ability to design SARS-CoV-2 nanobody binders. Please write out a series of five distinct search queries that you want to run to find relevant scientific papers on {topic}. Include both queries about {topic} generally as well as queries about how {topic} relates to designing nanobody binders for SARS-CoV-2. Please provide the queries in Python syntax as a list of strings.",
            agenda_questions=(
                f"What are the queries that you want to perform to identify the relevant literature on {topic} (as a list of strings in Python syntax)?",),
            save_dir=discussions_phase_to_dir["finetuning"],
            save_name=f"{topic.replace(' ', '_')}_queries",
            temperature=CONSISTENT_TEMPERATURE,
            model=model,
        ): topic
        for topic, agent in topic_to_agent.items()
    }

    # Waiting on futures without inspecting them hides any exception raised in a
    # worker thread, so check each one and report failures instead of losing them.
    # Failures are printed (not re-raised) so the remaining meetings still finish.
    for future in concurrent.futures.as_completed(future_to_topic):
        error = future.exception()
        if error is not None:
            print(f"Query meeting failed for {future_to_topic[future]}: {error!r}")

In [None]:
# Literature search queries per topic (five per topic), keyed by the same topic
# strings as topic_to_agent. The next cell runs one paper-selection meeting per query.
# NOTE(review): presumably transcribed from the "<topic>_queries" discussions
# produced by the previous cell — confirm against the saved files.
topic_to_queries = {
    "nanobodies": [
        "nanobodies AND SARS-CoV-2 spike protein",
        "nanobody engineering AND cross-reactivity AND SARS-CoV-2 variants",
        "machine learning AND nanobody design AND SARS-CoV-2",
        "broad-spectrum nanobodies AND coronavirus",
        "nanobody optimization AND immune response AND SARS-CoV-2"
    ],
    "SARS-CoV-2 spike protein": [
        "SARS-CoV-2 spike protein structure and function",
        "SARS-CoV-2 spike protein variants and mutations",
        "nanobody binding to SARS-CoV-2 spike protein",
        "machine learning in antibody and nanobody design for SARS-CoV-2",
        "cross-reactivity of nanobodies with SARS-CoV-2 spike protein variants"
    ],
    "SARS-CoV-2 variants KP.3 and JN.1": [
        "SARS-CoV-2 variant KP.3 spike protein structure",
        "SARS-CoV-2 variant JN.1 spike protein structure",
        "nanobody design for SARS-CoV-2 variant KP.3",
        "nanobody design for SARS-CoV-2 variant JN.1",
        "cross-reactivity of nanobodies with SARS-CoV-2 variants KP.3 and JN.1"
    ],
    "ESM": [
        "Evolutionary Scale Modeling (ESM) in protein design",
        "ESM for antibody and nanobody development",
        "ESM SARS-CoV-2 spike protein binding",
        "Machine learning ESM nanobody design SARS-CoV-2",
        "ESM optimization of nanobody binding affinity SARS-CoV-2 variants"
    ],
    "AlphaFold-Multimer": [
        "AlphaFold-Multimer protein structure prediction",
        "AlphaFold-Multimer SARS-CoV-2 spike protein",
        "designing nanobody binders using AlphaFold-Multimer",
        "AlphaFold-Multimer nanobody interaction SARS-CoV-2",
        "machine learning AlphaFold-Multimer nanobody SARS-CoV-2"
    ],
    "Rosetta": [
        "Rosetta software protein design",
        "Rosetta nanobody design SARS-CoV-2",
        "Rosetta molecular modeling SARS-CoV-2 spike protein",
        "Rosetta antibody engineering SARS-CoV-2",
        "Rosetta computational biology nanobody optimization"
    ],
}

In [None]:
# Agent fine-tuning papers
# For each topic that has queries, run one individual meeting per query (with
# pubmed_search=True) in which the agent is asked to search PubMed Central and
# list the PMCIDs and titles of the papers it wants to fine-tune on. The meetings
# for one topic run concurrently; topics are processed one after another. Each call
# is given save_name "<topic>_papers_<n>" with n starting at 1.
for topic, agent in topic_to_agent.items():
    if topic not in topic_to_queries:
        continue
    with concurrent.futures.ThreadPoolExecutor() as executor:
        # Keep a handle on every future so that failures can be attributed to a query
        future_to_query = {
            executor.submit(
                run_meeting,
                meeting_type="individual",
                team_member=agent,
                agenda=f"{background_prompt} {nanobody_prompt} You are responsible for understanding the topic \"{topic}\" in the context of designing nanobody binders for SARS-CoV-2. You need to fine-tune yourself on the relevant literature on {topic} to improve your ability to design SARS-CoV-2 nanobody binders. Please use PubMed Central and search for relevant papers on {topic} using the query \"{query}\" and request 25 articles with abstracts only. Read all of the abstracts and based on each abstract individually, decide whether you want to fine-tune yourself on the full text of that paper. Include as many papers as possible, but only include papers that are directly relevant to {topic}. Please provide the PMCIDs and titles of all the papers that you wish to fine-tune yourself on as a Python dictionary mapping PMCID as a double-quoted string to title as a double-quoted string.",
                agenda_questions=(
                    "What are the PMCIDs and titles of the papers you wish to fine-tune yourself on (as a Python dictionary mapping PMCID as a double-quoted string to title as double-quoted string)?",),
                save_dir=discussions_phase_to_dir["finetuning"],
                save_name=f"{topic.replace(' ', '_')}_papers_{query_num + 1}",
                temperature=CONSISTENT_TEMPERATURE,
                model=model,
                pubmed_search=True,
            ): query
            for query_num, query in enumerate(topic_to_queries[topic])
        }

        # Waiting on futures without inspecting them hides any exception raised in
        # a worker thread, so check each one and report failures explicitly.
        # Failures are printed (not re-raised) so the other queries and topics still run.
        for future in concurrent.futures.as_completed(future_to_query):
            error = future.exception()
            if error is not None:
                print(f"Paper meeting failed for {topic} (query {future_to_query[future]!r}): {error!r}")

In [None]:
# Download papers from PubMed Central for fine-tuning
# For every topic: read the saved "<topic>_papers_*.json" discussions, pull the
# PMCID -> title dictionary out of the last message of each, de-duplicate papers by
# lower-cased title, fetch each paper's paragraphs, and write one chat-formatted
# training example per paragraph to "<topic>_training_data.jsonl".

# Matches a flat {"key": "value", ...} dictionary of double-quoted strings (or {})
pmcid_to_title_pattern = re.compile(r'\{\s*(".*?"\s*:\s*".*?"\s*(,\s*".*?"\s*:\s*".*?"\s*)*)?\}')

# Number of "<topic>_papers_<n>.json" files that must exist before a topic is processed
NUM_PAPER_FILES_PER_TOPIC = 5

for topic, agent in topic_to_agent.items():
    pmcids, titles = set(), set()

    # Get all paper paths for a topic
    paper_paths = sorted(discussions_phase_to_dir["finetuning"].glob(f"{topic.replace(' ', '_')}_papers_*.json"))

    # Check if all papers results are present
    if len(paper_paths) != NUM_PAPER_FILES_PER_TOPIC:
        print(f"Missing papers for {topic}")
        continue

    # Extract PMC IDs and titles from each papers file
    for paper_path in paper_paths:
        # Load paper discussion
        with open(paper_path) as f:
            paper_discussion = json.load(f)

        # Search the final message of the discussion for the PMCID -> title dictionary
        paper_message = paper_discussion[-1]["message"]
        pattern_result = pmcid_to_title_pattern.search(paper_message)

        # Check if pattern matched
        if pattern_result is None:
            print(f"No papers found for {paper_path}")
            continue

        # Parse the matched text as JSON. The regex only checks the overall shape,
        # so the text can still be invalid JSON (e.g. an invalid backslash escape
        # inside a title); report and skip that file instead of aborting every topic.
        try:
            pmcid_to_title = json.loads(pattern_result.group())
        except json.JSONDecodeError as error:
            print(f"Could not parse papers for {paper_path}: {error}")
            continue

        # Add PMC IDs and titles to set, avoiding duplicates by title
        for pmcid, title in pmcid_to_title.items():
            title = title.lower()

            if title not in titles:
                pmcids.add(pmcid)
                titles.add(title)

    # Download papers by PMC ID and format for fine-tuning
    training_data = []
    for pmcid in sorted(pmcids):
        # Element [1] of the helper's return value is iterated as the paper's paragraphs
        paragraphs = get_pubmed_central_article(pmcid=pmcid)[1]

        # Defensive: skip a paper with no paragraphs rather than failing the whole cell
        # NOTE(review): assumes the helper may return None/empty for unavailable
        # articles — confirm against get_pubmed_central_article.
        if not paragraphs:
            print(f"No content found for PMCID {pmcid}")
            continue

        # One example per paragraph: agent prompt as system message, empty user
        # message, paragraph as the assistant message
        training_data.extend(
            {"messages": [{"role": "system", "content": agent.prompt},
                          {"role": "user", "content": ""},
                          {"role": "assistant", "content": paragraph}]}
            for paragraph in paragraphs
        )

    print(f"Number of examples for {topic}: {len(training_data):,}")

    # Save training data in jsonl format
    with open(discussions_phase_to_dir["finetuning"] / f"{topic.replace(' ', '_')}_training_data.jsonl", "w") as f:
        f.write("\n".join(json.dumps(example) for example in training_data))

In [None]:
# Upload fine-tuning data
# Upload each topic's "<topic>_training_data.jsonl" file with purpose "fine-tune",
# record the returned file ID per topic, and save the mapping to topic_to_id.json.
topic_to_id = {}

client = OpenAI()

for topic in topic_to_agent:
    path = discussions_phase_to_dir["finetuning"] / f"{topic.replace(' ', '_')}_training_data.jsonl"

    # The training-data cell skips topics whose paper files are incomplete, so a
    # topic may have no training file; skip it here instead of failing the uploads.
    if not path.exists():
        print(f"Missing training data for {topic}")
        continue

    # Open the file in a context manager so the handle is closed after the upload
    with open(path, "rb") as training_file:
        file_object = client.files.create(
            file=training_file,
            purpose="fine-tune"
        )

    topic_to_id[topic] = file_object.id

# Save file IDs
with open(discussions_phase_to_dir["finetuning"] / "topic_to_id.json", "w") as f:
    json.dump(topic_to_id, f)

# Report how many files the listing returns
file_objects = client.files.list().data
print(f"Found {len(file_objects)} files")

In [None]:
# Launch fine-tuning runs
# Re-read the topic -> uploaded file ID mapping from topic_to_id.json.
with open(discussions_phase_to_dir["finetuning"] / "topic_to_id.json") as id_file:
    topic_to_id = json.loads(id_file.read())

# Disabled variant that would start one job per topic:
# for topic, file_id in topic_to_id.items():
#     client.fine_tuning.jobs.create(
#         training_file=file_id,
#         model=finetuning_base_model,
#     )

# Only the "ESM" training file is submitted, on top of the fine-tuning base model.
client.fine_tuning.jobs.create(model=finetuning_base_model, training_file=topic_to_id["ESM"])

In [None]:
# Show the status of the first job in the fine-tuning job listing.
# NOTE(review): presumably the most recently created job is listed first — confirm
# against the API's ordering. Raises IndexError when no jobs exist.
list(client.fine_tuning.jobs.list())[0].status

In [None]:
# Smoke test: create an assistant with the machine learning specialist's title and
# prompt, ask it one question about the DeepRank-GNN-esm model, and print the reply.

# Model to query. The commented-out line is a fine-tuned model ID that can be
# swapped in instead.
finetuned_model = "gpt-4o-mini-2024-07-18"
# finetuned_model = "ft:gpt-4o-mini-2024-07-18:personal::AlQbTW9G"

# NOTE: `agent` is rebound here to the assistant object returned by the API; it is
# no longer the loop variable used by the earlier cells.
agent = client.beta.assistants.create(
    name=machine_learning_specialist.title,
    instructions=machine_learning_specialist.prompt,
    model=finetuned_model,
)
thread = client.beta.threads.create()
client.beta.threads.messages.create(
    thread_id=thread.id,
    role="user",
    content="What are the three algorithms of the DeepRank-GNN-esm model?",
)

# Start the run and block until polling returns
run = client.beta.threads.runs.create_and_poll(
    thread_id=thread.id,
    assistant_id=agent.id,
    model=finetuned_model,
    temperature=CONSISTENT_TEMPERATURE,
)

# Only print a reply when the run completed; otherwise the last message in the
# thread is not an answer to the question, so report the run status instead.
if run.status == "completed":
    messages = get_messages(client=client, thread_id=thread.id)
    print(messages[-1]["content"][0]["text"]["value"])
else:
    print(f"Run ended with status {run.status}")