In [14]:
from lib.models import OpenAIModel
from lib.agents import (
    FreeLambda,
    LLMLambda,
    Planner,
    PlannerDomainReflector,
    ToolsExecutor,
    Agent
)
from enum import Enum
from typing import Union

In [15]:
import os
# Never paste a real key here.  Export OPENAI_API_KEY in the environment that
# starts the kernel; setdefault() keeps such a key instead of overwriting it
# with the empty placeholder (a plain assignment would clobber it).
os.environ.setdefault('OPENAI_API_KEY', "")

# Google API

In [16]:
import os.path
import string

from google.auth.transport.requests import Request
from google.oauth2.credentials import Credentials
from google_auth_oauthlib.flow import InstalledAppFlow
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError

def column_to_index(column):
    """Convert an A1-style column label into a zero-based column index.

    The label is read as a bijective base-26 number over ``A``-``Z``:
    ``"A"`` -> 0, ``"Z"`` -> 25, ``"AA"`` -> 26.  An empty label yields -1 and
    any character outside ``string.ascii_uppercase`` raises ``ValueError``.
    """
    ordinal = 0
    # Horner evaluation, most significant letter first.
    for letter in column:
        ordinal = ordinal * 26 + string.ascii_uppercase.index(letter) + 1
    return ordinal - 1

def index_to_column(index):
    """Convert a zero-based column index into its A1-style column label.

    Inverse of ``column_to_index``: 0 -> ``"A"``, 25 -> ``"Z"``, 26 -> ``"AA"``.
    A negative index produces an empty string.
    """
    letters = []
    remaining = index
    while remaining >= 0:
        remaining, offset = divmod(remaining, 26)
        letters.append(string.ascii_uppercase[offset])
        # Bijective base-26 has no zero digit, hence the extra decrement.
        remaining -= 1
    # Letters were collected least significant first.
    return "".join(reversed(letters))

def parse_range(sheetId, cell_range):
    """Translate an A1-notation range into a Sheets API ``GridRange`` dict.

    Accepted forms are ``"A1:B2"`` and a single cell ``"B2"`` (treated as a
    one-cell range).  An optional ``"Sheet name!"`` prefix and ``$`` absolute
    markers are ignored, and column letters may be lower case.

    Args:
        sheetId: Numeric id of the sheet, copied verbatim into the result.
        cell_range: Range in A1 notation.

    Returns:
        dict with ``sheetId`` and half-open, zero-based ``startRowIndex`` /
        ``endRowIndex`` / ``startColumnIndex`` / ``endColumnIndex``.

    Raises:
        ValueError: if the range has more than two corners or a corner lacks
            its column letters or its row number.
    """
    # The sheet is identified by sheetId, so a "Sheet name!" prefix is dropped.
    a1 = cell_range.rsplit('!', 1)[-1].replace('$', '').strip().upper()
    corners = a1.split(':')
    if len(corners) == 1:
        # A single cell starts and ends on itself.
        corners = corners * 2
    if len(corners) != 2:
        raise ValueError(f"Invalid A1 range: {cell_range!r}")

    parsed = []
    for corner in corners:
        col = ''.join(filter(str.isalpha, corner))
        row = ''.join(filter(str.isdigit, corner))
        if not col or not row:
            raise ValueError(f"Invalid A1 range: {cell_range!r}")
        parsed.append((col, int(row)))
    (start_col, start_row), (end_col, end_row) = parsed

    return {
        "sheetId": sheetId,
        # A1 rows are 1-based and inclusive; GridRange is 0-based, end-exclusive.
        "startRowIndex": start_row - 1,
        "endRowIndex": end_row,
        "startColumnIndex": column_to_index(start_col),
        "endColumnIndex": column_to_index(end_col) + 1
    }
    
    
