# Storm Research Assistant

Reference
https://github.com/langchain-ai/langgraph/blob/main/examples/storm/storm.ipynb


## Prereqs


In [None]:

# %pip install -U langchain_community langchain_openai langgraph wikipedia  scikit-learn  langchain_fireworks langchain_anthropic
# # We use one or the other search engine below
# %pip install -U tavily-python playwright
# %pip install -U duckduckgo-search
# # ! apt-get install graphviz graphviz-dev
# %pip install pygraphviz



In [1]:
from storm import *

# LLMs and embeddings are created once here and reused by the cells below.

# Only the first value returned by get_openai_llms is kept (as fast_llm) ...
fast_llm, _ = get_openai_llms(regular_model="gpt-3.5-turbo", long_context_model="gpt-3.5-turbo-0125")
# ... and only the second value returned by get_anthropic_llms (as long_context_llm).
_, long_context_llm = get_anthropic_llms()

# Disabled alternative: use a single local Ollama model for both roles.
# ollama_model = 'mistral:7b-instruct-q4_K_M'
# fast_llm, long_context_llm = get_ollama_llms(regular_model=ollama_model, long_context_model=ollama_model)


embeddings = get_gpt4all_embeddings()

# Rebind the throwaway name so it no longer refers to the discarded return values.
_ = None



# Topic used by every example cell in this notebook.
example_topic = "Impact of THE Red Cross Church IN Zimbabwe early history"

# The cells below read .outline, .related_topics, .survey_subjects and
# .generate_question from this object.
interview_graph = StormInterviewGraph(fast_llm)


bert_load_from_file: gguf version     = 2
bert_load_from_file: gguf alignment   = 32
bert_load_from_file: gguf data offset = 695552
bert_load_from_file: model name           = BERT
bert_load_from_file: model architecture   = bert
bert_load_from_file: model file type      = 1
bert_load_from_file: bert tokenizer vocab = 30522


### Generate Initial Outline


In [2]:
# Chains and interview graph

# generate_outline_direct = get_chain_outline(fast_llm)
# expand_chain = get_chain_expand_related_topics(fast_llm)
# gen_perspectives_chain = get_chain_perspective_generator(fast_llm)
# gen_queries_chain = get_chain_queries(fast_llm)
# gen_answer_chain = get_chain_answer(fast_llm)

### Test Chains


In [None]:
# 1. Generate Outline

initial_outline = interview_graph.outline.invoke({"topic": example_topic})
logger.info(initial_outline.as_str)

In [None]:
# 2. Expand related topics

related_subjects = await interview_graph.related_topics.ainvoke({"topic": example_topic})
related_subjects

#### Generate Perspectives

From these related subjects, we can select representative Wiki editors as "subject matter experts" with distinct backgrounds and affiliations. These will help distribute the search process to encourage a more well-rounded final report.


In [3]:
# 3. Generate perspectives (the result exposes an `.editors` list used by later cells)
# NOTE(review): this awaits the result of `.invoke(...)`, whereas the related-topics
# cell awaits `.ainvoke(...)` - confirm that `survey_subjects.invoke` returns an awaitable.
perspectives = await interview_graph.survey_subjects.invoke(example_topic)
perspectives

2024-03-31 23:57:34,240 [MainThread  ] [INFO ]  Survey Subjects for Topic: Impact of THE Red Cross Church IN Zimbabwe early history
2024-03-31 23:57:37,947 [MainThread  ] [INFO ]  Retrieved 6 docs for Topic: Impact of THE Red Cross Church IN Zimbabwe early history


