In [1]:
from langchain.agents import initialize_agent, Tool, AgentType
from langchain.chat_models import ChatOpenAI
import os
import json
from apitable import Apitable
import tiktoken
from typing import Optional
from pydantic import BaseModel, Field
from typing import Dict, List
from langchain.callbacks import get_openai_callback
from langchain.tools import BaseTool
from typing import Type

In [2]:
# Chat model handed to the agent; temperature is pinned to 0.
llm = ChatOpenAI(model="gpt-3.5-turbo-0613", temperature=0)

# The APITable token comes from the environment. The lookup yields None when
# APITABLE_API_TOKEN is unset, and that None is passed straight to the client.
apitable_api_token = os.environ.get("APITABLE_API_TOKEN")
apitable = Apitable(token=apitable_api_token)

In [3]:
# Module-level tokenizer shared by count_tokens.
enc = tiktoken.encoding_for_model("gpt-3.5-turbo")


def count_tokens(s):
    """Return how many tokens `enc` produces when encoding the string `s`."""
    tokens = enc.encode(s)
    return len(tokens)

In [4]:
def trans_key(field_key_map, key: str):
    """
    Resolve a possibly-mapped field key to the actual key.

    "_id" and "recordId" are always returned unchanged, even when the map
    contains an entry for them. Otherwise the key is looked up in
    `field_key_map`; when the map is empty/None or has no entry for the key,
    the key itself is returned.
    """
    is_reserved = key in ("_id", "recordId")
    if is_reserved or not field_key_map:
        return key
    return field_key_map.get(key, key)

def query_parse(field_key_map, **kwargs) -> str:
    """
    Build a filter formula from keyword conditions.

    Each keyword becomes one `{field}=value` term, with the field name passed
    through trans_key(field_key_map, ...). A single term is returned as-is;
    any other number of terms (including zero) is wrapped as `AND(t1,t2,...)`.
    """

    def _literal(value):
        # Render one Python value as the right-hand side of a term.
        if value is None:
            return "BLANK()"
        if isinstance(value, str):
            return f'"{value}"'
        if isinstance(value, bool):
            return "TRUE()" if value else "FALSE()"
        if isinstance(value, list):
            # List items are joined into one quoted, comma-separated string.
            return f'"{", ".join(value)}"'
        # Anything else (e.g. numbers) is interpolated unchanged.
        return value

    terms = []
    for name, value in kwargs.items():
        rendered = _literal(value)
        column = trans_key(field_key_map, name)
        terms.append(f"{{{column}}}={rendered}")

    if len(terms) == 1:
        return terms[0]
    return "AND(" + ",".join(terms) + ")"

def get_spaces(question: str):
    """
    List every space returned by `apitable.spaces.all()`.

    `question` is accepted but never read. Each space is serialised with
    `.json()` and decoded again, so the result is a list of plain dicts.
    """
    decoded = []
    for space in apitable.spaces.all():
        decoded.append(json.loads(space.json()))
    return decoded

def get_nodes(space_id: str):
    """
    List the nodes of the space identified by `space_id`.

    Each node is serialised with `.json()` and decoded again, so the result
    is a list of plain dicts.
    """
    space = apitable.space(space_id=space_id)
    return list(map(lambda node: json.loads(node.json()), space.nodes.all()))

def get_fields(datasheet_id: str):
    """
    List the fields of the datasheet identified by `datasheet_id`.

    Each field is serialised with `.json()` and decoded again, so the result
    is a list of plain dicts.
    """
    datasheet = apitable.datasheet(datasheet_id)
    result = []
    for field in datasheet.fields.all():
        result.append(json.loads(field.json()))
    return result

def create_fields(space_id: str, datasheet_id: str, field_data: dict[str, str]):
    """
    Create a field on a datasheet inside a space.

    `field_data` is forwarded unchanged to `fields.create`. The return value
    is whatever the created field's `.json()` yields; unlike the get_*
    helpers, it is not passed through json.loads.
    """
    datasheet = apitable.space(space_id).datasheet(datasheet_id)
    created = datasheet.fields.create(field_data)
    return created.json()

