In [14]:
import asyncio
import concurrent.futures
import json
import re

from openai import AsyncOpenAI, OpenAI
from tqdm import tqdm

from virtual_lab.agent import Agent
from virtual_lab.constants import CONSISTENT_TEMPERATURE
from virtual_lab.run_meeting import run_meeting
from virtual_lab.utils import async_get_messages, count_tokens, get_pubmed_central_article

from nanobody_constants import (
    background_prompt,
    nanobody_prompt,
    discussions_phase_to_dir,
    model,
    finetuning_base_model,
    immunologist,
    machine_learning_specialist,
    computational_biologist,
)

In [2]:
# Constants
finetuning_dir = discussions_phase_to_dir["finetuning"]
papers_dir = finetuning_dir / "papers"
summaries_dir = finetuning_dir / "summaries"

papers_dir.mkdir(parents=True, exist_ok=True)
summaries_dir.mkdir(parents=True, exist_ok=True)

client = OpenAI()
async_client = AsyncOpenAI()
num_concurrent = 50

In [3]:
# Topic to agent mapping
topic_to_agent = {
    "nanobodies": immunologist,
    "SARS-CoV-2 spike protein": immunologist,
    "SARS-CoV-2 variants KP.3 and JN.1": immunologist,
    "ESM": machine_learning_specialist,
    "AlphaFold-Multimer": computational_biologist,
    "Rosetta": computational_biologist,
}

In [None]:
# Agent fine-tuning queries
with concurrent.futures.ThreadPoolExecutor() as executor:
    concurrent.futures.wait([
        executor.submit(
            run_meeting,
            meeting_type="individual",
            team_member=agent,
            agenda=f"{background_prompt} {nanobody_prompt} You are responsible for understanding the topic \"{topic}\" in the context of designing nanobody binders for SARS-CoV-2. You need to fine-tune yourself on the relevant literature on {topic} to improve your ability to design SARS-CoV-2 nanobody binders. Please write out a series of five distinct search queries that you want to run to find relevant scientific papers on {topic}. Include both queries about {topic} generally as well as queries about how {topic} relates to designing nanobody binders for SARS-CoV-2. Please provide the queries in Python syntax as a list of double-quoted strings.",
            agenda_questions=(
                f"What are the queries that you want to perform to identify the relevant literature on {topic} (as a list of double-quoted strings in Python syntax)?",),
            save_dir=finetuning_dir,
            save_name=f"{topic.replace(' ', '_')}_queries",
            temperature=CONSISTENT_TEMPERATURE,
            model=model,
        ) for topic, agent in topic_to_agent.items()
    ])

In [None]:
# Agent fine-tuning queries
query_list_pattern = re.compile(r'\[\s*(".*?"\s*(,\s*".*?"\s*)*)?,?\s*\]')

topic_to_queries = {}

for topic, agent in topic_to_agent.items():
    # Get query path for topic
    query_path = finetuning_dir / f"{topic.replace(' ', '_')}_queries.json"

    # Load query discussion
    with open(query_path) as f:
        query_discussion = json.load(f)

    # Extract queries
    query_message = query_discussion[-1]["message"]
    pattern_result = query_list_pattern.search(query_message)

    # Check if pattern is matched
    if pattern_result is None:
        print(f"No queries found for {query_path}")
        continue

    # Extract queries
    queries = json.loads(pattern_result.group())
    topic_to_queries[topic] = queries

