<a href="https://colab.research.google.com/github/segzee/3mtt-capstone-project/blob/main/my_hedge_fund_trading_team1.0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 0. Setup

In [1]:
!pip install -q google-generativeai langchain
!pip install -qU langchain-google-genai

In [2]:
!pip install langgraph



In [3]:
import getpass
import os


def _set_if_undefined(var: str):
    """Prompt for environment variable ``var`` unless it already has a non-empty value.

    The value is read with ``getpass`` so it is not echoed, then stored in
    ``os.environ`` for the rest of the session.
    """
    already_set = bool(os.environ.get(var))
    if already_set:
        return
    secret = getpass.getpass(f"Please provide your {var}")
    os.environ[var] = secret


_set_if_undefined("GOOGLE_API_KEY")               # For calling the Gemini LLM. Get a key from Google AI Studio: https://aistudio.google.com/app/apikey
#_set_if_undefined("FINANCIAL_DATASETS_API_KEY")   # For getting financial data. Get from https://financialdatasets.ai

Please provide your GOOGLE_API_KEY··········


In [4]:
import json
import pandas as pd
import requests
import os
from datetime import timedelta
import matplotlib.pyplot as plt

# Import your agent's dependencies
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import HumanMessage, SystemMessage
from langgraph.graph import END, MessagesState, StateGraph, START
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import JsonOutputParser
from pydantic import BaseModel, Field
from typing import Literal
import yfinance as yf
import pandas as pd
import numpy as np
import time

In [5]:
# Shared chat model used by the quant, risk-management and portfolio-management agents below.
# We use gemini-1.5-flash, but you can use any LLM
llm =  ChatGoogleGenerativeAI(
    model="gemini-1.5-flash",
    temperature=0,  # 0 requests the least random sampling the model offers
    max_tokens=None,  # no explicit output cap is set here
    timeout=60,  # presumably seconds -- confirm against the langchain-google-genai docs
    max_retries=2,
)

# 1. Create Market Data Agent

In [6]:
def _last_scalar(value):
    """Return the last element of a pandas Series; pass any other value through unchanged."""
    if isinstance(value, pd.Series):
        return value.iloc[-1]
    return value


def market_data_agent(state: MessagesState):
    """Fetch price history for the requested ticker and publish its SMA signals.

    The request parameters (``ticker``, ``start_date``, ``end_date``) are read
    from ``additional_kwargs`` of the most recent message in the state.

    Returns:
        dict: ``{"messages": [...]}`` containing the existing messages followed
        by a new ``HumanMessage`` named ``market_data_agent`` that lists the
        current price and the current/previous 5- and 20-period SMAs.
    """
    messages = state["messages"]
    params = messages[-1].additional_kwargs
    ticker = params["ticker"]

    # Get the historical price data and derive the trading signals from it
    historical_data = get_price_data(ticker, params["start_date"], params["end_date"])
    signals = calculate_trading_signals(historical_data)

    # Each signal may arrive as a one-element Series instead of a scalar, so
    # collapse them all the same way before formatting.
    current_price = _last_scalar(signals["current_price"])
    sma_5_curr = _last_scalar(signals["sma_5_curr"])
    sma_5_prev = _last_scalar(signals["sma_5_prev"])
    sma_20_curr = _last_scalar(signals["sma_20_curr"])
    sma_20_prev = _last_scalar(signals["sma_20_prev"])

    # NOTE: everything inside this f-string is sent to the LLM verbatim, so no
    # code comments belong in it (the original leaked one into the prompt).
    message = HumanMessage(
        content=f"""
        Here are the trading signals for {ticker}:
        Current Price: ${current_price:.2f}
        SMA 5: {sma_5_curr:.2f}
        SMA 5 Previous: {sma_5_prev:.2f}
        SMA 20: {sma_20_curr:.2f}
        SMA 20 Previous: {sma_20_prev:.2f}
        """,
        name="market_data_agent",
    )

    return {"messages": messages + [message]}

# 2. Create Quant Agent