Perspectives(editors=[Editor(affiliation='History of Zimbabwe Research Institute', name='Dr. Nkosi Moyo', role='Historian', description='Dr. Moyo is a renowned historian specializing in the pre-colonial and colonial history of Zimbabwe. With a focus on the Kingdom of Zimbabwe and the colonial era, Dr. Moyo will provide insights into the historical developments of the region.'), Editor(affiliation='International Red Cross and Red Crescent Movement Association', name='Sarah Patel', role='Humanitarian Advocate', description='Sarah is a dedicated humanitarian advocate with a strong commitment to protecting human rights and alleviating suffering. She will contribute perspectives on humanitarian efforts and the role of organizations in Zimbabwe.'), Editor(affiliation='Religious Harmony Council of Zimbabwe', name='Bishop Tendai Chikomo', role='Religious Leader', description='Bishop Chikomo is a respected religious leader promoting interfaith dialogue and understanding in Zimbabwe. He will foc

In [6]:
# 4. Generate questions
# NOTE(review): `initial_state` is only assigned in later cells of this notebook, so
# this cell depends on one of them having been executed first. The recorded output
# of this cell is a TypeError raised inside StormInterviewGraph.__init__'s lambda.

question = await interview_graph.generate_question.invoke(initial_state)

# Show the first message of the returned state.
question["messages"][0]

2024-03-31 23:58:28,050 [MainThread  ] [INFO ]  Generating question for DrNkosiMoyo


Swapping roles for DrNkosiMoyo
Converted messages for DrNkosiMoyo while swapping roles: 1 messages


TypeError: StormInterviewGraph.__init__.<locals>.<lambda>() got an unexpected keyword argument 'name'

## Expert Dialog

Each Wikipedia writer is primed to role-play using the perspectives presented above. It will ask a series of questions of a second "domain expert" with access to a search engine. This generates the content used to produce a refined outline as well as an updated index of reference documents.

### Interview State

The conversation is cyclic, so we will construct it within its own graph. The State will contain messages, the reference docs, and the editor (with its own "persona") to make it easy to parallelize these conversations.


# Dialog Roles

The graph will have two participants: the Wikipedia editor (`generate_question`), who asks questions based on its assigned role, and a domain expert (`gen_answer_chain`), who uses a search engine to answer the questions as accurately as possible.


In [None]:
# gen_qn_prompt = get_chat_prompt_from_prompt_templates([prompts.gen_question_system_generator, prompts.generate_messages_placeholder()])


# def swap_roles(state: InterviewState, name: str) -> InterviewState:

#     # Normalize name
#     name = cleanup_name(name)

#     logger.info(f'Swapping roles for {name}')

#     converted = []
#     for message in state["messages"]:
#         if isinstance(message, AIMessage) and message.name != name:
#             message = HumanMessage(**message.dict(exclude={"type"}))
#         converted.append(message)
    
#     state['messages'] = converted
    
#     logger.info(f'Converted messages for {name} while swapping roles: {len(converted)} messages')
#     return state


# @as_runnable
# async def generate_question(state: InterviewState) -> InterviewState:
#     editor: Editor = state["editor"]

#     name = cleanup_name(editor.name)


#     logger.info(f'Generating question for {name}')

#     gn_chain = (
#         RunnableLambda(swap_roles).bind(name=name)
#         | gen_qn_prompt.partial(persona=editor.persona)
#         | fast_llm
#         | RunnableLambda(tag_with_name).bind(name=name)
#     )
#     result:AIMessage = await gn_chain.ainvoke(state)
#     state["messages"] = ([result])

#     logger.info(f'Generated question for {name}')
#     return state

### Answer questions

The `gen_answer_chain` first generates queries (query expansion) to answer the editor's question, then responds with citations.


In [None]:

# Expand the editor's first question into search-engine queries.
# NOTE(review): within this notebook `gen_queries_chain` is only assigned in a
# commented-out line; presumably it is provided by `from storm import *` - confirm.
queries = await gen_queries_chain.ainvoke(
    {"messages": [HumanMessage(content=question["messages"][0].content)]}
)

queries

In [None]:

async def gen_answer(
    state: InterviewState,
    config: Optional[RunnableConfig] = None,
    name: str = "SubjectMatterExpert",
    max_str_len: int = 15000,
):
    """Answer the editor's latest question in the role of the domain expert.

    The conversation is role-swapped for ``name``, expanded into search-engine
    queries, the searches are run, and the (truncated) results are handed to
    ``gen_answer_chain``. The resulting answer and the sources it cites are
    then recorded on ``state``.

    Args:
        state: Interview state; ``state["messages"]`` and
            ``state["references"]`` are read and updated.
        config: Forwarded to ``search_engine.abatch``.
        name: Speaker name for the answer message (normalized with
            ``cleanup_name`` first).
        max_str_len: Maximum number of characters of JSON-encoded search
            results passed to the answer chain.

    Returns:
        The same ``state`` object, with the answer added to ``"messages"`` via
        ``add_messages`` and the cited sources merged into ``"references"``
        via ``update_references``.
    """
    name = cleanup_name(name)

    logger.info(f'START - Generate answers for [{name}]')

    swapped_state = swap_roles(state, name)  # Convert all other AI messages

    # Generate search engine queries
    queries: Queries = await gen_queries_chain.ainvoke(swapped_state)

    logger.info(f"Got {len(queries.queries)} search engine queries for [{name}] -\n\t {queries.queries}")

    # Run search engine; failed queries come back as exception objects and are
    # filtered out instead of aborting the whole batch.
    query_results = await search_engine.abatch(
        queries.queries, config, return_exceptions=True
    )
    successful_results = [
        res for res in query_results if not isinstance(res, Exception)
    ]

    # Flatten to url -> content; a URL returned by several queries keeps the
    # content seen last.
    all_query_results = {
        res["url"]: res["content"] for results in successful_results for res in results
    }

    logger.info(f"Got {len(successful_results)} search engine results for [{name}] - \n\t {all_query_results}")

    # We could be more precise about handling max token length if we wanted to here
    dumped_successful_results = json.dumps(all_query_results)[:max_str_len]

    logger.info(f"Dumped {len(dumped_successful_results)} characters for [{name}] - \n\t {dumped_successful_results}")

    # Append Questions from Wikipedia and Answers from the search engine
    ai_message_for_queries: AIMessage = get_ai_message(json.dumps(queries.as_dict()))

    tool_results_message = generate_human_message(dumped_successful_results)

    logger.debug(f"Got {ai_message_for_queries} for [{name}]")

    # Only update the shared state with the final answer to avoid polluting the
    # dialogue history with intermediate messages. The chain input is therefore
    # built on a fresh message list: if ``swap_roles`` hands back the very same
    # state object, extending its list in place would leak the query/tool
    # messages into ``state["messages"]``.
    answer_input = {
        **swapped_state,
        "messages": [
            *swapped_state["messages"],
            ai_message_for_queries,
            tool_results_message,
        ],
    }

    try:
        generated: AnswerWithCitations = await gen_answer_chain.ainvoke(answer_input)

        logger.info(f"Genreted final answer {generated} for [{name}] - \n\t {generated.as_str}")

    except Exception as e:
        # Best effort: fall back to an empty answer so the interview can go on,
        # but keep the traceback in the log.
        logger.exception(f"Error generating answer for [{name}] - {e}")
        generated = AnswerWithCitations(answer="", cited_urls=[])

    cited_urls = set(generated.cited_urls)

    # Save the retrieved information to the shared state for future reference
    cited_references = {k: v for k, v in all_query_results.items() if k in cited_urls}

    formatted_message = AIMessage(name=name, content=generated.as_str)
    # Add message to shared state
    state["messages"] = add_messages(state["messages"], [formatted_message])

    # Update references with cited references
    state["references"] = update_references(state["references"], cited_references)

    logger.info(f'END - generate answer for [{name}]')

    return state
    

In [None]:
# Seed the example conversation: the opening prompt followed by the editor's
# first generated question, re-wrapped as a human message.
intial_messages = [prompts.initial_question, generate_human_message(question["messages"][0].content)]

# Interview state for the first editor, starting with no references.
initial_state: InterviewState = {
    "editor": perspectives.editors[0],
    "messages": intial_messages,
    "references": {}
}

# Run a single expert answer and show the text of the last message.
example_answer = await gen_answer(initial_state)
example_answer["messages"][-1].content

In [None]:
example_answer["messages"]

# Construct the Interview Graph

Now that we've defined the editor and domain expert, we can compose them in a graph.


In [None]:
# Two-node loop: the editor asks, the expert answers, and `route_messages`
# decides where to go after each answer.
builder = StateGraph(InterviewState)

builder.add_node("ask_question", generate_question)
builder.add_node("answer_question", gen_answer)
builder.add_conditional_edges("answer_question", route_messages)
builder.add_edge("ask_question", "answer_question")

builder.set_entry_point("ask_question")
# NOTE(review): this rebinds `interview_graph`, replacing the
# `StormInterviewGraph(fast_llm)` instance created in the setup cell. Earlier
# cells read `.outline`, `.related_topics` and `.survey_subjects` from that
# instance, so re-running them after this cell will presumably fail - confirm.
interview_graph = builder.compile().with_config(run_name="Conduct Interviews")

In [None]:
from IPython.display import Image

# comment out if you have not installed pygraphviz
# Image(interview_graph.get_graph().draw_png())

In [None]:

final_step = None

initial_state = {
    "editor": perspectives.editors[0],
    "messages": [
        AIMessage(
            content=f"So you said you were writing an article on {example_topic}?",
            name="SubjectMatterExpert",
        )
    ],
}
async for step in interview_graph.astream(initial_state):
    name = next(iter(step))
    logger.info(f"Processing step: {name}")
    logger.debug("-- ", str(step[name]["messages"])[:300])
    if END in step:
        final_step = step

In [None]:
final_state = next(iter(final_step.values()))


In [None]:
final_state

## Refine Outline

At this point in STORM, we've conducted a large amount of research from different perspectives. It's time to refine the original outline based on these investigations. Below, create a chain using the LLM with a long context window to update the original outline.


In [None]:
# Prompt pair (system + human) for rewriting the outline after the interviews.
refine_outline_prompt = get_chat_prompt_from_prompt_templates([prompts.pmt_s_refine_outline, prompts.pmt_h_refine_outline])

# NOTE(review): the chain is built with `fast_llm` even though the conversation
# context can get quite long; the disabled variant below used `long_context_llm`.
# Confirm which model is intended here.
refine_outline_chain = get_chain_with_outputparser(refine_outline_prompt, fast_llm, outline_parser)\
    .with_config(run_name="Refine Outline")

# Disabled variant using the long-context model:
# refine_outline_prompt.partial(format_instructions=outline_parser.get_format_instructions()) | long_context_llm | outline_parser

In [None]:
refined_outline = refine_outline_chain.invoke(
    {
        "topic": example_topic,
        "old_outline": initial_outline.as_str,
        "conversations": "\n\n".join(
            f"### {m.name}\n\n{m.content}" for m in final_state["messages"]
        ),
    }
)

In [None]:
logger.info(refined_outline.as_str)

## Generate Article


In [None]:
# Splitter used for the summarized pages below (chunk_size=250, chunk_overlap=25,
# sized via the tiktoken encoder).
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=250, chunk_overlap=25
)


# One Document per cited reference: the mapping key becomes metadata["source"],
# the mapping value becomes the page content.
reference_docs = [
    Document(page_content=v, metadata={"source": k})
    for k, v in final_state["references"].items()
]

logger.info(f"Number of references: {len(reference_docs)}")

# Seed the vector store with the raw reference snippets.
vectorstore = get_inmemory_db(reference_docs, embeddings)

# Get contents of the references (a shallow copy of the list is passed)
full_docs = await fetch_pages_from_refs(reference_docs[:])

# Summarize
summaries = summarize_full_docs(fast_llm, example_topic, full_docs)

# f1 = list(chain.from_iterable(summaries.values()))
# assumes `summaries` is a mapping whose values are Document objects - TODO confirm
f1 = summaries.values()
full_split_docs = text_splitter.split_documents(f1)

# Add the split summaries alongside the raw snippets.
vectorstore.add_documents(full_split_docs)

retriever = vectorstore.as_retriever()


In [None]:
full_split_docs

In [None]:
d1 = retriever.invoke("What did the red cross do in Zimbabwe?")
print(d1)

#### Generate Sections

Now you can generate the sections using the indexed docs.


In [None]:


section_writer_prompt = get_chat_prompt_from_prompt_templates([prompts.pmt_s_section_writer, prompts.pmt_h_section_writer])


async def retrieve(inputs: dict):
    """Fetch documents for one section and attach them to the chain inputs.

    The retriever is queried with ``"<topic>: <section>"``. Each hit is wrapped
    in a ``<Document href="...">...</Document>`` element whose ``href`` is the
    document's ``metadata["source"]``.

    Args:
        inputs: Mapping that contains at least ``"topic"`` and ``"section"``.

    Returns:
        A new dict with the formatted documents under ``"docs"`` plus every
        key of ``inputs`` (an existing ``"docs"`` key in ``inputs`` wins).
    """
    docs = await retriever.ainvoke(inputs["topic"] + ": " + inputs["section"])
    # The opening tag must not be self-closing: the content and the closing
    # </Document> that follow it belong to the same element.
    formatted = "\n".join(
        f'<Document href="{doc.metadata["source"]}">\n{doc.page_content}\n</Document>'
        for doc in docs
    )
    return {"docs": formatted, **inputs}

# Parser that turns the model output into a WikiSection.
wiki_parser = get_pydantic_parser(WikiSection)

# Section pipeline: retrieve docs -> fill the prompt (including the parser's
# format instructions) -> long-context model -> parse.
# `retrieve` is a plain coroutine function; piping it presumably relies on the
# prompt object coercing it into a runnable - confirm.
section_writer = (
    retrieve
    | section_writer_prompt.partial(format_instructions=wiki_parser.get_format_instructions())
    | long_context_llm
    | wiki_parser
)

In [None]:
section = await section_writer.ainvoke(
    {
        "outline": refined_outline.as_str,
        "section": refined_outline.sections[1].section_title,
        "topic": example_topic,
    }
)
print(section.as_str)

#### Generate final article

Now we can rewrite the draft to appropriately group all the citations and maintain a consistent voice.


In [None]:
# System prompt for the final writer; expects the `topic` and `draft` variables.
prompts.pmt_s_writer = generate_system_chat_prompt("""
You are an expert Wikipedia author. Write the complete wiki article on {topic} using the following section drafts:

{draft}

Strictly follow Wikipedia format guidelines.
""")

# Human prompt for the final writer. The stray quote characters that had been
# left in the instruction text (after "[1]" and at the end of the sentence)
# are removed so the model receives a clean sentence.
prompts.pmt_h_writer = generate_human_chat_prompt("""
Write the complete Wiki article using markdown format. Organize citations using footnotes like "[1]", avoiding duplicates in the footer. Include URLs in the footer.
""")


In [None]:


writer_prompt = get_chat_prompt_from_prompt_templates([prompts.pmt_s_writer, prompts.pmt_h_writer])

writer = writer_prompt | long_context_llm | StrOutputParser()

In [None]:
for tok in writer.stream({"topic": example_topic, "draft": section.as_str}):
    print(tok, end="")

## Final Flow

Now it's time to string everything together. We will have 6 main stages in sequence:

1. Generate the initial outline + perspectives
2. Batch converse with each perspective to expand the content for the article
3. Refine the outline based on the conversations
4. Index the reference docs from the conversations
5. Write the individual sections of the article
6. Write the final wiki

The state tracks the outputs of each stage.


In [None]:
class ResearchState(TypedDict):
    """State dictionary handed from one stage of the research pipeline to the next."""

    # Subject the article is written about.
    topic: str
    # Set by initialize_research and replaced by refine_outline.
    outline: Outline
    # Set by initialize_research; conduct_interviews runs one interview per editor.
    editors: List[Editor]
    # Set by conduct_interviews; read by refine_outline and index_references.
    interview_results: List[InterviewState]
    # The final sections output (set by write_sections)
    sections: List[WikiSection]
    # Set by write_article.
    article: str

In [None]:
import asyncio


async def initialize_research(state: ResearchState):
    """Start the pipeline by drafting an outline and surveying editors concurrently.

    Returns a copy of ``state`` with ``"outline"`` set to the drafted outline
    and ``"editors"`` set to the ``editors`` of the survey result.
    """
    topic = state["topic"]
    # Both requests are awaited together; results come back in argument order.
    outline_draft, survey = await asyncio.gather(
        generate_outline_direct.ainvoke({"topic": topic}),
        survey_subjects.ainvoke(topic),
    )
    updated = dict(state)
    updated["outline"] = outline_draft
    updated["editors"] = survey.editors
    return updated


async def conduct_interviews(state: ResearchState):
    """Run one interview per editor and store the resulting interview states.

    Returns a copy of ``state`` with ``"interview_results"`` set to the output
    of ``interview_graph.abatch``.
    """
    topic = state["topic"]

    def opening_state(editor):
        # Every interview starts from the same opening line, attributed to the
        # subject-matter expert.
        greeting = AIMessage(
            content=f"So you said you were writing an article on {topic}?",
            name="SubjectMatterExpert",
        )
        return {"editor": editor, "messages": [greeting]}

    start_states = [opening_state(editor) for editor in state["editors"]]
    # We call in to the sub-graph here to parallelize the interviews
    transcripts = await interview_graph.abatch(start_states)

    return {**state, "interview_results": transcripts}


def format_conversation(interview_state):
    """Render one interview as text: a header line, then ``name: content`` per message."""
    turns = [f"{m.name}: {m.content}" for m in interview_state["messages"]]
    editor_name = interview_state["editor"].name
    header = f'Conversation with {editor_name}\n\n'
    return header + "\n".join(turns)


async def refine_outline(state: ResearchState):
    """Rewrite the outline using every interview transcript.

    Returns a copy of ``state`` whose ``"outline"`` is the result of
    ``refine_outline_chain``.
    """
    # One text block per interview, separated by blank lines.
    conversations = "\n\n".join(
        map(format_conversation, state["interview_results"])
    )
    request = {
        "topic": state["topic"],
        "old_outline": state["outline"].as_str,
        "conversations": conversations,
    }
    revised = await refine_outline_chain.ainvoke(request)
    return {**state, "outline": revised}


async def index_references(state: ResearchState):
    """Add every reference gathered during the interviews to ``vectorstore``.

    ``state`` is returned unchanged.
    """
    # Flatten all interviews into one list: the mapping key becomes the
    # document's "source", the mapping value its content.
    collected = [
        Document(page_content=text, metadata={"source": url})
        for result in state["interview_results"]
        for url, text in result["references"].items()
    ]
    await vectorstore.aadd_documents(collected)
    return state


async def write_sections(state: ResearchState):
    """Write every section of the outline stored in ``state``.

    One request per entry of ``state["outline"].sections`` is sent to
    ``section_writer``; each request carries the outline text, the section
    title and the topic.

    Returns a copy of ``state`` with the written sections under ``"sections"``.
    """
    outline = state["outline"]
    # Use the outline carried in the state rather than a notebook-level
    # variable, so the function works for whatever state it is given.
    outline_text = outline.as_str
    sections = await section_writer.abatch(
        [
            {
                "outline": outline_text,
                "section": section.section_title,
                "topic": state["topic"],
            }
            for section in outline.sections
        ]
    )
    return {
        **state,
        "sections": sections,
    }


async def write_article(state: ResearchState):
    """Combine the written sections into a draft and produce the final article.

    Returns a copy of ``state`` with the writer's output under ``"article"``.
    """
    request = {
        "topic": state["topic"],
        # Section texts are separated by a blank line in the draft.
        "draft": "\n\n".join(part.as_str for part in state["sections"]),
    }
    final_text = await writer.ainvoke(request)
    return {**state, "article": final_text}

#### Create the graph


In [None]:
builder_of_storm = StateGraph(ResearchState)

# Pipeline stages in execution order: (node name, node function).
nodes = [
    ("init_research", initialize_research),
    ("conduct_interviews", conduct_interviews),
    ("refine_outline", refine_outline),
    ("index_references", index_references),
    ("write_sections", write_sections),
    ("write_article", write_article),
]
# Register each stage and chain it to the one before it.
previous_name = None
for name, node in nodes:
    builder_of_storm.add_node(name, node)
    if previous_name is not None:
        builder_of_storm.add_edge(previous_name, name)
    previous_name = name

# The first stage is the entry point, the last one the finish point.
builder_of_storm.set_entry_point(nodes[0][0])
builder_of_storm.set_finish_point(nodes[-1][0])
storm = builder_of_storm.compile()

In [None]:
# async for step in storm.astream(
#     {
#         "topic": example_topic,
#     }
# ):
#     name = next(iter(step))
#     print(name)
#     logger.info("-- ", str(step[name])[:300])
#     if END in step:
#         results = step

In [None]:
# NOTE(review): `results` is only assigned inside the commented-out streaming
# loop in the previous cell, so this line needs that loop to be re-enabled and
# run first.
article = results[END]["article"]

## Render the Wiki

Now we can render the final wiki page!


In [None]:
from IPython.display import Markdown

# We will down-header the sections to create less confusion in this notebook
Markdown(article.replace("\n#", "\n##"))

In [None]:
# Write article to file (headers are shifted one level down, as in the rendered
# view). The encoding is fixed so non-ASCII text is written the same way on
# every platform.
with open(f"{example_topic}_article.md", "w", encoding="utf-8") as f:
    f.write(article.replace("\n#", "\n##"))