In [2]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.colors as pc
import yfinance as yf
import pandas as pd
import plotly.io as pio
from typing import TypedDict, Annotated, Optional
from langgraph.graph import StateGraph, END
from langchain_core.messages import HumanMessage

# Define the state structure
class GraphState(TypedDict):
    """State dictionary handed from one LangGraph node to the next in this file."""
    # Raw user request; read by parse_query.
    query: str
    # Set by parse_query: "candlestick", "balance", "assets", or None when no keyword matched.
    plot_type: Optional[str]
    # Set by parse_query: derived from the last word of the query, or None when no keyword matched.
    ticker: Optional[str]
    # Set by generate_plot on success: the figure serialized with fig.to_json().
    plot_json: Optional[str]
    # Help or error text from generate_plot, then the final payload written by format_response.
    response: Optional[str]

# Data fetching functions
def fetch_stock_data(ticker, period="1y"):
    """Return what yfinance's ``Ticker.history`` reports for ``ticker`` over ``period``."""
    return yf.Ticker(ticker).history(period=period)

def fetch_balance(ticker, tp="Annual"):
    """Return the balance sheet for ``ticker`` without its mostly-empty columns.

    ``tp == "Annual"`` selects the annual statement; any other value selects
    the quarterly one. Only columns whose share of missing values is below
    0.5 are kept.
    """
    company = yf.Ticker(ticker)
    if tp == "Annual":
        sheet = company.balance_sheet
    else:
        sheet = company.quarterly_balance_sheet

    # Fraction of NaN cells per column.
    missing_share = sheet.isna().mean()
    return sheet.loc[:, missing_share < 0.5]

# Plotting functions
def plot_candles_stick(df, title=""):
    """Return a candlestick figure built from the rows of ``df``.

    ``df`` must provide 'Open', 'High', 'Low' and 'Close' columns; its index
    is used as the x axis. ``title`` becomes the figure title.
    """
    candles = go.Candlestick(
        x=df.index,
        open=df['Open'],
        high=df['High'],
        low=df['Low'],
        close=df['Close'],
    )
    fig = go.Figure(data=[candles])
    fig.update_layout(title=title)
    return fig

def plot_balance(df, ticker="", currency=""):
    """Build a stacked bar chart of assets versus liabilities plus equity.

    For every column of ``df`` two bars are drawn on a two-level x axis
    (date, group): an "Assets" bar, and an "L+E" bar made of stockholders'
    equity and total liabilities. Each bar is labelled with the total assets
    in billions, the liabilities-to-assets percentage, and (from the second
    column on) the change in total assets relative to the preceding column.

    Args:
        df: Balance-sheet table with the rows 'Total Assets',
            'Stockholders Equity' and 'Total Liabilities Net Minority
            Interest'. Column labels must be accepted by ``pd.to_datetime``.
            The caller's object is not modified.
        ticker: Symbol shown in the figure title.
        currency: Currency name shown in the y-axis title.

    Returns:
        The assembled plotly figure.

    Raises:
        KeyError: If one of the required rows is missing from ``df``.
    """
    # Work on a copy: relabelling the columns in place would silently change
    # the DataFrame the caller passed in.
    df = df.copy()
    df.columns = pd.to_datetime(df.columns).strftime('%b %d, %Y')

    # Row label -> how it is drawn and which x-axis group it is stacked in.
    components = {
        'Total Assets': {
            'color': 'forestgreen', 'name': 'Assets', 'group': 'Assets'},
        'Stockholders Equity': {
            'color': 'CornflowerBlue', 'name': "Stockholder's Equity", 'group': 'L+E'},
        'Total Liabilities Net Minority Interest': {
            'color': 'tomato', 'name': "Total Liabilities", 'group': 'L+E'},
    }

    fig = go.Figure()
    n_dates = len(df.columns)
    for row, style in components.items():
        fig.add_trace(go.Bar(
            x=[df.columns, [style['group']] * n_dates],
            y=df.loc[row],
            name=style['name'],
            marker=dict(color=style['color']),
        ))

    def add_label(date, group, y, text):
        # All bar labels share the same look; only position and text vary.
        fig.add_annotation(
            x=[date, group],
            y=y,
            text=text,
            showarrow=False,
            font=dict(size=12, color="black"),
            align="center",
        )

    total_assets = df.loc['Total Assets']
    equity = df.loc['Stockholders Equity']
    liabilities = df.loc['Total Liabilities Net Minority Interest']

    # Vertical gap between the top of the assets bar and its growth label.
    offset = 0.03 * total_assets.max()
    for i, date in enumerate(df.columns):
        assets_now = total_assets.iloc[i]

        # Total assets in billions, centred inside the assets bar.
        add_label(date, "Assets", assets_now / 2,
                  str(round(assets_now / 1e9, 1)) + 'B')

        # Liabilities as a share of assets, placed at equity + liabilities / 2,
        # i.e. mid-height of the liabilities segment stacked after equity.
        percentage = round((liabilities.iloc[i] / assets_now) * 100, 1)
        add_label(date, "L+E", equity.iloc[i] + liabilities.iloc[i] / 2,
                  str(percentage) + '%')

        if i > 0:
            # NOTE(review): this compares each column with the one before it,
            # so it is period-over-period growth only if the columns are in
            # chronological order - confirm the order fetch_balance returns.
            percentage = round((assets_now / total_assets.iloc[i - 1] - 1) * 100, 1)
            sign = '+' if percentage >= 0 else ''
            add_label(date, "Assets", assets_now + offset,
                      sign + str(percentage) + '%')

    fig.update_layout(
        barmode='stack',
        title=f'Accounting Balance: {ticker}',
        xaxis_title='Year',
        yaxis_title=f'Amount (in {currency})',
        legend_title='Balance components',
    )
    return fig

