# Strategy Backtest Framework

## 1. Import Library

In [5]:
import dai
import random
import pandas as pd
import numpy as np
import bigcharts
from bigcharts import opts

## 2. Strategy Settings

In [6]:
strategy_param_dict = {
    "sd":'2020-12-31',
    "ed":'2026-02-13',
    "sql_factor":f"""
    SELECT
        date,
        instrument,
        ps_ttm AS factor
    FROM cn_stock_prefactors
    WHERE sw2021_level1 IN ('110000', '210000', '220000', '230000', '240000', '270000', '280000', '310000', '320000', '330000', '340000', '350000', '360000', '370000', '410000', '420000', '430000', '440000', '450000', '460000', '470000', '480000', '490000', '510000', '630000', '650000', '710000', '720000', '730000', '740000', '760000', '770000')
    AND instrument NOT LIKE '%BJ%'
    QUALIFY c_group_pct_rank(sw2021_level2, ps_ttm) < 0.20 
    AND c_group_pct_rank(sw2021_level2, total_operating_revenue_yoy_lf) > 0.67
    AND net_profit_yoy_lf > 0
    """,
    "stock_num":10,
    "rebalance":"每月",   # 可选：每日, 每周, 每月, 每季, 每年
    "weighting":"等权重", # 可选：等权重、市值权重、因子值权重、因子排名权重
    "capital_base":1000000
}

## 3. Strategy Backtesting Framework

