In [1]:
import ast
import concurrent.futures
import json
import re

from openai import OpenAI
from tqdm import tqdm

from virtual_lab.constants import CONSISTENT_TEMPERATURE
from virtual_lab.run_meeting import run_meeting
from virtual_lab.utils import get_messages, get_pubmed_central_article

from nanobody_constants import (
    background_prompt,
    nanobody_prompt,
    discussions_phase_to_dir,
    model,
    finetuning_base_model,
    immunologist,
    machine_learning_specialist,
    computational_biologist,
)

In [2]:
# Topic to agent mapping: each agent is responsible for fine-tuning itself on its topics.
# Built from per-agent topic groups; insertion order matches the listing below.
topic_to_agent = {
    agent_topic: responsible_agent
    for responsible_agent, agent_topics in (
        (immunologist, ("nanobodies", "SARS-CoV-2 spike protein", "SARS-CoV-2 variants KP.3 and JN.1")),
        (machine_learning_specialist, ("ESM",)),
        (computational_biologist, ("AlphaFold-Multimer", "Rosetta")),
    )
    for agent_topic in agent_topics
}

In [None]:
# Agent fine-tuning queries: each agent proposes literature search queries for its topic.
# Meetings run concurrently; each one saves its discussion under the fine-tuning directory.
with concurrent.futures.ThreadPoolExecutor() as executor:
    future_to_topic = {
        executor.submit(
            run_meeting,
            meeting_type="individual",
            team_member=agent,
            agenda=f"{background_prompt} {nanobody_prompt} You are responsible for understanding the topic \"{topic}\" in the context of designing nanobody binders for SARS-CoV-2. You need to fine-tune yourself on the relevant literature on {topic} to improve your ability to design SARS-CoV-2 nanobody binders. Please write out a series of five distinct search queries that you want to run to find relevant scientific papers on {topic}. Include both queries about {topic} generally as well as queries about how {topic} relates to designing nanobody binders for SARS-CoV-2. Please provide the queries in Python syntax as a list of double-quoted strings.",
            agenda_questions=(
                f"What are the queries that you want to perform to identify the relevant literature on {topic} (as a list of double-quoted strings in Python syntax)?",),
            save_dir=discussions_phase_to_dir["finetuning"],
            save_name=f"{topic.replace(' ', '_')}_queries",
            temperature=CONSISTENT_TEMPERATURE,
            model=model,
        ): topic
        for topic, agent in topic_to_agent.items()
    }

# Waiting on futures does not re-raise worker exceptions, so report failures explicitly
for future, failed_topic in future_to_topic.items():
    if future.exception() is not None:
        print(f"Query meeting for {failed_topic} failed: {future.exception()!r}")

In [3]:
# Extract the search queries proposed in each topic's query discussion.
# The agents were asked for a Python list of double-quoted strings. The pattern tolerates
# a trailing comma, which is valid Python but not valid JSON, so parsing falls back to
# ast.literal_eval when json.loads rejects the matched text.
query_list_pattern = re.compile(r'\[\s*(".*?"\s*(,\s*".*?"\s*)*)?,?\s*\]')

topic_to_queries = {}

for topic in topic_to_agent:
    # Get query path for topic
    query_path = discussions_phase_to_dir["finetuning"] / f"{topic.replace(' ', '_')}_queries.json"

    # Load query discussion
    with open(query_path) as f:
        query_discussion = json.load(f)

    # The query list is searched for in the final message of the discussion
    query_message = query_discussion[-1]["message"]
    pattern_result = query_list_pattern.search(query_message)

    # Check if pattern is matched
    if pattern_result is None:
        print(f"No queries found for {query_path}")
        continue

    # Parse as JSON first; fall back to Python literal syntax (e.g. a trailing comma)
    query_list_text = pattern_result.group()
    try:
        queries = json.loads(query_list_text)
    except json.JSONDecodeError:
        try:
            queries = ast.literal_eval(query_list_text)
        except (ValueError, SyntaxError):
            print(f"Could not parse queries for {query_path}")
            continue

    topic_to_queries[topic] = queries

