In [1]:
import sys, os
import numpy as np
import pandas as pd
import qlib

In [2]:
from pathlib import Path
scripts_dir = Path("/data/students/huzb/qlib/scripts")
print(scripts_dir.joinpath("get_data.py"))
assert scripts_dir.joinpath("get_data.py").exists()

/data/students/huzb/qlib/scripts/get_data.py


In [3]:
if not scripts_dir.joinpath("get_data.py").exists():
    # download get_data.py script
    scripts_dir = Path("~/tmp/qlib_code/scripts").expanduser().resolve()
    scripts_dir.mkdir(parents=True, exist_ok=True)
    import requests
    with requests.get("https://raw.githubusercontent.com/microsoft/qlib/main/scripts/get_data.py") as resp:
        with open(scripts_dir.joinpath("get_data.py"), "wb") as fp:
            fp.write(resp.content)

In [4]:
from qlib.constant import REG_CN
from qlib.utils import exists_qlib_data, init_instance_by_config
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
from qlib.utils import flatten_dict

In [5]:
provider_uri = "/data/students/huzb/qlib/qlib_data/cn_data"  # target_dir
# if not exists_qlib_data(provider_uri):
#     print(f"Qlib data is not found in {provider_uri}")
#     sys.path.append(str(scripts_dir))
#     from get_data import GetData
#     GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
qlib.init(provider_uri=provider_uri, region=REG_CN)

[19962:MainThread](2022-08-08 17:16:30,084) INFO - qlib.Initialization - [config.py:413] - default_conf: client.
[19962:MainThread](2022-08-08 17:16:30,090) INFO - qlib.Initialization - [__init__.py:74] - qlib successfully initialized based on client settings.
[19962:MainThread](2022-08-08 17:16:30,091) INFO - qlib.Initialization - [__init__.py:76] - data_path={'__DEFAULT_FREQ': PosixPath('/data/students/huzb/qlib/qlib_data/cn_data')}


In [6]:
market = "csi300"
benchmark = "SH000300"

In [7]:
from qlib.data import D
from qlib.data.filter import ExpressionDFilter
from qlib.data.filter import NameDFilter

In [8]:
instruments = D.instruments(market='csi300')
fields = ['$close', '(Ref($close, -1)-$close)/$close', '(Ref($close, -2)-Ref($close, -1))/Ref($close, -1)', '(Ref($close, -3)-Ref($close, -2))/Ref($close, -2)', '(Ref($close, -4)-Ref($close, -3))/Ref($close, -3)', '(Ref($close, -5)-Ref($close, -4))/Ref($close, -4)', '(Ref($close, -6)-Ref($close, -5))/Ref($close, -5)', '(Ref($close, -7)-Ref($close, -6))/Ref($close, -6)']
f_d = D.features(instruments, fields, start_time='2020-01-01', end_time='2020-08-01', freq='day')
df = f_d
df.index = df.index.get_level_values('datetime')
print(df.index.min(), df.index.max())

start_time = pd.to_datetime(df.index.min())
end_time = pd.to_datetime(df.index.max())
print(start_time.strftime('%Y-%m-%d'), end_time.strftime('%Y-%m-%d'))

2020-01-02 00:00:00 2020-07-31 00:00:00
2020-01-02 2020-07-31


In [9]:
experiment_name="online_srv"

In [10]:
###################################
# train model
###################################
data_handler_config = {
    "start_time": start_time, # 
    "end_time": end_time,
    "fit_start_time": start_time,
    "fit_end_time": "2020-06-01",
    "instruments": market,
    "infer_processors": [
      {
        "class": "RobustZScoreNorm",
        "kwargs": {
          "fields_group": "feature",
          "clip_outlier": True
        }
      },
      {
        "class": "Fillna",
        "kwargs": {
          "fields_group": "feature"
        }
      }
    ],
    "learn_processors": [
      {
        "class": "DropnaLabel"
      },
      {
        "class": "CSRankNorm",
        "kwargs": {
          "fields_group": "label"
        }
      }
    ],
    "label": [
      "Ref($close, -2) / Ref($close, -1) - 1"
    ]
}

task = {   
    "model": {
        "class": "LSTM",
        "module_path": "qlib.contrib.model.pytorch_lstm",
        "kwargs": {
            "d_feat": 6,
            "hidden_size": 64,
            "num_layers": 2,
            "dropout": 0.0,
            "dec_dropout": 0.0,
            "n_epochs": 200,
            "lr": 1e-3,
            "early_stop": 20,
            "batch_size": 800,
            "metric": "loss",
            "loss": "mse",
            "optimizer": "adam",
            "GPU": 0
        },
    },
    "dataset": {
        "class": "DatasetH",
        "module_path": "qlib.data.dataset",
        "kwargs": {
            "handler": {
                "class": "Alpha360",
                "module_path": "qlib.contrib.data.handler",
                "kwargs": data_handler_config,
            },
            "segments": {
                "train": (start_time, "2020-06-01"),
                "valid": ("2020-06-02", "2020-06-30"),
                "test": ("2020-07-01", "2020-07-31"),
            },
        },
    },
}

# model initiaiton
model = init_instance_by_config(task["model"])
dataset = init_instance_by_config(task["dataset"])

[19962:MainThread](2022-08-08 17:16:33,694) INFO - qlib.LSTM - [pytorch_lstm.py:58] - LSTM pytorch version...
[19962:MainThread](2022-08-08 17:16:33,727) INFO - qlib.LSTM - [pytorch_lstm.py:75] - LSTM parameters setting:
d_feat : 6
hidden_size : 64
num_layers : 2
dropout : 0.0
n_epochs : 200
lr : 0.001
metric : loss
batch_size : 800
early_stop : 20
optimizer : adam
loss_type : mse
visible_GPU : 0
use_GPU : True
seed : None
[19962:MainThread](2022-08-08 17:16:57,250) INFO - qlib.timer - [log.py:117] - Time cost: 20.311s | Loading data Done
  result = np.apply_along_axis(_nanmedian1d, axis, a, overwrite_input)
[19962:MainThread](2022-08-08 17:17:03,969) INFO - qlib.timer - [log.py:117] - Time cost: 6.623s | RobustZScoreNorm Done
[19962:MainThread](2022-08-08 17:17:04,010) INFO - qlib.timer - [log.py:117] - Time cost: 0.039s | Fillna Done
[19962:MainThread](2022-08-08 17:17:04,091) INFO - qlib.timer - [log.py:117] - Time cost: 0.044s | DropnaLabel Done
A value is trying to be set on a cop

In [11]:
# start exp to train model

experiment_id = 'cn_backtest'
# experiment_name: Optional[Text] = None,
# recorder_id: Optional[Text] = None,

# with R.start(experiment_name=experiment_name, experimen
# t_id=experiment_id):
with R.start(experiment_name=experiment_name):
    R.log_params(**flatten_dict(task))
    model.fit(dataset)
    R.save_objects(trained_model=model)
    rid = R.get_recorder().id
    # prediction
    recorder = R.get_recorder()
    sr = SignalRecord(model, dataset, recorder)
    sr.generate()

[19962:MainThread](2022-08-08 17:17:04,569) INFO - qlib.workflow - [expm.py:315] - <mlflow.tracking.client.MlflowClient object at 0x7fe9b887f9d0>
[19962:MainThread](2022-08-08 17:17:04,592) INFO - qlib.workflow - [exp.py:257] - Experiment 1 starts running ...
[19962:MainThread](2022-08-08 17:17:04,786) INFO - qlib.workflow - [recorder.py:295] - Recorder e8b77e0db7264931a75e04ad020343fa starts running under Experiment 1 ...
[19962:MainThread](2022-08-08 17:17:05,249) INFO - qlib.LSTM - [pytorch_lstm.py:236] - training...
[19962:MainThread](2022-08-08 17:17:05,251) INFO - qlib.LSTM - [pytorch_lstm.py:240] - Epoch0:
[19962:MainThread](2022-08-08 17:17:05,253) INFO - qlib.LSTM - [pytorch_lstm.py:241] - training...
[19962:MainThread](2022-08-08 17:17:05,800) INFO - qlib.LSTM - [pytorch_lstm.py:243] - evaluating...
[19962:MainThread](2022-08-08 17:17:06,008) INFO - qlib.LSTM - [pytorch_lstm.py:246] - train -0.997005, valid -0.993565
[19962:MainThread](2022-08-08 17:17:06,013) INFO - qlib.LST