class GooogleSheetsApi:
    """Small convenience wrapper around the Google Sheets v4 client.

    The instance holds OAuth user credentials (``self.creds``) and the id of
    the spreadsheet the helpers operate on (``self.spreadsheet_id``, assigned
    through ``set_sheet_id``).  All structural helpers (autofill, repeat
    formula, pivot table, chart) target the sheet whose ``sheetId`` is 0.
    """

    def __init__(self, credentials, user_data):
        """Load cached user credentials or run the OAuth flow to obtain them.

        Args:
            credentials: Path to the OAuth client-secrets JSON file.
            user_data: Path of the cached authorised-user token file.  It is
                read when it exists and (re)written whenever the credentials
                had to be refreshed or newly obtained.
        """
        SCOPES = ["https://www.googleapis.com/auth/spreadsheets", "https://www.googleapis.com/auth/drive.file"]
        self.creds = None
        # Assigned later through set_sheet_id(); initialised here so that the
        # attribute always exists.
        self.spreadsheet_id = None
        if os.path.exists(user_data):
            self.creds = Credentials.from_authorized_user_file(user_data, SCOPES)
        if not self.creds or not self.creds.valid:
            if self.creds and self.creds.expired and self.creds.refresh_token:
                self.creds.refresh(Request())
            else:
                # Interactive: starts a local server and waits for the user to
                # authorise the application.
                flow = InstalledAppFlow.from_client_secrets_file(
                    credentials, SCOPES
                )
                self.creds = flow.run_local_server(port=0)
            with open(user_data, "w") as token:
                token.write(self.creds.to_json())

    def _service(self):
        """Build a Sheets v4 client bound to the stored credentials."""
        return build("sheets", "v4", credentials=self.creds)

    def _batch_update(self, requests):
        """Send ``requests`` as one ``spreadsheets.batchUpdate`` call and return the response."""
        body = {"requests": requests}
        return (
            self._service()
            .spreadsheets()
            .batchUpdate(spreadsheetId=self.spreadsheet_id, body=body)
            .execute()
        )

    def set_sheet_id(self, spreadsheet_id):
        """Select the spreadsheet that subsequent calls read from and write to."""
        self.spreadsheet_id = spreadsheet_id

    def create_spreadsheet(self, title):
        """Create a new spreadsheet named ``title`` and return its id.

        The new id is only returned; it is not stored on the instance.
        """
        spreadsheet = {"properties": {"title": title}}
        spreadsheet = (
            self._service()
            .spreadsheets()
            .create(body=spreadsheet, fields="spreadsheetId")
            .execute()
        )
        return spreadsheet.get("spreadsheetId")

    def read_values(self, range_name='A1:AZ100000'):
        """Return the values of ``range_name`` as a list of rows (``[]`` when empty)."""
        result = (
            self._service()
            .spreadsheets()
            .values()
            .get(spreadsheetId=self.spreadsheet_id, range=range_name)
            .execute()
        )
        return result.get("values", [])

    def write_values(self, range_name, values):
        """Write the 2-D list ``values`` starting at ``range_name``.

        Values are sent with ``valueInputOption="USER_ENTERED"``.  Returns the
        API response of ``values().update``.
        """
        body = {"values": values}
        return (
            self._service()
            .spreadsheets()
            .values()
            .update(
                spreadsheetId=self.spreadsheet_id,
                range=range_name,
                valueInputOption="USER_ENTERED",
                body=body)
            .execute()
        )

    def autofill(self, range_name):
        """Issue an ``autoFill`` request over ``range_name`` on sheet 0."""
        return self._batch_update([
            {
                "autoFill": {
                    "useAlternateSeries": False,
                    "range": parse_range(0, range_name)
                }
            }
        ])

    def repeat_formula(self, formula, range_name):
        """Issue a ``repeatCell`` request that puts ``formula`` into ``range_name`` on sheet 0."""
        return self._batch_update([
            {
                "repeatCell": {
                    "range": parse_range(0, range_name),
                    "cell": {
                        "userEnteredValue": {
                            "formulaValue": formula
                        }
                    },
                    "fields": "userEnteredValue"
                }
            }
        ])

    def create_pivot_table(self, source_range, target_cell, rows=None, columns=None, values=None):
        """Create a pivot table on sheet 0 from ``source_range``.

        Args:
            source_range: A1 range holding the source data.
            target_cell: A1 cell id used to anchor the pivot table.
            rows: Source column offsets used as pivot rows (default: none).
            columns: Source column offsets used as pivot columns (default: none).
            values: ``(column offset, summarize function)`` pairs (default: none).

        Returns:
            The ``batchUpdate`` API response.
        """
        # None instead of mutable [] defaults; behaviour for callers is unchanged.
        rows = [] if rows is None else rows
        columns = [] if columns is None else columns
        values = [] if values is None else values

        pivot_table = {
            "source": parse_range(0, source_range),
            "rows": [
                {
                    "sourceColumnOffset": row,
                    "sortOrder": "ASCENDING",
                    "showTotals": False,
                } for row in rows
            ],
            "columns": [
                {
                    "sourceColumnOffset": col,
                    "sortOrder": "ASCENDING",
                    "showTotals": False,
                } for col in columns
            ],
            "values": [
                {
                    "summarizeFunction": func,
                    "sourceColumnOffset": val,
                } for val, func in values
            ],
            "valueLayout": "HORIZONTAL",
        }
        return self._batch_update([
            {
                "updateCells": {
                    "rows": {
                        "values": [
                            {
                                "pivotTable": pivot_table
                            }
                        ]
                    },
                    "start": {
                        "sheetId": 0,
                        # NOTE(review): unlike create_chart, the 1-based A1 row is
                        # NOT reduced by one here, so the anchor is one row below
                        # target_cell.  describe_pivot_in_cell (defined elsewhere
                        # in this file) appears to compensate for that offset, so
                        # both must be changed together -- confirm before fixing.
                        "rowIndex": int(''.join(filter(str.isdigit, target_cell))),
                        "columnIndex": column_to_index(''.join(filter(str.isalpha, target_cell))),
                    },
                    "fields": "pivotTable",
                }
            }
        ])

    def create_chart(self, chart_type, target_cell, domain_range, chart_series, title="", botton_axis_title="", left_axis_title=""):
        """Add a basic chart to sheet 0, anchored at ``target_cell``.

        Args:
            chart_type: Basic chart type, e.g. ``LINE``, ``COLUMN`` or ``AREA``.
            target_cell: A1 cell id of the chart's anchor cell.
            domain_range: A1 range used as the chart domain.
            chart_series: Iterable of A1 ranges, one per series.
            title: Chart title.
            botton_axis_title: Title of the bottom axis.
            left_axis_title: Title of the left axis.

        Returns:
            The ``batchUpdate`` API response.
        """
        basic_chart = {
            "chartType": chart_type,  # LINE, COLUMN, AREA
            "legendPosition": "BOTTOM_LEGEND",
            "axis": [
                {
                    "position": "BOTTOM_AXIS",
                    "title": botton_axis_title
                },
                {
                    "position": "LEFT_AXIS",
                    "title": left_axis_title
                }
            ],
            "domains": [
                {
                    "domain": {
                        "sourceRange": {
                            # `sources` is a list of grid ranges, exactly as for
                            # the series below (it used to be a bare dict).
                            "sources": [parse_range(0, domain_range)]
                        }
                    }
                }
            ],
            "series": [
                {
                    "series": {
                        "sourceRange": {
                            "sources": [parse_range(0, serie)]
                        }
                    },
                    "targetAxis": "LEFT_AXIS"
                } for serie in chart_series
            ],
        }
        anchor_cell = {
            "sheetId": 0,
            # A1 rows are 1-based, grid coordinates are 0-based.
            "rowIndex": int(''.join(filter(str.isdigit, target_cell)))-1,
            "columnIndex": column_to_index(''.join(filter(str.isalpha, target_cell))),
        }
        return self._batch_update([
            {
                "addChart": {
                    "chart": {
                        "spec": {
                            "title": title,
                            "basicChart": basic_chart
                        },
                        "position": {
                            "overlayPosition": {
                                "anchorCell": anchor_cell
                            }
                        }
                    }
                }
            }
        ])

    def get_metadata(self):
        """Return the response of ``spreadsheets().get`` for the selected spreadsheet."""
        return (
            self._service()
            .spreadsheets()
            .get(spreadsheetId=self.spreadsheet_id)
            .execute()
        )

    def get_errors(self):
        """List cells in ``A1:V50`` whose value starts with ``#``.

        Returns:
            A list of ``(cell id, value)`` tuples, e.g. ``("B3", "#REF!")``.
        """
        errors = []
        for row, cells in enumerate(self.read_values("A1:V50")):
            for col, value in enumerate(cells):
                if isinstance(value, str) and value.startswith('#'):
                    errors.append((index_to_column(col) + str(row + 1), value))
        return errors

