# Goals

* Use the map-reduce subgraph framework with the SRA tools agent

In [5]:
# import 
import os
import re
import time
from pprint import pprint 
from datetime import datetime, timedelta
from typing import Annotated, List, Dict, Tuple, Optional, Union, Any
import xml.etree.ElementTree as ET
from pydantic import BaseModel, Field
from langchain_core.tools import tool
from Bio import Entrez

# Tools

In [2]:
# set up Entrez
# Module-level contact address used by Biopython's Entrez helpers
# (presumably attached to each NCBI request as required by NCBI's usage policy — confirm).
Entrez.email = "nick.youngblut@arcinstitute.org"

In [108]:
@tool
def esearch(
    esearch_query: Annotated[str, "Entrez query string."],
    database: Annotated[str, "Database name ('sra' or 'gds')"] = "sra",
) -> Annotated[List[str], "IDs of database records"]:
    """
    Run an Entrez search query and return the IDs of the results.
    Only records published within the last 7 days are returned.
    """
    # Restrict to records published in the last 7 days. Take one timestamp so
    # both ends of the range come from the same instant.
    end_date = datetime.now()
    start_date = end_date - timedelta(days=7)
    date_range = f"{start_date.strftime('%Y/%m/%d')}:{end_date.strftime('%Y/%m/%d')}[PDAT]"
    esearch_query += f" AND {date_range}"

    # DEBUG cap on the number of IDs returned; a falsy value (0/None) means
    # "page through every result".
    max_ids = 2

    # Page size; no point requesting more IDs per page than we will keep.
    retmax = min(50, max_ids) if max_ids else 50

    ids = []
    retstart = 0
    while True:
        try:
            search_handle = Entrez.esearch(
                db=database,
                term=esearch_query,
                retstart=retstart,
                retmax=retmax,
            )
            try:
                search_results = Entrez.read(search_handle)
            finally:
                # Release the connection even if parsing the response fails.
                search_handle.close()
            ids.extend(search_results["IdList"])
            retstart += retmax
            time.sleep(0.5)  # throttle between successive Entrez requests
            if max_ids and len(ids) >= max_ids:
                break
            if retstart >= int(search_results["Count"]):
                break
        except Exception as e:
            # Best-effort tool: report the failure and return whatever was collected.
            print(f"Error searching {database} with query: {esearch_query}: {str(e)}")
            break

    # Apply the cap only when one is set (ids[:0] would otherwise drop everything).
    return ids[:max_ids] if max_ids else ids

#query = '("single cell RNA sequencing" OR "single cell RNA-seq")'
#IDs = esearch.invoke({"esearch_query" : query, "database" : "sra"})
#IDs

In [118]:
@tool
def efetch(
    database: Annotated[str, "Database name ('sra' or 'gds')"],
    dataset_id: Annotated[str, "Entrez ID"],
) -> Annotated[str, "eFetch results in xml format"]:
    """
    Run an Entrez efetch query and return the results.
    """
    # Throttle consecutive Entrez requests.
    time.sleep(0.5)

    # Fetch dataset record; close the handle even if reading fails.
    handle = Entrez.efetch(db=database, id=dataset_id, retmode="xml")
    try:
        record = handle.read()
    finally:
        handle.close()

    # read() may return bytes; decode explicitly instead of swallowing errors,
    # so invalid bytes never turn into a "b'...'" repr string.
    if isinstance(record, (bytes, bytearray)):
        record = record.decode("utf-8", errors="replace")

    return str(record)

#record = efetch.invoke({"database" : "sra", "dataset_id" : "35966237"})
#pprint(record)

# Graph

In [98]:
import operator
from typing import Annotated, Sequence, Tuple, Union
from typing_extensions import TypedDict
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from langchain_core.runnables import RunnableConfig
from langchain_openai import ChatOpenAI
from langgraph.types import Send
from langgraph.graph import START, END, StateGraph
from pydantic import BaseModel, Field
from langgraph.prebuilt import create_react_agent, ToolNode

In [99]:
# set model
# Chat model for the agent. NOTE(review): not referenced by the cells shown here —
# confirm it is used later in the notebook or remove it.
model = ChatOpenAI(model="gpt-4o-mini")

## Subgraph

In [138]:
class SubState(TypedDict):
    """
    Shared state of the agents in the subgraph.

    One subgraph run handles a single Entrez ID: it receives the database and
    ID, and run_efetch fills in the fetched record.
    """
    database: str  # Entrez database name, e.g. "sra"
    ID: str  # Entrez record ID to fetch (one of the esearch results)
    record: str  # efetch output (XML text), written by run_efetch

In [139]:
def run_efetch(state):
    """
    Subgraph node: fetch the Entrez record for ``state["ID"]``.

    ``state`` is the subgraph state mapping (keys ``"database"`` and ``"ID"``).
    Returns a partial state update holding the efetch XML under ``"record"``.
    The parameter was previously named ``SubState``, which shadowed the class.
    """
    record = efetch.invoke({"database": state["database"], "dataset_id": state["ID"]})
    return {"record": record}

In [141]:
#-- subgraph: START -> efetch_node -> END --#
subworkflow = StateGraph(SubState)