In [None]:
# Agent fine-tuning papers: for each topic, run one PubMed-search meeting per query.
# Topics are processed one at a time; the queries within a topic run concurrently.
for topic, agent in topic_to_agent.items():
    topic_name = topic.replace(' ', '_')

    # Topics whose query extraction failed above have no entry in topic_to_queries
    if topic not in topic_to_queries:
        print(f"No queries available for {topic}")
        continue

    future_to_save_name = {}
    with concurrent.futures.ThreadPoolExecutor() as executor:
        for query_num, query in enumerate(topic_to_queries[topic]):
            save_name = f"{topic_name}_papers_{query_num + 1}"

            # Skip queries whose paper discussion was already saved by a previous run
            if (discussions_phase_to_dir["finetuning"] / f"{save_name}.json").exists():
                continue

            future = executor.submit(
                run_meeting,
                meeting_type="individual",
                team_member=agent,
                agenda=f"{background_prompt} {nanobody_prompt} You are responsible for understanding the topic \"{topic}\" in the context of designing nanobody binders for SARS-CoV-2. You need to fine-tune yourself on the relevant literature on {topic} to improve your ability to design SARS-CoV-2 nanobody binders. Please use PubMed Central and search for relevant papers on {topic} using the query \"{query}\" and request 100 articles with abstracts only. Read all of the abstracts and based on each abstract individually, decide whether you want to fine-tune yourself on the full text of that paper. Include as many papers as possible, but only include papers that are directly relevant to {topic}. Please provide the PMCIDs and titles of all the papers that you wish to fine-tune yourself on as a Python dictionary mapping PMCID as a double-quoted string to title as a double-quoted string.",
                agenda_questions=(
                    "What are the PMCIDs and titles of the papers you wish to fine-tune yourself on (as a Python dictionary mapping PMCID as a double-quoted string to title as double-quoted string)?",),
                save_dir=discussions_phase_to_dir["finetuning"],
                save_name=save_name,
                temperature=CONSISTENT_TEMPERATURE,
                model=model,
                pubmed_search=True,
            )
            future_to_save_name[future] = save_name

    # Report failures instead of discarding them; rerunning this cell retries only the missing ones
    for future, failed_save_name in future_to_save_name.items():
        if future.exception() is not None:
            print(f"Paper meeting {failed_save_name} failed: {future.exception()!r}")

In [10]:
# Extract selected papers for fine-tuning.
# The agents were asked for a Python dict mapping PMCID to title (double-quoted strings).
# The pattern tolerates a trailing comma, which is valid Python but not valid JSON, so
# parsing falls back to ast.literal_eval when json.loads rejects the matched text.
pmcid_to_title_pattern = re.compile(r'\{\s*(".*?"\s*:\s*".*?"\s*(,\s*".*?"\s*:\s*".*?"\s*)*)?,?\s*\}')

# One paper-search discussion is expected per query; the query prompt asks for five queries
expected_num_paper_files = 5

for topic in topic_to_agent:
    # Set up title to PMC ID dictionary; papers are deduplicated by lowercased title and by PMC ID
    title_to_pmcid = {}
    titles_lower, pmcids = set(), set()
    topic_name = topic.replace(' ', '_')

    # Get all paper paths for a topic
    paper_paths = sorted(discussions_phase_to_dir["finetuning"].glob(f"{topic_name}_papers_*.json"))

    # Check if all papers results are present
    if len(paper_paths) != expected_num_paper_files:
        print(f"Missing papers for {topic}")
        continue

    # Extract PMC IDs and titles from each papers file
    for paper_path in paper_paths:
        # Load paper discussion
        with open(paper_path) as f:
            paper_discussion = json.load(f)

        # NOTE(review): the query extraction cell reads the last message ([-1]) while this
        # reads message [1]; confirm which index holds the agent's final answer.
        paper_message = paper_discussion[1]["message"]
        pattern_result = pmcid_to_title_pattern.search(paper_message)

        # Check if pattern is matched
        if pattern_result is None:
            print(f"No papers found for {paper_path}")
            continue

        # Parse as JSON first; fall back to Python literal syntax (e.g. a trailing comma)
        pmcid_to_title_text = pattern_result.group()
        try:
            pmcid_to_title = json.loads(pmcid_to_title_text)
        except json.JSONDecodeError:
            try:
                pmcid_to_title = ast.literal_eval(pmcid_to_title_text)
            except (ValueError, SyntaxError):
                print(f"Could not parse papers for {paper_path}")
                continue

        # Add PMC IDs and titles to dictionary, avoiding duplicates
        for pmcid, title in pmcid_to_title.items():
            # Replace en dash and em dash with a hyphen; the lowercased form is the dedup key
            title = title.replace("–", "-").replace("—", "-")
            title_lower = title.lower()

            if title_lower not in titles_lower and pmcid not in pmcids:
                title_to_pmcid[title] = pmcid
                titles_lower.add(title_lower)
                pmcids.add(pmcid)

    print(f"Number of papers for {topic}: {len(title_to_pmcid):,}")

    # Save title to PMC ID dictionary
    with open(discussions_phase_to_dir["finetuning"] / f"{topic_name}_title_to_pmcid.json", "w") as f:
        json.dump(title_to_pmcid, f, indent=4, sort_keys=True)

