In [13]:
import warnings
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from data_manager import LoacalDatasource
from factor_engine import register_factor, FactorEngine, list_factors
from evaluation.engine import EvaluatorEngine

In [None]:
import json
import ipywidgets as widgets
from ipywidgets import interact

In [14]:
# Matplotlib text setup for Chinese labels: use the SimHei font as the default
# sans-serif face, and disable the Unicode minus glyph so that negative signs
# are not rendered as empty boxes under that font.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})
# Silence all warnings for the rest of the session.
warnings.filterwarnings('ignore')

In [15]:
# Point the data source at the local daily-price parquet file.
local_datasource = LoacalDatasource(file_path="./data/daily_price.parquet")
# Load the rows for the requested date window into `df`.
# NOTE(review): whether `end` is inclusive is decided by LoacalDatasource.load_data — confirm.
df = local_datasource.load_data(start="2020-01-01", end="2026-01-01")

In [16]:
# List the loaded column names so factor definitions can see which fields are available.
print(df.columns.to_list())

['open', 'high', 'low', 'close', 'volume', 'amount', 'market_cap_float', 'market_cap_total', 'limit_status']


In [17]:
# Factor engine: computes a registered factor by name (see compute_one usage in this notebook)
factor_engine = FactorEngine()
# Factor evaluator: scores a factor over several holding horizons
evaluator_engine = EvaluatorEngine()

In [None]:
# 使用已经写好的因子
# factor_name = ""
# factor: pd.Series = factor_engine.compute_one(df, factor_name)

# 自定义因子模版
# factor_name = "custom_factor"
# @register_factor(
#     name=factor_name,
#     required_fields=["close"],
# )
# def factor_func(df: pd.DataFrame) -> pd.Series:
#     return df["close"]

factor_name = "liquidity_momentum_deviation"
@register_factor(
    name=factor_name,
    required_fields=["open", "close", "amount"],
    force_update=True,
)
def liquidity_momentum_deviation_factor(df: pd.DataFrame) -> pd.Series:
    """Squared deviation of the amount-weighted intraday log return from its same-date median.

    For every row the raw signal is ``log(close / open) * amount``. The median of
    that signal is taken across all rows sharing the same ``date``, and the factor
    value is ``(signal - median) ** 2``.

    Args:
        df: Frame with ``open``, ``close`` and ``amount`` columns. ``date`` must be
            available either as an index level or as a column.

    Returns:
        Series of squared deviations carrying exactly the same index (and row
        order) as ``df``.

    Raises:
        KeyError: if a required column is missing, or ``date`` is neither a column
            nor an index level.
    """
    # Compute the signal once for the whole frame instead of twice per date group.
    signed_turnover = np.log(df["close"] / df["open"]) * df["amount"]

    # Accept `date` as a column or as an index level, like DataFrame.groupby("date") does.
    if "date" in df.columns:
        date_key = df["date"]
    else:
        date_key = df.index.get_level_values("date")

    # `transform` broadcasts each date's median back onto that date's rows and keeps
    # the original index. This avoids relying on how GroupBy.apply re-attaches group
    # keys (which varies with pandas version and with the number of groups) and the
    # follow-up droplevel that depended on it.
    daily_median = signed_turnover.groupby(date_key).transform("median")

    return (signed_turnover - daily_median) ** 2

In [19]:
# Compute the registered factor by name and evaluate it at every holding horizon.
# `reports` is indexed by horizon later in the notebook (reports[h] in show_report).
reports = evaluator_engine.evaluate_multi_horizons(
    df=df,
    factor=factor_engine.compute_one(df, factor_name),
    horizons=[1, 5, 10, 20],  # holding periods, in days
    evaluator="common_eval"
)

In [None]:
# for horizon, report in reports.items():
#     print(f"=== 持有期: {horizon} 天 ===")
#     print(json.dumps(report.metrics, indent=4))

# for horizon, report in reports.items():
#     print(f"=== 持有期: {horizon} 天 ===")
#     report.plot_artifacts(show_fig=True)

In [27]:
def show_report(h):
    """Print the metrics and draw the evaluation figures for one holding horizon.

    Args:
        h: Key into the notebook-level ``reports`` mapping (a holding horizon).

    Side effects:
        Closes every open matplotlib figure, writes the horizon banner and the
        JSON-formatted ``metrics`` of the selected report to stdout, then asks the
        report to plot its artifacts with ``show_fig=True``.
    """
    # Start from a clean slate so a new selection never draws onto a stale figure.
    plt.close('all')

    selected = reports[h]

    # Banner first, then the metrics of the chosen horizon as indented JSON.
    banner = f"=== 持有期: {h} 天 ==="
    print(banner)
    metrics_text = json.dumps(selected.metrics, indent=4)
    print(metrics_text)

    # Announce and regenerate the full set of evaluation plots.
    print("\n绘制因子评价图...")
    selected.plot_artifacts(show_fig=True)

# Dropdown choices are the keys of `reports`, sorted ascending
horizons = sorted(reports.keys())

# Bind show_report to a dropdown so it is re-run for the selected horizon;
# the smallest horizon is preselected. The trailing semicolon suppresses the
# echo of interact's return value in the cell output.
interact(
    show_report,
    h=widgets.Dropdown(
        options=horizons,
        value=horizons[0],
        description="Horizon",
    )
);


interactive(children=(Dropdown(description='Horizon', options=(1, 5, 10, 20), value=1), Output()), _dom_classe…