def get_records(
    datasheet_id: str,
    filter_condition: Optional[dict] = None,
    sort_condition: Optional[list] = None,
    maxRecords_condition: Optional[int] = None,
    max_tokens: int = 1000,
):
    """
    Fetch records from a datasheet, optionally filtered, sorted and limited.

    Args:
        datasheet_id: ID of the datasheet to read.
        filter_condition: field-name -> lookup-value pairs, turned into a
            formula by query_parse and sent as `filterByFormula`.
        sort_condition: forwarded unchanged as the `sort` query argument.
        maxRecords_condition: forwarded unchanged as `maxRecords`.
        max_tokens: largest token count (per count_tokens, applied to the
            str() of the result list) that is still returned in full.

    Returns:
        The list of `record.json()` values, or, when that list exceeds
        `max_tokens`, a short string stating how many records were found.
    """
    dst = apitable.datasheet(datasheet_id)
    query_kwargs = {}
    if filter_condition:
        # query_parse's first positional parameter is the field-key map and the
        # conditions must arrive as keyword arguments. Passing the dict
        # positionally left the keyword arguments empty, so every filter
        # degenerated to the formula "AND()".
        query_kwargs["filterByFormula"] = query_parse(None, **filter_condition)
    if sort_condition:
        query_kwargs["sort"] = sort_condition
    if maxRecords_condition:
        query_kwargs["maxRecords"] = maxRecords_condition

    records = dst.records.all(**query_kwargs)
    parsed_records = [record.json() for record in records]

    # Oversized results are replaced by a hint so the caller can narrow the query.
    if count_tokens(str(parsed_records)) > max_tokens:
        return (
            "Found "
            + str(len(parsed_records))
            + " records, too many to show. Try to limit"
        )
    return parsed_records

In [9]:
class SortCondition(BaseModel):
    # One sort instruction; used as the element type of
    # GetRecordsInput.sort_condition.
    field: str = Field(description="Field name")
    # `enum` is passed to Field as an extra keyword listing the two directions.
    # NOTE(review): presumably this only annotates the generated schema and does
    # not reject other strings at validation time — confirm.
    order: str = Field(description="Sort order", enum=["desc", "asc"])

class GetRecordsInput(BaseModel):
    """Inputs for get records"""
    # Argument schema for GetRecordsTool (attached through its args_schema).
    # Only datasheet_id is a plain str; the other three fields are Optional.
    datasheet_id: str = Field(description="The ID of the datasheet to retrieve records from.")
    # Field-name -> lookup-value pairs; get_records turns a non-empty dict into
    # the filterByFormula query argument.
    filter_condition: Optional[Dict[str, str]] = Field(
        description="""
            Find records that meet specific conditions.
            This object should contain a key-value pair where the key is the field name and the value is the lookup value. For instance: {"title": "test"}.
            """
    )
    # List of SortCondition entries; get_records forwards it as the `sort` argument.
    # NOTE(review): min_items=1 presumably rejects an empty list when one is given — confirm.
    sort_condition: Optional[List[SortCondition]] = Field(min_items=1,
        description="Sort returned records by specific field"
    )
    # get_records forwards a truthy value as the `maxRecords` argument.
    maxRecords_condition: Optional[int] = Field(
        description="Limit the number of returned values."
    )

class GetRecordsTool(BaseTool):
    """Structured tool that exposes get_records to the agent.

    Arguments are described by GetRecordsInput. Failures are not raised:
    `_run` converts any exception into a message string so the agent can read
    it and try a different call.
    """

    name = "get_records"
    description = """
        Useful for retrieving data in a datasheet
        """
    args_schema: Type[BaseModel] = GetRecordsInput

    def _run(self, datasheet_id: str, filter_condition: Optional[dict] = None, sort_condition: Optional[list] = None, maxRecords_condition: Optional[int] = None):
        """Call get_records and return its result, or a hint string on failure."""
        try:
            return get_records(datasheet_id, filter_condition, sort_condition, maxRecords_condition)
        except Exception as e:
            # Deliberately broad: the error text is handed back to the model
            # instead of aborting the run. Two known messages get a targeted
            # hint; everything else gets the generic one.
            if str(e) == "The sorted field does not exist":
                return f"Function excute failed: '{e}', please try another function to get right field name."
            elif str(e) == "api_param_formula_error":
                return f"Function excute failed: '{e}', please try to make right filter_condition."
            else:
                return f"Function excute failed: '{e}', please try another function."

    def _arun(self, *args, **kwargs):
        """Async execution is not supported.

        Accepts any arguments so that a caller forwarding the tool's inputs
        gets NotImplementedError rather than a TypeError about the signature.
        """
        raise NotImplementedError("get_records does not support async")