def plot_assets(df, ticker="", currency=""):
    """Draw current and non-current asset components as two stacked bar panels.

    The left panel stacks the current-asset rows found in ``df``, the right
    panel the non-current ones; rows absent from ``df.index`` are skipped.
    Above every bar the corresponding total ('Current Assets' or
    'Total Non Current Assets', both required rows) is printed in billions.

    ``ticker`` is used in the figure title and ``currency`` in the y-axis
    title. Returns the plotly figure.
    """
    current_rows = (
        'Cash Cash Equivalents And Short Term Investments',
        'Receivables',
        'Prepaid Assets',
        'Inventory',
        'Hedging Assets Current',
        'Other Current Assets',
    )
    non_current_rows = (
        'Net PPE',
        'Goodwill And Other Intangible Assets',
        'Investments And Advances',
        'Investment Properties',
        'Other Non Current Assets',
    )

    fig = make_subplots(
        rows=1,
        cols=2,
        shared_yaxes=True,
        horizontal_spacing=0.05,
        subplot_titles=['Current Assets', 'Non-Current Assets'],
    )

    # (subplot column, candidate rows, colour scale, legend group)
    panels = (
        (1, current_rows, pc.sequential.Blugrn[::-1], 'Current Assets'),
        (2, non_current_rows, pc.sequential.Purp[::-1], 'Non-current Assets'),
    )
    for col, rows, palette, group in panels:
        present = [name for name in rows if name in df.index]
        # Colours are handed out in order to the rows that actually exist.
        for shade, name in enumerate(present):
            bar = go.Bar(
                x=df.columns,
                y=df.loc[name],
                name=name,
                marker=dict(color=palette[shade]),
                legendgroup=group,
                showlegend=True,
            )
            fig.add_trace(bar, row=1, col=col)

    # Gap between the top of a bar and its label, relative to the tallest total.
    tallest = max(df.loc['Current Assets'].max(), df.loc['Total Non Current Assets'].max())
    offset = 0.03 * tallest

    totals = (('Current Assets', 1), ('Total Non Current Assets', 2))
    for date in df.columns:
        for total_row, col in totals:
            amount = df.loc[total_row, date]
            fig.add_annotation(
                x=date,
                y=amount + offset,
                text=str(round(amount / 1e9, 1)) + 'B',
                showarrow=False,
                font=dict(size=12, color="black"),
                align="center",
                row=1, col=col
            )

    date_axis = dict(title='Date', type='date', tickvals=df.columns)
    fig.update_layout(
        barmode='stack',
        title=f'Assets: {ticker}',
        xaxis1=dict(date_axis),
        xaxis2=dict(date_axis),
        yaxis_title=f'Amount (in {currency})',
        legend_title='Asset Components',
    )
    return fig