'The following are prediction results of the LSTM model.'
                          score
datetime   instrument          
2020-07-01 SH600000   -0.014263
           SH600004    0.001635
           SH600009   -0.013610
           SH600010    0.019348
           SH600011    0.001502


## WeekTopkDropoutStrategy

In [15]:
###################################
# prediction, backtest & analysis
###################################
port_analysis_config = {
    "executor": {
        "class": "SimulatorExecutor",
        "module_path": "qlib.backtest.executor",
        "kwargs": {
            "time_per_step": "day",
            "generate_portfolio_metrics": True,
        },
    },
    "strategy": {
        "class": "WeekTopkDropoutStrategy",
        "module_path": "qlib.contrib.strategy.signal_strategy",
        "kwargs": {
            "model": model,
            "dataset": dataset,
            "topk": 50,
            "n_drop": 5,
        },
    },
    "backtest": {
        "start_time": "2020-07-01",
        "end_time": "2020-07-31",
        "account": 100000000,
        "benchmark": benchmark,
        "exchange_kwargs": {
            "freq": "day",
            "limit_threshold": 0.095,
            "deal_price": "close",
            "open_cost": 0.0005,
            "close_cost": 0.0015,
            "min_cost": 5,
        },
    },
}

# backtest and analysis
with R.start(experiment_name="backtest_analysis"):
    recorder = R.get_recorder(recorder_id=rid, experiment_name="online_srv")
    model = recorder.load_object("trained_model")

    # prediction
    recorder = R.get_recorder()
    ba_rid = recorder.id
    sr = SignalRecord(model, dataset, recorder)
    sr.generate()

    # backtest & analysis
    par = PortAnaRecord(recorder, port_analysis_config, "day")
    par.generate()

[19962:MainThread](2022-08-09 13:57:03,062) INFO - qlib.workflow - [expm.py:315] - <mlflow.tracking.client.MlflowClient object at 0x7fe8184ec8e0>
[19962:MainThread](2022-08-09 13:57:03,236) INFO - qlib.workflow - [exp.py:257] - Experiment 2 starts running ...
[19962:MainThread](2022-08-09 13:57:03,290) INFO - qlib.workflow - [recorder.py:295] - Recorder 19c7504e6c844249aec1a45bb9739e37 starts running under Experiment 2 ...

RNN module weights are not part of single contiguous chunk of memory. This means they need to be compacted at every call, possibly greatly increasing memory usage. To compact weights again call flatten_parameters(). (Triggered internally at  /opt/conda/conda-bld/pytorch_1646755853042/work/aten/src/ATen/native/cudnn/RNN.cpp:926.)

[19962:MainThread](2022-08-09 13:57:06,035) INFO - qlib.workflow - [record_temp.py:194] - Signal record 'pred.pkl' has been saved as the artifact of the Experiment 2
[19962:MainThread](2022-08-09 13:57:06,174) INFO - qlib.backtest caller - 

'The following are prediction results of the LSTM model.'
                          score
datetime   instrument          
2020-07-01 SH600000   -0.014263
           SH600004    0.001635
           SH600009   -0.013610
           SH600010    0.019348
           SH600011    0.001502



`model` `dataset` is deprecated; use `signal`.


RNN module weights are not part of single contiguous chunk of memory. This means they need to be compacted at every call, possibly greatly increasing memory usage. To compact weights again call flatten_parameters(). (Triggered internally at  /opt/conda/conda-bld/pytorch_1646755853042/work/aten/src/ATen/native/cudnn/RNN.cpp:926.)



backtest loop:   0%|          | 0/23 [00:00<?, ?it/s]

[19962:MainThread](2022-08-09 13:58:44,659) INFO - qlib.workflow - [record_temp.py:499] - Portfolio analysis record 'port_analysis_1day.pkl' has been saved as the artifact of the Experiment 2
[19962:MainThread](2022-08-09 13:58:44,679) INFO - qlib.workflow - [record_temp.py:524] - Indicator analysis record 'indicator_analysis_1day.pkl' has been saved as the artifact of the Experiment 2
[19962:MainThread](2022-08-09 13:58:44,745) INFO - qlib.timer - [log.py:117] - Time cost: 0.012s | waiting `async_log` Done


'The following are analysis results of benchmark return(1day).'
                       risk
mean               0.005478
std                0.022599
annualized_return  1.303725
information_ratio  3.739509
max_drawdown      -0.071481
'The following are analysis results of the excess return without cost(1day).'
                        risk
mean                0.010493
std                 0.010615
annualized_return   2.497304
information_ratio  15.249321
max_drawdown       -0.005057
'The following are analysis results of the excess return with cost(1day).'
                        risk
mean                0.010300
std                 0.010651
annualized_return   2.451487
information_ratio  14.918802
max_drawdown       -0.005139
'The following are analysis results of indicators(1day).'
     value
ffr    1.0
pa     0.0
pos    0.0


In [16]:
from qlib.contrib.report import analysis_model, analysis_position
from qlib.data import D
recorder = R.get_recorder(recorder_id=ba_rid, experiment_name="online_srv")
print(recorder)
pred_df = recorder.load_object("pred.pkl")
pred_df_dates = pred_df.index.get_level_values(level='datetime')
report_normal_df = recorder.load_object("portfolio_analysis/report_normal_1day.pkl")
positions = recorder.load_object("portfolio_analysis/positions_normal_1day.pkl")
analysis_df = recorder.load_object("portfolio_analysis/port_analysis_1day.pkl")

{'class': 'Recorder', 'id': '19c7504e6c844249aec1a45bb9739e37', 'name': 'mlflow_recorder', 'experiment_id': '1', 'start_time': '2022-08-09 13:57:03', 'end_time': '2022-08-09 13:58:44', 'status': 'FINISHED'}


In [17]:
report_normal_df

Unnamed: 0_level_0,account,return,total_turnover,turnover,total_cost,cost,value,cash,bench
datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
2020-07-01,99954400.0,0.0,91200000.0,0.912,45600.0,0.000456,91200000.0,8754400.0,0.02013
2020-07-02,102662000.0,0.027313,116261100.0,0.250725,68021.260641,0.000224,99209670.0,3452324.0,0.020731
2020-07-03,105459900.0,0.027452,136946300.0,0.201489,88390.320112,0.000198,102660300.0,2799614.0,0.019318
2020-07-06,110895100.0,0.05162,146316100.0,0.088847,97070.472362,8.2e-05,109483500.0,1411663.0,0.056677
2020-07-07,112897700.0,0.018247,167646800.0,0.19235,117994.352562,0.000189,112320500.0,577205.5,0.006004
2020-07-08,117114800.0,0.037531,187644200.0,0.177129,137974.229252,0.000177,116592600.0,522162.2,0.016149
2020-07-09,122152800.0,0.043133,201239400.0,0.116084,151494.257991,0.000115,121794500.0,358288.8,0.013986
2020-07-10,122950700.0,0.0067,221613200.0,0.16679,171962.24046,0.000168,122424600.0,526150.7,-0.018105
2020-07-13,127635400.0,0.038275,242822100.0,0.172499,193194.483817,0.000173,127083800.0,551609.6,0.021003
2020-07-14,127978000.0,0.002917,272491200.0,0.232452,222986.142598,0.000233,127211000.0,766958.4,-0.009534


In [167]:
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import copy
import warnings
import numpy as np
import pandas as pd
from datetime import datetime
from dateutil.relativedelta import relativedelta

from typing import Dict, List, Text, Tuple, Union

from qlib.data import D
from qlib.data.dataset import Dataset
from qlib.model.base import BaseModel
from qlib.strategy.base import BaseStrategy
from qlib.backtest.position import Position
from qlib.backtest.signal import Signal, create_signal_from
from qlib.backtest.decision import Order, OrderDir, TradeDecisionWO
from qlib.log import get_module_logger
from qlib.utils import get_pre_trading_date, load_dataset
from qlib.contrib.strategy.order_generator import OrderGenWOInteract
from qlib.contrib.strategy.optimizer import EnhancedIndexingOptimizer



week_instruments = D.instruments('csi300')