In [10]:
# Tools offered to the agent: three lookup helpers wrapped in Tool, plus the
# GetRecordsTool class defined in this notebook for record retrieval.
tools = [
    Tool(
        name="get_spaces",
        description="useful for accessing all spaces the user has access to. Input should be in the form of a question containing full context",
        func=get_spaces,
    ),
    Tool(
        name="get_nodes",
        description="useful for retrieving all file nodes in a space",
        func=get_nodes,
    ),
    Tool(
        name="get_fields",
        description="useful for retrieving fields in a datasheet, if you need get",
        func=get_fields,
    ),
    GetRecordsTool(),
]

# Agent built from the tools and chat model above; verbose=True is what prints
# the step-by-step trace seen in the cell outputs.
mrkl = initialize_agent(
    tools,
    llm,
    agent=AgentType.OPENAI_FUNCTIONS,
    verbose=True,
)

def main(prompt: str):
    """
    Run the agent on `prompt` and print the usage figures for that run.

    The agent's return value is discarded and nothing is returned. After the
    run, the token counts, request count and cost from the callback are printed.
    """
    with get_openai_callback() as usage:
        mrkl.run(prompt)
        report = (
            f"Total Tokens: {usage.total_tokens}",
            f"Prompt Tokens: {usage.prompt_tokens}",
            f"Completion Tokens: {usage.completion_tokens}",
            f"Successful Requests: {usage.successful_requests}",
            f"Total Cost (USD): ${usage.total_cost}",
        )
        print(*report, sep="\n")

In [7]:
# Single-step check: the trace below shows one get_spaces call answering this.
main("What spaces do I have?")



[1m> Entering new  chain...[0m
[32;1m[1;3m
Invoking: `get_spaces` with `What spaces do I have?`


[0m[36;1m[1;3m[{'id': 'spcS0eZxZ8mSA', 'name': 'APITable Ltd.', 'isAdmin': None}, {'id': 'spctqtTZpssYw', 'name': "xukecheng's Space", 'isAdmin': True}, {'id': 'spcGSVizcwRYF', 'name': 'Gmail Space', 'isAdmin': None}, {'id': 'spcV5VkzU71Lf', 'name': 'test03-bug验证升级', 'isAdmin': None}][0m[32;1m[1;3mYou have the following spaces:

1. APITable Ltd.
2. xukecheng's Space
3. Gmail Space
4. test03-bug验证升级[0m

[1m> Finished chain.[0m
Total Tokens: 642
Prompt Tokens: 585
Completion Tokens: 57
Successful Requests: 2
Total Cost (USD): $0.0009915


In [11]:
# Multi-step question: the trace below starts with get_spaces, then get_nodes.
main("Tell me the latest value of APITable MAU in xukecheng's space")



[1m> Entering new  chain...[0m
[32;1m[1;3m
Invoking: `get_spaces` with `xukecheng`


[0m[36;1m[1;3m[{'id': 'spcS0eZxZ8mSA', 'name': 'APITable Ltd.', 'isAdmin': None}, {'id': 'spctqtTZpssYw', 'name': "xukecheng's Space", 'isAdmin': True}, {'id': 'spcGSVizcwRYF', 'name': 'Gmail Space', 'isAdmin': None}, {'id': 'spcV5VkzU71Lf', 'name': 'test03-bug验证升级', 'isAdmin': None}][0m[32;1m[1;3m
Invoking: `get_nodes` with `spctqtTZpssYw`


[0m[33;1m[1;3m[{'id': 'dstZTEueDaWedFWgbh', 'name': 'Project management', 'type': 'Datasheet', 'icon': '', 'isFav': False}, {'id': 'fodHl5i2LeMBH', 'name': 'Make Test', 'type': 'Folder', 'icon': '', 'isFav': False}, {'id': 'fomh7eepvbqkBrwCvW', 'name': 'Form', 'type': 'Form', 'icon': '', 'isFav': False}, {'id': 'dstFQKumRsAp4p5RBE', 'name': 'AITEST', 'type': 'Datasheet', 'icon': '', 'isFav': False}, {'id': 'dstWuVdBJNDzpjn8ck', 'name': 'email_oss', 'type': 'Datasheet', 'icon': '', 'isFav': False}, {'id': 'dsti6VpNpuKQpHVSnh', 'name': 'APITable MAUs'