# The only node: fetch one record for the ID carried in the incoming state.
subworkflow.add_node("efetch_node", run_efetch)

# Entry/finish points are shorthand for edges from START and to END.
subworkflow.set_entry_point("efetch_node")
subworkflow.set_finish_point("efetch_node")

# Compile into a runnable that invoke_subgraph calls once per ID.
subgraph = subworkflow.compile()

## Graph

In [152]:
class TopState(TypedDict):
    """
    Shared state of the agents in the parent (map-reduce) graph.
    """
    database: str  # Entrez database searched and fetched from, e.g. "sra"
    # esearch IDs; operator.add is the reducer, so list updates are concatenated
    IDs: Annotated[List[str], operator.add]
    # efetch records; each parallel invoke_subgraph task contributes via operator.add
    records: Annotated[List[str], operator.add]

In [None]:
def run_esearch(state: TopState):
    """
    Graph node: search ``state["database"]`` for single-cell RNA-seq records.

    Returns a partial state update with the matching Entrez IDs under ``"IDs"``.
    """
    found_ids = esearch.invoke(
        {
            "esearch_query": '("single cell RNA sequencing" OR "single cell RNA-seq")',
            "database": state["database"],
        }
    )
    return {"IDs": found_ids}

def invoke_subgraph(state: SubState):
    """
    Map step: run the efetch subgraph for a single dispatched ID.

    ``state`` is the ``Send`` payload (``"ID"`` and ``"database"``).
    ``subgraph.invoke`` returns the subgraph's full final state. Only its
    ``"record"`` string is kept, so ``TopState.records`` stays a list of
    strings as declared. The original appended the whole state dict.
    """
    final_sub_state = subgraph.invoke({"database": state["database"], "ID": state["ID"]})
    return {"records": [final_sub_state["record"]]}

# Fan-out logic: dispatch one invoke_subgraph task per esearch ID
def continue_to_subgraphs(state: TopState):
    """
    Conditional-edge router for the map step.

    Returns one ``Send`` per ID in ``state["IDs"]``, each targeting the
    ``invoke_subgraph`` node with that ID and the shared database name.
    """
    dispatches = []
    for dataset_id in state["IDs"]:
        payload = {"ID": dataset_id, "database": state["database"]}
        dispatches.append(Send("invoke_subgraph", payload))
    return dispatches

In [166]:
#-- graph: START -> esearch_node -(one Send per ID)-> invoke_subgraph -> END --#
workflow = StateGraph(TopState)

# Nodes: the search step and the per-ID fetch step.
workflow.add_node("esearch_node", run_esearch)
workflow.add_node("invoke_subgraph", invoke_subgraph)

# Start at the search node (shorthand for an edge from START).
workflow.set_entry_point("esearch_node")
# Fan out: continue_to_subgraphs emits Send objects that all target invoke_subgraph.
workflow.add_conditional_edges("esearch_node", continue_to_subgraphs, ["invoke_subgraph"])
# Every fetch task finishes the graph (shorthand for an edge to END).
workflow.set_finish_point("invoke_subgraph")

graph = workflow.compile()

In [167]:
from IPython.display import Image
#Image(graph.get_graph().draw_mermaid_png())

In [None]:
# Run the graph against the SRA database, streaming updates from both the
# parent graph and its subgraphs; at most 2 tasks run concurrently.
run_config = {"max_concurrency": 2}
for chunk in graph.stream({"database": "sra"}, subgraphs=True, config=run_config):
    print(chunk)

((), {'esearch_node': {'IDs': ['36004814', '36004694']}})
{'ID': '36004814', 'database': 'sra'}
{'ID': '36004694', 'database': 'sra'}
(('invoke_subgraph:d81fd364-ab08-52dd-eb41-33ae58bc08fe',), {'efetch_node': {'record': '<?xml version="1.0" encoding="UTF-8"  ?>\n<EXPERIMENT_PACKAGE_SET>\n<EXPERIMENT_PACKAGE><EXPERIMENT accession="ERX13336124" alias="ena-EXPERIMENT-TAB-07-11-2024-13:04:12:932-215992" center_name="Klinikum rechts der Isar, II. Medizinische Klinik"><IDENTIFIERS><PRIMARY_ID>ERX13336124</PRIMARY_ID></IDENTIFIERS><TITLE>Illumina NovaSeq 6000 paired end sequencing</TITLE><STUDY_REF accession="ERP165048"><IDENTIFIERS><PRIMARY_ID>ERP165048</PRIMARY_ID><EXTERNAL_ID namespace="BioProject">PRJEB81204</EXTERNAL_ID></IDENTIFIERS></STUDY_REF><DESIGN><DESIGN_DESCRIPTION/><SAMPLE_DESCRIPTOR accession="ERS21882372"><IDENTIFIERS><PRIMARY_ID>ERS21882372</PRIMARY_ID><EXTERNAL_ID namespace="BioSample">SAMEA116821676</EXTERNAL_ID></IDENTIFIERS></SAMPLE_DESCRIPTOR><LIBRARY_DESCRIPTOR><LIBRAR