In [52]:
from abc import ABC
from enum import Enum
from typing import Any, List, Literal, Optional, Type

from geojson_pydantic import FeatureCollection
from langchain.tools import BaseTool
from pydantic import BaseModel, Field

# Kinds of vertices a computational graph can contain. Mixing in `str` makes
# each member compare equal to its string value (e.g. NodeType.OPERATION ==
# 'operation'), so plain strings and members are interchangeable as dict keys.
class NodeType(str, Enum):
    INPUT_DATA = 'input_data'  # data read from an existing collection (see InputDataNode)
    INTERMEDIARY_DATA = 'intermediary_data'  # data saved to a file between steps (see IntermediaryDataNode)
    OPERATION = 'operation'  # a processing step between data nodes
    RESULT_DATA = 'result_data'  # final output; presumably what answers the user's question

# Common base for every graph vertex.
# NOTE(review): ABC declares no abstract methods here, so Node itself can still
# be instantiated; the ABC base only documents intent.
class Node(ABC, BaseModel):
    # Identifier of the node: Edge refers to nodes by this name and the graphviz
    # rendering uses it as the node id, so it should be unique within a graph.
    name: str
    # Kind of node; also selects the fill colour when the graph is rendered.
    type: NodeType

# Source dataset that enters the graph from a named collection.
# `type` is pinned with a Literal (not just NodeType) so that, when nodes are
# parsed from a Union of node classes, only items whose type is 'input_data'
# can validate as an InputDataNode.
class InputDataNode(Node):
    type: Literal[NodeType.INPUT_DATA] = NodeType.INPUT_DATA
    collection_name: str = Field(..., description='Name of the collection that the input data comes from')

# A processing step that consumes and produces data nodes.
# `type` is pinned with a Literal: OperationNode has no required fields besides
# `name`, so with a plain `type: NodeType` it would successfully validate
# intermediary/result items too (dropping e.g. IntermediaryDataNode.path) when
# nodes are parsed from a Union of node classes.
class OperationNode(Node):
    type: Literal[NodeType.OPERATION] = NodeType.OPERATION
    # Disabled descriptive fields, kept for reference:
    # operation_description: str = Field(..., description='Describes how the operation should manipulate the input data')
    # input_data_description: str = Field(..., description='Describes how the input data is formatted, important field names, etc.')
    # output_data_description: str = Field(..., description='Describes how the output data is formatted, important field names, etc.')

# Data produced between operations and persisted to a file.
# `type` is pinned with a Literal so only items whose type is
# 'intermediary_data' can validate as this class during Union parsing.
class IntermediaryDataNode(Node):
    type: Literal[NodeType.INTERMEDIARY_DATA] = NodeType.INTERMEDIARY_DATA
    path: str = Field(..., description='Relative save path to the file that contains the intermediary data')

# Final output of the graph.
# `type` is pinned with a Literal: this class has no fields besides `name`, so
# with a plain `type: NodeType` it would accept items of any node type.
class ResultDataNode(Node):
    type: Literal[NodeType.RESULT_DATA] = NodeType.RESULT_DATA

# Directed connection between two nodes, referenced by their `name`.
class Edge(BaseModel):
    # Name of the node the edge starts from.
    tail_name: str
    # Name of the node the edge points to.
    head_name: str

In [53]:
import graphviz
from typing import List, Dict, Union