In [7]:
def quant_agent(state: MessagesState):
    """Ask the LLM for a technical-analysis verdict on the latest trading signals.

    The content of the last message in the state is passed to the prompt as the
    ``analysis`` template variable. It is deliberately NOT interpolated with an
    f-string: ``ChatPromptTemplate`` treats ``{`` / ``}`` as variable markers,
    so any braces inside the message text would otherwise break the template.

    Returns:
        dict: ``{"messages": [...]}`` with the existing messages followed by a
        new ``HumanMessage`` named ``quant_agent`` holding the recommendation.
    """
    last_message = state["messages"][-1]

    summary_prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """You are a hedge fund quant / technical analyst.
                You are given trading signals for a stock.
                Analyze the signals and provide a recommendation.
                - signal: bullish | bearish | neutral,
                - confidence: <float between 0 and 1>
                """
            ),
            MessagesPlaceholder(variable_name="messages"),
            (
                "human",
                """Based on the trading signals below, analyze the data and provide your assessment.

                Trading Analysis: {analysis}

                Only include your trading signal and confidence in the output.
                """
            ),
        ]
    )

    chain = summary_prompt | llm

    # Values supplied at invoke time are inserted verbatim, braces included.
    result = chain.invoke(
        {"messages": state["messages"], "analysis": last_message.content}
    ).content
    message = HumanMessage(
        content=f"Here is the trading analysis and my recommendation:{result}",
        name="quant_agent",
    )

    return {"messages": state["messages"] + [message]}


# 3. Create Risk Management Agent

In [8]:
def risk_management_agent(state: MessagesState):
    """Ask the LLM for a position-size limit and risk score for the proposed trade.

    The portfolio is read from ``additional_kwargs["portfolio"]`` of the first
    message in the state (keys ``cash`` and ``stock``). The previous agent's
    message and the portfolio figures are passed as template variables rather
    than f-string-interpolated, because the previous message is LLM output and
    may contain ``{`` / ``}`` (e.g. JSON), which ``ChatPromptTemplate`` would
    otherwise parse as variable markers.

    Returns:
        dict: ``{"messages": [...]}`` with the existing messages followed by a
        new ``HumanMessage`` named ``risk_management``.
    """
    portfolio = state["messages"][0].additional_kwargs["portfolio"]
    last_message = state["messages"][-1]

    risk_prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """You are a risk management specialist.
                Your job is to take a look at the trading analysis and
                evaluate portfolio exposure and recommend position sizing.
                Provide the following in your output (not as a JSON):
                - max_position_size: <float greater than 0>,
                - risk_score: <integer between 1 and 10>"""
            ),
            MessagesPlaceholder(variable_name="messages"),
            (
                "human",
                """Based on the trading analysis below, provide your risk assessment.

                Trading Analysis: {analysis}

                Here is the current portfolio:
                Portfolio:
                Cash: ${cash}
                Current Position: {stock} shares

                Only include the max position size and risk score in your output.
                """
            ),
        ]
    )
    chain = risk_prompt | llm
    result = chain.invoke(
        {
            "messages": state["messages"],
            "analysis": last_message.content,
            # Pre-format cash so the rendered prompt matches the original "$12345.67" form
            "cash": f"{portfolio['cash']:.2f}",
            "stock": portfolio["stock"],
        }
    ).content
    message = HumanMessage(
        content=f"Here is the risk management recommendation: {result}",
        name="risk_management",
    )
    return {"messages": state["messages"] + [message]}

# 4. Create Portfolio Management Agent