# Tools

In [17]:
import json
import urllib.parse
import urllib.request
from typing import Type, TypedDict, Annotated, List

import numpy as np
from pydantic import BaseModel, Field, Extra

# ------------------------------------------------- Auxiliary functions  ----------------------------------------------------
def matrix_to_markdown(header, matrix):
    """Render a header row plus data rows as a pipe-delimited markdown table.

    ``header`` must contain strings; cells of ``matrix`` are converted with
    ``str``.  Every line, including the last, ends with a newline.
    """
    def render(cells):
        return "| " + " | ".join(cells) + " |\n"

    lines = [render(header), render(["---"] * len(header))]
    lines.extend(render(map(str, row)) for row in matrix)
    return "".join(lines)

def column_to_index(column):
    """Map an A1 column label (``"A"``, ``"AB"``, ...) to a zero-based index.

    Note: this redefines the function of the same name from the Google API
    cell; once this cell has run, this is the definition every caller uses.
    """
    alphabet = string.ascii_uppercase
    total = sum(
        (alphabet.index(letter) + 1) * 26 ** power
        for power, letter in enumerate(column[::-1])
    )
    return total - 1

def index_to_column(index):
    """Map a zero-based column index to its A1 label (0 -> ``"A"``, 26 -> ``"AA"``).

    Note: this redefines the function of the same name from the Google API
    cell; once this cell has run, this is the definition every caller uses.
    """
    label = ""
    while not index < 0:
        quotient, position = divmod(index, 26)
        label = string.ascii_uppercase[position] + label
        # No zero digit in bijective base-26, so step one below the quotient.
        index = quotient - 1
    return label

def markdown_to_matrix(markdown):
    """Parse a pipe-delimited markdown table into a list of rows.

    Literal ``\\n`` sequences (backslash followed by ``n``) are treated as line
    breaks.  Blank lines are skipped.  The second line is dropped only when it
    is a delimiter row (cells made of ``-``, ``:`` and spaces); a table sent
    without one keeps all of its rows.  Empty cells are preserved, so columns
    never shift.

    Args:
        markdown: The markdown table as a string.

    Returns:
        List of rows, header first; each row is a list of stripped cell
        strings.  Empty input yields ``[]``.
    """
    def split_row(line):
        # Remove the optional outer pipes, then split on the inner ones so
        # that empty cells ("||") survive as ''.
        line = line.strip()
        if line.startswith('|'):
            line = line[1:]
        if line.endswith('|'):
            line = line[:-1]
        return [cell.strip() for cell in line.split('|')]

    def is_delimiter(cells):
        return bool(cells) and all(
            '-' in cell and set(cell) <= set('-: ') for cell in cells
        )

    text = markdown.replace('\\n', '\n')
    matrix = [split_row(line) for line in text.split('\n') if line.strip()]
    if len(matrix) > 1 and is_delimiter(matrix[1]):
        del matrix[1]
    return matrix

def longest_leftmost_non_empty_subarray(lst):
    """Return the first run of consecutive non-``''`` items and their indices.

    Leading ``''`` items are skipped; the run ends at the next ``''`` or at the
    end of ``lst``.  Returns ``(items, indices)``, both empty lists when
    ``lst`` holds no non-empty item.
    """
    size = len(lst)
    start = next((pos for pos in range(size) if lst[pos] != ''), size)
    stop = next((pos for pos in range(start, size) if lst[pos] == ''), size)
    positions = list(range(start, stop))
    return [lst[pos] for pos in positions], positions