def get_week_HighAndLow(start_time, end_time, datetime, instruments):
    # 计算一周后的涨跌幅
    # instruments = D.instruments(market='all')
    datetime = datetime.strftime('%Y-%m-%d')
    start_time = start_time.strftime('%Y-%m-%d')
    end_time = (end_time + relativedelta(days=20)).strftime('%Y-%m-%d')
    fields = ['(Ref($close, -1)-$close)/$close', '(Ref($close, -2)-Ref($close, -1))/Ref($close, -1)',
              '(Ref($close, -3)-Ref($close, -2))/Ref($close, -2)',
              '(Ref($close, -4)-Ref($close, -3))/Ref($close, -3)',
              '(Ref($close, -5)-Ref($close, -4))/Ref($close, -4)',
              '(Ref($close, -6)-Ref($close, -5))/Ref($close, -5)',
              '(Ref($close, -7)-Ref($close, -6))/Ref($close, -6)']
    f_d = D.features(instruments, fields, start_time=start_time, end_time=end_time, freq='day')
    
    '''
    # end_time向后推7天
    f_d_index = str(f_d.index)
    j = 1
    for i in range(0, 6):
        iftime = end_time.strftime('%Y-%m-%d') + relativedelta(days=j)
        if iftime in f_d_index:
            i+=1
            j+=1
        else:
            j+=1
    
    fields = ['(Ref($close, -1)-$close)/$close', '(Ref($close, -2)-Ref($close, -1))/Ref($close, -1)',
              '(Ref($close, -3)-Ref($close, -2))/Ref($close, -2)',
              '(Ref($close, -4)-Ref($close, -3))/Ref($close, -3)',
              '(Ref($close, -5)-Ref($close, -4))/Ref($close, -4)',
              '(Ref($close, -6)-Ref($close, -5))/Ref($close, -5)',
              '(Ref($close, -7)-Ref($close, -6))/Ref($close, -6)']
    '''

    # wfd = D.features(instruments, fields, start_time=start_time, end_time=iftime, freq='day')
    wha = f_d.loc[(slice(None), datetime), :]
    wha.loc[:, 'A_Week_HighAndLow'] = wha.apply(lambda x: x.sum(), axis=1)
    wha = wha.loc[:, 'A_Week_HighAndLow'].droplevel('datetime')
    return wha


class BaseSignalStrategy(BaseStrategy):
    def __init__(
        self,
        *,
        signal: Union[Signal, Tuple[BaseModel, Dataset], List, Dict, Text, pd.Series, pd.DataFrame] = None,
        model=None,
        dataset=None,
        risk_degree: float = 0.95,
        trade_exchange=None,
        level_infra=None,
        common_infra=None,
        **kwargs,
    ):
        """
        Parameters
        -----------
        signal :
            the information to describe a signal. Please refer to the docs of `qlib.backtest.signal.create_signal_from`
            the decision of the strategy will base on the given signal
        risk_degree : float
            position percentage of total value.
        trade_exchange : Exchange
            exchange that provides market info, used to deal order and generate report
            - If `trade_exchange` is None, self.trade_exchange will be set with common_infra
            - It allowes different trade_exchanges is used in different executions.
            - For example:
                - In daily execution, both daily exchange and minutely are usable, but the daily exchange is recommended because it run faster.
                - In minutely execution, the daily exchange is not usable, only the minutely exchange is recommended.

        """
        super().__init__(level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs)

        self.risk_degree = risk_degree

        # This is trying to be compatible with previous version of qlib task config
        if model is not None and dataset is not None:
            warnings.warn("`model` `dataset` is deprecated; use `signal`.", DeprecationWarning)
            signal = model, dataset

        self.signal: Signal = create_signal_from(signal)

    def get_risk_degree(self, trade_step=None):
        """get_risk_degree
        Return the proportion of your total value you will used in investment.
        Dynamically risk_degree will result in Market timing.
        """
        # It will use 95% amount of your total value by default
        return self.risk_degree


class TopkDropoutStrategy(BaseSignalStrategy):
    # TODO:
    # 1. Supporting leverage the get_range_limit result from the decision
    # 2. Supporting alter_outer_trade_decision
    # 3. Supporting checking the availability of trade decision
    def __init__(
        self,
        *,
        topk,
        n_drop,
        method_sell="bottom",
        method_buy="top",
        hold_thresh=1,
        only_tradable=False,
        **kwargs,
    ):
        """
        Parameters
        -----------
        topk : int
            the number of stocks in the portfolio.
        n_drop : int
            number of stocks to be replaced in each trading date.
        method_sell : str
            dropout method_sell, random/bottom.
        method_buy : str
            dropout method_buy, random/top.
        hold_thresh : int
            minimum holding days
            before sell stock , will check current.get_stock_count(order.stock_id) >= self.hold_thresh.
        only_tradable : bool
            will the strategy only consider the tradable stock when buying and selling.
            if only_tradable:
                strategy will make buy sell decision without checking the tradable state of the stock.
            else:
                strategy will make decision with the tradable state of the stock info and avoid buy and sell them.
        """
        super().__init__(**kwargs)
        self.topk = topk
        self.n_drop = n_drop
        self.method_sell = method_sell
        self.method_buy = method_buy
        self.hold_thresh = hold_thresh
        self.only_tradable = only_tradable

    def generate_trade_decision(self, execute_result=None):
        # get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1]
        trade_step = self.trade_calendar.get_trade_step()
        trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
        pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
        pred_score = self.signal.get_signal(start_time=pred_start_time, end_time=pred_end_time)
        # NOTE: the current version of topk dropout strategy can't handle pd.DataFrame(multiple signal)
        # So it only leverage the first col of signal
        if isinstance(pred_score, pd.DataFrame):
            pred_score = pred_score.iloc[:, 0]
        if pred_score is None:
            return TradeDecisionWO([], self)
        if self.only_tradable:
            # If The strategy only consider tradable stock when make decision
            # It needs following actions to filter stocks
            def get_first_n(li, n, reverse=False):
                cur_n = 0
                res = []
                for si in reversed(li) if reverse else li:
                    if self.trade_exchange.is_stock_tradable(
                        stock_id=si, start_time=trade_start_time, end_time=trade_end_time
                    ):
                        res.append(si)
                        cur_n += 1
                        if cur_n >= n:
                            break
                return res[::-1] if reverse else res

            def get_last_n(li, n):
                return get_first_n(li, n, reverse=True)

            def filter_stock(li):
                return [
                    si
                    for si in li
                    if self.trade_exchange.is_stock_tradable(
                        stock_id=si, start_time=trade_start_time, end_time=trade_end_time
                    )
                ]

        else:
            # Otherwise, the stock will make decision with out the stock tradable info
            def get_first_n(li, n):
                return list(li)[:n]

            def get_last_n(li, n):
                return list(li)[-n:]

            def filter_stock(li):
                return li

        current_temp = copy.deepcopy(self.trade_position)
        # generate order list for this adjust date
        sell_order_list = []
        buy_order_list = []
        # load score
        cash = current_temp.get_cash()
        current_stock_list = current_temp.get_stock_list()
        # last position (sorted by score)
        last = pred_score.reindex(current_stock_list).sort_values(ascending=False).index
        # The new stocks today want to buy **at most**
        if self.method_buy == "top":
            today = get_first_n(
                pred_score[~pred_score.index.isin(last)].sort_values(ascending=False).index,
                self.n_drop + self.topk - len(last),
            )
        elif self.method_buy == "random":
            topk_candi = get_first_n(pred_score.sort_values(ascending=False).index, self.topk)
            candi = list(filter(lambda x: x not in last, topk_candi))
            n = self.n_drop + self.topk - len(last)
            try:
                today = np.random.choice(candi, n, replace=False)
            except ValueError:
                today = candi
        else:
            raise NotImplementedError(f"This type of input is not supported")
        # combine(new stocks + last stocks),  we will drop stocks from this list
        # In case of dropping higher score stock and buying lower score stock.
        comb = pred_score.reindex(last.union(pd.Index(today))).sort_values(ascending=False).index

        # Get the stock list we really want to sell (After filtering the case that we sell high and buy low)
        if self.method_sell == "bottom":
            sell = last[last.isin(get_last_n(comb, self.n_drop))]
        elif self.method_sell == "random":
            candi = filter_stock(last)
            try:
                sell = pd.Index(np.random.choice(candi, self.n_drop, replace=False) if len(last) else [])
            except ValueError:  # No enough candidates
                sell = candi
        else:
            raise NotImplementedError(f"This type of input is not supported")

        # Get the stock list we really want to buy
        buy = today[: len(sell) + self.topk - len(last)]
        for code in current_stock_list:
            if not self.trade_exchange.is_stock_tradable(
                stock_id=code, start_time=trade_start_time, end_time=trade_end_time
            ):
                continue
            if code in sell:
                # check hold limit
                time_per_step = self.trade_calendar.get_freq()
                if current_temp.get_stock_count(code, bar=time_per_step) < self.hold_thresh:
                    continue
                # sell order
                sell_amount = current_temp.get_stock_amount(code=code)
                factor = self.trade_exchange.get_factor(
                    stock_id=code, start_time=trade_start_time, end_time=trade_end_time
                )
                # sell_amount = self.trade_exchange.round_amount_by_trade_unit(sell_amount, factor)
                sell_order = Order(
                    stock_id=code,
                    amount=sell_amount,
                    start_time=trade_start_time,
                    end_time=trade_end_time,
                    direction=Order.SELL,  # 0 for sell, 1 for buy
                )
                # is order executable
                if self.trade_exchange.check_order(sell_order):
                    sell_order_list.append(sell_order)
                    trade_val, trade_cost, trade_price = self.trade_exchange.deal_order(
                        sell_order, position=current_temp
                    )
                    # update cash
                    cash += trade_val - trade_cost
        # buy new stock
        # note the current has been changed
        current_stock_list = current_temp.get_stock_list()
        value = cash * self.risk_degree / len(buy) if len(buy) > 0 else 0

        # open_cost should be considered in the real trading environment, while the backtest in evaluate.py does not
        # consider it as the aim of demo is to accomplish same strategy as evaluate.py, so comment out this line
        # value = value / (1+self.trade_exchange.open_cost) # set open_cost limit
        for code in buy:
            # check is stock suspended
            if not self.trade_exchange.is_stock_tradable(
                stock_id=code, start_time=trade_start_time, end_time=trade_end_time
            ):
                continue
            # buy order
            buy_price = self.trade_exchange.get_deal_price(
                stock_id=code, start_time=trade_start_time, end_time=trade_end_time, direction=OrderDir.BUY
            )
            buy_amount = value / buy_price
            factor = self.trade_exchange.get_factor(stock_id=code, start_time=trade_start_time, end_time=trade_end_time)
            buy_amount = self.trade_exchange.round_amount_by_trade_unit(buy_amount, factor)
            buy_order = Order(
                stock_id=code,
                amount=buy_amount,
                start_time=trade_start_time,
                end_time=trade_end_time,
                direction=Order.BUY,  # 1 for buy
            )
            buy_order_list.append(buy_order)
        return TradeDecisionWO(sell_order_list + buy_order_list, self)