In [9]:
def portfolio_management_agent(state: MessagesState):
    """Ask the LLM for the final trading decision (action and quantity).

    The portfolio is read from ``additional_kwargs["portfolio"]`` of the first
    message in the state (keys ``cash`` and ``stock``).

    Two fixes relative to the original prompt:

    * The risk-management text and the portfolio figures are passed as template
      variables instead of being f-string-interpolated, so braces in upstream
      LLM output cannot be misread as ``ChatPromptTemplate`` variables.
    * The system prompt now explicitly requests a bare JSON object, because the
      backtester in this notebook parses the reply with ``json.loads``.
      Literal braces in the template are escaped as ``{{`` / ``}}``.

    Returns:
        dict: ``{"messages": [HumanMessage]}`` with a single new message named
        ``portfolio_management`` whose content is the raw LLM reply.
    """
    portfolio = state["messages"][0].additional_kwargs["portfolio"]
    last_message = state["messages"][-1]

    portfolio_prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """You are a portfolio manager making final trading decisions.
                Your job is to make a trading decision based on the risk management data.
                Provide the following in your output:
                - "action": "buy" | "sell" | "hold",
                - "quantity": <positive integer>
                Only buy if you have available cash.
                The quantity that you buy must be less than or equal to the max position size.
                Only sell if you have shares in the portfolio to sell.
                The quantity that you sell must be less than or equal to the current position.
                Respond with a single JSON object and nothing else (no markdown, no code fences),
                for example: {{"action": "buy", "quantity": 10}}"""
            ),
            MessagesPlaceholder(variable_name="messages"),
            (
                "human",
                """Based on the risk management data below, make your trading decision.

                Risk Management Data: {risk_data}

                Here is the current portfolio:
                Portfolio:
                Cash: ${cash}
                Current Position: {stock} shares

                Only include the action and quantity.

                Remember, the action must be either buy, sell, or hold.
                You can only buy if you have available cash.
                You can only sell if you have shares in the portfolio to sell.
                """
            ),
        ]
    )

    chain = portfolio_prompt | llm
    result = chain.invoke(
        {
            "messages": state["messages"],
            "risk_data": last_message.content,
            # Pre-format cash so the rendered prompt matches the original "$12345.67" form
            "cash": f"{portfolio['cash']:.2f}",
            "stock": portfolio["stock"],
        }
    ).content
    return {"messages": [HumanMessage(content=result, name="portfolio_management")]}

# 5. Create Agent Graph

In [10]:
# Build the agent pipeline as a LangGraph state graph whose state is MessagesState
workflow = StateGraph(MessagesState)

# Register each agent function as a named node
workflow.add_node("market_data_agent", market_data_agent)
workflow.add_node("quant_agent", quant_agent)
workflow.add_node("risk_management_agent", risk_management_agent)
workflow.add_node("portfolio_management_agent", portfolio_management_agent)

# Wire the nodes into a single linear chain:
# START -> market data -> quant -> risk management -> portfolio management -> END
workflow.add_edge(START, "market_data_agent")
workflow.add_edge("market_data_agent", "quant_agent")
workflow.add_edge("quant_agent", "risk_management_agent")
workflow.add_edge("risk_management_agent", "portfolio_management_agent")
workflow.add_edge("portfolio_management_agent", END)

# Compile into the runnable that run_agent invokes
app = workflow.compile()

In [11]:
# Update the run_agent function to include portfolio state
def run_agent(ticker: str, start_date: str, end_date: str, portfolio: dict):
    """Run the compiled agent graph once and return the content of its last message.

    The ticker, date window and portfolio travel to the agents through the
    ``additional_kwargs`` of the initial ``HumanMessage``.
    """
    request_metadata = {
        "ticker": ticker,
        "start_date": start_date,
        "end_date": end_date,
        "portfolio": portfolio,
    }
    initial_message = HumanMessage(
        content="Make a trading decision based on the provided data.",
        additional_kwargs=request_metadata,
    )
    run_config = {"configurable": {"thread_id": 42}}

    final_state = app.invoke({"messages": [initial_message]}, config=run_config)
    closing_message = final_state["messages"][-1]
    return closing_message.content


# 6. Get Stock Price and Trading Signals

In [12]:
def calculate_trading_signals(historical_data: pd.DataFrame) -> dict:
    """Calculate the inputs for an SMA-crossover strategy from closing prices.

    Args:
        historical_data: Frame with a ``"close"`` column, one row per period.
            When the columns are a MultiIndex and ``historical_data["close"]``
            is a one-column DataFrame, that single column is used so the
            returned values are scalars rather than Series.

    Returns:
        dict with keys ``current_price``, ``sma_5_curr``, ``sma_5_prev``,
        ``sma_20_curr`` and ``sma_20_prev``. Every value is NaN when fewer than
        20 rows are supplied. With exactly 20 rows ``sma_20_prev`` is NaN,
        because the previous 20-period window is not yet complete.
    """
    # Not enough history for a 20-period average: report every signal as NaN
    if len(historical_data) < 20:
        return {
            "current_price": np.nan,
            "sma_5_curr": np.nan,
            "sma_5_prev": np.nan,
            "sma_20_curr": np.nan,
            "sma_20_prev": np.nan,
        }

    close = historical_data["close"]
    # MultiIndex columns make ["close"] a DataFrame; collapse the one-column
    # case to a Series so .iloc[-1] below yields a scalar.
    if isinstance(close, pd.DataFrame) and close.shape[1] == 1:
        close = close.iloc[:, 0]

    sma_5 = close.rolling(window=5).mean()
    sma_20 = close.rolling(window=20).mean()

    # The length check above guarantees at least 20 rows, so positions -1 and
    # -2 always exist and need no further guarding.
    return {
        "current_price": close.iloc[-1],
        "sma_5_curr": sma_5.iloc[-1],
        "sma_5_prev": sma_5.iloc[-2],
        "sma_20_curr": sma_20.iloc[-1],
        "sma_20_prev": sma_20.iloc[-2],
    }