def get_data_outline(gapi, max_rows=20):
    """Build a markdown sample of the sheet's table and the range it occupies.

    The table starts at the first non-empty row returned by
    ``gapi.read_values()``; its columns are the first run of non-empty cells in
    that row.  Rows shorter than the header are padded with ``''``.  When there
    are more than ``max_rows`` data rows, ``max_rows`` of them are drawn at
    random (unseeded ``np.random.choice``) for the sample.

    Args:
        gapi: Object exposing ``read_values()`` that returns a list of rows.
        max_rows: Maximum number of data rows in the sample.

    Returns:
        ``(markdown_sample, a1_range)``.  Both elements are the string
        ``"No data on the sheet"`` when no table is found.
    """
    no_data = "No data on the sheet"
    values_list = gapi.read_values()

    # Skip leading empty rows without running past the end of the list.
    first_row = 0
    while first_row < len(values_list) and len(values_list[first_row]) == 0:
        first_row += 1
    values_list = values_list[first_row:]
    if len(values_list) == 0:
        return no_data, no_data
    first_row += 1  # switch to the 1-based row number used in A1 notation

    _, ind = longest_leftmost_non_empty_subarray(values_list[0])
    if not ind:
        return no_data, no_data

    # Rows may be ragged (shorter than the header), so select the table
    # columns cell by cell and pad instead of building a rectangular array.
    table = [
        [str(row[col]) if col < len(row) else '' for col in ind]
        for row in values_list
    ]
    header, body = table[0], table[1:]
    if len(body) > max_rows:
        selected_indices = np.random.choice(len(body), max_rows, replace=False)
        body = [body[i] for i in selected_indices]

    last_row = len(values_list) + first_row - 1
    # index_to_column keeps the range correct beyond column Z.
    data_range = f"{index_to_column(ind[0])}{first_row}:{index_to_column(ind[-1])}{last_row}"
    return matrix_to_markdown(header, body), data_range


def describe_pivot_in_cell(api, cell):
    """Describe where the pivot table anchored at ``cell`` placed its data.

    Reads everything from ``cell`` to the bottom-right corner of the first
    sheet's grid and derives the reported range and counts from the number of
    rows returned and the longest row.

    Args:
        api: Object exposing ``get_metadata()`` and ``read_values(range)``.
        cell: A1 id of the cell the pivot table was requested for.

    Returns:
        A multi-line text with the target cell, the data range and the number
        of row and column values.
    """
    grid = api.get_metadata()['sheets'][0]['properties']['gridProperties']
    last_column = index_to_column(grid['columnCount'] - 1)
    data = api.read_values(f"{cell}:{last_column}{grid['rowCount']}")

    row_number, col_number = len(data), max(map(len, data))

    start_col = ''.join(ch for ch in cell if ch.isalpha())
    # The reported range starts one row below the requested cell.
    start_row = int(''.join(ch for ch in cell if ch.isdigit())) + 1
    end_col = index_to_column(column_to_index(start_col) + col_number - 1)
    end_row = start_row + row_number - 2
    return \
    f"""
    Pivot table was written to cell {cell}
    Pivot table main data is located in range {start_col}{start_row}:{end_col}{end_row}
    Pivot table has {row_number-2} rows values and {col_number-1} columns values
    """

# ------------------------------------------------- Descriptions -----------------------------------------------------------

# Argument schema paired with WriteMarkdownTableTool in the `tools` list.
# The field names match the keyword arguments of WriteMarkdownTableTool.__call__.
# NOTE(review): the docstring and Field descriptions are presumably shown to the
# LLM as the tool description -- confirm in lib.agents before rewording them.
class WriteMarkdownTable(BaseModel):
    """
    Call this function with a cell ID and a markdown table to write this table into the specified cell.
    Markdown table must have headers. Pass it into argument as a string.
    Formulas should begin with "=".
    """
    range_name: str = Field(description="The name of the range on this sheet in A1 notation.")
    markdown_table: str = Field(description="Markdown table that will be written into the range.")
    
# Argument schema paired with WriteValueTool in the `tools` list.
# The field names match the keyword arguments of WriteValueTool.__call__.
# NOTE(review): the docstring and Field descriptions are presumably shown to the
# LLM as the tool description -- confirm in lib.agents before rewording them.
class WriteValue(BaseModel):
    """
    Call this function with a cell ID and a value of formula to write this value or formula into only one specified cell.
    Only use formulas you have an access to. If formula has arguments, they should be separated by commas.
    Example of a formula: "=SUM(A1:A5)"
    """
    cell_id: str = Field(description="ID of the cell in A1 notation.")
    value: str = Field(description="Value or formula that will be written into the cell. Formula should start with symbol '='.")
    
# One value entry of a pivot table; element type of CreatePivotTable.values.
# CreatePivotTableTool reads both keys with .get() and forwards them as the
# pivot value's sourceColumnOffset and summarizeFunction.
class PivotPosition(BaseModel): 
    column_index: int
    aggregation_function: str    
    
# Argument schema paired with CreatePivotTableTool in the `tools` list.
# Fixes in the descriptions: the `rows` example no longer contradicts the
# zero-based numbering, and `values` is described as a list of objects with
# the PivotPosition keys (the tool reads them with .get()), not as tuples.
class CreatePivotTable(BaseModel):
    """
    Call this function to create a pivot table with data on the sheet.
    """
    source_range: str = Field(description="The name of the source range of pivot table in A1 notation. Defines where data will be read from.")
    target_cell: str = Field(description="ID of the target cell in A1 notation. Pivot table will be written into this cell.")
    rows: list[int] = Field(description=\
        """
        Column indices in the source range that will be rows in the pivot table. Numbering starts from 0. Should be list of non-negative integers.
        For example: [0, 2] means that columns with indices 0 and 2 will be rows in the pivot table.
        List can be empty, then there will be no rows in pivot table.
        """
    )
    columns: list[int] = Field(description=\
        """
        Column indices in the source range that will be columns in the pivot table. Numbering starts from 0. Should be list of non-negative integers.
        For example: [1, 3] means that columns with indices 1 and 3 will be columns in the pivot table.
        List can be empty, then there will be no columns in pivot table.
        """
    )
    values: list[PivotPosition] = Field(description=\
        """
        Columns in the source range that will be data values in the pivot table. Numbering starts from 0.
        Should be list of objects: "column_index" is the column index, "aggregation_function" is the aggregation function.
        Aggregation function can be one of these: "SUM", "COUNTA", "AVERAGE", "MAX", "MIN"
        For example: [{"column_index": 4, "aggregation_function": "SUM"}] means that column with index 4 will be summed up in the pivot table.
        List can be empty, then no values will be written into the pivot table.
        """
    )