class WeekTopkDropoutStrategy(BaseSignalStrategy):
    # TODO:
    # 1. Supporting leverage the get_range_limit result from the decision
    # 2. Supporting alter_outer_trade_decision
    # 3. Supporting checking the availability of trade decision
    def __init__(
            self,
            *,
            topk,
            n_drop,
            method_sell="bottom",
            method_buy="top",
            hold_thresh=1,
            only_tradable=False,
            **kwargs,
    ):
        """
        Parameters
        -----------
        topk : int
            the number of stocks in the portfolio.
        n_drop : int
            number of stocks to be replaced in each trading date.
        method_sell : str
            dropout method_sell, random/bottom.
        method_buy : str
            dropout method_buy, random/top.
        hold_thresh : int
            minimum holding days
            before sell stock , will check current.get_stock_count(order.stock_id) >= self.hold_thresh.
        only_tradable : bool
            will the strategy only consider the tradable stock when buying and selling.
            if only_tradable:
                strategy will make buy sell decision without checking the tradable state of the stock.
            else:
                strategy will make decision with the tradable state of the stock info and avoid buy and sell them.
        """
        super().__init__(**kwargs)
        self.topk = topk
        self.n_drop = n_drop
        self.method_sell = method_sell
        self.method_buy = method_buy
        self.hold_thresh = hold_thresh
        self.only_tradable = only_tradable



    def generate_trade_decision(self, execute_result=None):
        # get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1]
        trade_step = self.trade_calendar.get_trade_step()
        trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
        pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=-1)
        pred_score = self.signal.get_signal(start_time=pred_start_time, end_time=pred_end_time)

        wha = get_week_HighAndLow(start_time=pred_start_time, end_time=pred_end_time, datetime=pred_start_time, instruments=week_instruments)
        if wha is None:
            print("A_Week_HighAndLow is None")
        pred = pd.concat([wha, pred_score], axis=1)
        # print(pred.columns.values)
        # pred.columns = ['A_Week_HighAndLow', 'score']
        print(pred)

        def compute_score(a, b):
            # a是wha，b是pred
            return a * (2 / 3) + b * (1 / 3)

        pred['value'] = pred.apply(lambda row: compute_score(row['A_Week_HighAndLow'], row[0]), axis=1).copy()
        pred_score = pred['value']
        print(pred_score)

        # NOTE: the current version of topk dropout strategy can't handle pd.DataFrame(multiple signal)
        # So it only leverage the first col of signal

        if isinstance(pred_score, pd.DataFrame):
            pred_score = pred_score.iloc[:, 0]
        if pred_score is None:
            print("pred_score is None")
            # print(trade_step, trade_start_time, trade_end_time)
            # print(pred_start_time, pred_end_time)
            return TradeDecisionWO([], self)
        if self.only_tradable:
            # If The strategy only consider tradable stock when make decision
            # It needs following actions to filter stocks
            def get_first_n(l, n, reverse=False):
                cur_n = 0
                res = []
                for si in reversed(l) if reverse else l:
                    if self.trade_exchange.is_stock_tradable(
                            stock_id=si, start_time=trade_start_time, end_time=trade_end_time
                    ):
                        res.append(si)
                        cur_n += 1
                        if cur_n >= n:
                            break
                return res[::-1] if reverse else res

            def get_last_n(l, n):
                return get_first_n(l, n, reverse=True)

            def filter_stock(l):
                return [
                    si
                    for si in l
                    if self.trade_exchange.is_stock_tradable(
                        stock_id=si, start_time=trade_start_time, end_time=trade_end_time
                    )
                ]

        else:
            # Otherwise, the stock will make decision with out the stock tradable info
            def get_first_n(l, n):
                return list(l)[:n]

            def get_last_n(l, n):
                return list(l)[-n:]

            def filter_stock(l):
                return l

        current_temp = copy.deepcopy(self.trade_position)
        # generate order list for this adjust date
        sell_order_list = []
        buy_order_list = []
        # load score
        cash = current_temp.get_cash()
        current_stock_list = current_temp.get_stock_list()
        # last position (sorted by score)
        last = pred_score.reindex(current_stock_list).sort_values(ascending=False).index
        # The new stocks today want to buy **at most**
        if self.method_buy == "top":
            today = get_first_n(
                pred_score[~pred_score.index.isin(last)].sort_values(ascending=False).index,
                self.n_drop + self.topk - len(last),
            )
        elif self.method_buy == "random":
            topk_candi = get_first_n(pred_score.sort_values(ascending=False).index, self.topk)
            candi = list(filter(lambda x: x not in last, topk_candi))
            n = self.n_drop + self.topk - len(last)
            try:
                today = np.random.choice(candi, n, replace=False)
            except ValueError:
                today = candi
        else:
            raise NotImplementedError(f"This type of input is not supported")
        # combine(new stocks + last stocks),  we will drop stocks from this list
        # In case of dropping higher score stock and buying lower score stock.
        comb = pred_score.reindex(last.union(pd.Index(today))).sort_values(ascending=False).index

        # Get the stock list we really want to sell (After filtering the case that we sell high and buy low)
        if self.method_sell == "bottom":
            sell = last[last.isin(get_last_n(comb, self.n_drop))]
        elif self.method_sell == "random":
            candi = filter_stock(last)
            try:
                sell = pd.Index(np.random.choice(candi, self.n_drop, replace=False) if len(last) else [])
            except ValueError:  # No enough candidates
                sell = candi
        else:
            raise NotImplementedError(f"This type of input is not supported")

        # Get the stock list we really want to buy
        buy = today[: len(sell) + self.topk - len(last)]
        for code in current_stock_list:
            if not self.trade_exchange.is_stock_tradable(
                    stock_id=code, start_time=trade_start_time, end_time=trade_end_time
            ):
                continue
            if code in sell:
                # check hold limit
                time_per_step = self.trade_calendar.get_freq()
                if current_temp.get_stock_count(code, bar=time_per_step) < self.hold_thresh:
                    continue
                # sell order
                sell_amount = current_temp.get_stock_amount(code=code)
                factor = self.trade_exchange.get_factor(
                    stock_id=code, start_time=trade_start_time, end_time=trade_end_time
                )
                # sell_amount = self.trade_exchange.round_amount_by_trade_unit(sell_amount, factor)
                sell_order = Order(
                    stock_id=code,
                    amount=sell_amount,
                    start_time=trade_start_time,
                    end_time=trade_end_time,
                    direction=Order.SELL,  # 0 for sell, 1 for buy
                )
                # is order executable
                if self.trade_exchange.check_order(sell_order):
                    sell_order_list.append(sell_order)
                    trade_val, trade_cost, trade_price = self.trade_exchange.deal_order(
                        sell_order, position=current_temp
                    )
                    # update cash
                    cash += trade_val - trade_cost
        # buy new stock
        # note the current has been changed
        current_stock_list = current_temp.get_stock_list()
        value = cash * self.risk_degree / len(buy) if len(buy) > 0 else 0

        # open_cost should be considered in the real trading environment, while the backtest in evaluate.py does not
        # consider it as the aim of demo is to accomplish same strategy as evaluate.py, so comment out this line
        # value = value / (1+self.trade_exchange.open_cost) # set open_cost limit
        for code in buy:
            # check is stock suspended
            if not self.trade_exchange.is_stock_tradable(
                    stock_id=code, start_time=trade_start_time, end_time=trade_end_time
            ):
                continue
            # buy order
            buy_price = self.trade_exchange.get_deal_price(
                stock_id=code, start_time=trade_start_time, end_time=trade_end_time, direction=OrderDir.BUY
            )
            buy_amount = value / buy_price
            factor = self.trade_exchange.get_factor(stock_id=code, start_time=trade_start_time, end_time=trade_end_time)
            buy_amount = self.trade_exchange.round_amount_by_trade_unit(buy_amount, factor)
            buy_order = Order(
                stock_id=code,
                amount=buy_amount,
                start_time=trade_start_time,
                end_time=trade_end_time,
                direction=Order.BUY,  # 1 for buy
            )
            buy_order_list.append(buy_order)
        return TradeDecisionWO(sell_order_list + buy_order_list, self)


