## Test `sheet_analysis` on Housing Data
This notebook demonstrates the capabilities of the sheet analysis agent.

In [1]:
import pandas as pd

from sheet_analysis.llm.agent import AnalysisAgent

### Definitions
Define the path to the file to analyze, the model to use, and the generation temperature. Note: `model_name` and `temperature` have defaults, but we set them explicitly here for demonstration.

In [2]:
filepath = "../data/housing.csv"
model_name = "llama3.2"
temperature = 0.0

### Preview data
This is not necessary for the agent, but we load the data and preview it to get an idea of what it looks like.

In [3]:
data = pd.read_csv(filepath)
data.head()

| | Index | Street Address | City | Square Footage | Price |
|---|---|---|---|---|---|
| 0 | 1 | 847 Oakdale Drive | Centennial | 1568 | 460200 |
| 1 | 2 | 421 Maple Street | Centennial | 4099 | 1152700 |
| 2 | 3 | 189 Pineview Avenue | Centennial | 4595 | 1264300 |
| 3 | 4 | 623 Cedar Court | Centennial | 2529 | 718600 |
| 4 | 5 | 842 Walnut Road | Centennial | 1321 | 360600 |


### Initialize agent

In [4]:
agent = AnalysisAgent(
    filepath=filepath,
    model_name=model_name,
    temperature=temperature,
)

### Ask the agent a basic query
We ask for the average home price filtered to a single city. The answer is correct, and the response includes the code the agent generated.

In [5]:
prompt = "What is the average home price in Parker?"
response = agent.invoke(prompt)
print(response)

The average home price in Parker is approximately $788,777.78.

Here is the code I generated to get that:

```python

import pandas as pd
import matplotlib

def func(data):
    # Filter the DataFrame to only include rows where 'City' is 'Parker'
    parker_data = data[data['City'] == 'Parker']
    
    # Calculate and return the average home price
    return parker_data['Price'].mean()

```
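
The answer can be checked by computing the same statistic directly with pandas. This is a minimal sketch, assuming the column names shown in the preview (`City`, `Price`). The helper name `average_price` and the three-row demo frame are hypothetical stand-ins and are not part of `sheet_analysis`.

```python
import pandas as pd


def average_price(data: pd.DataFrame, city: str) -> float:
    """Return the mean of 'Price' over the rows whose 'City' equals `city`."""
    return float(data.loc[data["City"] == city, "Price"].mean())


# Hypothetical stand-in rows; in the notebook, use the `data` frame loaded above.
demo = pd.DataFrame(
    {
        "City": ["Parker", "Parker", "Centennial"],
        "Price": [700_000, 800_000, 460_200],
    }
)
print(average_price(demo, "Parker"))  # 750000.0
```

Calling `average_price(data, "Parker")` on the real file should agree with the agent's figure up to rounding.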


### Ask the agent to plot something
We prompt it to generate a figure that breaks the data down by town (the result may be several subplots or one plot with a legend). The generated figure is saved to "notebooks/figures/output.png". Note: the next query writes its figure to the same file, so stop running at this cell if you want to view this one.

In [6]:
prompt = "Show a scatter plot of home price vs square feet broken down by town."
response = agent.invoke(prompt)
print(response)

I generated a figure for you. It is located at figures/output.png

Here is the code I generated to get that:

```python

import pandas as pd
import matplotlib.pyplot as plt

def func(data):
    # Group the data by city and create a scatter plot of price vs square footage
    fig, ax = plt.subplots(figsize=(10, 8))
    for city in data['City'].unique():
        city_data = data[data['City'] == city]
        ax.scatter(city_data['Square Footage'], city_data['Price'], label=city)
    
    # Set the title and labels
    ax.set_title('Home Price vs Square Footage by City')
    ax.set_xlabel('Square Footage')
    ax.set_ylabel('Price')
    
    # Add a legend
    ax.legend()
    
    # Save the figure to a file
    plt.savefig('figures/output.png')
    
    # Close the figure before returning
    plt.close()
    
    return None

```
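
Because the agent reports the same output path for every figure, one way to keep this plot is to copy it to a distinct name before running further cells. The sketch below uses only the standard library. The helper `preserve_figure` and its naming scheme are hypothetical, and the path you pass in is assumed to be relative to the notebook's working directory.

```python
import shutil
from pathlib import Path


def preserve_figure(src: str, label: str) -> Path:
    """Copy the figure at `src` to a sibling file named '<stem>_<label><suffix>'."""
    source = Path(src)
    if not source.is_file():
        raise FileNotFoundError(f"No figure found at {source}")
    # e.g. output.png with label "scatter" becomes output_scatter.png
    target = source.with_name(f"{source.stem}_{label}{source.suffix}")
    shutil.copy2(source, target)
    return target
```

For example, `preserve_figure("figures/output.png", "scatter")` would leave the original in place and write a copy to `figures/output_scatter.png`.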


### Ask the agent another query
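We prompt it for home counts broken down by town, even though the spreadsheet column is headed "City". The mismatch did not trip up the model. When this was run, the model also generated code to plot a bar chart of the counts, even though no plot was requested. Note that the generated code imports only `matplotlib` and then calls `matplotlib.pyplot`, which works only if `matplotlib.pyplot` has already been imported elsewhere in the process.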

In [7]:
prompt = "How many homes are in each town?"
response = agent.invoke(prompt)
print(response)

Here is the answer to your question:

There are 9 homes in Parker, 6 homes in Centennial, 6 homes in Lone Tree, 5 homes in Lakewood, and 4 homes in Arvada.

I generated a figure for you. It is located at figures/output.png

Here is the code I generated to get that:

```python

import pandas as pd
import matplotlib

def func(data):
    # Group by city and count the number of rows (homes) in each group
    result = data['City'].value_counts()
    
    # Save the figure to a file
    fig, ax = matplotlib.pyplot.subplots(figsize=(10, 6))
    ax.bar(result.index, result.values)
    ax.set_xlabel('Town')
    ax.set_ylabel('Number of Homes')
    ax.set_title('Number of Homes in Each Town')
    matplotlib.pyplot.savefig('figures/output.png')
    matplotlib.pyplot.close(fig)
    
    return result

```
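
The reported counts can be cross-checked against the `City` column with `value_counts`. In this minimal sketch, the `reported` dictionary transcribes the agent's answer above, the helper `counts_match` is hypothetical, and the demo frame is a stand-in built from those same counts so that the block runs without the CSV.

```python
import pandas as pd

# Counts transcribed from the agent's answer above.
reported = {"Parker": 9, "Centennial": 6, "Lone Tree": 6, "Lakewood": 5, "Arvada": 4}


def counts_match(data: pd.DataFrame, expected: dict) -> bool:
    """Check that the per-city row counts in `data` equal `expected` exactly."""
    actual = data["City"].value_counts().to_dict()
    return actual == expected


# Stand-in frame with one row per reported home; in the notebook, pass `data` instead.
demo = pd.DataFrame(
    {"City": [city for city, n in reported.items() for _ in range(n)]}
)
print(counts_match(demo, reported))  # True
print(len(demo))  # 30
```

In the notebook, `counts_match(data, reported)` performs the same comparison against the real file.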