# Argument schema paired with DrawChartTool in the `tools` list.
# The field names match the keyword arguments of DrawChartTool.__call__, which
# forwards x_title / y_title as the bottom and left axis titles.
# NOTE(review): the docstring and Field descriptions are presumably shown to the
# LLM as the tool description -- confirm in lib.agents before rewording them.
class DrawChart(BaseModel):
    """
    Call this function to create a line, column or area chart with data on the sheet.
    """
    chart_type: str = Field(description="Type of the chart, can be on of the following: LINE, COLUMN, AREA")
    target_cell: str = Field(description="ID of the target cell in A1 notation. Chart will be written into this cell.")
    domain_range: str = Field(description="Range name of chart domain in A1 notation (like K9:M123). This domain will be the X axis.")
    chart_series: list[str] = Field(description=\
        """
        List of range names in A1 notation. Each element of this list defines range name of data for chart series.
        For example: ["I5:I19", "J5:J19"] means that there will be two series on the chart: with data from range "I5:I19" and "J5:J19".
        """
    )
    title: str = Field(description="Title of the chart")
    x_title: str = Field(description="Title of the X axis")
    y_title: str = Field(description="Title of the Y axis (or series)")
    
# Argument schema paired with RepeatFormulaTool in the `tools` list.
# The field names match the keyword arguments of RepeatFormulaTool.__call__.
# NOTE(review): the docstring and Field descriptions are presumably shown to the
# LLM as the tool description -- confirm in lib.agents before rewording them.
class RepeatFormula(BaseModel):
    """
    Call this function to repeat formula in specified range.
    The formula's range automatically increments for each row and column in the range, starting with the upper left cell. 
    For example, if cell B1 has the formula =FLOOR(A1*PI()), while cell D6 has the formula =FLOOR(C6*PI()).
    """
    formula: str = Field(description="Formula to be repeated in range. Should start with '='")
    range_name: str = Field(description="The name of the range, where formula will be repeated.")

    
# Argument schema paired with AutofillConstantTool in the `tools` list.
# The field names match the keyword arguments of AutofillConstantTool.__call__.
# NOTE(review): the docstring and Field descriptions are presumably shown to the
# LLM as the tool description -- confirm in lib.agents before rewording them.
class AutofillConstant(BaseModel):
    """
    Call this function to autofill constant value in specified range.
    For example, if you need to write value "678" into all cells of range "G3:G8", call this function with constant "678" and range_name "G3:G8".
    """
    constant: str = Field(description="Constant value to be written in a range (not only one cell).")
    range_name: str = Field(description="The name of the range to be autofilled with constant value.")
    
    
# Argument schema paired with AutofillDeltaTool in the `tools` list.
# Fix: the docstring used to name a non-existent argument ("fisrt_value");
# it now matches the real field name `first_value`.
class AutofillDelta(BaseModel):
    """
    Call this function to autofill values in specified range with a constant difference.
    For example, if you need to write values "1 2 3 4 5 6" into range "G3:G8", call this function with 
    - first_value: "1"
    - second_value: "2"
    - range_name "G3:G8".
    This will work because all elements in "1 2 3 4 5 6" have constant difference.
    """
    first_value: str = Field(description="First value to be written in a range.")
    second_value: str = Field(description="Second value to be written in a range.")
    range_name: str = Field(description="The name of the range to be autofilled with constant difference between values.")
    
# Argument schema paired with GetStockDataTool in the `tools` list.
# The field names match the keyword arguments of GetStockDataTool.__call__,
# which interpolates stock_id, start_dt and end_dt into the request URL.
# NOTE(review): the docstring and Field descriptions are presumably shown to the
# LLM as the tool description -- confirm in lib.agents before rewording them.
class GetStockData(BaseModel):
    """
    Call this function to write stock data into specified cell.
    You have access to the following stock identifiers (stock_id):
    - "SBER" for SberBank or Sber
    - "YNDX" for Yandex 
    - "GAZP" for Gazprom
    Data will occupy two columns: one for dates, one for prices.
    """
    stock_id: str = Field(description="Identifier of a stock.")
    start_dt: str = Field(description="Start date in form yyyy-mm-dd")
    end_dt: str = Field(description="End date in form yyyy-mm-dd")
    cell_id: str = Field(description="ID of the target cell in A1 notation.")

# ------------------------------------------------- Tools -----------------------------------------------------------    

class WriteMarkdownTableTool:
    # Callable tool: parses a markdown table and writes it through `gapi`.
    def __init__(self, gapi):
        self.gapi = gapi

    def __call__(self, range_name, markdown_table):
        # Parse first, then write the resulting rows starting at range_name.
        self.gapi.write_values(range_name, markdown_to_matrix(markdown_table))
        confirmation = f"The data from the table {markdown_table} has been written into the cell {range_name}"
        return confirmation
    
    
