diff --git a/scrapegraphai/graphs/smart_scraper_graph_burr.py b/scrapegraphai/graphs/smart_scraper_graph_burr.py index b6cc03da..ff76da2a 100644 --- a/scrapegraphai/graphs/smart_scraper_graph_burr.py +++ b/scrapegraphai/graphs/smart_scraper_graph_burr.py @@ -1,7 +1,7 @@ """ SmartScraperGraph Module Burr Version """ -from typing import Tuple +from typing import Tuple, Union from burr import tracking from burr.core import Application, ApplicationBuilder, State, default, when @@ -14,6 +14,7 @@ from langchain_community.document_transformers import Html2TextTransformer, EmbeddingsRedundantFilter from langchain_community.vectorstores import FAISS from langchain_core.documents import Document +from langchain_core import load as lc_serde from langchain_core.output_parsers import JsonOutputParser from langchain_core.prompts import PromptTemplate from langchain_core.runnables import RunnableParallel @@ -67,10 +68,10 @@ def parse_node(state: State, chunk_size: int = 4096) -> tuple[dict, State]: @action(reads=["user_prompt", "parsed_doc", "doc"], writes=["relevant_chunks"]) -def rag_node(state: State, llm_model: object, embedder_model: object) -> tuple[dict, State]: - # bug around input serialization with tracker - llm_model = OpenAI({"model_name": "gpt-3.5-turbo"}) - embedder_model = OpenAIEmbeddings() +def rag_node(state: State, llm_model: str, embedder_model: object) -> tuple[dict, State]: + # bug around input serialization with tracker -- so instantiate objects here: + llm_model = OpenAI({"model_name": llm_model}) + embedder_model = OpenAIEmbeddings() if embedder_model == "openai" else None user_prompt = state["user_prompt"] doc = state["parsed_doc"] @@ -104,8 +105,10 @@ def rag_node(state: State, llm_model: object, embedder_model: object) -> tuple[d @action(reads=["user_prompt", "relevant_chunks", "parsed_doc", "doc"], writes=["answer"]) -def generate_answer_node(state: State, llm_model: object) -> tuple[dict, State]: - llm_model = OpenAI({"model_name": "gpt-3.5-turbo"}) +def generate_answer_node(state: State, llm_model: str) -> tuple[dict, State]: + # bug around input serialization with tracker -- so instantiate objects here: + llm_model = OpenAI({"model_name": llm_model}) + user_prompt = state["user_prompt"] doc = state.get("relevant_chunks", state.get("parsed_doc", @@ -207,21 +210,48 @@ def post_run_step( ): print(f"Finishing action: {action.name}") +import json + +def _deserialize_document(x: Union[str, dict]) -> Document: + if isinstance(x, dict): + return lc_serde.load(x) + elif isinstance(x, str): + try: + return lc_serde.loads(x) + except json.JSONDecodeError: + return Document(page_content=x) + def run(prompt: str, input_key: str, source: str, config: dict) -> str: + # these configs aren't really used yet. llm_model = config["llm_model"] - embedder_model = config["embedder_model"] - open_ai_embedder = OpenAIEmbeddings() + # open_ai_embedder = OpenAIEmbeddings() chunk_size = config["model_token"] + tracker = tracking.LocalTrackingClient(project="smart-scraper-graph") + app_instance_id = "testing-12345678919" initial_state = { "user_prompt": prompt, input_key: source, } - from burr.core import expr - tracker = tracking.LocalTrackingClient(project="smart-scraper-graph") - + entry_point = "fetch_node" + if app_instance_id: + persisted_state = tracker.load(None, app_id=app_instance_id, sequence_no=None) + if not persisted_state: + print(f"Warning: No persisted state found for app_id {app_instance_id}.") + else: + initial_state = persisted_state["state"] + # for now we need to manually deserialize LangChain messages into LangChain Objects + # i.e. we know which objects need to be LC objects + initial_state = initial_state.update(**{ + "doc": _deserialize_document(initial_state["doc"]) + }) + docs = [_deserialize_document(doc) for doc in initial_state["relevant_chunks"]] + initial_state = initial_state.update(**{ + "relevant_chunks": docs + }) + entry_point = persisted_state["position"] app = ( ApplicationBuilder() @@ -236,16 +266,17 @@ def run(prompt: str, input_key: str, source: str, config: dict) -> str: ("parse_node", "rag_node", default), ("rag_node", "generate_answer_node", default) ) - # .with_entrypoint("fetch_node") - # .with_state(**initial_state) - .initialize_from( - tracker, - resume_at_next_action=True, # always resume from entrypoint in the case of failure - default_state=initial_state, - default_entrypoint="fetch_node", - ) - # .with_identifiers(app_id="testing-123456") - .with_tracker(project="smart-scraper-graph") + .with_entrypoint(entry_point) + .with_state(**initial_state) + # this will work once we get serialization plugin for langchain objects done + # .initialize_from( + # tracker, + # resume_at_next_action=True, # always resume from entrypoint in the case of failure + # default_state=initial_state, + # default_entrypoint="fetch_node", + # ) + .with_identifiers(app_id=app_instance_id) + .with_tracker(tracker) .with_hooks(PrintLnHook()) .build() ) @@ -270,8 +301,8 @@ def run(prompt: str, input_key: str, source: str, config: dict) -> str: source = "https://en.wikipedia.org/wiki/Paris" input_key = "url" config = { - "llm_model": "rag-token", - "embedder_model": "foo", + "llm_model": "gpt-3.5-turbo", + "embedder_model": "openai", "model_token": "bar", } - run(prompt, input_key, source, config) + print(run(prompt, input_key, source, config))