Skip to content

Commit

Permalink
Fixes LC document deserialization
Browse files Browse the repository at this point in the history
  • Loading branch information
skrawcz committed May 11, 2024
1 parent 82afa0e commit 60c123b
Showing 1 changed file with 56 additions and 25 deletions.
81 changes: 56 additions & 25 deletions scrapegraphai/graphs/smart_scraper_graph_burr.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
SmartScraperGraph Module Burr Version
"""
from typing import Tuple
from typing import Tuple, Union

from burr import tracking
from burr.core import Application, ApplicationBuilder, State, default, when
Expand All @@ -14,6 +14,7 @@
from langchain_community.document_transformers import Html2TextTransformer, EmbeddingsRedundantFilter
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_core import load as lc_serde
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableParallel
Expand Down Expand Up @@ -67,10 +68,10 @@ def parse_node(state: State, chunk_size: int = 4096) -> tuple[dict, State]:

@action(reads=["user_prompt", "parsed_doc", "doc"],
writes=["relevant_chunks"])
def rag_node(state: State, llm_model: object, embedder_model: object) -> tuple[dict, State]:
# bug around input serialization with tracker
llm_model = OpenAI({"model_name": "gpt-3.5-turbo"})
embedder_model = OpenAIEmbeddings()
def rag_node(state: State, llm_model: str, embedder_model: object) -> tuple[dict, State]:
# bug around input serialization with tracker -- so instantiate objects here:
llm_model = OpenAI({"model_name": llm_model})
embedder_model = OpenAIEmbeddings() if embedder_model == "openai" else None
user_prompt = state["user_prompt"]
doc = state["parsed_doc"]

Expand Down Expand Up @@ -104,8 +105,10 @@ def rag_node(state: State, llm_model: object, embedder_model: object) -> tuple[d

@action(reads=["user_prompt", "relevant_chunks", "parsed_doc", "doc"],
writes=["answer"])
def generate_answer_node(state: State, llm_model: object) -> tuple[dict, State]:
llm_model = OpenAI({"model_name": "gpt-3.5-turbo"})
def generate_answer_node(state: State, llm_model: str) -> tuple[dict, State]:
# bug around input serialization with tracker -- so instantiate objects here:
llm_model = OpenAI({"model_name": llm_model})

user_prompt = state["user_prompt"]
doc = state.get("relevant_chunks",
state.get("parsed_doc",
Expand Down Expand Up @@ -207,21 +210,48 @@ def post_run_step(
):
print(f"Finishing action: {action.name}")

import json

def _deserialize_document(x: Union[str, dict]) -> Document:
if isinstance(x, dict):
return lc_serde.load(x)
elif isinstance(x, str):
try:
return lc_serde.loads(x)
except json.JSONDecodeError:
return Document(page_content=x)


def run(prompt: str, input_key: str, source: str, config: dict) -> str:
# these configs aren't really used yet.
llm_model = config["llm_model"]

embedder_model = config["embedder_model"]
open_ai_embedder = OpenAIEmbeddings()
# open_ai_embedder = OpenAIEmbeddings()
chunk_size = config["model_token"]

tracker = tracking.LocalTrackingClient(project="smart-scraper-graph")
app_instance_id = "testing-12345678919"
initial_state = {
"user_prompt": prompt,
input_key: source,
}
from burr.core import expr
tracker = tracking.LocalTrackingClient(project="smart-scraper-graph")

entry_point = "fetch_node"
if app_instance_id:
persisted_state = tracker.load(None, app_id=app_instance_id, sequence_no=None)
if not persisted_state:
print(f"Warning: No persisted state found for app_id {app_instance_id}.")
else:
initial_state = persisted_state["state"]
# for now we need to manually deserialize LangChain messages into LangChain Objects
# i.e. we know which objects need to be LC objects
initial_state = initial_state.update(**{
"doc": _deserialize_document(initial_state["doc"])
})
docs = [_deserialize_document(doc) for doc in initial_state["relevant_chunks"]]
initial_state = initial_state.update(**{
"relevant_chunks": docs
})
entry_point = persisted_state["position"]

app = (
ApplicationBuilder()
Expand All @@ -236,16 +266,17 @@ def run(prompt: str, input_key: str, source: str, config: dict) -> str:
("parse_node", "rag_node", default),
("rag_node", "generate_answer_node", default)
)
# .with_entrypoint("fetch_node")
# .with_state(**initial_state)
.initialize_from(
tracker,
resume_at_next_action=True, # always resume from entrypoint in the case of failure
default_state=initial_state,
default_entrypoint="fetch_node",
)
# .with_identifiers(app_id="testing-123456")
.with_tracker(project="smart-scraper-graph")
.with_entrypoint(entry_point)
.with_state(**initial_state)
# this will work once we get serialization plugin for langchain objects done
# .initialize_from(
# tracker,
# resume_at_next_action=True, # always resume from entrypoint in the case of failure
# default_state=initial_state,
# default_entrypoint="fetch_node",
# )
.with_identifiers(app_id=app_instance_id)
.with_tracker(tracker)
.with_hooks(PrintLnHook())
.build()
)
Expand All @@ -270,8 +301,8 @@ def run(prompt: str, input_key: str, source: str, config: dict) -> str:
source = "https://en.wikipedia.org/wiki/Paris"
input_key = "url"
config = {
"llm_model": "rag-token",
"embedder_model": "foo",
"llm_model": "gpt-3.5-turbo",
"embedder_model": "openai",
"model_token": "bar",
}
run(prompt, input_key, source, config)
print(run(prompt, input_key, source, config))

0 comments on commit 60c123b

Please sign in to comment.