Number of papers for nanobodies: 261
Number of papers for SARS-CoV-2 spike protein: 331
No papers found for discussions/finetuning/SARS-CoV-2_variants_KP.3_and_JN.1_papers_5.json
Number of papers for SARS-CoV-2 variants KP.3 and JN.1: 24
Number of papers for ESM: 34
Number of papers for AlphaFold-Multimer: 112
Number of papers for Rosetta: 113


In [11]:
# Download papers from PubMed Central for fine-tuning.
# Each paragraph (and the title) of a paper becomes one chat example whose assistant turn is the text.
for topic, agent in topic_to_agent.items():
    topic_name = topic.replace(' ', '_')
    title_to_pmcid_path = discussions_phase_to_dir["finetuning"] / f"{topic_name}_title_to_pmcid.json"

    # The extraction cell skips topics with missing paper files, so no mapping may exist
    if not title_to_pmcid_path.exists():
        print(f"No title to PMC ID mapping for {topic}")
        continue

    # Load title to PMC ID dictionary
    with open(title_to_pmcid_path) as f:
        title_to_pmcid = json.load(f)

    # Get unique PMC IDs
    pmcids = sorted(set(title_to_pmcid.values()))

    # Download papers by PMC ID and format for fine-tuning
    training_data = []
    num_downloaded = 0
    for pmcid in tqdm(pmcids):
        title, content = get_pubmed_central_article(pmcid=pmcid)

        # A None title presumably means the full text could not be retrieved; skip the paper
        if title is None:
            continue

        num_downloaded += 1
        training_data += [
            {"messages": [{"role": "system", "content": agent.prompt},
                          {"role": "user", "content": ""},
                          {"role": "assistant", "content": paragraph}]}
            for paragraph in [title] + content
        ]

    # Report papers actually downloaded, not just requested
    print(f"Number of full text papers for {topic}: {num_downloaded:,}")
    print(f"Number of paragraph examples for {topic}: {len(training_data):,}")

    # Save training data in jsonl format
    with open(discussions_phase_to_dir["finetuning"] / f"{topic_name}_training_data.jsonl", "w") as f:
        f.write("\n".join(json.dumps(example) for example in training_data))

100%|██████████| 261/261 [02:07<00:00,  2.05it/s]


Number of full text papers for nanobodies: 261
Number of paragraph examples for nanobodies: 36,715


100%|██████████| 331/331 [02:38<00:00,  2.09it/s]


Number of full text papers for SARS-CoV-2 spike protein: 331
Number of paragraph examples for SARS-CoV-2 spike protein: 46,472


100%|██████████| 24/24 [00:10<00:00,  2.23it/s]


Number of full text papers for SARS-CoV-2 variants KP.3 and JN.1: 24
Number of paragraph examples for SARS-CoV-2 variants KP.3 and JN.1: 926


100%|██████████| 34/34 [00:15<00:00,  2.26it/s]


Number of full text papers for ESM: 34
Number of paragraph examples for ESM: 1,853


100%|██████████| 112/112 [00:49<00:00,  2.28it/s]


Number of full text papers for AlphaFold-Multimer: 112
Number of paragraph examples for AlphaFold-Multimer: 6,478


100%|██████████| 113/113 [00:52<00:00,  2.17it/s]

Number of full text papers for Rosetta: 113
Number of paragraph examples for Rosetta: 5,890





In [12]:
# OpenAI API client used below for file uploads, fine-tuning jobs, and assistant queries
client = OpenAI()

In [13]:
# Upload fine-tuning data: one training file per topic, recording the returned file ID
topic_to_id = {}

for topic in topic_to_agent:
    path = discussions_phase_to_dir["finetuning"] / f"{topic.replace(' ', '_')}_training_data.jsonl"

    # Topics skipped in the download step have no training data file
    if not path.exists():
        print(f"No training data for {topic}")
        continue

    # Use a context manager so the file handle is closed once the upload completes
    with open(path, "rb") as training_file:
        file_object = client.files.create(
            file=training_file,
            purpose="fine-tune"
        )

    topic_to_id[topic] = file_object.id

