<img src="../images/LATS.png" alt="LATS diagram" width="700"/>

In [1]:
cd ..

/home/vule/projects/learn_agent


## Language Agent Tree Search

In [2]:
import math
from collections import deque
from typing import Optional

from langchain_core.messages import AIMessage, HumanMessage, ToolMessage, BaseMessage

from pydantic import BaseModel, Field


# Structured output schema for the reflection (grading) step of the agent.
class Reflection(BaseModel):
    """Critique, 0-10 score, and solved-flag produced by the reflection LLM for one candidate.

    The helpers below rescale the raw score into a 0.0-1.0 reward and decide
    whether it counts as "good".
    """

    reflection: str = Field(description="The critique and reflection on the sufficiency, superfluency, and general quality of the response.")
    score: int = Field(description="The score of the response, from 0 to 10", ge=0, le=10)
    found_solution: bool = Field(description="Whether the response has fully solved the question or task.")

    def as_message(self) -> HumanMessage:
        """Render this reflection as a ``HumanMessage``."""
        return HumanMessage(content=f"Reasoning: {self.reflection}\nScore: {self.score}\nFound Solution: {self.found_solution}")

    @property
    def normalized_score(self) -> float:
        """Score rescaled from the 0-10 range to 0.0-1.0."""
        return self.score / 10

    @property
    def is_good_score(self) -> bool:
        """True when the normalized score is at least 0.8 (a raw score of 8 or more)."""
        return self.normalized_score >= 0.8


# Search tree for Language Agent Tree Search (LATS).
class Node:
    """One node of the LATS search tree.

    Each node stores the message trajectory that produced it and the
    reflection that graded it. ``value`` is the running *sum* of normalized
    reflection scores backpropagated through the node and ``visits`` counts
    them, so ``value / visits`` is the mean reward used by
    ``upper_confidence_bound``.
    """

    def __init__(self, messages: "list[BaseMessage]", reflection: "Optional[Reflection]", parent: Optional["Node"] = None):
        self.messages = messages
        self.parent = parent
        self.children = []
        self.value = 0
        self.visits = 0
        self.reflection = reflection
        self.depth = parent.depth + 1 if parent else 0
        self._is_solved = reflection.found_solution if reflection else False
        if self._is_solved:
            # A solved node makes every ancestor (up to the root) solved too.
            self._mark_tree_as_solved()
        # Seed this node and its ancestors with the reflection's reward; a node
        # created without a reflection has no score to propagate.
        if reflection is not None:
            self.backpropagate(reflection.normalized_score)

    # display the node
    def __repr__(self) -> str:
        return (
            f"<Node value={self.value}, visits={self.visits}, solution={self.messages}, reflection={self.reflection}>"
        )

    @property
    def is_solved(self) -> bool:
        """Whether this node, or some node below it, was graded as a full solution."""
        return self._is_solved

    @property
    def is_terminal(self) -> bool:
        """Whether this node is a leaf (has no children yet)."""
        return len(self.children) == 0

    @property
    def best_child_score(self) -> Optional[float]:
        """Highest ``value`` among solved children (unsolved children count as 0), or ``None`` if there are no children."""
        if not self.children:
            return None
        return max(child.value * int(child.is_solved) for child in self.children)

    @property
    def height(self) -> int:
        """Number of levels in the subtree rooted at this node; a leaf has height 1."""
        if not self.children:
            return 1
        return 1 + max(child.height for child in self.children)

    def upper_confidence_bound(self, exploration_weight: float = 1.0) -> float:
        """Return the UCB score, balancing exploitation and exploration.

        Args:
            exploration_weight: multiplier applied to the exploration term.

        Raises:
            ValueError: if called on the root (UCB needs the parent's visit count).
        """
        if self.parent is None:
            raise ValueError("Root node has no parent")
        if self.visits == 0:
            return self.value
        # Encourage exploitation of high-value trajectories
        average_reward = self.value / self.visits
        # Encourage exploration of less-explored trajectories
        exploration_term = exploration_weight * math.sqrt(math.log(self.parent.visits) / self.visits)
        return average_reward + exploration_term

    def backpropagate(self, score: float):
        """Add ``score`` to this node and every ancestor, counting one visit on each."""
        node = self
        while node is not None:
            node.value += score
            node.visits += 1
            node = node.parent

    def get_messages(self) -> "list[BaseMessage]":
        """Return the messages stored on this node."""
        return self.messages

    def get_trajectory(self, include_reflection: bool = False) -> "list[BaseMessage]":
        """Return the trajectory of messages leading to this node.

        In this notebook nodes are built with their full history (parent
        trajectory plus the new messages), so ``self.messages`` is returned as-is.
        NOTE(review): ``include_reflection`` is currently ignored — any reflection
        messages the caller stored in ``messages`` are always included.
        """
        return self.messages

    def _get_all_children(self) -> list['Node']:
        """Return every descendant of this node (excluding the node itself), in BFS order."""
        all_nodes = []
        queue = deque(self.children)
        while queue:
            node = queue.popleft()
            all_nodes.append(node)
            queue.extend(node.children)
        return all_nodes

    def get_best_solution(self):
        """Return the highest-value solved leaf in this subtree.

        Nodes that are not solved leaves score 0, so if every node scores 0 the
        first node examined (``self``) is returned.
        """
        all_nodes = [self] + self._get_all_children()
        return max(all_nodes, key=lambda node: int(node.is_solved and node.is_terminal) * node.value)

    def _mark_tree_as_solved(self):
        """Mark every ancestor of this node as solved."""
        parent = self.parent
        while parent:
            parent._is_solved = True
            parent = parent.parent