class WriteValueTool:
    # Callable tool: writes one value or formula into one cell through `gapi`.
    def __init__(self, gapi):
        self.gapi = gapi

    def __call__(self, cell_id, value):
        # Single quotes are swapped for double quotes before writing; the
        # confirmation message reports the value as it was passed in.
        sanitized = str(value).replace("'", "\"")
        self.gapi.write_values(cell_id, [[sanitized]])
        return f"Value {value} has been written into the cell {cell_id}"
    
    
class CreatePivotTableTool:
    # Callable tool: creates a pivot table through `gapi` and describes the result.
    def __init__(self, gapi):
        self.gapi = gapi

    def __call__(self, source_range, target_cell, rows, columns, values):
        # Each incoming value is a mapping; flatten it into the
        # (column index, aggregation function) pair expected by gapi.
        pairs = []
        for entry in values:
            pairs.append((entry.get('column_index'), entry.get('aggregation_function')))
        self.gapi.create_pivot_table(source_range, target_cell, rows=rows, columns=columns, values=pairs)
        description = describe_pivot_in_cell(self.gapi, target_cell)
        summary = f"Created pivot table from data in {source_range} with rows {rows}, columns {columns} and values {pairs}\n"
        return summary + description
    
    
class DrawChartTool:
    # Callable tool: adds a chart to the sheet through `gapi`.
    def __init__(self, gapi):
        self.gapi = gapi

    def __call__(self, chart_type, target_cell, domain_range, chart_series, title, x_title, y_title):
        # x_title / y_title map onto gapi's bottom and left axis titles.
        axis_titles = {"botton_axis_title": x_title, "left_axis_title": y_title}
        self.gapi.create_chart(chart_type, target_cell, domain_range, chart_series, title=title, **axis_titles)
        return f"Created {chart_type} chart from data in {chart_series} with domain in {domain_range}"
    
    
class RepeatFormulaTool:
    # Callable tool: repeats a formula over a range through `gapi`.
    def __init__(self, gapi):
        self.gapi = gapi

    def __call__(self, formula, range_name):
        api = self.gapi
        api.repeat_formula(formula, range_name)
        report = f"Repeated formula {formula} in range {range_name}"
        return report
    
    
class AutofillConstantTool:
    # Callable tool: seeds the first cell of a range with a constant and
    # autofills the rest of the range through `gapi`.
    def __init__(self, gapi):
        self.gapi = gapi

    def __call__(self, constant, range_name):
        # TODO: clear range_name
        first_cell = range_name.partition(':')[0]
        self.gapi.write_values(first_cell, [[constant]])
        self.gapi.autofill(range_name)
        return f"Constant value {constant} was written in every cell of range {range_name}"
    
    
class AutofillDeltaTool:
    # Callable tool: seeds the first two cells of a range and autofills the
    # rest through `gapi`, producing a series with a constant difference.
    def __init__(self, gapi):
        self.gapi = gapi

    def __call__(self, first_value, second_value, range_name):
        # TODO: clear range_name
        first_cell, _, last_cell = range_name.partition(':')
        start_col = ''.join(filter(str.isalpha, first_cell))
        start_row = ''.join(filter(str.isdigit, first_cell))
        end_col = ''.join(filter(str.isalpha, last_cell))
        end_row = ''.join(filter(str.isdigit, last_cell))

        if last_cell and end_row == start_row and end_col.upper() != start_col.upper():
            # Single-row range: the series runs left to right, so the second
            # seed belongs in the next column (it used to be written into the
            # row below, outside the range).
            second_cell = index_to_column(column_to_index(start_col.upper()) + 1) + start_row
        else:
            # Column or block range: the second seed goes into the next row.
            second_cell = start_col + str(int(start_row) + 1)

        self.gapi.write_values(first_cell, [[first_value]])
        self.gapi.write_values(second_cell, [[second_value]])
        self.gapi.autofill(range_name)
        return f"Values were written into range {range_name}"
    

class GetStockDataTool:
    # Callable tool: downloads daily candles for a stock from the MOEX ISS
    # endpoint and writes a two-column (date, price) table through `gapi`.
    def __init__(self, gapi):
        self.gapi = gapi

    def __call__(self, stock_id, start_dt, end_dt, cell_id):
        # Standard library HTTP client: `requests` was never imported in this
        # notebook, so the previous requests.get() call raised NameError.
        url = (
            "https://iss.moex.com/iss/engines/stock/markets/shares/securities/"
            + urllib.parse.quote(str(stock_id), safe='')
            + "/candles.json?"
            + urllib.parse.urlencode({"from": start_dt, "till": end_dt, "interval": 24})
        )
        with urllib.request.urlopen(url, timeout=30) as reply:
            res = json.load(reply)

        # Price is candle[4] / candle[5] and the date is the first 10 characters
        # of candle[6] (presumably value / volume and the candle start time --
        # confirm against the "columns" list of the response).  Candles with a
        # zero denominator are skipped instead of raising ZeroDivisionError.
        matrix = [
            [candle[6][:10], candle[4] / candle[5]]
            for candle in res['candles']['data']
            if candle[5]
        ]
        if not matrix:
            return f"No stock data found for {stock_id} between {start_dt} and {end_dt}; nothing was written."

        response = self.gapi.write_values(cell_id, matrix)
        # updatedRange looks like "Sheet name!A1:B5"; keep only the A1 part.
        result_range = response['updatedRange'].rsplit('!', 1)[-1]
        return \
        f"""
        Stock data has been written to range: {result_range}
        First column contains dates, second column contains prices.
        Table does not contain header.
        """

        

# Prompts