In [None]:
# Agent fine-tuning papers
for topic, agent in topic_to_agent.items():
    with concurrent.futures.ThreadPoolExecutor() as executor:
        concurrent.futures.wait([
            executor.submit(
                run_meeting,
                meeting_type="individual",
                team_member=agent,
                agenda=f"{background_prompt} {nanobody_prompt} You are responsible for understanding the topic \"{topic}\" in the context of designing nanobody binders for SARS-CoV-2. You need to fine-tune yourself on the relevant literature on {topic} to improve your ability to design SARS-CoV-2 nanobody binders. Please use PubMed Central and search for relevant papers on {topic} using the query \"{query}\" and request 100 articles with abstracts only. Read all of the abstracts and based on each abstract individually, decide whether you want to fine-tune yourself on the full text of that paper. Include as many papers as possible, but only include papers that are directly relevant to {topic}. Please provide the PMCIDs and titles of all the papers that you wish to fine-tune yourself on as a Python dictionary mapping PMCID as a double-quoted string to title as a double-quoted string.",
                agenda_questions=(
                    "What are the PMCIDs and titles of the papers you wish to fine-tune yourself on (as a Python dictionary mapping PMCID as a double-quoted string to title as double-quoted string)?",),
                save_dir=finetuning_dir,
                save_name=f"{topic.replace(' ', '_')}_papers_{query_num + 1}",
                temperature=CONSISTENT_TEMPERATURE,
                model=model,
                pubmed_search=True,
            ) for query_num, query in enumerate(topic_to_queries[topic]) if not (discussions_phase_to_dir[
                                                                                     "finetuning"] / f"{topic.replace(' ', '_')}_papers_{query_num + 1}.json").exists()
        ])

In [None]:
# Extract selected papers for fine-tuning
pmcid_to_title_pattern = re.compile(r'\{\s*(".*?"\s*:\s*".*?"\s*(,\s*".*?"\s*:\s*".*?"\s*)*)?\}')

for topic, agent in topic_to_agent.items():
    # Set up title to PMC ID dictionary
    title_to_pmcid = {}
    titles_lower, pmcids = set(), set()
    topic_name = topic.replace(' ', '_')

    # Get all paper paths for a topic
    paper_paths = sorted(finetuning_dir.glob(f"{topic_name}_papers_*.json"))

    # Check if all papers results are present
    if len(paper_paths) != 5:
        print(f"Missing papers for {topic}")
        continue

    # Extract PMC IDs and titles from each papers file
    for paper_path in paper_paths:
        # Load paper discussion
        with open(paper_path) as f:
            paper_discussion = json.load(f)

        # Extract PMC IDs and titles dictionary
        paper_message = paper_discussion[1]["message"]
        pattern_result = pmcid_to_title_pattern.search(paper_message)

        # Check if pattern is matched
        if pattern_result is None:
            print(f"No papers found for {paper_path}")
            continue

        # Extract PMC IDs and titles dictionary
        pmcid_to_title = json.loads(pattern_result.group())

        # Add PMC IDs and titles to dictionary, avoiding duplicates
        for pmcid, title in pmcid_to_title.items():
            # Replace en dash and em dash with a hyphen and convert to lowercase
            title = title.replace("–", "-").replace("—", "-")
            title_lower = title.lower()

            if title_lower not in titles_lower and pmcid not in pmcids:
                title_to_pmcid[title] = pmcid
                titles_lower.add(title_lower)
                pmcids.add(pmcid)

    print(f"Number of papers found for {topic}: {len(title_to_pmcid):,}")

    # Save title to PMC ID dictionary
    with open(finetuning_dir / f"{topic_name}_title_to_pmcid.json", "w") as f:
        json.dump(title_to_pmcid, f, indent=4, sort_keys=True)

In [None]:
# Get all PMCIDs
pmcids = set()

for topic in topic_to_agent:
    topic_name = topic.replace(' ', '_')

    with open(finetuning_dir / f"{topic_name}_title_to_pmcid.json") as f:
        title_to_pmcid = json.load(f)

    pmcids.update(title_to_pmcid.values())

print(f"Number of unique PMCIDs: {len(pmcids):,}")

In [None]:
# Download papers from PubMed Central
paper_count = 0

for pmcid in tqdm(sorted(pmcids)):
    title, content = get_pubmed_central_article(pmcid=pmcid)

    if title is None:
        continue

    paper_count += 1

    # Save paper
    with open(papers_dir / f"{pmcid}.json", "w") as f:
        json.dump({"title": title, "content": content}, f, indent=4, sort_keys=True)

print(f"Number of papers downloaded: {paper_count:,}")

In [4]:
%autoawait asyncio