class ComputationalGraph:
    """Collects typed nodes and edges and renders them as a graphviz digraph.

    Every node is drawn filled with the colour ``color_map`` assigns to its
    ``type``; every edge is drawn as an arrow from ``Edge.tail_name`` to
    ``Edge.head_name``.
    """

    # Fill colour used for each kind of node when rendering.
    color_map: Dict[NodeType, str] = {
        NodeType.INPUT_DATA: 'lightgreen',
        NodeType.INTERMEDIARY_DATA: 'orange',
        NodeType.OPERATION: 'skyblue',
        NodeType.RESULT_DATA: 'deeppink'
    }

    def __init__(
        self,
        nodes: Optional[List[Node]] = None,
        edges: Optional[List[Edge]] = None,
        filename: str = 'output/filled_colorful_organogram.gv',
    ):
        """Create a graph.

        Args:
            nodes: Initial nodes (copied into a new list). Defaults to empty.
            edges: Initial edges (copied into a new list). Defaults to empty.
            filename: File name graphviz uses for the DOT source when the
                graph is rendered or viewed.
        """
        # None sentinels instead of `[]` defaults: a default list is created
        # once and would be shared, and grown by add_node, across every
        # instance built without explicit nodes/edges.
        self.nodes: List[Node] = list(nodes) if nodes is not None else []
        self.edges: List[Edge] = list(edges) if edges is not None else []
        self.graph = graphviz.Digraph(filename=filename)

    def add_node(self, node: Node):
        """Append ``node`` to the graph's node list."""
        self.nodes.append(node)

    def add_edge(self, edge: Edge):
        """Append ``edge`` to the graph's edge list."""
        self.edges.append(edge)

    def generate_graph(self):
        """Populate ``self.graph`` from the current nodes and edges and return it.

        The digraph body is cleared first, so calling this more than once
        (e.g. after ``add_node``) does not emit duplicate node/edge
        statements. Graph/node/edge attributes already set on ``self.graph``
        are kept.

        Returns:
            The ``graphviz.Digraph``; call ``.view()`` or ``.render()`` on it.

        Raises:
            KeyError: if a node's ``type`` has no entry in ``color_map``.
        """
        self.graph.clear(keep_attrs=True)
        for node in self.nodes:
            self.graph.node(node.name, style='filled', fillcolor=self.color_map[node.type])
        for edge in self.edges:
            # graphviz implicitly creates (unstyled) any endpoint that was not
            # declared as a node above.
            self.graph.edge(edge.tail_name, edge.head_name)
        return self.graph

In [54]:
# Argument schema of the create_computational_graph tool: the agent supplies the
# complete graph in one call as typed nodes plus the edges between them.
# (Comments rather than a docstring: a class docstring would be copied into the
# JSON schema sent to the LLM.)
class CreateComputationalGraphInput(BaseModel):
    # Each item must validate as one of the four concrete node classes. The
    # item's `type` value is what distinguishes OperationNode from
    # ResultDataNode, which otherwise share the same fields.
    nodes: List[Union[InputDataNode, OperationNode, IntermediaryDataNode, ResultDataNode]] = Field(..., description='List of nodes')
    # Directed connections that refer to nodes by their `name`.
    edges: List[Edge] = Field(..., description='List of edges between nodes')


class CreateComputationalGraphTool(BaseTool):
    # LangChain tool through which the agent submits a whole graph; its arguments
    # are validated against CreateComputationalGraphInput before _run is called.
    name: str = 'create_computational_graph'
    args_schema: Type[BaseModel] = CreateComputationalGraphInput
    description: str = 'Use this to create computational graph meant to help solve GIS-related problems.'

    def _run(self, nodes: List[Node], edges: List[Edge], *args: Any, **kwargs: Any) -> Any:
        """Render the submitted graph, open it in a viewer and echo the input.

        Returns:
            Two lines, ``nodes: ...`` and ``edges: ...``, so the agent sees
            exactly what was drawn.
        """
        rendered = ComputationalGraph(nodes, edges).generate_graph()
        # NOTE(review): .view() launches an external viewer; on a machine
        # without a usable display this can fail or print display errors.
        rendered.view()
        summary_lines = [f'nodes: {nodes}', f'edges: {edges}']
        return '\n'.join(summary_lines)

In [55]:
from langchain_openai import ChatOpenAI, OpenAI
from langchain.agents import create_openai_tools_agent
from langchain_core.messages import AIMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, HumanMessagePromptTemplate, PromptTemplate
from langchain.chains import LLMChain
from dotenv import load_dotenv
from langchain.agents import create_openai_tools_agent, AgentExecutor
from langchain.tools.requests.tool import RequestsGetTool
from langchain.requests import RequestsWrapper

from query_collection import QueryOGCAPIFeaturesCollectionTool

# Load environment variables (presumably including the OpenAI API key that
# ChatOpenAI reads) from a .env file; the path is relative to the notebook's
# working directory.
load_dotenv('../../../.env')

