In [None]:
from typing import Optional, Union, Dict, Any

import polars as pl

from vnpy.app.factor_maker.template import FactorTemplate
from vnpy.trader.constant import Interval

In [2]:
class OPEN(FactorTemplate):
    """Leaf factor that hands the 'open' entry of the input data to dependents.

    It declares no dependency factors and performs no computation of its own.
    """

    # Identifier of this factor.
    factor_name = 'open'
    # This factor has no upstream factors.
    dependencies_factor = []
    # Bar interval the factor is declared for.
    freq = Interval.MINUTE

    def __init__(self, setting, **kwargs):
        """Forward the setting and any keyword parameters to FactorTemplate."""
        super().__init__(setting, **kwargs)

    def __init_dependencies__(self):
        """Nothing to set up: there are no dependency factors."""
        return None

    def calculate(self, input_data: Optional[Union[pl.DataFrame, Dict[str, Any]]], *args, **kwargs) -> Any:
        """Return whatever is stored under 'open' in ``input_data``.

        Parameters:
        input_data (Optional[Union[pl.DataFrame, Dict[str, Any]]]): Container indexed by 'open'.

        Returns:
        Any: The object found under the 'open' key/column, unmodified.
        """
        open_data = input_data['open']
        return open_data

    def calculate_polars(self, input_data: pl.DataFrame, *args, **kwargs) -> Any:
        """Not implemented; returns None."""
        return None

In [3]:
class MA(FactorTemplate):
    """Rolling-mean factor computed column by column over the OPEN factor data."""

    # Identifier of this factor.
    factor_name = 'ma'
    # Class-level default; replaced per instance in __init_dependencies__.
    dependencies_factor = []
    # Bar interval the factor is declared for.
    freq = Interval.MINUTE

    def __init__(self, setting, window: int = None):
        """Forward the setting and the rolling window size to FactorTemplate."""
        super().__init__(setting, window=window)

    def __init_dependencies__(self):
        """Create the OPEN dependency and register it on this instance."""
        self.open = OPEN({})
        self.dependencies_factor = [self.open]

    def calculate(self, input_data: Optional[Union[pl.DataFrame, Dict[str, pl.DataFrame]]], *args, **kwargs) -> Any:
        """
        Compute the rolling mean of every non-'datetime' column.

        Parameters:
        input_data (Optional[Union[pl.DataFrame, Dict[str, pl.DataFrame]]]): Either the frame to
            average directly, or a dict from which the frame is looked up under the OPEN
            dependency's ``factor_key``.

        Returns:
        pl.DataFrame: Rolling means per column; when the input has a 'datetime' column it is
            re-attached as the first column.

        Raises:
        ValueError: If ``input_data`` is neither a DataFrame nor a dict, if the resolved data is
            not a Polars DataFrame, or if the 'window' parameter is None.
        """
        # Resolve the frame to work on.
        if isinstance(input_data, pl.DataFrame):
            frame = input_data
        elif isinstance(input_data, dict):
            frame = input_data.get(self.open.factor_key)
        else:
            raise ValueError("Invalid input_data format. Expected pl.DataFrame or Dict[str, pl.DataFrame].")

        # A missing dict entry resolves to None and is rejected here as well.
        if not isinstance(frame, pl.DataFrame):
            raise ValueError("The 'open' data must be a Polars DataFrame.")

        window = self.params.get_parameter('window')
        if window is None:
            raise ValueError("The rolling window size (window) is not set.")

        # Everything except the timestamp column is averaged.
        has_datetime = 'datetime' in frame.columns
        value_columns = [name for name in frame.columns if name != 'datetime']
        mean_exprs = [pl.col(name).rolling_mean(window).alias(name) for name in value_columns]
        result = frame.select(mean_exprs)

        # Put the timestamps back in front of the averaged columns.
        if has_datetime:
            result = result.insert_column(0, frame['datetime'])

        return result

    def calculate_polars(self, input_data: pl.DataFrame, *args, **kwargs) -> Any:
        """Not implemented; returns None."""
        return None