class WeightStrategyBase(BaseSignalStrategy):
    # TODO:
    # 1. Supporting leverage the get_range_limit result from the decision
    # 2. Supporting alter_outer_trade_decision
    # 3. Supporting checking the availability of trade decision
    def __init__(
        self,
        *,
        order_generator_cls_or_obj=OrderGenWOInteract,
        **kwargs,
    ):
        """
        signal :
            the information to describe a signal. Please refer to the docs of `qlib.backtest.signal.create_signal_from`
            the decision of the strategy will base on the given signal
        trade_exchange : Exchange
            exchange that provides market info, used to deal order and generate report
            - If `trade_exchange` is None, self.trade_exchange will be set with common_infra
            - It allowes different trade_exchanges is used in different executions.
            - For example:
                - In daily execution, both daily exchange and minutely are usable, but the daily exchange is recommended because it run faster.
                - In minutely execution, the daily exchange is not usable, only the minutely exchange is recommended.
        """
        super().__init__(**kwargs)

        if isinstance(order_generator_cls_or_obj, type):
            self.order_generator = order_generator_cls_or_obj()
        else:
            self.order_generator = order_generator_cls_or_obj

    def generate_target_weight_position(self, score, current, trade_start_time, trade_end_time):
        """
        Generate target position from score for this date and the current position.The cash is not considered in the position
        Parameters
        -----------
        score : pd.Series
            pred score for this trade date, index is stock_id, contain 'score' column.
        current : Position()
            current position.
        trade_exchange : Exchange()
        trade_date : pd.Timestamp
            trade date.
        """
        raise NotImplementedError()

    def generate_trade_decision(self, execute_result=None):
        # generate_trade_decision
        # generate_target_weight_position() and generate_order_list_from_target_weight_position() to generate order_list

        # get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1]
        trade_step = self.trade_calendar.get_trade_step()
        trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
        pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
        pred_score = self.signal.get_signal(start_time=pred_start_time, end_time=pred_end_time)
        if pred_score is None:
            return TradeDecisionWO([], self)
        current_temp = copy.deepcopy(self.trade_position)
        assert isinstance(current_temp, Position)  # Avoid InfPosition

        target_weight_position = self.generate_target_weight_position(
            score=pred_score, current=current_temp, trade_start_time=trade_start_time, trade_end_time=trade_end_time
        )
        order_list = self.order_generator.generate_order_list_from_target_weight_position(
            current=current_temp,
            trade_exchange=self.trade_exchange,
            risk_degree=self.get_risk_degree(trade_step),
            target_weight_position=target_weight_position,
            pred_start_time=pred_start_time,
            pred_end_time=pred_end_time,
            trade_start_time=trade_start_time,
            trade_end_time=trade_end_time,
        )
        return TradeDecisionWO(order_list, self)