# Save topic file IDs
with open(discussions_phase_to_dir["finetuning"] / "topic_to_id.json", "w") as f:
    json.dump(topic_to_id, f)

In [14]:
# Reload the topic -> uploaded training file ID mapping saved by the upload step,
# and keep an alphabetically sorted list of its topics.
topic_to_id = json.loads((discussions_phase_to_dir["finetuning"] / "topic_to_id.json").read_text())

topics = sorted(topic_to_id.keys())

In [23]:
# Launch fine-tuning jobs, one per selected topic.
# Topics are selected by name rather than by position in the sorted `topics` list;
# set `topics_to_finetune = topics` to launch a job for every uploaded topic.
topics_to_finetune = ["SARS-CoV-2 spike protein"]

finetuning_jobs = [
    client.fine_tuning.jobs.create(
        training_file=topic_to_id[topic],
        model=finetuning_base_model,  # TODO: swap for GPT-4o, not mini
        suffix=topic.replace(" ", "_"),
    )
    for topic in topics_to_finetune
]
finetuning_jobs

FineTuningJob(id='ftjob-cM8IxrOQCAnKnfZwsZG67NVn', created_at=1736215504, error=Error(code=None, message=None, param=None), fine_tuned_model=None, finished_at=None, hyperparameters=Hyperparameters(n_epochs='auto', batch_size='auto', learning_rate_multiplier='auto'), model='gpt-4o-mini-2024-07-18', object='fine_tuning.job', organization_id='org-JdFlu7kGMrLJr0BRViNsU6SO', result_files=[], seed=2049714135, status='validating_files', trained_tokens=None, training_file='file-HxGMGMpgR4RtrSJrAv3dPL', validation_file=None, estimated_finish=None, integrations=[], user_provided_suffix='SARS-CoV-2_spike_protein', method={'type': 'supervised', 'supervised': {'hyperparameters': {'batch_size': 'auto', 'learning_rate_multiplier': 'auto', 'n_epochs': 'auto'}}})

In [24]:
# Status of the first job in the listing (presumably the most recent one — confirm ordering).
# Fetch a single page of one job: calling list() on the result would auto-paginate through every job.
client.fine_tuning.jobs.list(limit=1).data[0].status

'validating_files'

In [36]:
# Map each topic to the ID of the model fine-tuned on that topic's papers.
# TODO: update model IDs
topic_to_model = dict(
    [
        ("AlphaFold-Multimer", "ft:gpt-4o-mini-2024-07-18:personal::AlsW3uTg"),
        ("ESM", "ft:gpt-4o-mini-2024-07-18:personal::AlrtT5SS"),
        ("Rosetta", "ft:gpt-4o-mini-2024-07-18:personal::AlsI0vLG"),
        ("SARS-CoV-2 spike protein", "ft:gpt-4o-mini-2024-07-18:personal::Am2Y4mU1"),
        ("SARS-CoV-2 variants KP.3 and JN.1", "ft:gpt-4o-mini-2024-07-18:personal::Am1FIJLR"),
        ("nanobodies", "ft:gpt-4o-mini-2024-07-18:personal::Am1xzqlG"),
    ]
)

In [None]:
# Ad-hoc check of an agent's topic knowledge with either the base or the fine-tuned model
topic = "ESM"
use_finetuned_model = False

selected_model = topic_to_model[topic] if use_finetuned_model else finetuning_base_model
agent = topic_to_agent[topic]
query = "What are the three algorithms of the DeepRank-GNN-esm model?"

assistant = client.beta.assistants.create(name=agent.title, instructions=agent.prompt, model=selected_model)
thread = client.beta.threads.create()
try:
    client.beta.threads.messages.create(thread_id=thread.id, role="user", content=query)
    run = client.beta.threads.runs.create_and_poll(
        thread_id=thread.id,
        assistant_id=assistant.id,
        model=selected_model,
        temperature=CONSISTENT_TEMPERATURE,
    )

    # Only print a reply for a completed run; otherwise the last message may not be an answer
    if run.status != "completed":
        print(f"Run ended with status {run.status!r}")
    else:
        messages = get_messages(client=client, thread_id=thread.id)
        print(messages[-1]["content"][0]["text"]["value"])
finally:
    # Delete the throwaway thread and assistant so repeated runs do not accumulate them
    client.beta.threads.delete(thread.id)
    client.beta.assistants.delete(assistant.id)