In [4]:
class MACD(FactorTemplate):
    """Histogram factor: (fast MA - slow MA) minus the rolling mean of that difference.

    Both moving averages are supplied by MA dependency factors, so they are
    simple rolling means rather than exponential ones.
    """

    # Identifier of this factor.
    factor_name = 'macd'
    # Class-level default; replaced per instance in __init_dependencies__.
    dependencies_factor = []
    # Bar interval the factor is declared for.
    freq = Interval.MINUTE

    def __init__(self, setting, fast_period: Optional[int] = None, slow_period: Optional[int] = None,
                 signal_period: Optional[int] = None):
        """Forward the setting and the three period parameters to FactorTemplate."""
        super().__init__(setting=setting, fast_period=fast_period, slow_period=slow_period, signal_period=signal_period)

    def __init_dependencies__(self):
        """Create the fast and slow MA dependencies and register them on this instance."""
        self.ma_fast = MA({}, self.params.get_parameter('fast_period'))
        self.ma_slow = MA({}, self.params.get_parameter('slow_period'))
        # Kept for backward compatibility with code that reads `self.signal_period`.
        # NOTE(review): the notebook log prints "Parameter signal_period is updated: 5 -> 5",
        # which suggests this assignment re-sets an existing parameter property - confirm
        # whether it is still needed.
        self.signal_period = self.params.get_parameter('signal_period')
        self.dependencies_factor = [self.ma_fast, self.ma_slow]

    def calculate(self, input_data: Optional[Union[pl.DataFrame, Dict[str, pl.DataFrame]]], *args, **kwargs) -> Any:
        """
        Calculate the histogram (MACD line minus signal line) from pre-calculated moving averages.

        Parameters:
        input_data (Optional[Union[pl.DataFrame, Dict[str, pl.DataFrame]]]): Must be a dict holding
            the fast and slow MA frames under the ``factor_key`` of the respective dependency.

        Returns:
        pl.DataFrame: The histogram only, one column per input column. When either input carries a
            'datetime' column it is re-attached as the first column.

        Raises:
        ValueError: If ``input_data`` is not a dict, a moving average is missing or is not a Polars
            DataFrame, the 'signal_period' parameter is None, or the two moving averages do not
            contain the same columns.
        """
        # Ensure input data is a dictionary
        if not isinstance(input_data, dict):
            raise ValueError("Expected input_data to be a dictionary with pre-calculated moving averages.")

        # Retrieve the pre-calculated moving averages
        ma_fast = input_data.get(self.ma_fast.factor_key)
        ma_slow = input_data.get(self.ma_slow.factor_key)

        if ma_fast is None or ma_slow is None:
            raise ValueError("Missing required moving averages (ma_fast or ma_slow) in input_data.")

        # Ensure the moving averages are Polars DataFrames
        if not isinstance(ma_fast, pl.DataFrame) or not isinstance(ma_slow, pl.DataFrame):
            raise ValueError("ma_fast and ma_slow must be Polars DataFrames.")

        # Read the parameter at calculation time (as MA does for 'window') so the current
        # value is used, and fail with a clear message instead of inside rolling_mean.
        signal_period = self.params.get_parameter('signal_period')
        if signal_period is None:
            raise ValueError("The signal period (signal_period) is not set.")

        # Strip the timestamp column from each frame independently so it never takes part in
        # the arithmetic, even when only one of the two inputs carries it.
        datetime_col = None
        if "datetime" in ma_fast.columns:
            datetime_col = ma_fast["datetime"]
            ma_fast = ma_fast.drop("datetime")
        if "datetime" in ma_slow.columns:
            if datetime_col is None:
                datetime_col = ma_slow["datetime"]
            ma_slow = ma_slow.drop("datetime")

        # DataFrame arithmetic pairs columns by position, so make the slow frame follow the
        # fast frame's column order and refuse inputs whose columns differ.
        if set(ma_fast.columns) != set(ma_slow.columns):
            raise ValueError("ma_fast and ma_slow must contain the same columns.")
        ma_slow = ma_slow.select(ma_fast.columns)

        # Calculate MACD line
        macd_line = ma_fast - ma_slow

        # Calculate Signal line using a rolling mean of the MACD line
        signal_line = macd_line.select(
            [pl.col(col).rolling_mean(signal_period).alias(col) for col in macd_line.columns]
        )

        # Calculate Histogram (MACD line - Signal line)
        histogram = macd_line - signal_line

        # Add the datetime column back to the result if it exists
        if datetime_col is not None:
            histogram = histogram.insert_column(0, datetime_col)

        return histogram

    def calculate_polars(self, input_data: pl.DataFrame, *args, **kwargs) -> Any:
        """Not implemented; returns None."""
        pass

In [5]:
# Build the MACD factor with an empty setting dict and explicit period parameters.
macd = MACD(
    {},
    fast_period=5,
    slow_period=20,
    signal_period=5,
)

Created property for parameter: fast_period
Parameter fast_period is set: 5
Created property for parameter: slow_period
Parameter slow_period is set: 20
Created property for parameter: signal_period
Parameter signal_period is set: 5
Created property for parameter: window
Parameter window is set: 5
window is a property
  - Getter is defined
  - Setter is defined
Parameter window is set: 20
Parameter signal_period is updated: 5 -> 5