In [3]:
from typing_extensions import TypedDict


# Graph state shared by the LATS graph nodes.
class TreeState(TypedDict):
    """State passed between graph steps: the whole search tree plus the task."""

    root: Node  # the full tree
    input: str  # the original input

In [4]:
# llm open ai
from src.key import openai_api_key
from langchain_openai import ChatOpenAI
llm = ChatOpenAI(api_key=openai_api_key, model="gpt-4o-mini")  # one chat model shared by generation and reflection

In [5]:
# define actions/tools
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper
from langgraph.prebuilt import ToolNode

search = TavilySearchAPIWrapper()
# Web-search tool the answer model may call; at most 5 results per query.
tavily_tool = TavilySearchResults(max_results=5, api_wrapper=search)
search_tools = [tavily_tool]
# LangGraph node that executes tool calls against `search_tools`.
tool_node = ToolNode(search_tools)

In [6]:
# prompt
from langchain_core.output_parsers.openai_tools import (
    JsonOutputToolsParser,
    PydanticToolsParser,
)
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import chain as as_runnable

# Grader prompt: the user task followed by the candidate trajectory to critique.
reflection_prompt_template = ChatPromptTemplate.from_messages(
    [
        ("system", "Reflect and grade the assistant response to the user question below."),
        ("user", "{input}"),
        MessagesPlaceholder("candidate"),
    ]
)

# Force the model to answer through the `Reflection` tool, then parse the call(s)
# into `Reflection` objects.
reflection_llm_chain = (
    reflection_prompt_template
    | llm.bind_tools(tools=[Reflection], tool_choice="Reflection").with_config(run_name="Reflection")
    | PydanticToolsParser(tools=[Reflection])
)

@as_runnable
def reflection_chain(inputs) -> Reflection:
    """Grade a candidate trajectory and return a single ``Reflection``.

    ``inputs`` must provide ``"input"`` (the user task) and ``"candidate"`` (the
    list of messages to grade). A candidate that is empty or whose last message
    is not an ``AIMessage`` (e.g. it ends on a tool result) is never marked as
    solved.

    Raises:
        ValueError: if the model returned no ``Reflection`` tool call.
    """
    tool_choices = reflection_llm_chain.invoke(inputs)
    if not tool_choices:
        raise ValueError("Reflection model returned no Reflection tool call")
    reflection = tool_choices[0]
    candidate = inputs["candidate"]
    if not candidate or not isinstance(candidate[-1], AIMessage):
        reflection.found_solution = False
    return reflection

**Initial Response**

In [7]:
from langchain_core.prompt_values import ChatPromptValue
from langchain_core.runnables import RunnableConfig

# Answer prompt: the user task plus (optionally) the trajectory generated so far.
response_prompt_template = ChatPromptTemplate.from_messages(
    [
        ("system", "You are an AI assistant."),
        ("user", "{input}"),
        MessagesPlaceholder("messages", optional=True),
    ]
)

# Answer model that may request web searches through `search_tools`.
initial_answer_chain = response_prompt_template | llm.bind_tools(tools=search_tools)

# Extracts tool calls from an AIMessage, keeping each call's id.
parser = JsonOutputToolsParser(return_id=True)

# Smoke test with a trivial question.
initial_answer_chain.invoke({"input": "What is the capital of France?"})