# LangGraph nodes
def parse_query(state: GraphState) -> GraphState:
    """Parse the user query to determine plot type and ticker.

    The plot type is chosen by the first phrase found in the lower-cased
    query, checked in this order: "candlestick chart", "balance sheet",
    "assets". The ticker is the last whitespace-separated word of the query
    with trailing sentence punctuation removed, upper-cased.

    Args:
        state: Graph state; must contain the "query" string.

    Returns:
        A partial state update with "plot_type" and "ticker". Both are None
        when the query is empty, has no usable last word, or contains none of
        the known phrases.
    """
    query = state["query"].lower()
    words = query.split()

    # An empty or whitespace-only query has no last word; indexing it would
    # raise IndexError. Trailing punctuation ("... for AAPL?") is not part of
    # the symbol, so it is stripped before upper-casing.
    ticker = words[-1].rstrip(".,!?;:").upper() if words else ""
    if not ticker:
        return {"plot_type": None, "ticker": None}

    # Order matters: the first phrase found in the query wins.
    phrase_to_plot_type = (
        ("candlestick chart", "candlestick"),
        ("balance sheet", "balance"),
        ("assets", "assets"),
    )
    for phrase, plot_type in phrase_to_plot_type:
        if phrase in query:
            return {"plot_type": plot_type, "ticker": ticker}
    return {"plot_type": None, "ticker": None}

def generate_plot(state: GraphState) -> GraphState:
    """Generate the appropriate plot based on the parsed query.

    Args:
        state: Graph state; "plot_type" and "ticker" are read from it. A
            missing key is treated like None.

    Returns:
        A partial state update: ``{"plot_json": ...}`` holding the figure
        serialized with ``fig.to_json()`` on success, or ``{"response": ...}``
        holding a help message (plot type or ticker missing) or an error
        message (fetching or plotting failed, or the plot type is unknown).
    """
    plot_type = state.get("plot_type")
    ticker = state.get("ticker")
    if not plot_type or not ticker:
        return {"response": "I can generate candlestick charts, balance sheets, or assets visualizations. Please specify what you'd like to see (e.g., 'Show me a candlestick chart for AAPL')"}

    try:
        if plot_type == "candlestick":
            df = fetch_stock_data(ticker)
            fig = plot_candles_stick(df, title=f"{ticker} Candlestick Chart")
        elif plot_type == "balance":
            df = fetch_balance(ticker)
            fig = plot_balance(df, ticker=ticker, currency="USD")
        elif plot_type == "assets":
            df = fetch_balance(ticker)
            fig = plot_assets(df, ticker=ticker, currency="USD")
        else:
            # Without this branch `fig` would be unbound and the user would
            # see an UnboundLocalError message instead of the real cause.
            raise ValueError(f"Unsupported plot type: {plot_type!r}")

        plot_json = fig.to_json()
        return {"plot_json": plot_json}

    except Exception as e:
        # Node boundary: report any failure to the user rather than
        # aborting the graph run.
        return {"response": f"Error generating plot: {str(e)}"}

def format_response(state: GraphState) -> GraphState:
    """Choose the final "response" value for the graph run.

    Preference order: a non-empty "plot_json", then a non-empty "response",
    then a generic failure message.
    """
    payload = (
        state.get("plot_json")
        or state.get("response")
        or "Something went wrong while processing your request"
    )
    return {"response": payload}

# The graph itself is built and compiled in the next cell.


In [3]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.colors as pc
import yfinance as yf
import pandas as pd
import plotly.io as pio
from typing import TypedDict, Annotated, Optional
from langgraph.graph import StateGraph, END


class GraphState(TypedDict):
    """State dictionary handed from one LangGraph node to the next in this file."""
    # Raw user request; read by parse_query.
    query: str
    # Set by parse_query: "candlestick", "balance", "assets", or None when no keyword matched.
    plot_type: Optional[str]
    # Set by parse_query: derived from the last word of the query, or None when no keyword matched.
    ticker: Optional[str]
    # Set by generate_plot on success: the figure serialized with fig.to_json().
    plot_json: Optional[str]
    # Help or error text from generate_plot, then the final payload written by format_response.
    response: Optional[str]

# LangGraph nodes (unchanged from previous code)
def parse_query(state: GraphState) -> GraphState:
    """Work out which plot the user asked for and for which ticker.

    The lower-cased query is searched for "candlestick chart", then
    "balance sheet", then "assets"; the first phrase found decides the plot
    type. The ticker is the query's last whitespace-separated word, with
    trailing sentence punctuation removed and upper-cased.

    Args:
        state: Graph state; must contain the "query" string.

    Returns:
        A partial state update with "plot_type" and "ticker". Both are None
        when the query is empty, has no usable last word, or contains none of
        the known phrases.
    """
    nothing_found = {"plot_type": None, "ticker": None}

    query = state["query"].lower()
    words = query.split()
    # Guard against an empty or whitespace-only query: there is no last word
    # to index, and split()[-1] would raise IndexError.
    if not words:
        return nothing_found

    # "... for AAPL?" should yield AAPL, not AAPL?.
    ticker = words[-1].rstrip(".,!?;:").upper()
    if not ticker:
        return nothing_found

    # Checked in order; the first phrase present in the query wins.
    for phrase, plot_type in (
        ("candlestick chart", "candlestick"),
        ("balance sheet", "balance"),
        ("assets", "assets"),
    ):
        if phrase in query:
            return {"plot_type": plot_type, "ticker": ticker}
    return nothing_found