class EnhancedIndexingStrategy(WeightStrategyBase):

    """Enhanced Indexing Strategy

    Enhanced indexing combines the arts of active management and passive management,
    with the aim of outperforming a benchmark index (e.g., S&P 500) in terms of
    portfolio return while controlling the risk exposure (a.k.a. tracking error).

    Users need to prepare their risk model data like below:

    ├── /path/to/riskmodel
    ├──── 20210101
    ├────── factor_exp.{csv|pkl|h5}
    ├────── factor_cov.{csv|pkl|h5}
    ├────── specific_risk.{csv|pkl|h5}
    ├────── blacklist.{csv|pkl|h5}  # optional

    The risk model data can be obtained from risk data provider. You can also use
    `qlib.model.riskmodel.structured.StructuredCovEstimator` to prepare these data.

    Args:
        riskmodel_path (str): risk model path
        name_mapping (dict): alternative file names
    """

    FACTOR_EXP_NAME = "factor_exp.pkl"
    FACTOR_COV_NAME = "factor_cov.pkl"
    SPECIFIC_RISK_NAME = "specific_risk.pkl"
    BLACKLIST_NAME = "blacklist.pkl"

    def __init__(
        self,
        *,
        riskmodel_root,
        market="csi500",
        turn_limit=None,
        name_mapping={},
        optimizer_kwargs={},
        verbose=False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.logger = get_module_logger("EnhancedIndexingStrategy")

        self.riskmodel_root = riskmodel_root
        self.market = market
        self.turn_limit = turn_limit

        self.factor_exp_path = name_mapping.get("factor_exp", self.FACTOR_EXP_NAME)
        self.factor_cov_path = name_mapping.get("factor_cov", self.FACTOR_COV_NAME)
        self.specific_risk_path = name_mapping.get("specific_risk", self.SPECIFIC_RISK_NAME)
        self.blacklist_path = name_mapping.get("blacklist", self.BLACKLIST_NAME)

        self.optimizer = EnhancedIndexingOptimizer(**optimizer_kwargs)

        self.verbose = verbose

        self._riskdata_cache = {}

    def get_risk_data(self, date):

        if date in self._riskdata_cache:
            return self._riskdata_cache[date]

        root = self.riskmodel_root + "/" + date.strftime("%Y%m%d")
        if not os.path.exists(root):
            return None

        factor_exp = load_dataset(root + "/" + self.factor_exp_path, index_col=[0])
        factor_cov = load_dataset(root + "/" + self.factor_cov_path, index_col=[0])
        specific_risk = load_dataset(root + "/" + self.specific_risk_path, index_col=[0])

        if not factor_exp.index.equals(specific_risk.index):
            # NOTE: for stocks missing specific_risk, we always assume it have the highest volatility
            specific_risk = specific_risk.reindex(factor_exp.index, fill_value=specific_risk.max())

        universe = factor_exp.index.tolist()

        blacklist = []
        if os.path.exists(root + "/" + self.blacklist_path):
            blacklist = load_dataset(root + "/" + self.blacklist_path).index.tolist()

        self._riskdata_cache[date] = factor_exp.values, factor_cov.values, specific_risk.values, universe, blacklist

        return self._riskdata_cache[date]

    def generate_target_weight_position(self, score, current, trade_start_time, trade_end_time):

        trade_date = trade_start_time
        pre_date = get_pre_trading_date(trade_date, future=True)  # previous trade date

        # load risk data
        outs = self.get_risk_data(pre_date)
        if outs is None:
            self.logger.warning(f"no risk data for {pre_date:%Y-%m-%d}, skip optimization")
            return None
        factor_exp, factor_cov, specific_risk, universe, blacklist = outs

        # transform score
        # NOTE: for stocks missing score, we always assume they have the lowest score
        score = score.reindex(universe).fillna(score.min()).values

        # get current weight
        # NOTE: if a stock is not in universe, its current weight will be zero
        cur_weight = current.get_stock_weight_dict(only_stock=False)
        cur_weight = np.array([cur_weight.get(stock, 0) for stock in universe])
        assert all(cur_weight >= 0), "current weight has negative values"
        cur_weight = cur_weight / self.get_risk_degree(trade_date)  # sum of weight should be risk_degree
        if cur_weight.sum() > 1 and self.verbose:
            self.logger.warning(f"previous total holdings excess risk degree (current: {cur_weight.sum()})")

        # load bench weight
        bench_weight = D.features(
            D.instruments("all"), [f"${self.market}_weight"], start_time=pre_date, end_time=pre_date
        ).squeeze()
        bench_weight.index = bench_weight.index.droplevel(level="datetime")
        bench_weight = bench_weight.reindex(universe).fillna(0).values

        # whether stock tradable
        # NOTE: currently we use last day volume to check whether tradable
        tradable = D.features(D.instruments("all"), ["$volume"], start_time=pre_date, end_time=pre_date).squeeze()
        tradable.index = tradable.index.droplevel(level="datetime")
        tradable = tradable.reindex(universe).gt(0).values
        mask_force_hold = ~tradable

        # mask force sell
        mask_force_sell = np.array([stock in blacklist for stock in universe], dtype=bool)

        # optimize
        weight = self.optimizer(
            r=score,
            F=factor_exp,
            cov_b=factor_cov,
            var_u=specific_risk**2,
            w0=cur_weight,
            wb=bench_weight,
            mfh=mask_force_hold,
            mfs=mask_force_sell,
        )

        target_weight_position = {stock: weight for stock, weight in zip(universe, weight) if weight > 0}

        if self.verbose:
            self.logger.info("trade date: {:%Y-%m-%d}".format(trade_date))
            self.logger.info("number of holding stocks: {}".format(len(target_weight_position)))
            self.logger.info("total holding weight: {:.6f}".format(weight.sum()))

        return target_weight_position


In [168]:
from abc import abstractmethod
import copy
from qlib.backtest.position import BasePosition
from qlib.log import get_module_logger
from types import GeneratorType
from qlib.backtest.account import Account
import pandas as pd
from typing import List, Tuple, Union
from collections import defaultdict

from qlib.backtest.decision import Order, BaseTradeDecision
from qlib.backtest.exchange import Exchange
from qlib.backtest.utils import TradeCalendarManager, CommonInfrastructure, LevelInfrastructure, get_start_end_idx

from qlib.utils import init_instance_by_config
from qlib.strategy.base import BaseStrategy


class BaseExecutor:
    """Base executor for trading"""

    def __init__(
        self,
        time_per_step: str,
        start_time: Union[str, pd.Timestamp] = None,
        end_time: Union[str, pd.Timestamp] = None,
        indicator_config: dict = {},
        generate_portfolio_metrics: bool = False,
        verbose: bool = False,
        track_data: bool = False,
        trade_exchange: Exchange = None,
        common_infra: CommonInfrastructure = None,
        settle_type=BasePosition.ST_NO,
        **kwargs,
    ):
        """
        Parameters
        ----------
        time_per_step : str
            trade time per trading step, used for generate the trade calendar
        show_indicator: bool, optional
            whether to show indicators, :
            - 'pa', the price advantage
            - 'pos', the positive rate
            - 'ffr', the fulfill rate
        indicator_config: dict, optional
            config for calculating trade indicator, including the following fields:
            - 'show_indicator': whether to show indicators, optional, default by False. The indicators includes
                - 'pa', the price advantage
                - 'pos', the positive rate
                - 'ffr', the fulfill rate
            - 'pa_config': config for calculating price advantage(pa), optional
                - 'base_price': the based price than which the trading price is advanced, Optional, default by 'twap'
                    - If 'base_price' is 'twap', the based price is the time weighted average price
                    - If 'base_price' is 'vwap', the based price is the volume weighted average price
                - 'weight_method': weighted method when calculating total trading pa by different orders' pa in each step, optional, default by 'mean'
                    - If 'weight_method' is 'mean', calculating mean value of different orders' pa
                    - If 'weight_method' is 'amount_weighted', calculating amount weighted average value of different orders' pa
                    - If 'weight_method' is 'value_weighted', calculating value weighted average value of different orders' pa
            - 'ffr_config': config for calculating fulfill rate(ffr), optional
                - 'weight_method': weighted method when calculating total trading ffr by different orders' ffr in each step, optional, default by 'mean'
                    - If 'weight_method' is 'mean', calculating mean value of different orders' ffr
                    - If 'weight_method' is 'amount_weighted', calculating amount weighted average value of different orders' ffr
                    - If 'weight_method' is 'value_weighted', calculating value weighted average value of different orders' ffr
            Example:
                {
                    'show_indicator': True,
                    'pa_config': {
                        "agg": "twap",  # "vwap"
                        "price": "$close", # default to use deal price of the exchange
                    },
                    'ffr_config':{
                        'weight_method': 'value_weighted',
                    }
                }
        generate_portfolio_metrics : bool, optional
            whether to generate portfolio_metrics, by default False
        verbose : bool, optional
            whether to print trading info, by default False
        track_data : bool, optional
            whether to generate trade_decision, will be used when training rl agent
            - If `self.track_data` is true, when making data for training, the input `trade_decision` of `execute` will be generated by `collect_data`
            - Else,  `trade_decision` will not be generated

        trade_exchange : Exchange
            exchange that provides market info, used to generate portfolio_metrics
            - If generate_portfolio_metrics is None, trade_exchange will be ignored
            - Else If `trade_exchange` is None, self.trade_exchange will be set with common_infra

        common_infra : CommonInfrastructure, optional:
            common infrastructure for backtesting, may including:
            - trade_account : Account, optional
                trade account for trading
            - trade_exchange : Exchange, optional
                exchange that provides market info

        settle_type : str
            Please refer to the docs of BasePosition.settle_start
        """
        self.time_per_step = time_per_step
        self.indicator_config = indicator_config
        self.generate_portfolio_metrics = generate_portfolio_metrics
        self.verbose = verbose
        self.track_data = track_data
        self._trade_exchange = trade_exchange
        self.level_infra = LevelInfrastructure()
        self.level_infra.reset_infra(common_infra=common_infra)
        self._settle_type = settle_type
        self.reset(start_time=start_time, end_time=end_time, common_infra=common_infra)
        if common_infra is None:
            get_module_logger("BaseExecutor").warning(f"`common_infra` is not set for {self}")

        # record deal order amount in one day
        self.dealt_order_amount = defaultdict(float)
        self.deal_day = None

    def reset_common_infra(self, common_infra, copy_trade_account=False):
        """
        reset infrastructure for trading
            - reset trade_account
        """
        if not hasattr(self, "common_infra"):
            self.common_infra = common_infra
        else:
            self.common_infra.update(common_infra)

        if common_infra.has("trade_account"):
            if copy_trade_account:
                # NOTE: there is a trick in the code.
                # shallow copy is used instead of deepcopy.
                # 1. So positions are shared
                # 2. Others are not shared, so each level has it own metrics (portfolio and trading metrics)
                self.trade_account: Account = copy.copy(common_infra.get("trade_account"))
            else:
                self.trade_account = common_infra.get("trade_account")
            self.trade_account.reset(freq=self.time_per_step, port_metr_enabled=self.generate_portfolio_metrics)

    @property
    def trade_exchange(self) -> Exchange:
        """get trade exchange in a prioritized order"""
        return getattr(self, "_trade_exchange", None) or self.common_infra.get("trade_exchange")

    @property
    def trade_calendar(self) -> TradeCalendarManager:
        """
        Though trade calendar can be accessed from multiple sources, but managing in a centralized way will make the
        code easier
        """
        return self.level_infra.get("trade_calendar")

    def reset(self, common_infra: CommonInfrastructure = None, **kwargs):
        """
        - reset `start_time` and `end_time`, used in trade calendar
        - reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc
        """

        if "start_time" in kwargs or "end_time" in kwargs:
            start_time = kwargs.get("start_time")
            end_time = kwargs.get("end_time")
            self.level_infra.reset_cal(freq=self.time_per_step, start_time=start_time, end_time=end_time)
        if common_infra is not None:
            self.reset_common_infra(common_infra)

    def get_level_infra(self):
        return self.level_infra

    def finished(self):
        return self.trade_calendar.finished()

    def execute(self, trade_decision: BaseTradeDecision, level: int = 0):
        """execute the trade decision and return the executed result

        NOTE: this function is never used directly in the framework. Should we delete it?

        Parameters
        ----------
        trade_decision : BaseTradeDecision

        level : int
            the level of current executor

        Returns
        ----------
        execute_result : List[object]
            the executed result for trade decision
        """
        return_value = {}
        for _decision in self.collect_data(trade_decision, return_value=return_value, level=level):
            pass
        return return_value.get("execute_result")

    @classmethod
    @abstractmethod
    def _collect_data(cls, trade_decision: BaseTradeDecision, level: int = 0) -> Tuple[List[object], dict]:
        """
        Please refer to the doc of collect_data
        The only difference between `_collect_data` and `collect_data` is that some common steps are moved into
        collect_data

        Parameters
        ----------
        Please refer to the doc of collect_data


        Returns
        -------
        Tuple[List[object], dict]:
            (<the executed result for trade decision>, <the extra kwargs for `self.trade_account.update_bar_end`>)
        """

    def collect_data(
        self, trade_decision: BaseTradeDecision, return_value: dict = None, level: int = 0
    ) -> List[object]:
        """Generator for collecting the trade decision data for rl training

        his function will make a step forward

        Parameters
        ----------
        trade_decision : BaseTradeDecision

        level : int
            the level of current executor. 0 indicates the top level

        return_value : dict
            the mem address to return the value
            e.g.  {"return_value": <the executed result>}

        Returns
        ----------
        execute_result : List[object]
            the executed result for trade decision.
            ** NOTE!!!! **:
            1) This is necessary,  The return value of generator will be used in NestedExecutor
            2) Please note the executed results are not merged.

        Yields
        -------
        object
            trade decision
        """
        if self.track_data:
            yield trade_decision

        atomic = not issubclass(self.__class__, NestedExecutor)  #  issubclass(A, A) is True

        if atomic and trade_decision.get_range_limit(default_value=None) is not None:
            raise ValueError("atomic executor doesn't support specify `range_limit`")

        if self._settle_type != BasePosition.ST_NO:
            self.trade_account.current_position.settle_start(self._settle_type)

        obj = self._collect_data(trade_decision=trade_decision, level=level)

        if isinstance(obj, GeneratorType):
            res, kwargs = yield from obj
        else:
            # Some concrete executor don't have inner decisions
            res, kwargs = obj

        trade_start_time, trade_end_time = self.trade_calendar.get_step_time()
        # Account will not be changed in this function
        self.trade_account.update_bar_end(
            trade_start_time,
            trade_end_time,
            self.trade_exchange,
            atomic=atomic,
            outer_trade_decision=trade_decision,
            indicator_config=self.indicator_config,
            **kwargs,
        )

        self.trade_calendar.step()

        if self._settle_type != BasePosition.ST_NO:
            self.trade_account.current_position.settle_commit()

        if return_value is not None:
            return_value.update({"execute_result": res})
        return res

    def get_all_executors(self):
        """get all executors"""
        return [self]


class NestedExecutor(BaseExecutor):
    """
    Nested Executor with inner strategy and executor
    - At each time `execute` is called, it will call the inner strategy and executor to execute the `trade_decision` in a higher frequency env.
    """

    def __init__(
        self,
        time_per_step: str,
        inner_executor: Union[BaseExecutor, dict],
        inner_strategy: Union[BaseStrategy, dict],
        start_time: Union[str, pd.Timestamp] = None,
        end_time: Union[str, pd.Timestamp] = None,
        indicator_config: dict = {},
        generate_portfolio_metrics: bool = False,
        verbose: bool = False,
        track_data: bool = False,
        skip_empty_decision: bool = True,
        align_range_limit: bool = True,
        common_infra: CommonInfrastructure = None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        inner_executor : BaseExecutor
            trading env in each trading bar.
        inner_strategy : BaseStrategy
            trading strategy in each trading bar
        skip_empty_decision: bool
            Will the executor skip call inner loop when the decision is empty.
            It should be False in following cases
            - The decisions may be updated by steps
            - The inner executor may not follow the decisions from the outer strategy
        align_range_limit: bool
            force to align the trade_range decision
            It is only for nested executor, because range_limit is given by outer strategy
        """
        self.inner_executor: BaseExecutor = init_instance_by_config(
            inner_executor, common_infra=common_infra, accept_types=BaseExecutor
        )
        self.inner_strategy: BaseStrategy = init_instance_by_config(
            inner_strategy, common_infra=common_infra, accept_types=BaseStrategy
        )

        self._skip_empty_decision = skip_empty_decision
        self._align_range_limit = align_range_limit

        super(NestedExecutor, self).__init__(
            time_per_step=time_per_step,
            start_time=start_time,
            end_time=end_time,
            indicator_config=indicator_config,
            generate_portfolio_metrics=generate_portfolio_metrics,
            verbose=verbose,
            track_data=track_data,
            common_infra=common_infra,
            **kwargs,
        )

    def reset_common_infra(self, common_infra, copy_trade_account=False):
        """
        reset infrastructure for trading
            - reset inner_strategyand inner_executor common infra
        """
        # NOTE: please refer to the docs of BaseExecutor.reset_common_infra for the meaning of `copy_trade_account`

        # The first level follow the `copy_trade_account` from the upper level
        super(NestedExecutor, self).reset_common_infra(common_infra, copy_trade_account=copy_trade_account)

        # The lower level have to copy the trade_account
        self.inner_executor.reset_common_infra(common_infra, copy_trade_account=True)
        self.inner_strategy.reset_common_infra(common_infra)

    def _init_sub_trading(self, trade_decision):
        trade_start_time, trade_end_time = self.trade_calendar.get_step_time()
        self.inner_executor.reset(start_time=trade_start_time, end_time=trade_end_time)
        sub_level_infra = self.inner_executor.get_level_infra()
        self.level_infra.set_sub_level_infra(sub_level_infra)
        self.inner_strategy.reset(level_infra=sub_level_infra, outer_trade_decision=trade_decision)

    def _update_trade_decision(self, trade_decision: BaseTradeDecision) -> BaseTradeDecision:
        # outer strategy have chance to update decision each iterator
        updated_trade_decision = trade_decision.update(self.inner_executor.trade_calendar)
        if updated_trade_decision is not None:
            trade_decision = updated_trade_decision
            # NEW UPDATE
            # create a hook for inner strategy to update outer decision
            self.inner_strategy.alter_outer_trade_decision(trade_decision)
        return trade_decision

    def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0):
        execute_result = []
        inner_order_indicators = []
        decision_list = []
        # NOTE:
        # - this is necessary to calculating the steps in sub level
        # - more detailed information will be set into trade decision
        self._init_sub_trading(trade_decision)

        _inner_execute_result = None
        while not self.inner_executor.finished():
            trade_decision = self._update_trade_decision(trade_decision)

            if trade_decision.empty() and self._skip_empty_decision:
                # give one chance for outer strategy to update the strategy
                # - For updating some information in the sub executor(the strategy have no knowledge of the inner
                # executor when generating the decision)
                break

            sub_cal: TradeCalendarManager = self.inner_executor.trade_calendar

            # NOTE: make sure get_start_end_idx is after `self._update_trade_decision`
            start_idx, end_idx = get_start_end_idx(sub_cal, trade_decision)
            if not self._align_range_limit or start_idx <= sub_cal.get_trade_step() <= end_idx:
                # if force align the range limit, skip the steps outside the decision range limit

                res = self.inner_strategy.generate_trade_decision(_inner_execute_result)

                # NOTE: !!!!!
                # the two lines below is for a special case in RL
                # To solve the confliction below
                # - Normally, user will create a strategy and embed it into Qlib's executor and simulator interaction loop
                #   For a _nested qlib example_, (Qlib Strategy) <=> (Qlib Executor[(inner Qlib Strategy) <=> (inner Qlib Executor)])
                # - However, RL-based framework has it's own script to run the loop
                #   For an _RL learning example_, (RL Policy) <=> (RL Env[(inner Qlib Executor)])
                # To make it possible to run  _nested qlib example_ and _RL learning example_ together, the solution below is proposed
                # - The entry script follow the example of  _RL learning example_ to be compatible with all kinds of RL Framework
                # - Each step of (RL Env) will make (inner Qlib Executor) one step forward
                #     - (inner Qlib Strategy) is a proxy strategy, it will give the program control right to (RL Env) by `yield from` and wait for the action from the policy
                # So the two lines below is the implementation of yielding control rights
                if isinstance(res, GeneratorType):
                    res = yield from res

                _inner_trade_decision: BaseTradeDecision = res

                trade_decision.mod_inner_decision(_inner_trade_decision)  # propagate part of decision information

                # NOTE sub_cal.get_step_time() must be called before collect_data in case of step shifting
                decision_list.append((_inner_trade_decision, *sub_cal.get_step_time()))

                # NOTE: Trade Calendar will step forward in the follow line
                _inner_execute_result = yield from self.inner_executor.collect_data(
                    trade_decision=_inner_trade_decision, level=level + 1
                )
                self.post_inner_exe_step(_inner_execute_result)
                execute_result.extend(_inner_execute_result)

                inner_order_indicators.append(
                    self.inner_executor.trade_account.get_trade_indicator().get_order_indicator(raw=True)
                )
            else:
                # do nothing and just step forward
                sub_cal.step()

        return execute_result, {"inner_order_indicators": inner_order_indicators, "decision_list": decision_list}

    def post_inner_exe_step(self, inner_exe_res):
        """
        A hook for doing sth after each step of inner strategy

        Parameters
        ----------
        inner_exe_res :
            the execution result of inner task
        """

    def get_all_executors(self):
        """get all executors, including self and inner_executor.get_all_executors()"""
        return [self, *self.inner_executor.get_all_executors()]

In [169]:
level_infra = LevelInfrastructure()
print(type(level_infra))

<class 'qlib.backtest.utils.LevelInfrastructure'>


In [170]:
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
import copy
from typing import List, Tuple, Union, TYPE_CHECKING

from qlib.backtest.account import Account

if TYPE_CHECKING:
    from qlib.strategy.base import BaseStrategy
    from qlib.backtest.executor import BaseExecutor
    from qlib.backtest.decision import BaseTradeDecision
from qlib.backtest.position import Position
from qlib.backtest.exchange import Exchange
from qlib.backtest.backtest import backtest_loop
from qlib.backtest.backtest import collect_data_loop
from qlib.backtest.utils import CommonInfrastructure
from qlib.backtest.decision import Order
from qlib.utils import init_instance_by_config
from qlib.log import get_module_logger
from qlib.config import C


def create_account_instance(
    start_time, end_time, benchmark: str, account: Union[float, int, dict], pos_type: str = "Position"
) -> Account:
    """
    # TODO: is very strange pass benchmark_config in the account(maybe for report)
    # There should be a post-step to process the report.

    Parameters
    ----------
    start_time
        start time of the benchmark
    end_time
        end time of the benchmark
    benchmark : str
        the benchmark for reporting
    account :   Union[
                    float,
                    {
                        "cash": float,
                        "stock1": Union[
                                        int,    # it is equal to {"amount": int}
                                        {"amount": int, "price"(optional): float},
                                  ]
                    },
                ]
        information for describing how to creating the account
        For `float`:
            Using Account with only initial cash
        For `dict`:
            key "cash" means initial cash.
            key "stock1" means the information of first stock with amount and price(optional).
            ...
    """
    if isinstance(account, (int, float)):
        pos_kwargs = {"init_cash": account}
    elif isinstance(account, dict):
        init_cash = account["cash"]
        del account["cash"]
        pos_kwargs = {
            "init_cash": init_cash,
            "position_dict": account,
        }
    else:
        raise ValueError("account must be in (int, float, Position)")

    kwargs = {
        "init_cash": account,
        "benchmark_config": {
            "benchmark": benchmark,
            "start_time": start_time,
            "end_time": end_time,
        },
        "pos_type": pos_type,
    }
    kwargs.update(pos_kwargs)
    return Account(**kwargs)

In [171]:
from qlib.backtest.utils import CommonInfrastructure
from qlib.strategy.base import BaseStrategy  # pylint: disable=C0415
from qlib.backtest.executor import BaseExecutor  # pylint: disable=C0415
from qlib.backtest import get_exchange


trade_account = create_account_instance(
        start_time=start_time, end_time='2020-07-31', benchmark=benchmark, account=1e9, pos_type="Position"
    )

exchange_kwargs: dict = {}
exchange_kwargs = copy.copy(exchange_kwargs)
if "start_time" not in exchange_kwargs:
    exchange_kwargs["start_time"] = start_time
if "end_time" not in exchange_kwargs:
    exchange_kwargs["end_time"] = '2020-07-31'
trade_exchange = get_exchange(**exchange_kwargs)

common_infra = CommonInfrastructure(trade_account=trade_account, trade_exchange=trade_exchange)
print(type(common_infra))

[35932:MainThread](2022-08-08 17:13:25,572) INFO - qlib.backtest caller - [__init__.py:94] - Create new exchange


<class 'qlib.backtest.utils.CommonInfrastructure'>


In [172]:
common_infra.get_support_infra()

{'trade_account', 'trade_exchange'}

In [173]:
get_support_infra = level_infra.get_support_infra()
get_support_infra

{'common_infra', 'sub_level_infra', 'trade_calendar'}

In [174]:
level_infra.reset_cal("day", "2020-07-01", "2020-07-31")

In [175]:
level_infra.set_sub_level_infra(sub_level_infra = LevelInfrastructure())

In [176]:
sl = model.predict(dataset)
from qlib.contrib.strategy.signal_strategy import create_signal_from
sgl = create_signal_from(sl)
tk = WeekTopkDropoutStrategy(signal = sgl, risk_degree = 0.95,common_infra=common_infra, topk = 50, 
                         n_drop = 5, method_sell="bottom", method_buy="top", level_infra = level_infra,
                         hold_thresh=1, only_tradable=False)

In [177]:
pred_score = tk.generate_trade_decision()

            A_Week_HighAndLow         0
instrument                             
SH600000             0.058319 -0.041688
SH600004             0.035207 -0.039097
SH600009            -0.025850 -0.046769
SH600010             0.115170 -0.037071
SH600011             0.112623 -0.040753
...                       ...       ...
SZ300413             0.186556  0.007762
SZ300433             0.215624 -0.030400
SZ300498             0.139678 -0.020143
SZ300601             0.254703 -0.010046
SZ300628             0.041932 -0.031295

[300 rows x 2 columns]
instrument
SH600000    0.024983
SH600004    0.010439
SH600009   -0.032823
SH600010    0.064423
SH600011    0.061498
              ...   
SZ300413    0.126958
SZ300433    0.133616
SZ300498    0.086405
SZ300601    0.166453
SZ300628    0.017523
Name: value, Length: 300, dtype: float64


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  wha.loc[:, 'A_Week_HighAndLow'] = wha.apply(lambda x: x.sum(), axis=1).copy()