# Long-form system prompt with a worked example graph. It is not used by the
# agent below (the prompt is built from `system_message2`); kept for reference.
# Class names and the Node fields in the text match the models defined in this
# notebook, so the model is not shown types that do not exist.
system_message = """Your task is to create a computational graph for a GIS-related question.
You have four types of nodes (InputDataNode, OperationNode, IntermediaryDataNode and ResultDataNode) that all implement the abstract Node class.

class Node(ABC, BaseModel):
    name: str
    type: NodeType

Create an array of nodes and edges that looks something like this: 

nodes = [
    InputDataNode("mobility_data"),
    InputDataNode("france_data"),
    OperationNode("load_mobility_data"),
    OperationNode("load_france_data"),
    IntermediaryDataNode("mobility_df"),
    IntermediaryDataNode("france_gdf"),
    OperationNode("filter_mobility_data"),
    OperationNode("merge_change_rates_with_gdf"),
    IntermediaryDataNode("monthly_change_rates_df"),
    IntermediaryDataNode("merged_gdf"),
    OperationNode("compute_monthly_change_rates"),
    ResultDataNode("france_mobility_df"),
    OperationNode("visualize_monthly_change_rates"),
    ResultDataNode("france_map_matrix"),
    OperationNode("draw_line_chart"),
    ResultDataNode("line_chart")
]

edges = [
    Edge("mobility_data", "load_mobility_data"),
    Edge("load_mobility_data", "mobility_df"),
    Edge("france_data", "load_france_data"),
    Edge("load_france_data", "france_gdf"),
    Edge("mobility_df", "filter_mobility_data"),
    Edge("filter_mobility_data", "france_mobility_df"),
    Edge("france_gdf", "merge_change_rates_with_gdf"),
    Edge("france_mobility_df", "compute_monthly_change_rates"),
    Edge("compute_monthly_change_rates", "monthly_change_rates_df"),
    Edge("monthly_change_rates_df", "merge_change_rates_with_gdf"),
    Edge("merge_change_rates_with_gdf", "merged_gdf"),
    Edge("merged_gdf", "visualize_monthly_change_rates"),
    Edge("visualize_monthly_change_rates", "france_map_matrix"),
    Edge("monthly_change_rates_df", "draw_line_chart"),
    Edge("draw_line_chart", "line_chart")
]

The graph should solve the user's question.
"""

# Short system prompt that the agent prompt below actually uses.
system_message2 = """Create a computational graph for a GIS-related question.
All operation nodes should be followed by an intermediary node.
"""

# Chat model driving the agent. A cheaper alternative that was tried:
#   ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
llm = ChatOpenAI(model="gpt-4-0125-preview", temperature=0)

# Assistant-side hint steering the agent to inspect the local OGC API
# collections first. Not currently part of `prompt` (its AIMessage entry is
# commented out below).
AI_SUFFIX = """I should look at the available data at http://localhost:9000/collections and available properties in these to see what is usable for the problem at hand.
I should then use relevant datasets to create a computational graph that can solve the question from the user.   
"""

# Message layout: fixed system instructions, the user's question, then the
# agent_scratchpad placeholder that the tools agent fills with its
# intermediate tool calls and results.
prompt = ChatPromptTemplate.from_messages([
    SystemMessage(content=system_message2),
    HumanMessagePromptTemplate.from_template('{input}'),
    # AIMessage(content=AI_SUFFIX),
    MessagesPlaceholder(variable_name="agent_scratchpad"),
])

# NOTE(review): RequestsGetTool lets the model issue GET requests to any URL it
# chooses; nothing in the active prompt restricts it to http://localhost:9000.
tools = [
    CreateComputationalGraphTool(),
    RequestsGetTool(requests_wrapper=RequestsWrapper()),
]

agent = create_openai_tools_agent(llm=llm, prompt=prompt, tools=tools)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

# Run the agent end to end on a sample question.
res = agent_executor.invoke(
    {'input': 'Get me the road segments of the noisiest roads in Trondheim.'}
)



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m
Invoking: `requests_get` with `http://localhost:9000/collections`


[0m[33;1m[1;3m{"links":[{"href":"http://localhost:9000/collections","rel":"self","type":"application/json","title":"This document as JSON"},{"href":"http://localhost:9000/collections.html","rel":"alternate","type":"text/html","title":"This document as HTML"}],"collections":[{"id":"public.5001_25832_ar50_gml","title":"public.5001_25832_ar50_gml","description":"Data for table public.5001_25832_ar50_gml","extent":{"spatial":{"crs":"http://www.opengis.net/def/crs/EPSG/0/4326","bbox":[0,0,0,0]}},"links":[{"href":"http://localhost:9000/collections/public.5001_25832_ar50_gml","rel":"self","type":"application/json","title":"Metadata as JSON"},{"href":"http://localhost:9000/collections/public.5001_25832_ar50_gml.html","rel":"alternate","type":"text/html","title":"Metadata as HTML"},{"href":"http://localhost:9000/collections/public.5001_25832_ar50_gml/items","rel":

X connection to :0 broken (explicit kill or server shutdown).
X connection to :0 broken (explicit kill or server shutdown).
X connection to :0 broken (explicit kill or server shutdown).
X connection to :0 broken (explicit kill or server shutdown).
X connection to :0 broken (explicit kill or server shutdown).
X connection to :0 broken (explicit kill or server shutdown).