In [18]:
# Annotator
# Prompt template passed as `context` to the `data_annotator` LLMLambda.
# {query} and {sample} are placeholders; literal JSON braces are doubled.
# NOTE(review): presumably the template is filled with str.format -- confirm in
# lib.agents.

annotation_prompt = \
"""
You are a professional data annotator
You are given a markdown table with columns names, sample data and user query
You should find out if this table is in a usual relation database form or in a form of dictionary and return json with information about this table

Example of a dictionary:
| cats | 5 |
| --- | --- |
| dogs | 4 |
| all | 9 |

Output for dictionary:
{{
    "table_type" : "dictionary",
    "content": [],
    "table_description": "Table with information about cats and dogs"
}}

Example of a relation database:
| price | amount |
| --- | --- |
| 10.9 | 490 |
| 0.23 | 91 |

Output for relation database:
{{
    "table_type" : "relational",
    "content": [
        {{
            "name": "price",
            "description": "Price",
            "type": "number"
        }},
        {{
            "name": "amount",
            "description": "Amount",
            "type": "number"
        }}
    ],
    "table_description": "Table with information about prices and amounts"
}}

If sample is empty, then return table_type="empty"!

Query: 
{query}

Sample: 
{sample}
"""

# Allowed values of TableDescription.table_type; they match the table_type
# strings used in annotation_prompt ("dictionary", "relational", "empty").
class TableType(str, Enum):
    dictionary = "dictionary"
    relational = "relational"
    empty = "empty"
    
# One column of a relational table; mirrors the objects of the "content" list
# in the annotation_prompt example (name, description, type).
class RelationValue(BaseModel):
    name: str
    description: str
    type: str
    
# Structured output of the data annotator (passed as `structure` to the
# `data_annotator` LLMLambda); mirrors the JSON shown in annotation_prompt.
class TableDescription(BaseModel):
    table_type: TableType
    content: Union[None, list[RelationValue]]
    table_description: str

In [19]:
# Planner
# Prompt template passed as `context` to the `planner` node.
# Placeholders: {query}, {data_outline}, {data_range}.

planner_context = \
"""
You are an expert data analyst who is tasked with writing a detailed plan for data analysis task in Gooogle Sheets.

User query: 
{query}

Data outline: 
{data_outline}

Data is located in cells {data_range}
"""

In [20]:
# Reflector 
# `reflector_context` is the prompt template shared by both
# PlannerDomainReflector nodes (cell_critic and formulas_critic); its text is
# identical to planner_context.  Placeholders: {query}, {data_outline},
# {data_range}.

reflector_context = \
"""
You are an expert data analyst who is tasked with writing a detailed plan for data analysis task in Gooogle Sheets.

User query: 
{query}

Data outline: 
{data_outline}

Data is located in cells {data_range}
"""

# Passed as `restriction` to the `cell_critic` reflector.
cell_critic_restriction =\
"""
- All formulas have correct adresses of data ranges in arguments
- Result data ranges on the sheets do not overlap
- All tables and formulas on the sheet are displayes in a human convinient way.
"""

# Passed as `restriction` to the `formulas_critic` reflector.
formulas_critic_restriction =\
"""
- Plan use only formulas with correct names, correct number of arguments and correct types of arguments.
"""


In [21]:
# Executor
# Prompt template passed as `context` to the `main_executor` ToolsExecutor.
# Placeholders: {query}, {data_outline}, {data_range}, {main_plan}.

executor_prompt = \
"""
You are helpful assistant who can work with markdown tables and use Google Sheets API to read and write data
You have a bunch of functions you can call to iteract with Google Sheets API.

User query: 
{query}

Data outline: 
{data_outline}

Data is located in cells {data_range}

Do all steps from the following plan:
{main_plan}
"""

# Agent

In [23]:
# Shared Sheets client.  Constructing it may start the interactive OAuth flow
# when no valid cached token exists at the `user_data` path.
gapi = GooogleSheetsApi(credentials=f"creds/credentials.json", user_data=f"creds/authorized_user.json")

# (argument schema, callable) pairs handed to the ToolsExecutor; every callable
# shares the same `gapi` instance.
tools = \
[
    (WriteMarkdownTable, WriteMarkdownTableTool(gapi)),
    (WriteValue, WriteValueTool(gapi)),
    (CreatePivotTable, CreatePivotTableTool(gapi)),
    (DrawChart, DrawChartTool(gapi)),
    (RepeatFormula, RepeatFormulaTool(gapi)),
    (AutofillConstant, AutofillConstantTool(gapi)),
    (AutofillDelta, AutofillDeltaTool(gapi)),
    (GetStockData, GetStockDataTool(gapi)),
]

# Two model handles: `mini` is used by the annotator and the critics, `maxi`
# by the planner and the executor.
mini = OpenAIModel('gpt-4o-mini')
maxi = OpenAIModel('gpt-4o-2024-08-06')

Please visit this URL to authorize this application: https://accounts.google.com/o/oauth2/auth?response_type=code&client_id=919060483740-uh93u5ljntblub534ulmq1p84cegne4j.apps.googleusercontent.com&redirect_uri=http%3A%2F%2Flocalhost%3A51361%2F&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fspreadsheets+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive.file&state=UT02cmUtNFZvXfkIWX6O2SyKWx1DeQ&access_type=offline


In [24]:
# Natural-language list of available actions, passed as `action_space` to the
# planner.  There is one entry per tool in `tools` (listed in a different
# order).
action_space = \
[
    'Get stock market data for a specific company and time interval and write to specified cell (result table will have two columns!)',
    'Write a markdown table to specified cell',
    'Write a formula or a value to one specified cell',
    'Create a pivot table from existing data',
    'Create a simple chart of one of the following types: line, column, area',
    'Repeat a formula to a range of cells with automatic increment of cell address',
    'Autofill blank range with a constant value',
    'Autofill blank range with a data series with constant difference (for example, 1 2 3 4 5 ...)',
]