In [13]:
# Function to fetch price data using yfinance
def get_price_data(ticker, start_date, end_date):
    """
    Fetch historical daily stock price data using yfinance.

    Args:
        ticker (str): Stock ticker symbol (e.g., 'AAPL').
        start_date (str): Start date for data in 'YYYY-MM-DD' format.
        end_date (str): End date for data in 'YYYY-MM-DD' format.

    Returns:
        pd.DataFrame: DataFrame indexed by "Date" with the columns
        "open", "close", "high", "low" and "volume". When yfinance returns
        MultiIndex columns for a single ticker, they are flattened to plain
        column names so ``df["close"]`` is a Series.

    Raises:
        ValueError: If no price data is returned for the requested window.
    """
    # Fetch historical data
    data = yf.download(ticker, start=start_date, end=end_date, interval="1d")

    # Check if data is returned
    if data is None or data.empty:
        raise ValueError(f"No price data found for {ticker} between {start_date} and {end_date}.")

    # Flatten MultiIndex columns: find the level that holds the field names
    # and use it alone. The uniqueness check restricts this to the
    # single-ticker case, where dropping the other level loses nothing.
    if isinstance(data.columns, pd.MultiIndex):
        for level in range(data.columns.nlevels):
            labels = data.columns.get_level_values(level)
            if "Close" in labels and labels.is_unique:
                data = data.copy()
                data.columns = labels
                break

    # Rename columns to match the lowercase naming used elsewhere in this
    # notebook, producing a new frame instead of mutating in place.
    data = data.rename(
        columns={
            "Open": "open",
            "Close": "close",
            "High": "high",
            "Low": "low",
            "Volume": "volume",
        }
    )

    # Ensure the index is called "Date"
    data.index.name = "Date"

    # Select only required columns
    return data[["open", "close", "high", "low", "volume"]]

# 7. Create Backtester