In [5]:
async def summarize_paper(semaphore: asyncio.Semaphore, agent: Agent, topic: str, pmcid: str, title: str,
                          content: list[str]) -> tuple[str, str, str]:
    """Summarize a paper using the model.

    :param semaphore: Semaphore to limit the number of concurrent requests.
    :param agent: Agent to use for summarization.
    :param topic: Topic of interest.
    :param pmcid: PMC ID of the paper.
    :param title: Title of the paper.
    :param content: Content of the paper.
    :return: Tuple of PMC ID, title, and summary of the paper.
    """
    # Set up query with paper
    query = "\n\n".join([
                            f"Please summarize in extreme detail the following paper titled \"{title}\". Please focus in particular on summarizing key insights about the topic \"{topic}\" in relation to designing SARS-CoV-2 nanobody binders."] + content)

    # Run query to get summary
    async with semaphore:
        assistant = await async_client.beta.assistants.create(name=agent.title, instructions=agent.prompt, model=model)
        thread = await async_client.beta.threads.create()
        await async_client.beta.threads.messages.create(thread_id=thread.id, role="user", content=query)
        await async_client.beta.threads.runs.create_and_poll(
            thread_id=thread.id,
            assistant_id=assistant.id,
            model=model,
            temperature=CONSISTENT_TEMPERATURE,
        )
        messages = await async_get_messages(client=async_client, thread_id=thread.id)
        summary = messages[-1]["content"][0]["text"]["value"]

    return pmcid, title, summary

In [10]:
# Create agent summaries of papers for fine-tuning
for topic, agent in topic_to_agent.items():
    topic_name = topic.replace(' ', '_')

    # Create save directory
    topic_summary_dir = summaries_dir / topic_name
    topic_summary_dir.mkdir(parents=True, exist_ok=True)

    # Load title to PMC ID dictionary
    with open(finetuning_dir / f"{topic_name}_title_to_pmcid.json") as f:
        title_to_pmcid: dict[str, str] = json.load(f)

    # Get unique PMC IDs
    pmcids = sorted(set((title_to_pmcid.values())))

    # Load papers
    pmcid_to_paper = {}
    for pmcid in pmcids:
        paper_path = papers_dir / f"{pmcid}.json"

        if paper_path.exists():
            with open(paper_path) as f:
                paper: dict[str, str | list[str]] = json.load(f)
                pmcid_to_paper[pmcid] = paper

    print(f"Number of papers loaded for {topic}: {len(pmcid_to_paper):,}")

    # Set up semaphore with the number of concurrent requests
    semaphore = asyncio.Semaphore(num_concurrent)

    # Create tasks for each paper
    tasks = [
        asyncio.create_task(
            summarize_paper(semaphore=semaphore, agent=agent, topic=topic, pmcid=pmcid, title=paper["title"],
                            content=paper["content"]))
        for pmcid, paper in pmcid_to_paper.items()]

    # Run agent summary of each paper
    results = [(await task) for task in tqdm(asyncio.as_completed(tasks), total=len(tasks))]

    # Save summaries
    for pmcid, title, summary in results:
        with open(topic_summary_dir / f"{pmcid}.json", "w") as f:
            json.dump({"pmcid": pmcid, "title": title, "summary": summary}, f, indent=4, sort_keys=True)

Number of papers loaded for nanobodies: 261


 12%|█▏        | 32/261 [00:21<02:33,  1.49it/s] 


BadRequestError: Error code: 400 - {'error': {'message': "Invalid 'content': string too long. Expected a string with maximum length 256000, but got a string with length 1671206 instead.", 'type': 'invalid_request_error', 'param': 'content', 'code': 'string_above_max_length'}}