In [25]:
def data_reader_hook(config):
    """Load the sheet outline into ``config`` and report whether data exists.

    Stores the markdown sample in ``config.sample`` and the A1 range in
    ``config.data_range``, logging before and after the read.  Returns
    ``False`` only when the sample is the ``"No data on the sheet"`` marker.
    """
    log = config.print_logs
    log("Data outlining: getting data...")
    outline = get_data_outline(config.gapi)
    config.sample, config.data_range = outline
    log("Data outlining: data obtained")
    return config.sample != "No data on the sheet"

def data_annotator_post_hook(config):
    """Replace ``config.data_outline`` with its JSON text, braces swapped for brackets.

    Every ``{`` becomes ``[`` and every ``}`` becomes ``]`` in the string
    returned by ``config.data_outline.json()``.
    """
    serialized = config.data_outline.json()
    config.data_outline = serialized.translate(str.maketrans("{}", "[]"))

In [31]:
# Pipeline nodes.  The node classes come from lib.agents; the notes below only
# restate how the nodes are wired here.
# NOTE(review): the meaning of each keyword argument is defined in lib.agents
# and is not visible in this notebook -- confirm there.

# Runs data_reader_hook; its result is published as "data_exists".
data_reader = FreeLambda(
    hook = data_reader_hook,
    output_name = "data_exists",
    print_name='data reader'
)

# LLM step using annotation_prompt with TableDescription as the output
# structure; data_annotator_post_hook then rewrites config.data_outline.
data_annotator = LLMLambda(
    context=annotation_prompt,
    structure=TableDescription,
    model=mini,
    post_hook=data_annotator_post_hook,
    output_name="data_outline",
    print_name="data annotator"
)

# Produces "main_plan" from planner_context and the action list.
planner = Planner(
    context = planner_context,
    action_space = action_space, # list, RAG source, None
    model = maxi, 
    self_consistency_rounds = 2, # int, None
    algorithm = "cot", # "sample", "COT", "TOT-best"
    output_name = "main_plan",
    print_name = "main planner",
    tot_number = 3,
    tot_rounds = 10
)

# Both critics read "main_plan" and write their result back to "main_plan".
cell_critic = PlannerDomainReflector(
    context = reflector_context,
    restriction = cell_critic_restriction,
    max_rounds = 3,
    model=mini,
    input_name = "main_plan",
    output_name = "main_plan",
    print_name = "cell critic",
)

formulas_critic = PlannerDomainReflector(
    context = reflector_context,
    restriction = formulas_critic_restriction,
    max_rounds = 3,
    model=mini,
    input_name = "main_plan",
    output_name = "main_plan",
    print_name = "formulas critic",
)

# Executes "main_plan" with the Google Sheets tools; output is "main_result".
main_executor = ToolsExecutor(
    context = executor_prompt,
    tools = tools,
    model = maxi,
    plan = "main_plan",
    plan_type = "proxy",
    output_name = "main_result",
    print_name = "executor",
)

In [32]:
# Execution graph handed to Agent.
# NOTE(review): the entry format is defined by lib.agents.Agent.  Presumably the
# 2-tuple is an edge, the 4-tuple a branch on the "data_exists" output
# (data_annotator vs. planner), and the list a linear chain ending at "END" --
# confirm against the library.
graph_list = \
[
    ("START", data_reader),
    (data_reader, "data_exists", data_annotator, planner),
    [data_annotator, planner, cell_critic, formulas_critic, main_executor, "END"]
]

In [33]:
# Build the agent from the graph; data_reader_hook later reads `config.gapi`,
# which is presumably populated from the `gapi` keyword -- confirm in lib.agents.
agent = Agent(graph=graph_list, gapi=gapi)

# Calling agent

In [34]:
# Select the spreadsheet all tools operate on.  The id is hard-coded; replace
# it with the id of your own spreadsheet before running.
gapi.set_sheet_id("18-7vOhBdfn7VacGL-eVGqQIUT_dD0S1_TDSCcfvhI0A")

In [36]:
# Example user request passed to agent.invoke() in the next cell.
query = \
"""
I have data of users clicking on button with two different colors.
Please do a simple A/B test in Google Sheets and show if there is a statistical difference between button colors.
Don't use pivot tables.
"""

In [37]:
# Run the agent on the query; the tools write to the spreadsheet selected with
# gapi.set_sheet_id().
agent.invoke(query=query)

In [38]:
# Print the log text accumulated on agent.config during the run.
print(agent.config.logs)

Data outlining: getting data...
Data outlining: data obtained

---------------------------LLM Lambda data annotator----------------------------
Query: 
You are a professional data annotator
You are given a markdown table with columns names, sample data and user query
You should find out if this table is in a usual relation database form or in a form of dictionary and return json with information about this table

Example of a dictionary:
| cats | 5 |
| --- | --- |
| dogs | 4 |
| all | 9 |

Output for dictionary:
{{
    "table_type" : "dictionary",
    "content": [],
    "table_description": "Table with information about cats and dogs"
}}

Example of a relation database:
| price | amount |
| --- | --- |
| 10.9 | 490 |
| 0.23 | 91 |

Output for relation database:
{{
    "table_type" : "relational",
    "content": [
        {{
            "name": "price",
            "description": "Price",
            "type": "number"
        }},
        {{
            "name": "amount",
            "desc