In [7]:
class StrategyBacktest:

    def __init__(self, strategy_param_dict):
        self.sd = strategy_param_dict["sd"]
        self.ed = strategy_param_dict["ed"]
        self.df_trade = self.get_df_trade(strategy_param_dict)
        self.df_account, self.df_position, self.df_trades = self.backtest_strategy(self.df_trade, strategy_param_dict)
        self.df_nav = self.backtest_output()

    def get_df_trade(self, strategy_param_dict):

        if strategy_param_dict["weighting"] == "市值权重":
            sql_weight = "float_market_cap / c_sum(float_market_cap)"
        elif strategy_param_dict["weighting"] == "因子值权重":
            sql_weight = "score / c_sum(score)"
        elif strategy_param_dict["weighting"] == "因子排名权重":
            sql_weight = "score_rank / c_sum(score_rank)"
        else:
            sql_weight = "1 / c_sum(1)"

        if strategy_param_dict["rebalance"] == "每周":
            sql_rebalance = "is_week_end_trade = 1"
        elif strategy_param_dict["rebalance"] == "每月":
            sql_rebalance = "is_month_end_trade = 1"
        elif strategy_param_dict["rebalance"] == "每季":
            sql_rebalance = "is_quarter_end_trade = 1"
        elif strategy_param_dict["rebalance"] == "每年":
            sql_rebalance = "is_year_end_trade = 1"
        else:
            sql_rebalance = "1=1"

        sql_strategy = f"""
        WITH
        data_strategy AS (
            {strategy_param_dict["sql_factor"]}
        ),
        data_filter AS (
            SELECT
                date,
                instrument,
                factor AS score,
                c_rank(factor) AS score_rank,
            FROM data_strategy
            QUALIFY score_rank <= {strategy_param_dict["stock_num"]}
        ),
        data_date AS (
            SELECT
                date,
                instrument,
                score, 
                score_rank, 
                {sql_weight} AS position, 
            FROM data_filter JOIN cn_stock_valuation USING (date, instrument) JOIN mldt_cn_stock_calendar_daily USING (date)
            WHERE {sql_rebalance}
        )
        SELECT *
        FROM data_date
        ORDER BY date, score_rank
        """

        df_strategy = dai.query(sql_strategy, filters={"date":[strategy_param_dict["sd"], strategy_param_dict["ed"]]}).df()

        df_date = dai.query("""
        SELECT 
            date,
            LAG(date, -1) OVER (PARTITION BY market_code ORDER BY date) AS date_trade
        FROM all_trading_days 
        WHERE market_code = 'CN'
        """).df()

        df_trade = pd.merge(df_date, df_strategy, how='inner', on='date')

        return df_trade
    
    def backtest_strategy(self, df_trade, strategy_param_dict):

        BUY_COST  = 0.0003 
        SELL_COST = 0.0013 
        LOT = 100          

        sd = strategy_param_dict["sd"]
        ed = strategy_param_dict["ed"]
        capital_base = float(strategy_param_dict["capital_base"])


        df_days = dai.query("""
            SELECT date AS date_trade
            FROM all_trading_days
            WHERE market_code = 'CN'
            ORDER BY date
        """).df()
        df_days["date_trade"] = pd.to_datetime(df_days["date_trade"]).dt.normalize()
        all_days = pd.to_datetime(pd.Index(df_days["date_trade"].tolist())).normalize()


        df_trade = df_trade.copy()
        df_trade["date_trade"] = pd.to_datetime(df_trade["date_trade"]).dt.normalize()

        universe = sorted(df_trade["instrument"].unique().tolist())

        target_on_rebalance = (
            df_trade.pivot_table(index="date_trade", columns="instrument", values="position", aggfunc="sum")
            .sort_index()
        )

        target_w = target_on_rebalance.reindex(all_days)
        rebalance_days = target_w.notna().any(axis=1)
        target_w.loc[rebalance_days] = target_w.loc[rebalance_days].fillna(0.0) 
        target_w = target_w.ffill().fillna(0.0)


        df_real = dai.query("""
            SELECT date, instrument, open, high, low, close, volume
            FROM cn_stock_real_bar1d
        """, filters={"date": [sd, ed], "instrument": universe}).df()

        df_adj = dai.query("""
            SELECT date, instrument, open, close
            FROM cn_stock_bar1d
        """, filters={"date": [sd, ed], "instrument": universe}).df()

        df_real["date_trade"] = pd.to_datetime(df_real["date"]).dt.normalize()
        df_adj["date_trade"]  = pd.to_datetime(df_adj["date"]).dt.normalize()

        vol_exec  = df_real.pivot_table(index="date_trade", columns="instrument", values="volume", aggfunc="last").sort_index()
        high_exec = df_real.pivot_table(index="date_trade", columns="instrument", values="high",   aggfunc="last").sort_index()
        low_exec  = df_real.pivot_table(index="date_trade", columns="instrument", values="low",    aggfunc="last").sort_index()

        real_close = df_real.pivot_table(index="date_trade", columns="instrument", values="close", aggfunc="last").sort_index()

        adj_open  = df_adj.pivot_table(index="date_trade", columns="instrument", values="open",  aggfunc="last").sort_index()
        adj_close = df_adj.pivot_table(index="date_trade", columns="instrument", values="close", aggfunc="last").sort_index()

        target_w   = target_w[~target_w.index.duplicated(keep="last")].copy()
        vol_exec   = vol_exec.groupby(level=0).last()
        high_exec  = high_exec.groupby(level=0).last()
        low_exec   = low_exec.groupby(level=0).last()
        real_close = real_close.groupby(level=0).last()
        adj_open   = adj_open.groupby(level=0).last()
        adj_close  = adj_close.groupby(level=0).last()

        common_days = (
            target_w.index
            .intersection(vol_exec.index)
            .intersection(high_exec.index)
            .intersection(low_exec.index)
            .intersection(adj_open.index)
            .intersection(adj_close.index)
            .intersection(real_close.index)
        )

        target_w   = target_w.reindex(common_days)
        vol_exec   = vol_exec.reindex(common_days)
        high_exec  = high_exec.reindex(common_days)
        low_exec   = low_exec.reindex(common_days)
        real_close = real_close.reindex(common_days)
        adj_open   = adj_open.reindex(common_days)
        adj_close  = adj_close.reindex(common_days)

        common_cols = (
            target_w.columns
            .intersection(vol_exec.columns)
            .intersection(high_exec.columns)
            .intersection(low_exec.columns)
            .intersection(adj_open.columns)
            .intersection(adj_close.columns)
            .intersection(real_close.columns)
        )

        target_w   = target_w[common_cols].copy()
        vol_exec   = vol_exec[common_cols].copy()
        high_exec  = high_exec[common_cols].copy()
        low_exec   = low_exec[common_cols].copy()
        real_close = real_close[common_cols].copy()
        adj_open   = adj_open[common_cols].copy()
        adj_close  = adj_close[common_cols].copy()

        last_real = real_close.ffill().iloc[-1]
        last_adj  = adj_close.ffill().iloc[-1]

        scale = (last_real / last_adj).replace([np.inf, -np.inf], np.nan)
        scale = scale.fillna(1.0)

        px_exec = adj_open.mul(scale, axis=1)  
        px_mtm  = adj_close.mul(scale, axis=1) 

        cash = capital_base
        shares = pd.Series(0, index=common_cols, dtype="int64")

        account_rows = []
        trade_rows = []
        position_rows = []

        prev_w = None

        def floor_lot(qty: int) -> int:
            if qty <= 0:
                return 0
            return (qty // LOT) * LOT

        for dt in target_w.index:
            w = target_w.loc[dt].fillna(0.0)

            rebalance = True if prev_w is None else (not w.equals(prev_w))

            p = px_exec.loc[dt]
            v = vol_exec.loc[dt]
            h = high_exec.loc[dt]
            l = low_exec.loc[dt]

            tradable = (
                p.notna() & (p > 0) &
                v.notna() & (v > 0) &
                h.notna() & l.notna() & (h != l)
            )

            if rebalance:
                mask_exec = tradable
                cur_value_exec = float(np.dot(
                    shares[mask_exec].to_numpy(dtype="float64"),
                    p[mask_exec].to_numpy(dtype="float64")
                ))
                total_asset_exec = cash + cur_value_exec

                target_value = total_asset_exec * w
                raw_target_shares = (target_value / p).replace([np.inf, -np.inf], np.nan).fillna(0.0)
                raw_target_shares = np.floor(raw_target_shares).astype("int64")
                target_shares = raw_target_shares.apply(lambda x: floor_lot(int(x)))

                target_shares[~tradable] = shares[~tradable]

                delta = target_shares - shares

                sells = delta[delta < 0]
                for ins, dsh in sells.items():
                    if not bool(tradable.get(ins, False)):
                        continue
                    px = float(p[ins])
                    qty = floor_lot(int(-dsh))
                    if qty <= 0:
                        continue

                    amount = qty * px
                    fee = amount * SELL_COST

                    cash += (amount - fee)
                    shares[ins] -= qty

                    trade_rows.append({
                        "date_trade": dt,
                        "instrument": ins,
                        "side": "SELL",
                        "shares": qty,
                        "price_exec": px,
                        "amount": float(amount),
                        "fee": float(fee)
                    })

                buys = delta[delta > 0]
                for ins, dsh in buys.items():
                    if not bool(tradable.get(ins, False)):
                        continue
                    px = float(p[ins])

                    qty_want = floor_lot(int(dsh))
                    if qty_want <= 0:
                        continue

                    unit_cost = px * (1.0 + BUY_COST)
                    max_qty = int(np.floor(cash / unit_cost))
                    max_qty = floor_lot(max_qty)

                    qty = min(qty_want, max_qty)
                    if qty <= 0:
                        continue

                    amount = qty * px
                    fee = amount * BUY_COST
                    total_cost = amount + fee

                    cash -= total_cost
                    shares[ins] += qty

                    trade_rows.append({
                        "date_trade": dt,
                        "instrument": ins,
                        "side": "BUY",
                        "shares": qty,
                        "price_exec": px,
                        "amount": float(amount),
                        "fee": float(fee)
                    })

            prev_w = w

            pm = px_mtm.loc[dt]
            mtm_mask = pm.notna() & (pm > 0)

            position_value = float(np.dot(
                shares[mtm_mask].to_numpy(dtype="float64"),
                pm[mtm_mask].to_numpy(dtype="float64")
            ))

            total_value = cash + position_value
            nav = total_value / capital_base if capital_base > 0 else np.nan

            account_rows.append({
                "date_trade": dt,
                "cash": float(cash),
                "position_value": float(position_value),
                "total_value": float(total_value),
                "nav": float(nav)
            })

            nonzero = shares[shares != 0]
            if len(nonzero) > 0:
                for ins, sh in nonzero.items():
                    px = pm.get(ins, np.nan)
                    mv = float(sh * px) if (pd.notna(px) and px > 0) else 0.0
                    wt = mv / total_value if total_value > 0 else np.nan
                    position_rows.append({
                        "date_trade": dt,
                        "instrument": ins,
                        "shares": int(sh),
                        "price_mtm": float(px) if pd.notna(px) else np.nan,
                        "market_value": float(mv),
                        "weight": float(wt) if wt == wt else np.nan
                    })

        df_account = pd.DataFrame(account_rows).set_index("date_trade").sort_index()

        df_position = (
            pd.DataFrame(position_rows).sort_values(["date_trade", "instrument"])
            if position_rows else
            pd.DataFrame(columns=["date_trade","instrument","shares","price_mtm","market_value","weight"])
        )

        df_trades = (
            pd.DataFrame(trade_rows).sort_values(["date_trade", "instrument", "side"])
            if trade_rows else
            pd.DataFrame(columns=["date_trade","instrument","side","shares","price_exec","amount","fee"])
        )

        return df_account, df_position, df_trades

    def backtest_output(self):

        trading_days = 252

        df = self.df_account.copy()

        if df.index.name != "date_trade":
            df.index = pd.to_datetime(df.index).rename("date_trade")
        else:
            df.index = pd.to_datetime(df.index)
        df = df.sort_index()

        if "nav" not in df.columns:
            if "total_value" not in df.columns:
                raise ValueError("df_account needs column 'nav' or 'total_value'.")
            cap = float(getattr(self, "capital_base", np.nan))
            if not np.isfinite(cap) or cap == 0:
                raise ValueError("capital_base not found on self; please set self.capital_base.")
            df["nav"] = df["total_value"] / cap

        nav = df["nav"].astype(float).replace([np.inf, -np.inf], np.nan).dropna()
        if nav.empty:
            raise ValueError("nav is empty after cleaning; check df_account.")

        sd = getattr(self, "sd", None)
        ed = getattr(self, "ed", None)
        if sd is None or ed is None:
            sd = nav.index.min().strftime("%Y-%m-%d")
            ed = nav.index.max().strftime("%Y-%m-%d")

        df_bm = dai.query(
            "SELECT date, close AS bm_value FROM cn_stock_index_bar1d WHERE instrument = '000300.SH'",
            filters={"date": [sd, ed]}
        ).df()

        if len(df_bm) == 0:
            raise ValueError("benchmark data is empty; check instrument or date filters.")

        df_bm["date_trade"] = pd.to_datetime(df_bm["date"]).dt.normalize()
        bm_value = (
            df_bm.sort_values("date_trade")
                .drop_duplicates("date_trade", keep="last")
                .set_index("date_trade")["bm_value"]
                .astype(float)
                .replace([np.inf, -np.inf], np.nan)
                .dropna()
        )

        idx = nav.index.intersection(bm_value.index)
        nav = nav.loc[idx]
        bm_value = bm_value.loc[idx]

        if nav.empty or bm_value.empty:
            raise ValueError("after aligning, nav or benchmark is empty (no overlapping dates).")

        bm_nav = (bm_value / bm_value.iloc[0]).rename("基准净值")

        ret = nav.pct_change().replace([np.inf, -np.inf], np.nan).dropna()
        bm_ret = bm_nav.pct_change().replace([np.inf, -np.inf], np.nan).dropna()

        ridx = ret.index.intersection(bm_ret.index)
        ret = ret.loc[ridx]
        bm_ret = bm_ret.loc[ridx]

        total_ret = float(nav.iloc[-1] / nav.iloc[0] - 1)
        n = int(ret.shape[0])
        ann_ret = float((nav.iloc[-1] / nav.iloc[0]) ** (trading_days / n) - 1) if n > 0 else np.nan
        vol = float(ret.std(ddof=1) * np.sqrt(trading_days)) if ret.std(ddof=1) != 0 else 0.0
        mdd = float((nav / nav.cummax() - 1).min())

        bm_total_ret = float(bm_nav.iloc[-1] / bm_nav.iloc[0] - 1)
        ex_total_ret = float(total_ret - bm_total_ret)

        table = pd.DataFrame([{
            "收益率": round(total_ret, 4),
            "年化收益率": round(ann_ret, 4) if ann_ret == ann_ret else np.nan,
            "波动率": round(vol, 4) if vol == vol else np.nan,
            "最大回撤": round(mdd, 4) if mdd == mdd else np.nan,
            "基准收益率": round(bm_total_ret, 4),
            "超额收益率": round(ex_total_ret, 4),
        }])

        c_table = bigcharts.Chart(
            data=table,
            type_="table",
            chart_options=dict(
                title_opts=opts.ComponentTitleOpts(title="回测绩效指标")
            ),
            y=list(table.columns)
        )

        nav_df = pd.DataFrame({
            "净值": nav.values,
            "基准净值": bm_nav.values,
        }, index=nav.index)

        nav_df["日期"] = nav_df.index.strftime("%Y-%m-%d")
        nav_df = nav_df.reset_index(drop=True)

        c_line = bigcharts.Chart(
            data=nav_df,
            type_="line",
            x="日期",
            y=["净值", "基准净值"],
            chart_options=dict(
                title_opts=opts.TitleOpts(title="净值曲线", pos_left="left"),
                legend_opts=opts.LegendOpts(pos_top="bottom"),
                xaxis_opts=opts.AxisOpts(split_number=8),
            ),
            series_options={
                "净值": {
                    "itemstyle_opts": opts.ItemStyleOpts(color="#e41a1c"),   # 红色
                    "linestyle_opts": opts.LineStyleOpts(width=2),
                    "symbol": "none"
            },
            "基准净值": {
                "itemstyle_opts": opts.ItemStyleOpts(color="#ffd700"),   # 黄色
                "linestyle_opts": opts.LineStyleOpts(width=2),
                "symbol": "none"
            },
            }
        )

        c_table.render(display=True)
        c_line.render(display=True)

        return nav_df

In [8]:
strategy_backtest = StrategyBacktest(strategy_param_dict)

收益率,年化收益率,波动率,最大回撤,基准收益率,超额收益率
1.7424,0.2271,0.2464,-0.2618,-0.1057,1.8481