In [6]:
import numpy as np
import polars as pl
import pandas as pd

# Step 1: Generate simulated OHLCV data (one wide frame per field, one column per symbol)
SEED = 42     # fixed seed so Restart & Run All reproduces the same data
N_BARS = 200  # number of one-minute bars to simulate

# (low, high) bounds of the random draw for every field and symbol.
PRICE_BOUNDS = {
    "open": {"AAPL": (150, 155), "MSFT": (300, 305), "GOOG": (2800, 2810)},
    "high": {"AAPL": (155, 160), "MSFT": (305, 310), "GOOG": (2810, 2820)},
    "low": {"AAPL": (145, 150), "MSFT": (295, 300), "GOOG": (2790, 2800)},
    "close": {"AAPL": (150, 155), "MSFT": (300, 305), "GOOG": (2800, 2810)},
}
# Volumes are integers drawn from the half-open range [low, high).
VOLUME_BOUNDS = {"AAPL": (1000, 2000), "MSFT": (1000, 2000), "GOOG": (1000, 2000)}

date_range = pd.date_range("2024-01-01", periods=N_BARS, freq="1min")


def simulate_raw_data(timestamps, seed=SEED):
    """Build the dict of simulated frames keyed by field name.

    Parameters:
    timestamps: Sequence of timestamps used as the 'datetime' column of every frame.
    seed (int): Seed of the random generator, so repeated calls give identical data.

    Returns:
    dict: 'open', 'high', 'low', 'close' (uniform floats) and 'volume' (integers), each a
        pl.DataFrame with a 'datetime' column followed by one column per symbol.
    """
    rng = np.random.default_rng(seed)
    n_rows = len(timestamps)

    def build_frame(bounds, draw):
        # One column per symbol, drawn within that symbol's (low, high) bounds.
        columns = {"datetime": timestamps}
        for symbol, (low, high) in bounds.items():
            columns[symbol] = draw(low, high, size=n_rows)
        return pl.DataFrame(columns)

    data = {field: build_frame(bounds, rng.uniform) for field, bounds in PRICE_BOUNDS.items()}
    data["volume"] = build_frame(VOLUME_BOUNDS, rng.integers)
    return data


raw_data = simulate_raw_data(date_range)

In [7]:
from vnpy.app.factor_maker.backtesting import FactorBacktester
from vnpy.app.factor_maker.optimizer import FactorOptimizer
# Both the backtester and the optimizer are built over the same simulated data;
# the optimizer additionally receives the backtester instance.
bt = FactorBacktester(
    data=raw_data,
)
opt = FactorOptimizer(
    backtester=bt,
    data=raw_data,
)

In [8]:
# Register the MACD factor with the optimizer.
# NOTE(review): the factor_data displayed afterwards already contains an 'open@noparams'
# entry, so add_factor presumably also evaluates the dependency factors - confirm.
opt.add_factor(macd)

In [9]:
# Display the optimizer's factor_data mapping; it is kept as the bare last expression so the
# notebook renders it (the output shows keys such as 'open@noparams' mapped to Polars frames).
opt.factor_data

{'open@noparams': shape: (200, 4)
 ┌─────────────────────┬────────────┬────────────┬─────────────┐
 │ datetime            ┆ AAPL       ┆ MSFT       ┆ GOOG        │
 │ ---                 ┆ ---        ┆ ---        ┆ ---         │
 │ datetime[ns]        ┆ f64        ┆ f64        ┆ f64         │
 ╞═════════════════════╪════════════╪════════════╪═════════════╡
 │ 2024-01-01 00:00:00 ┆ 150.538204 ┆ 303.552723 ┆ 2806.504509 │
 │ 2024-01-01 00:01:00 ┆ 152.142576 ┆ 301.82936  ┆ 2802.370759 │
 │ 2024-01-01 00:02:00 ┆ 151.402797 ┆ 301.626982 ┆ 2802.486362 │
 │ 2024-01-01 00:03:00 ┆ 153.913028 ┆ 300.08127  ┆ 2804.17543  │
 │ 2024-01-01 00:04:00 ┆ 153.765339 ┆ 303.506308 ┆ 2802.840217 │
 │ …                   ┆ …          ┆ …          ┆ …           │
 │ 2024-01-01 03:15:00 ┆ 151.934677 ┆ 302.45119  ┆ 2808.940514 │
 │ 2024-01-01 03:16:00 ┆ 151.846658 ┆ 301.295687 ┆ 2807.810078 │
 │ 2024-01-01 03:17:00 ┆ 154.744704 ┆ 303.800958 ┆ 2807.547788 │
 │ 2024-01-01 03:18:00 ┆ 152.590906 ┆ 303.552692 ┆ 2804.