def generate_plot(state: GraphState) -> GraphState:
    """Fetch data for the parsed ticker and build the requested figure.

    Args:
        state: Graph state; "plot_type" and "ticker" are read from it. A
            missing key is treated like None.

    Returns:
        A partial state update: ``{"plot_json": ...}`` holding the figure
        serialized with ``fig.to_json()`` on success, or ``{"response": ...}``
        holding a help message (plot type or ticker missing) or an error
        message (fetching or plotting failed, or the plot type is unknown).
    """
    plot_type = state.get("plot_type")
    ticker = state.get("ticker")
    if not plot_type or not ticker:
        return {"response": "I can generate candlestick charts, balance sheets, or assets visualizations. Please specify what you'd like to see (e.g., 'Show me a candlestick chart for AAPL')"}

    try:
        if plot_type == "candlestick":
            df = fetch_stock_data(ticker)
            fig = plot_candles_stick(df, title=f"{ticker} Candlestick Chart")
        elif plot_type == "balance":
            df = fetch_balance(ticker)
            fig = plot_balance(df, ticker=ticker, currency="USD")
        elif plot_type == "assets":
            df = fetch_balance(ticker)
            fig = plot_assets(df, ticker=ticker, currency="USD")
        else:
            # An unknown plot type used to fall through with `fig` unbound,
            # surfacing as a confusing UnboundLocalError message.
            raise ValueError(f"Unsupported plot type: {plot_type!r}")

        plot_json = fig.to_json()
        return {"plot_json": plot_json}
    except Exception as e:
        # Node boundary: turn any failure into a user-facing message instead
        # of aborting the graph run.
        return {"response": f"Error generating plot: {str(e)}"}

def format_response(state: GraphState) -> GraphState:
    """Pick the value returned to the user as "response".

    A non-empty "plot_json" wins over a non-empty "response"; if neither is
    present a generic failure message is used.
    """
    for key in ("plot_json", "response"):
        value = state.get(key)
        if value:
            return {"response": value}
    return {"response": "Something went wrong while processing your request"}


# Build and compile the graph: a linear pipeline
# parse_query -> generate_plot -> format_response -> END.
workflow = StateGraph(GraphState)
# Register the three node functions under the names used by the edges below.
workflow.add_node("parse_query", parse_query)
workflow.add_node("generate_plot", generate_plot)
workflow.add_node("format_response", format_response)
# Execution starts at parse_query.
workflow.set_entry_point("parse_query")
# No conditional edges: every run passes through all three nodes in this order.
workflow.add_edge("parse_query", "generate_plot")
workflow.add_edge("generate_plot", "format_response")
workflow.add_edge("format_response", END)
# `app` is the compiled graph that chatbot_response invokes.
app = workflow.compile()

def chatbot_response(user_query: str) -> str:
    """Run ``user_query`` through the compiled graph and return its "response" value."""
    final_state = app.invoke({"query": user_query})
    return final_state["response"]

# Example usage
# Example usage
if __name__ == "__main__":
    # Demo: send three sample queries through the graph and display each result.
    queries = [
        "Show me a candlestick chart for AAPL",
        "Show me the balance sheet for MSFT",
        "Show me the assets for GOOGL"
    ]

    for query in queries:
        response = chatbot_response(query)
        print(f"Query: {query}")
        # A successful run returns the figure as a JSON string, while help and
        # error messages are plain text. The "{" test is only a cheap
        # heuristic, so parsing is still guarded below.
        if isinstance(response, str) and "{" in response:
            try:
                fig = pio.from_json(response)
                fig.show()
            except Exception:
                # Best effort: if the text is not a displayable figure, show
                # it as-is. `Exception` (not a bare except) so that Ctrl-C and
                # SystemExit are no longer swallowed.
                print(f"Response: {response}")
        else:
            print(f"Response: {response}")
        print("-" * 50)

Query: Show me a candlestick chart for AAPL


--------------------------------------------------
Query: Show me the balance sheet for MSFT


--------------------------------------------------
Query: Show me the assets for GOOGL


--------------------------------------------------