AIMessage(content='The capital of France is Paris.', additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 9, 'prompt_tokens': 92, 'total_tokens': 101, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_01aeff40ea', 'finish_reason': 'stop', 'logprobs': None}, id='run-ffa8ee39-9d41-4cb6-a1f2-f0baef87f710-0', usage_metadata={'input_tokens': 92, 'output_tokens': 9, 'total_tokens': 101, 'input_token_details': {'audio': 0, 'cache_read': 0}, 'output_token_details': {'audio': 0, 'reasoning': 0}})

In [8]:
initial_response = initial_answer_chain.invoke({"input": "Write a research report on lithium pollution."})
# Which searches (if any) did the model request?
initial_response.tool_calls

[{'name': 'tavily_search_results_json',
  'args': {'query': 'lithium pollution report'},
  'id': 'call_AnynvpY1giVW6NQ3rYgmwQXQ',
  'type': 'tool_call'}]

In [9]:
# Full AIMessage (content, tool calls, and metadata) for the research-report request.
initial_response

AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_AnynvpY1giVW6NQ3rYgmwQXQ', 'function': {'arguments': '{"query":"lithium pollution report"}', 'name': 'tavily_search_results_json'}, 'type': 'function'}], 'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 23, 'prompt_tokens': 93, 'total_tokens': 116, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_f2cd28694a', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-6c2cc39b-ee73-4e8f-bda0-3790e9087db9-0', tool_calls=[{'name': 'tavily_search_results_json', 'args': {'query': 'lithium pollution report'}, 'id': 'call_AnynvpY1giVW6NQ3rYgmwQXQ', 'type': 'tool_call'}], usage_metadata={'input_tokens': 93, 'output_tokens': 23, 'total_tokens': 116, 'input_token_details': {'audio': 0,

In [10]:
# JsonOutputToolsParser view of the same tool calls (return_id=True keeps each call's id).
parser.invoke(initial_response)

[{'args': {'query': 'lithium pollution report'},
  'id': 'call_AnynvpY1giVW6NQ3rYgmwQXQ',
  'type': 'tavily_search_results_json'}]

In [11]:
tool_response = tool_node.invoke({"messages": [initial_response]})  # execute the requested searches

In [12]:
for tool_message in tool_response["messages"]:
    print(tool_message)



content='[{"url": "https://poweringautos.com/how-much-pollution-yoes-lithium-ion-battery-production-cause/", "content": "The main sources of pollution in lithium-ion battery production include raw material extraction, manufacturing processes, chemical waste, and end-of-life disposal. In summary, lithium-ion battery production can lead to considerable pollution emissions that impact both the environment and human health, necessitating a careful consideration of sustainable practices in this expanding industry. In summary, lithium-ion battery production can generate significant carbon emissions ranging from 150 to 200 kg of CO2 per kWh. Various factors affect this outcome, including raw material extraction methods and energy sources. Lithium-ion battery production significantly impacts water resources through the extraction and processing of lithium and other materials. A research article published in the Journal of Cleaner Production (Buchanan et al., 2020) reported that the production 

In [28]:
# Candidate generation and reflection in one step: produces a new scored tree node.
def generate_initial_response(state: TreeState, messages: Optional[list[BaseMessage]] = None, parent: Optional[Node] = None) -> dict:
    """Generate one candidate continuation, grade it, and wrap it in a ``Node``.

    Args:
        state: graph state; only ``state["input"]`` (the user task) is read.
        messages: trajectory to continue (e.g. a parent node's messages);
            ``None`` or empty starts from scratch.
        parent: node whose statistics receive the new node's score via
            backpropagation. The new node is NOT appended to
            ``parent.children``; the caller must attach it.

    Returns:
        ``{"root": node, "input": state["input"]}``, where ``node`` is the newly
        created node (a child of ``parent`` when one is given).
    """
    history = list(messages) if messages else []
    res = initial_answer_chain.invoke({"input": state["input"], "messages": history})
    # Only run the tool node when the model actually requested tools.
    tool_messages = tool_node.invoke({"messages": [res]})["messages"] if res.tool_calls else []
    output_messages = history + [res] + tool_messages
    reflection = reflection_chain.invoke({"input": state["input"], "candidate": output_messages})
    # Keep the critique in the trajectory so later expansions can see it.
    reflection_message = AIMessage(content=f"Reflection: {reflection.reflection}\nScore: {reflection.score}\nFound Solution: {reflection.found_solution}")
    output_messages.append(reflection_message)
    node = Node(messages=output_messages, reflection=reflection, parent=parent)
    return {
        "root": node,
        "input": state["input"],
    }


# Smoke test: build the root node for the lithium-pollution task.
input = "Write a research report on lithium pollution."
initial_state = generate_initial_response({"input": input})


In [29]:
# Trajectory stored on the root node: model reply, tool results, and the appended reflection message.
initial_state['root'].messages

[AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_2OSYzQCRnmN9g2db9jDxyMyq', 'function': {'arguments': '{"query":"lithium pollution research report"}', 'name': 'tavily_search_results_json'}, 'type': 'function'}], 'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 24, 'prompt_tokens': 93, 'total_tokens': 117, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_01aeff40ea', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-4ca18a99-03f6-4795-8f6b-bb4df74a7791-0', tool_calls=[{'name': 'tavily_search_results_json', 'args': {'query': 'lithium pollution research report'}, 'id': 'call_2OSYzQCRnmN9g2db9jDxyMyq', 'type': 'tool_call'}], usage_metadata={'input_tokens': 93, 'output_tokens': 24, 'total_tokens': 117, 'input_token_det

In [30]:
print(initial_state["root"].reflection.reflection)  # critique text for the root candidate

The assistant's response seems to be missing. It should include a structured research report on lithium pollution, addressing various aspects such as the sources of pollution, its environmental impact, health risks, and potential solutions. Instead, the assistant provided search results for related articles rather than a comprehensive report. This approach does not answer the user’s request adequately. The response lacks depth, organization, and the synthesis of information that the user likely expected in a research report. Therefore, the response is inadequate and does not fulfill the user's request for a complete report.


In [31]:
from collections import defaultdict

def select_best_node(root: "Node") -> "Node":
    """Descend from ``root`` to a leaf, always following the child with the highest UCB score.

    Returns the leaf to expand next (``root`` itself when it has no children).
    """
    node = root
    while node.children:
        node = max(node.children, key=lambda child: child.upper_confidence_bound())
    return node


def expand(state: TreeState, n_candidates: int = 5) -> dict:
    """Expand the most promising leaf of the search tree.

    Selects a leaf with ``select_best_node``, generates ``n_candidates`` scored
    continuations of its trajectory, and attaches them as the leaf's children.
    Each child's score is backpropagated when its ``Node`` is constructed.

    Returns the (mutated) ``state``.
    """
    root = state["root"]
    best_candidate = select_best_node(root)
    messages = best_candidate.get_trajectory()
    # generate_initial_response returns the new node under the "root" key.
    new_candidates = [
        generate_initial_response(state, messages, parent=best_candidate)["root"]
        for _ in range(n_candidates)
    ]
    best_candidate.children = new_candidates
    return state

# Pick the leaf that the next expansion would grow from.
root = initial_state["root"]
best_candidate = select_best_node(root)



In [32]:
# Demo: grow two candidates under the selected leaf (same steps as `expand`).
n_candidates = 2
state = initial_state
new_candidates = [
    generate_initial_response(state, best_candidate.messages, parent=best_candidate)["root"]
    for _ in range(n_candidates)
]
# Attach the new nodes: their scores were already backpropagated into
# `best_candidate` and its ancestors, so leaving them detached would make the
# visit counts refer to children that are not in the tree.
best_candidate.children = new_candidates


In [33]:
for candidate_node in new_candidates:
    print('--------------')
    for message in candidate_node.messages:
        print(message.content)


--------------

[{"url": "https://pubs.acs.org/doi/10.1021/acs.est.4c00225", "content": "Lithium (Li) is an important resource that drives sustainable mobility and renewable energy. Its demand is projected to continue to increase in the coming decades. However, the risk of Li pollution has also emerged as a global concern. Here, we investigated the pollution characteristics, sources, exposure levels, and associated health risks of Li in the Jinjiang River basin, the largest area"}, {"url": "https://pubs.rsc.org/en/content/articlelanding/2021/ee/d1ee00691f", "content": "Energy & Environmental Science\nEnvironmental impacts, pollution sources and pathways of spent lithium-ion batteries\n*\nCorresponding authors\na\nSchool of Engineering, Newcastle University, Newcastle upon Tyne, UK\nb\nFaraday Institution (ReLIB project), Quad One, Harwell Science and Innovation Campus, Didcot, UK\nc\nFaraday Institution (SafeBatt project), Quad One, Harwell Science and Innovation Campus, Didcot, UK\nE-