In [14]:
class Backtester:
    """Replay a trading agent one business day at a time and track portfolio value.

    Args:
        agent: Callable invoked as
            ``agent(ticker=..., start_date=..., end_date=..., portfolio=...)``
            that returns the agent's decision as text.
        ticker: Ticker symbol to trade.
        start_date: First date of the backtest (anything ``pd.date_range`` accepts).
        end_date: Last date of the backtest.
        initial_capital: Starting cash.
        request_delay: Seconds to pause after each agent call to ease API rate
            limits. Defaults to 1.0, the delay that was previously hard-coded.
    """

    VALID_ACTIONS = ("buy", "sell", "hold")

    def __init__(self, agent, ticker, start_date, end_date, initial_capital, request_delay=1.0):
        self.agent = agent
        self.ticker = ticker
        self.start_date = start_date
        self.end_date = end_date
        self.initial_capital = initial_capital
        self.request_delay = request_delay
        self.portfolio = {"cash": initial_capital, "stock": 0}
        self.portfolio_values = []

    def parse_action(self, agent_output):
        """Extract ``(action, quantity)`` from the agent's reply.

        Accepts, in order of preference: a bare JSON object, a JSON object
        embedded in surrounding text (e.g. a fenced code block), or free text
        such as ``- "action": "buy"`` / ``quantity: 10``.

        Returns:
            tuple: ``(action, quantity)`` where ``action`` is one of "buy",
            "sell" or "hold" and ``quantity`` is a non-negative int. Anything
            that cannot be understood yields ``("hold", 0)``.
        """
        import json
        import re

        text = agent_output if isinstance(agent_output, str) else str(agent_output)

        # 1) Try JSON: the whole reply first, then the outermost {...} span.
        candidates = [text]
        embedded = re.search(r"\{.*\}", text, flags=re.DOTALL)
        if embedded:
            candidates.append(embedded.group(0))

        decision = None
        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except (TypeError, ValueError):
                continue
            if isinstance(parsed, dict):
                decision = parsed
                break

        if decision is not None:
            action = decision.get("action")
            quantity = decision.get("quantity")
        else:
            # 2) Fall back to "action ... buy" / "quantity ... 10" patterns in plain text.
            action_match = re.search(r"action\W*(buy|sell|hold)\b", text, flags=re.IGNORECASE)
            quantity_match = re.search(r"quantity\W*(\d+)", text, flags=re.IGNORECASE)
            action = action_match.group(1) if action_match else None
            quantity = quantity_match.group(1) if quantity_match else 0

        # Normalize and validate so execute_trade always gets a known action and an int.
        action = str(action).strip().lower() if action is not None else ""
        if action not in self.VALID_ACTIONS or action == "hold":
            return "hold", 0
        try:
            quantity = int(float(quantity))
        except (TypeError, ValueError, OverflowError):
            return "hold", 0
        if quantity < 0:
            return "hold", 0
        return action, quantity

    def execute_trade(self, action, quantity, current_price):
        """Validate and execute trades based on portfolio constraints.

        Buys are capped at what the available cash affords; sells are capped at
        the current share count. Returns the quantity actually traded (0 when
        nothing was traded).
        """
        # A NaN, zero or negative price cannot be traded against. Without this
        # guard a NaN price would turn cash into NaN on a sell.
        if not current_price > 0:
            return 0

        if action == "buy" and quantity > 0:
            cost = quantity * current_price
            if cost <= self.portfolio["cash"]:
                self.portfolio["stock"] += quantity
                self.portfolio["cash"] -= cost
                return quantity
            else:
                # Calculate maximum affordable quantity; int() keeps the share
                # count integral (floor division with a float price yields a float).
                max_quantity = int(self.portfolio["cash"] // current_price)
                if max_quantity > 0:
                    self.portfolio["stock"] += max_quantity
                    self.portfolio["cash"] -= max_quantity * current_price
                    return max_quantity
                return 0
        elif action == "sell" and quantity > 0:
            quantity = min(quantity, self.portfolio["stock"])
            if quantity > 0:
                self.portfolio["cash"] += quantity * current_price
                self.portfolio["stock"] -= quantity
                return quantity
            return 0
        return 0

    def run_backtest(self):
        """Run the agent on every business day in the range and record portfolio value.

        For each day the agent sees a 30-calendar-day lookback window ending on
        that day, its decision is parsed and executed at the last close in that
        window, and the resulting total value is appended to
        ``self.portfolio_values``.
        """
        dates = pd.date_range(self.start_date, self.end_date, freq="B")

        print("\nStarting backtest...")
        print(f"{'Date':<12} {'Action':<6} {'Quantity':>8} {'Price':>8} {'Cash':>12} {'Stock':>8} {'Total Value':>12}")
        print("-" * 70)

        for current_date in dates:
            lookback_start = (current_date - timedelta(days=30)).strftime("%Y-%m-%d")
            current_date_str = current_date.strftime("%Y-%m-%d")

            agent_output = self.agent(
                ticker=self.ticker,
                start_date=lookback_start,
                end_date=current_date_str,
                portfolio=self.portfolio
            )
            # Pause after calling the agent to avoid rate limiting
            if self.request_delay and self.request_delay > 0:
                time.sleep(self.request_delay)

            action, quantity = self.parse_action(agent_output)
            df = get_price_data(self.ticker, lookback_start, current_date_str)
            # Ensure current_price is a single numeric value
            current_price = df.iloc[-1]['close']
            if isinstance(current_price, pd.Series):
                current_price = current_price.item()  # If it's a Series, get the single value

            # Execute the trade with validation
            executed_quantity = self.execute_trade(action, quantity, current_price)

            # Update total portfolio value
            total_value = self.portfolio["cash"] + self.portfolio["stock"] * current_price
            self.portfolio["portfolio_value"] = total_value

            # Log the current state with executed quantity
            print(
                f"{current_date.strftime('%Y-%m-%d'):<12} {action:<6} {executed_quantity:>8} {current_price:>8.2f} "
                f"{self.portfolio['cash']:>12.2f} {self.portfolio['stock']:>8} {total_value:>12.2f}"
            )

            # Record the portfolio value
            self.portfolio_values.append(
                {"Date": current_date, "Portfolio Value": total_value}
            )

    def analyze_performance(self):
        """Print total return, Sharpe ratio and maximum drawdown, and plot the equity curve.

        Returns:
            pd.DataFrame: Frame indexed by "Date" with the columns
            "Portfolio Value" and "Daily Return".

        Raises:
            ValueError: If no portfolio values have been recorded yet.
        """
        if not self.portfolio_values:
            raise ValueError(
                "No portfolio values recorded; call run_backtest() before analyze_performance()."
            )

        # Convert portfolio values to DataFrame
        performance_df = pd.DataFrame(self.portfolio_values).set_index("Date")

        # Calculate total return from the last recorded portfolio value
        final_value = performance_df["Portfolio Value"].iloc[-1]
        total_return = (final_value - self.initial_capital) / self.initial_capital
        print(f"Total Return: {total_return * 100:.2f}%")

        # Plot the portfolio value over time
        performance_df["Portfolio Value"].plot(
            title="Portfolio Value Over Time", figsize=(12, 6)
        )
        plt.ylabel("Portfolio Value ($)")
        plt.xlabel("Date")
        plt.show()

        # Compute daily returns
        performance_df["Daily Return"] = performance_df["Portfolio Value"].pct_change()

        # Calculate Sharpe Ratio (assuming 252 trading days in a year)
        mean_daily_return = performance_df["Daily Return"].mean()
        std_daily_return = performance_df["Daily Return"].std()
        if pd.isna(std_daily_return) or std_daily_return == 0:
            # Undefined when returns never vary (e.g. the agent only ever held)
            sharpe_ratio = float("nan")
        else:
            sharpe_ratio = (mean_daily_return / std_daily_return) * (252 ** 0.5)
        print(f"Sharpe Ratio: {sharpe_ratio:.2f}")

        # Calculate Maximum Drawdown
        rolling_max = performance_df["Portfolio Value"].cummax()
        drawdown = performance_df["Portfolio Value"] / rolling_max - 1
        max_drawdown = drawdown.min()
        print(f"Maximum Drawdown: {max_drawdown * 100:.2f}%")

        return performance_df


# 8. Run the Backtest

In [15]:
# Define parameters for the backtest run
ticker = "AAPL"  # Example ticker symbol
start_date = "2024-01-01"  # Adjust as needed
end_date = "2024-11-23"  # Adjust as needed
initial_capital = 100000  # $100,000

# Create an instance of Backtester that drives run_agent once per business day
backtester = Backtester(
    agent=run_agent,
    ticker=ticker,
    start_date=start_date,
    end_date=end_date,
    initial_capital=initial_capital,
)

# Run the backtesting process, then report return / Sharpe / drawdown.
# NOTE(review): each simulated day makes three LLM calls (quant, risk and
# portfolio agents), and the recorded run below stopped with a 429
# ResourceExhausted error; consider a shorter date range or a higher quota.
backtester.run_backtest()
performance_df = backtester.analyze_performance()


Starting backtest...
Date         Action Quantity    Price         Cash    Stock  Total Value
----------------------------------------------------------------------


[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed


2024-01-01   hold          0   192.53    100000.00        0    100000.00


[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed


2024-01-02   hold          0   192.53    100000.00        0    100000.00


[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed


2024-01-03   hold          0   185.64    100000.00        0    100000.00


[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed


2024-01-04   hold          0   184.25    100000.00        0    100000.00


[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed


2024-01-05   hold          0   181.91    100000.00        0    100000.00


[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed

2024-01-08   hold          0   181.18    100000.00        0    100000.00



[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed


2024-01-09   hold          0   185.56    100000.00        0    100000.00


[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed


2024-01-10   hold          0   185.14    100000.00        0    100000.00


[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed


2024-01-11   hold          0   186.19    100000.00        0    100000.00


[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed


2024-01-12   hold          0   185.59    100000.00        0    100000.00




ResourceExhausted: 429 Resource has been exhausted (e.g. check quota).