In [15]:
# Convert summaries to training data format
for topic, agent in topic_to_agent.items():
    topic_name = topic.replace(' ', '_')

    # Get summary paths
    topic_summary_dir = summaries_dir / topic_name
    summary_paths = sorted(topic_summary_dir.glob("*.json"))

    # Convert summaries to training data format
    training_data = []

    for summary_path in summary_paths:
        # Load paper summary data
        with open(summary_path) as f:
            summary_data = json.load(f)

        # Extract title and summary
        title, summary = summary_data["title"], summary_data["summary"]

        # Add example to training data
        training_data.append(
            {"messages": [{"role": "system", "content": agent.prompt},
                          {"role": "user",
                           "content": f"Please tell me about the paper \"{title}\" and its insights into \"{topic}\" in relation to designing SARS-CoV-2 nanobody binders."},
                          {"role": "assistant", "content": summary}]}
        )

    # Count tokens
    token_count = [sum(count_tokens(message["content"]) for message in data["messages"]) for data in training_data]

    # TODO: estimate fine-tuning pricing

    # Print stats
    print(f"Number of paper examples for {topic}: {len(training_data):,}")
    print(f"Token count for {topic}: {sum(token_count):,}")

    # Save training data in jsonl format
    with open(finetuning_dir / f"{topic_name}_training_data.jsonl", "w") as f:
        f.write("\n".join(json.dumps(example) for example in training_data))

Number of paper examples for ESM: 34
Token count for ESM: 27,518


In [None]:
# Upload fine-tuning data
topic_to_id = {}

for topic in topic_to_agent:
    path = finetuning_dir / f"{topic.replace(' ', '_')}_training_data.jsonl"

    file_object = client.files.create(
        file=open(path, "rb"),
        purpose="fine-tune"
    )

    topic_to_id[topic] = file_object.id

# Save topic file IDs
with open(finetuning_dir / "topic_to_id.json", "w") as f:
    json.dump(topic_to_id, f)

In [None]:
# Load topic file IDs
with open(finetuning_dir / "topic_to_id.json") as f:
    topic_to_id = json.load(f)

topics = sorted(topic_to_id)

In [None]:
# Launch fine-tuning jobs
# for topic, file_id in topic_to_id.items():
#     client.fine_tuning.jobs.create(
#         training_file=file_id,
#         model=finetuning_model,
#     )

topic = topics[5]

client.fine_tuning.jobs.create(
    training_file=topic_to_id[topic],
    model=finetuning_base_model,  # TODO: swap for GPT-4o, not mini
    suffix=topic.replace(" ", "_"),
)

In [None]:
print(list(client.fine_tuning.jobs.list())[1].fine_tuned_model)

In [None]:
# Set up topic to fine-tuned model mapping
topic_to_model = {
    "AlphaFold-Multimer": "ft:gpt-4o-mini-2024-07-18:personal:alphafold-multimer:AmtqgHON",
    "ESM": "ft:gpt-4o-mini-2024-07-18:personal:esm:AmtjuDox",
    "Rosetta": "ft:gpt-4o-mini-2024-07-18:personal:rosetta:Amtuos8C",
    "SARS-CoV-2 spike protein": "ft:gpt-4o-mini-2024-07-18:personal:sars-cov-2-spike-protein:AmuRh1c1",
    "SARS-CoV-2 variants KP.3 and JN.1": "ft:gpt-4o-mini-2024-07-18:personal:sars-cov-2-variants-kp-3-and-jn-1:AmyELRn8",
    "nanobodies": "ft:gpt-4o-mini-2024-07-18:personal:nanobodies:AmyVyYww",
}

In [None]:
topic = "SARS-CoV-2 variants KP.3 and JN.1"
agent = topic_to_agent[topic]
query = "How are the JN.1 and KP.3 variants of SARS-CoV-2 related to each other?"

for selected_model, selected_model_name in [(finetuning_base_model, "base model"),
                                            (topic_to_model[topic], "fine-tuned model")]:
    print(f"Running query \"{query}\" with {selected_model_name} for {topic}.\n")

    assistant = client.beta.assistants.create(name=agent.title, instructions=agent.prompt, model=selected_model)
    thread = client.beta.threads.create()
    client.beta.threads.messages.create(thread_id=thread.id, role="user", content=query)
    run = client.beta.threads.runs.create_and_poll(
        thread_id=thread.id,
        assistant_id=assistant.id,
        model=selected_model,
        temperature=CONSISTENT_TEMPERATURE,
    )

    print(run.status)

    # if not run.status == "complete":
    #     print("Query failed to complete.")
    #     continue

    messages = get_messages(client=client, thread_id=thread.id)

    print(messages[-1]["content"][0]["text"]["value"])
    print()