In [None]:
import sys, os
import numpy as np
import pandas as pd
import qlib

In [None]:
from pathlib import Path

# Local qlib checkout that provides scripts/get_data.py. Override it with the QLIB_SCRIPTS_DIR
# environment variable instead of editing this machine-specific default.
scripts_dir = Path(os.environ.get("QLIB_SCRIPTS_DIR", "D:/py/qlib/scripts"))
get_data_script = scripts_dir.joinpath("get_data.py")
print(get_data_script)
# No assert here: when the script is missing the next cell downloads it. Asserting its
# existence would make that fallback unreachable on a fresh machine.
print("found" if get_data_script.exists() else "missing - it will be downloaded by the next cell")

In [None]:
if not scripts_dir.joinpath("get_data.py").exists():
    # get_data.py is not available locally: download it from the qlib repository.
    scripts_dir = Path("~/tmp/qlib_code/scripts").expanduser().resolve()
    scripts_dir.mkdir(parents=True, exist_ok=True)
    import requests

    get_data_url = "https://raw.githubusercontent.com/microsoft/qlib/main/scripts/get_data.py"
    # Fail loudly on HTTP errors instead of saving an error page as get_data.py,
    # and don't hang forever on a stalled connection.
    with requests.get(get_data_url, timeout=60) as resp:
        resp.raise_for_status()
        scripts_dir.joinpath("get_data.py").write_bytes(resp.content)

In [None]:
from qlib.constant import REG_CN
from qlib.utils import exists_qlib_data, init_instance_by_config
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
from qlib.utils import flatten_dict

In [None]:
# Root of the qlib CN daily data bundle (the target_dir of get_data.py). Override it with the
# QLIB_PROVIDER_URI environment variable instead of editing this machine-specific default.
provider_uri = os.environ.get("QLIB_PROVIDER_URI", "D:/py/qlib/.qlib/qlib_data/cn_data")
# Opt-in: set to True to fetch the data bundle with scripts/get_data.py when it is missing.
# This replaces the previously commented-out download code; the default keeps it disabled.
download_data_if_missing = False
if download_data_if_missing and not exists_qlib_data(provider_uri):
    print(f"Qlib data is not found in {provider_uri}")
    sys.path.append(str(scripts_dir))
    from get_data import GetData

    GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
qlib.init(provider_uri=provider_uri, region=REG_CN)

In [None]:
# Stock universe the model trades and the index it is benchmarked against.
market, benchmark = "csi300", "SH000300"

In [None]:
from qlib.data import D
from qlib.data.filter import ExpressionDFilter
from qlib.data.filter import NameDFilter

In [None]:
instruments = D.instruments(market=market)


def close_ref(k):
    """qlib expression for the close price k trading days ahead (k = 0 is the current close)."""
    return "$close" if k == 0 else f"Ref($close, -{k})"


# Current close plus the next seven one-day forward returns,
# (close[t+k] - close[t+k-1]) / close[t+k-1] for k = 1..7, as qlib expressions.
fields = ["$close"] + [f"({close_ref(k)}-{close_ref(k - 1)})/{close_ref(k - 1)}" for k in range(1, 8)]
f_d = D.features(instruments, fields, start_time="2021-01-04", end_time="2021-06-11", freq="day")
# Build a date-indexed copy instead of overwriting `f_d.index` through the `df` alias,
# so the (instrument, datetime) index of `f_d` stays intact.
df = f_d.set_axis(f_d.index.get_level_values("datetime"), axis=0)
print(df.index.min(), df.index.max())

# First / last available trading dates; they bound the data handler in the training cell.
start_time = pd.to_datetime(df.index.min())
end_time = pd.to_datetime(df.index.max())
print(start_time.strftime("%Y-%m-%d"), end_time.strftime("%Y-%m-%d"))

In [None]:
# Experiment that stores the training run (trained model and its predictions).
experiment_name = "online_srv"

In [None]:
###################################
# train model
###################################
# Date boundaries shared by the handler's fit window and the dataset segments. Keeping them in
# one place guarantees the fit window matches the train segment and the segments don't overlap.
FIT_START = "2021-01-04"
TRAIN_END = "2021-04-30"
VALID_START, VALID_END = "2021-05-01", "2021-05-19"
TEST_START, TEST_END = "2021-05-20", "2021-06-11"

data_handler_config = {
    "start_time": start_time,  # first trading date returned by D.features above
    "end_time": end_time,  # last trading date returned by D.features above
    "fit_start_time": FIT_START,
    "fit_end_time": TRAIN_END,
    "instruments": market,
    "infer_processors": [
        {"class": "RobustZScoreNorm", "kwargs": {"fields_group": "feature", "clip_outlier": True}},
        {"class": "Fillna", "kwargs": {"fields_group": "feature"}},
    ],
    "learn_processors": [
        {"class": "DropnaLabel"},
        {"class": "CSRankNorm", "kwargs": {"fields_group": "label"}},
    ],
    # Return between the next two closes (t+1 -> t+2).
    "label": ["Ref($close, -2) / Ref($close, -1) - 1"],
}

task = {
    "model": {
        "class": "LSTM",
        "module_path": "qlib.contrib.model.pytorch_lstm",
        "kwargs": {
            # Alpha360 provides 6 base series x 60 lags; the LSTM reads them as 60 steps of 6 features.
            "d_feat": 6,
            "hidden_size": 64,
            "num_layers": 2,
            "dropout": 0.1,
            "dec_dropout": 0.0,
            "n_epochs": 15,
            "lr": 1e-5,
            "early_stop": 3,
            "batch_size": 800,
            "metric": "loss",
            "loss": "mse",
            "optimizer": "adam",
            "GPU": 0,
        },
    },
    "dataset": {
        "class": "DatasetH",
        "module_path": "qlib.data.dataset",
        "kwargs": {
            "handler": {
                "class": "Alpha360",
                "module_path": "qlib.contrib.data.handler",
                "kwargs": data_handler_config,
            },
            "segments": {
                "train": (start_time, TRAIN_END),
                "valid": (VALID_START, VALID_END),
                "test": (TEST_START, TEST_END),
            },
        },
    },
}

# model initialisation
model = init_instance_by_config(task["model"])
dataset = init_instance_by_config(task["dataset"])

In [None]:
# start exp to train model
with R.start(experiment_name=experiment_name):
    R.log_params(**flatten_dict(task))
    model.fit(dataset)
    R.save_objects(trained_model=model)

    # Recorder of the active run; its id (`rid`) lets the backtest cells reload the trained model.
    recorder = R.get_recorder()
    rid = recorder.id

    # prediction (pred.pkl / label.pkl are saved in this run)
    sr = SignalRecord(model, dataset, recorder)
    sr.generate()

In [None]:
# Predictions of the trained model on `dataset` (model's default segment; presumably "test" for qlib's LSTM - confirm).
p = model.predict(dataset)

In [None]:
# Rows 4030..4049 by position (explicit .iloc; the index is not an integer index).
p.iloc[4030:4050]

In [None]:
# Artifact names SignalRecord declares it produces.
sr.list()

In [None]:
# Predictions saved by sr.generate() in the training run.
sr.load('pred.pkl')

In [None]:
# Labels saved alongside the predictions by SignalRecord.
sr.load('label.pkl')

In [None]:
# PortAnaRecord with its default config, bound to the training recorder; used only to inspect its artifact list.
pr = PortAnaRecord(recorder)

In [None]:
# Artifact names PortAnaRecord expects to produce (report / positions / analysis per frequency).
pr.list()

In [None]:
###################################
# prediction, backtest & analysis
###################################
port_analysis_config = {
    "executor": {
        "class": "SimulatorExecutor",
        "module_path": "qlib.backtest.executor",
        "kwargs": {
            "time_per_step": "day",
            "generate_portfolio_metrics": True,
        },
    },
    "strategy": {
        "class": "TopkDropoutStrategy",
        "module_path": "qlib.contrib.strategy.signal_strategy",
        "kwargs": {
            # "<PRED>" is replaced by PortAnaRecord with the pred.pkl that SignalRecord saves in the
            # same recorder below. Passing `model`/`dataset` here is deprecated, and would also
            # capture the in-memory model instead of the one reloaded from the recorder.
            "signal": "<PRED>",
            "topk": 50,
            "n_drop": 5,
        },
    },
    "backtest": {
        # Start on the first day of the dataset's "test" segment. Before it there are no
        # predictions, so the strategy has nothing to trade on while the benchmark return still
        # accrues, which biases the reported excess return.
        "start_time": "2021-05-20",
        "end_time": "2021-06-01",
        "account": 100000000,
        "benchmark": benchmark,
        "exchange_kwargs": {
            "freq": "day",
            "limit_threshold": 0.095,
            "deal_price": "close",
            "open_cost": 0.0005,
            "close_cost": 0.0015,
            "min_cost": 5,
        },
    },
}

# backtest and analysis
with R.start(experiment_name="backtest_analysis"):
    # Reload the model trained in the `experiment_name` run `rid`.
    recorder = R.get_recorder(recorder_id=rid, experiment_name=experiment_name)
    model = recorder.load_object("trained_model")

    # prediction, saved as pred.pkl in this backtest run
    recorder = R.get_recorder()
    ba_rid = recorder.id
    sr = SignalRecord(model, dataset, recorder)
    sr.generate()

    # backtest & analysis
    par = PortAnaRecord(recorder, port_analysis_config, "day")
    par.generate()

In [None]:
from qlib.contrib.report import analysis_model, analysis_position
from qlib.data import D

# `ba_rid` was created inside `R.start(experiment_name="backtest_analysis")`, so it must be looked
# up in that experiment. The "online_srv" experiment only holds the training run.
recorder = R.get_recorder(recorder_id=ba_rid, experiment_name="backtest_analysis")
print(recorder)
pred_df = recorder.load_object("pred.pkl")
pred_df_dates = pred_df.index.get_level_values(level="datetime")
# PortAnaRecord suffixes its artifacts with the normalised frequency ("day" -> "1day").
report_normal_df = recorder.load_object("portfolio_analysis/report_normal_1day.pkl")
positions = recorder.load_object("portfolio_analysis/positions_normal_1day.pkl")
analysis_df = recorder.load_object("portfolio_analysis/port_analysis_1day.pkl")

In [None]:
# Prediction scores loaded from the backtest run.
pred_df

In [None]:
# Date of every prediction row (one entry per row, so dates repeat across instruments).
pred_df_dates

In [None]:
# Daily portfolio report of the backtest (includes return, bench and cost columns).
report_normal_df

In [None]:
# Risk analysis of the excess return, with and without transaction cost.
analysis_df

In [None]:
#  Copyright (c) Microsoft Corporation.
#  Licensed under the MIT License.

import logging
import warnings
import pandas as pd
from pprint import pprint
from typing import Union, List, Optional

from qlib.utils.exceptions import LoadObjectError
from qlib.contrib.evaluate import risk_analysis, indicator_analysis

from qlib.data.dataset import DatasetH
from qlib.data.dataset.handler import DataHandlerLP
from qlib.backtest import backtest as normal_backtest
from qlib.log import get_module_logger
from qlib.utils import fill_placeholder, flatten_dict, class_casting, get_date_by_shift
from qlib.utils.time import Freq
from qlib.utils.data import deepcopy_basic_type
from qlib.contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec


# Logger used by the record classes defined in this cell (module name "workflow", INFO level).
logger = get_module_logger("workflow", logging.INFO)

class RecordTemp:
    """
    This is the Records Template class that enables user to generate experiment results such as IC and
    backtest in a certain format.

    Subclasses set ``artifact_path`` (the sub-directory of the recorder holding their artifacts) and,
    optionally, ``depend_cls`` (the record class whose artifacts they consume).
    """

    artifact_path = None
    depend_cls = None  # the depend class of the record; the record will depend on the results generated by `depend_cls`

    @classmethod
    def get_path(cls, path=None):
        """Join ``cls.artifact_path`` and ``path`` with "/", skipping whichever is None.

        Returns "" when both are None.
        """
        names = []
        if cls.artifact_path is not None:
            names.append(cls.artifact_path)

        if path is not None:
            names.append(path)

        return "/".join(names)

    def save(self, **kwargs):
        """
        It behaves the same as self.recorder.save_objects.
        But it is an easier interface because users don't have to care about `get_path` and `artifact_path`
        """
        art_path = self.get_path()
        if art_path == "":
            # no artifact_path: save at the recorder's root
            art_path = None
        self.recorder.save_objects(artifact_path=art_path, **kwargs)

    def __init__(self, recorder):
        self._recorder = recorder

    @property
    def recorder(self):
        """The recorder this record reads from / writes to.

        Raises
        ------
        ValueError
            if no recorder was given.
        """
        if self._recorder is None:
            raise ValueError("This RecordTemp did not set recorder yet.")
        return self._recorder

    def generate(self, **kwargs):
        """
        Generate certain records such as IC, backtest etc., and save them.

        Parameters
        ----------
        kwargs

        Return
        ------
        """
        raise NotImplementedError("Please implement the `generate` method.")

    def load(self, name: str, parents: bool = True):
        """
        It behaves the same as self.recorder.load_object.
        But it is an easier interface because users don't have to care about `get_path` and `artifact_path`

        Parameters
        ----------
        name : str
            the name for the file to be load.

        parents : bool
            Each recorder has different `artifact_path`.
            So parents recursively find the path in parents
            Sub classes has higher priority

        Return
        ------
        The stored records.

        Raises
        ------
        LoadObjectError
            if the object is found neither under this record's path nor (when ``parents``)
            under any parent's path.
        """
        try:
            return self.recorder.load_object(self.get_path(name))
        except LoadObjectError:
            if parents and self.depend_cls is not None:
                # Temporarily treat `self` as the parent record so `get_path` resolves against the
                # parent's artifact_path, then retry there (recursing further up the chain).
                with class_casting(self, self.depend_cls):
                    return self.load(name, parents=True)
            # Nothing left to try: propagate instead of silently returning None, so callers
            # fail with a clear error rather than on a later `None` attribute access.
            raise

    def list(self):
        """
        List the supported artifacts.
        Users don't have to consider self.get_path

        Return
        ------
        A list of all the supported artifacts.
        """
        return []

    def check(self, include_self: bool = False, parents: bool = True):
        """
        Check if the records is properly generated and saved.
        It is useful in following examples
        - checking if the depended files complete before generating new things.
        - checking if the final files is completed

        Parameters
        ----------
        include_self : bool
            is the file generated by self included
        parents : bool
            will we check parents

        Raise
        ------
        FileNotFoundError
        : whether the records are stored properly.
        """
        if include_self:

            # Some mlflow backend will not list the directly recursively.
            # So we force to the directly
            artifacts = {}

            def _get_arts(dirn):
                # cache one listing per directory
                if dirn not in artifacts:
                    artifacts[dirn] = self.recorder.list_artifacts(dirn)
                return artifacts[dirn]

            for item in self.list():
                ps = self.get_path(item).split("/")
                dirn = "/".join(ps[:-1])
                if self.get_path(item) not in _get_arts(dirn):
                    raise FileNotFoundError
        if parents:
            if self.depend_cls is not None:
                with class_casting(self, self.depend_cls):
                    self.check(include_self=True)


class ACRecordTemp(RecordTemp):
    """Record template that checks artifacts automatically before generating.

    Subclasses implement ``_generate``. ``generate`` wraps it with two checks:
    an optional early exit when this record's own artifacts already exist, and
    a guard that skips generation when the records it depends on are incomplete.
    """

    def __init__(self, recorder, skip_existing=False):
        """
        Parameters
        ----------
        recorder
            the recorder artifacts are read from and written to.
        skip_existing : bool
            when True, ``generate`` returns early if this record's artifacts are already saved.
        """
        self.skip_existing = skip_existing
        super().__init__(recorder=recorder)

    def generate(self, *args, **kwargs):
        """Run ``_generate`` unless the results already exist or the dependencies are missing."""
        if self.skip_existing:
            already_generated = True
            try:
                self.check(include_self=True, parents=False)
            except FileNotFoundError:
                # at least one own artifact is missing -> go on and generate
                already_generated = False
            if already_generated:
                logger.info("The results has previously generated, Generation skipped.")
                return

        try:
            self.check()
        except FileNotFoundError:
            logger.warning("The dependent data does not exists. Generation skipped.")
            return
        return self._generate(*args, **kwargs)

    def _generate(self, *args, **kwargs):
        """Concrete generation step; must be overridden by subclasses."""
        raise NotImplementedError("Please implement the `_generate` method")


class PortAnaRecord(ACRecordTemp):
    """
    This is the Portfolio Analysis Record class that generates the analysis results such as those of backtest. This class inherits the ``RecordTemp`` class.

    It reads ``pred.pkl`` (produced by ``SignalRecord``, see ``depend_cls``), runs a backtest and stores
    the following files in the recorder under ``portfolio_analysis/``. Here ``{freq}`` is a normalised
    frequency string such as ``1day``.
    - report_normal_{freq}.pkl & positions_normal_{freq}.pkl:
        - The return report and detailed positions of the backtest for each reported frequency,
          as returned by ``qlib.backtest.backtest``
    - indicators_normal_{freq}.pkl : the trade indicators of the backtest for each frequency
    - port_analysis_{freq}.pkl : The risk analysis of the excess return (with and without cost),
      computed by `qlib/contrib/evaluate.py:risk_analysis`, for each freq in ``risk_analysis_freq``
    - indicator_analysis_{freq}.pkl : the result of `qlib/contrib/evaluate.py:indicator_analysis`
      for each freq in ``indicator_analysis_freq``
    """

    artifact_path = "portfolio_analysis"
    depend_cls = SignalRecord

    def __init__(
        self,
        recorder,
        config=None,
        risk_analysis_freq: Union[List, str] = None,
        indicator_analysis_freq: Union[List, str] = None,
        indicator_analysis_method=None,
        skip_existing=False,
        **kwargs,
    ):
        """
        recorder :
            the recorder holding ``pred.pkl``; it also receives the analysis artifacts.
        config["strategy"] : dict
            define the strategy class as well as the kwargs. Values equal to "<PRED>" are replaced
            by the loaded predictions at generation time.
        config["executor"] : dict, optional
            define the executor class as well as the kwargs. Defaults to a daily
            ``SimulatorExecutor`` with ``generate_portfolio_metrics=True``.
        config["backtest"] : dict
            define the backtest kwargs. ``start_time`` / ``end_time`` left as None are filled
            from the prediction dates at generation time.
        risk_analysis_freq : str|List[str]
            risk analysis freq of report; defaults to the first reported frequency.
        indicator_analysis_freq : str|List[str]
            indicator analysis freq of report; defaults to the first reported frequency.
        indicator_analysis_method : str, optional, default by None
            the candidated values include 'mean', 'amount_weighted', 'value_weighted'
        skip_existing : bool
            skip generation when this record's artifacts already exist.

        NOTE(review): if no executor level sets ``generate_portfolio_metrics=True``, ``self.all_freq``
        is empty and the default-frequency lookups ``self.all_freq[0]`` below raise IndexError.
        """
        super().__init__(recorder=recorder, skip_existing=skip_existing, **kwargs)

        if config is None:
            config = {  # Default config for daily trading
                "strategy": {
                    "class": "TopkDropoutStrategy",
                    "module_path": "qlib.contrib.strategy",
                    "kwargs": {"signal": "<PRED>", "topk": 50, "n_drop": 5},
                },
                "backtest": {
                    "start_time": None,
                    "end_time": None,
                    "account": 100000000,
                    "benchmark": "SH000300",
                    "exchange_kwargs": {
                        "limit_threshold": 0.095,
                        "deal_price": "close",
                        "open_cost": 0.0005,
                        "close_cost": 0.0015,
                        "min_cost": 5,
                    },
                },
            }
        # We only deepcopy_basic_type because
        # - We don't want to affect the config outside.
        # - We don't want to deepcopy complex object to avoid overhead
        config = deepcopy_basic_type(config)

        self.strategy_config = config["strategy"]
        _default_executor_config = {
            "class": "SimulatorExecutor",
            "module_path": "qlib.backtest.executor",
            "kwargs": {
                "time_per_step": "day",
                "generate_portfolio_metrics": True,
            },
        }
        self.executor_config = config.get("executor", _default_executor_config)
        self.backtest_config = config["backtest"]

        # Frequencies for which the backtest produces portfolio reports (outermost executor first).
        self.all_freq = self._get_report_freq(self.executor_config)
        if risk_analysis_freq is None:
            risk_analysis_freq = [self.all_freq[0]]
        if indicator_analysis_freq is None:
            indicator_analysis_freq = [self.all_freq[0]]

        if isinstance(risk_analysis_freq, str):
            risk_analysis_freq = [risk_analysis_freq]
        if isinstance(indicator_analysis_freq, str):
            indicator_analysis_freq = [indicator_analysis_freq]

        # Normalise user-given frequencies to the "<count><unit>" keys used by the backtest results.
        self.risk_analysis_freq = [
            "{0}{1}".format(*Freq.parse(_analysis_freq)) for _analysis_freq in risk_analysis_freq
        ]
        self.indicator_analysis_freq = [
            "{0}{1}".format(*Freq.parse(_analysis_freq)) for _analysis_freq in indicator_analysis_freq
        ]
        self.indicator_analysis_method = indicator_analysis_method

    def _get_report_freq(self, executor_config):
        """Collect "<count><unit>" frequencies of executor levels that generate portfolio metrics.

        Recurses into ``kwargs["inner_executor"]`` so nested executors contribute their own frequency
        after the outer one.
        """
        ret_freq = []
        if executor_config["kwargs"].get("generate_portfolio_metrics", False):
            _count, _freq = Freq.parse(executor_config["kwargs"]["time_per_step"])
            ret_freq.append(f"{_count}{_freq}")
        if "inner_executor" in executor_config["kwargs"]:
            ret_freq.extend(self._get_report_freq(executor_config["kwargs"]["inner_executor"]))
        return ret_freq

    def _generate(self, **kwargs):
        """Run the backtest on the saved predictions, then save reports and their analyses."""
        # predictions saved by the depended SignalRecord (`load` falls back to depend_cls's path)
        pred = self.load("pred.pkl")

        # replace the "<PRED>" with prediction saved before
        placehorder_value = {"<PRED>": pred}
        for k in "executor_config", "strategy_config":
            setattr(self, k, fill_placeholder(getattr(self, k), placehorder_value))

        # if the backtesting time range is not set, it will automatically extract time range from the prediction file
        dt_values = pred.index.get_level_values("datetime")
        if self.backtest_config["start_time"] is None:
            self.backtest_config["start_time"] = dt_values.min()
        if self.backtest_config["end_time"] is None:
            self.backtest_config["end_time"] = get_date_by_shift(dt_values.max(), 1)

        # custom strategy and get backtest
        portfolio_metric_dict, indicator_dict = normal_backtest(
            executor=self.executor_config, strategy=self.strategy_config, **self.backtest_config
        )
        # raw backtest outputs, one artifact per frequency
        for _freq, (report_normal, positions_normal) in portfolio_metric_dict.items():
            self.save(**{f"report_normal_{_freq}.pkl": report_normal})
            self.save(**{f"positions_normal_{_freq}.pkl": positions_normal})

        for _freq, indicators_normal in indicator_dict.items():
            self.save(**{f"indicators_normal_{_freq}.pkl": indicators_normal})

        # risk analysis of the excess return (portfolio return minus benchmark, with/without cost)
        for _analysis_freq in self.risk_analysis_freq:
            if _analysis_freq not in portfolio_metric_dict:
                warnings.warn(
                    f"the freq {_analysis_freq} report is not found, please set the corresponding env with `generate_portfolio_metrics=True`"
                )
            else:
                report_normal, _ = portfolio_metric_dict.get(_analysis_freq)
                analysis = dict()
                analysis["excess_return_without_cost"] = risk_analysis(
                    report_normal["return"] - report_normal["bench"], freq=_analysis_freq
                )
                analysis["excess_return_with_cost"] = risk_analysis(
                    report_normal["return"] - report_normal["bench"] - report_normal["cost"], freq=_analysis_freq
                )

                analysis_df = pd.concat(analysis)  # type: pd.DataFrame
                # log metrics
                analysis_dict = flatten_dict(analysis_df["risk"].unstack().T.to_dict())
                self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()})
                # save results
                self.save(**{f"port_analysis_{_analysis_freq}.pkl": analysis_df})
                logger.info(
                    f"Portfolio analysis record 'port_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
                )
                # print out results
                pprint(f"The following are analysis results of benchmark return({_analysis_freq}).")
                pprint(risk_analysis(report_normal["bench"], freq=_analysis_freq))
                pprint(f"The following are analysis results of the excess return without cost({_analysis_freq}).")
                pprint(analysis["excess_return_without_cost"])
                pprint(f"The following are analysis results of the excess return with cost({_analysis_freq}).")
                pprint(analysis["excess_return_with_cost"])

        # trade-indicator analysis for each requested frequency
        for _analysis_freq in self.indicator_analysis_freq:
            if _analysis_freq not in indicator_dict:
                warnings.warn(f"the freq {_analysis_freq} indicator is not found")
            else:
                indicators_normal = indicator_dict.get(_analysis_freq)
                if self.indicator_analysis_method is None:
                    analysis_df = indicator_analysis(indicators_normal)
                else:
                    analysis_df = indicator_analysis(indicators_normal, method=self.indicator_analysis_method)
                # log metrics
                analysis_dict = analysis_df["value"].to_dict()
                self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()})
                # save results
                self.save(**{f"indicator_analysis_{_analysis_freq}.pkl": analysis_df})
                logger.info(
                    f"Indicator analysis record 'indicator_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
                )
                pprint(f"The following are analysis results of indicators({_analysis_freq}).")
                pprint(analysis_df)

    def list(self):
        """Artifacts verified by ``check(include_self=True)``.

        Note that the ``indicators_normal_{freq}.pkl`` files are saved but not listed here.
        """
        list_path = []
        for _freq in self.all_freq:
            list_path.extend(
                [
                    f"report_normal_{_freq}.pkl",
                    f"positions_normal_{_freq}.pkl",
                ]
            )
        for _analysis_freq in self.risk_analysis_freq:
            if _analysis_freq in self.all_freq:
                list_path.append(f"port_analysis_{_analysis_freq}.pkl")
            else:
                warnings.warn(f"risk_analysis freq {_analysis_freq} is not found")

        for _analysis_freq in self.indicator_analysis_freq:
            if _analysis_freq in self.all_freq:
                list_path.append(f"indicator_analysis_{_analysis_freq}.pkl")
            else:
                warnings.warn(f"indicator_analysis freq {_analysis_freq} is not found")
        return list_path


In [None]:
###################################
# prediction, backtest & analysis (weekly Top-K dropout variant)
###################################
port_analysis_config = {
    "executor": {
        "class": "SimulatorExecutor",
        "module_path": "qlib.backtest.executor",
        "kwargs": {
            "time_per_step": "day",
            "generate_portfolio_metrics": True,
        },
    },
    "strategy": {
        # NOTE(review): WeekTopkDropoutStrategy does not appear to be an upstream qlib class; it is
        # defined further down in this notebook. This config only resolves if the class is also
        # importable from `module_path` (e.g. added to the local qlib checkout) - confirm.
        "class": "WeekTopkDropoutStrategy",
        "module_path": "qlib.contrib.strategy.signal_strategy",
        "kwargs": {
            # "<PRED>" is replaced by PortAnaRecord with the pred.pkl saved by SignalRecord below;
            # passing `model`/`dataset` here is deprecated.
            "signal": "<PRED>",
            "topk": 50,
            "n_drop": 5,
        },
    },
    "backtest": {
        # first day of the dataset's "test" segment, so every backtest day has predictions
        "start_time": "2021-05-20",
        "end_time": "2021-06-01",
        "account": 100000000,
        "benchmark": benchmark,
        "exchange_kwargs": {
            "freq": "day",
            "limit_threshold": 0.095,
            "deal_price": "close",
            "open_cost": 0.0005,
            "close_cost": 0.0015,
            "min_cost": 5,
        },
    },
}

# backtest and analysis
with R.start(experiment_name="backtest_analysis"):
    # Reload the model trained in the `experiment_name` run `rid`.
    recorder = R.get_recorder(recorder_id=rid, experiment_name=experiment_name)
    model = recorder.load_object("trained_model")

    # prediction, saved as pred.pkl in this backtest run
    recorder = R.get_recorder()
    ba_rid = recorder.id
    sr = SignalRecord(model, dataset, recorder)
    sr.generate()

    # backtest & analysis
    par = PortAnaRecord(recorder, port_analysis_config, "day")
    par.generate()

In [None]:
import os
from qlib.backtest.signal import Signal, create_signal_from

import copy
import warnings
import numpy as np
import pandas as pd

from typing import Dict, List, Text, Tuple, Union

from qlib.data import D
from qlib.data.dataset import Dataset
from qlib.model.base import BaseModel
from qlib.strategy.base import BaseStrategy
from qlib.backtest.position import Position
from qlib.backtest.decision import Order, OrderDir, TradeDecisionWO
from qlib.log import get_module_logger
from qlib.utils import get_pre_trading_date, load_dataset
from qlib.contrib.strategy.order_generator import OrderGenWOInteract
from qlib.contrib.strategy.optimizer import EnhancedIndexingOptimizer

In [None]:
class BaseSignalStrategy(BaseStrategy):
    """Strategy base class whose trading decisions are driven by a prediction signal.

    It stores the signal (normalised with ``create_signal_from``) and the fraction of the total
    value to invest (``risk_degree``); subclasses implement the decision logic.
    """

    def __init__(
        self,
        *,
        signal: Union[Signal, Tuple[BaseModel, Dataset], List, Dict, Text, pd.Series, pd.DataFrame] = None,
        model=None,
        dataset=None,
        risk_degree: float = 0.95,
        trade_exchange=None,
        level_infra=None,
        common_infra=None,
        **kwargs,
    ):
        """
        Parameters
        -----------
        signal :
            the information describing a signal. Please refer to the docs of `qlib.backtest.signal.create_signal_from`.
            The decisions of the strategy are based on this signal.
        model, dataset :
            deprecated way to provide the signal. When both are given they override ``signal``
            and the pair ``(model, dataset)`` is used instead (a DeprecationWarning is emitted).
        risk_degree : float
            position percentage of total value.
        trade_exchange : Exchange
            exchange that provides market info, used to deal orders and generate reports.
            - If `trade_exchange` is None, self.trade_exchange will be set with common_infra.
            - It allows different trade_exchanges to be used in different executions.
            - For example:
                - In daily execution, both daily exchange and minutely are usable, but the daily exchange is recommended because it runs faster.
                - In minutely execution, the daily exchange is not usable, only the minutely exchange is recommended.
        level_infra, common_infra :
            forwarded unchanged to ``BaseStrategy.__init__``.
        """
        super().__init__(level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs)
        self.risk_degree = risk_degree

        # Backward compatibility with older task configs that passed `model` and `dataset`.
        legacy_pair_given = not (model is None or dataset is None)
        if legacy_pair_given:
            warnings.warn("`model` `dataset` is deprecated; use `signal`.", DeprecationWarning)
            signal = (model, dataset)

        self.signal: Signal = create_signal_from(signal)

    def get_risk_degree(self, trade_step=None):
        """Fraction of the total value to invest at ``trade_step``.

        Here it is the constant ``self.risk_degree`` (0.95 by default); subclasses may vary it
        over time to implement market timing.
        """
        degree = self.risk_degree
        return degree

In [None]:
import time


class WeekTopkDropoutStrategy(BaseSignalStrategy):
    """Top-k dropout strategy ranked by a blend of the model signal and a week-ahead return feature.

    At each trading step the ranking value of a stock is::

        value = 2/3 * A_Week_HighAndLow + 1/3 * score

    where ``score`` is the model signal of the previous step and ``A_Week_HighAndLow`` is the sum of
    the seven daily close-to-close returns that follow the prediction date. The portfolio is then
    rebalanced with the usual top-k / n-drop rules.

    NOTE(review): ``A_Week_HighAndLow`` is built from ``Ref($close, -k)`` with negative ``k``, i.e. from
    prices *after* the prediction date. That is look-ahead information, so backtest results produced
    by this strategy are not achievable in live trading.
    """

    # TODO:
    # 1. Supporting leverage the get_range_limit result from the decision
    # 2. Supporting alter_outer_trade_decision
    # 3. Supporting checking the availability of trade decision

    # Daily returns of the 1st..7th trading day after the query date; their row sum is the
    # "A_Week_HighAndLow" feature.
    _WEEK_RETURN_FIELDS = [
        "(Ref($close, -1)-$close)/$close",
        "(Ref($close, -2)-Ref($close, -1))/Ref($close, -1)",
        "(Ref($close, -3)-Ref($close, -2))/Ref($close, -2)",
        "(Ref($close, -4)-Ref($close, -3))/Ref($close, -3)",
        "(Ref($close, -5)-Ref($close, -4))/Ref($close, -4)",
        "(Ref($close, -6)-Ref($close, -5))/Ref($close, -5)",
        "(Ref($close, -7)-Ref($close, -6))/Ref($close, -6)",
    ]

    def __init__(
        self,
        *,
        topk,
        n_drop,
        method_sell="bottom",
        method_buy="top",
        hold_thresh=1,
        only_tradable=False,
        **kwargs,
    ):
        """
        Parameters
        -----------
        topk : int
            the number of stocks in the portfolio.
        n_drop : int
            number of stocks to be replaced in each trading date.
        method_sell : str
            dropout method_sell, random/bottom.
        method_buy : str
            dropout method_buy, random/top.
        hold_thresh : int
            minimum holding days
            before sell stock , will check current.get_stock_count(order.stock_id) >= self.hold_thresh.
        only_tradable : bool
            whether the strategy only considers tradable stocks when picking buy and sell candidates.
            if only_tradable:
                candidates are filtered with the exchange's tradable state before they are picked.
            else:
                candidates are picked without checking the tradable state; untradable stocks are
                skipped later, when the orders are built.
        kwargs :
            forwarded to BaseSignalStrategy (e.g. signal, risk_degree, level_infra, common_infra).
        """
        super().__init__(**kwargs)
        self.topk = topk
        self.n_drop = n_drop
        self.method_sell = method_sell
        self.method_buy = method_buy
        self.hold_thresh = hold_thresh
        self.only_tradable = only_tradable

    def _get_week_high_and_low(self, datetime, instruments):
        """Sum of the next seven daily returns of each instrument, measured from ``datetime``.

        Parameters
        ----------
        datetime : pd.Timestamp
            the prediction date; only its calendar day is used.
        instruments :
            instrument config accepted by ``D.features`` (e.g. ``D.instruments("all")``).

        Returns
        -------
        pd.Series
            indexed by instrument and named ``"A_Week_HighAndLow"``. Instruments whose future returns
            are all missing get 0.0, because ``sum`` skips NaN.
        """
        day = datetime.strftime("%Y-%m-%d")
        # Query only the prediction day instead of a hard-coded date range: qlib extends the load
        # window itself so that Ref($close, -k) can see the k following days.
        features = D.features(instruments, self._WEEK_RETURN_FIELDS, start_time=day, end_time=day, freq="day")
        week_change = features.loc[(slice(None), day), :].sum(axis=1).droplevel("datetime")
        return week_change.rename("A_Week_HighAndLow")

    @staticmethod
    def _combine_score(week_change, score):
        """Blend the week-ahead return (weight 2/3) with the model score (weight 1/3), element-wise."""
        return week_change * (2 / 3) + score * (1 / 3)

    def generate_trade_decision(self, execute_result=None):
        """Build the sell and buy orders for the current trading step.

        Parameters
        ----------
        execute_result :
            execution result of the previous step; not used by this strategy.

        Returns
        -------
        TradeDecisionWO
            sell orders followed by buy orders; empty when no signal is available for the step.
        """
        # get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1]
        trade_step = self.trade_calendar.get_trade_step()
        trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
        pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=-1)
        pred_score = self.signal.get_signal(start_time=pred_start_time, end_time=pred_end_time)

        # NOTE: the current version of topk dropout strategy can't handle pd.DataFrame(multiple signal)
        # So it only leverage the first col of signal
        if isinstance(pred_score, pd.DataFrame):
            pred_score = pred_score.iloc[:, 0]
        # Must be checked before blending: a missing signal cannot be combined with the week feature.
        if pred_score is None:
            return TradeDecisionWO([], self)

        week_change = self._get_week_high_and_low(datetime=pred_start_time, instruments=D.instruments("all"))
        # Outer-align on instrument; stocks missing either input end up with a NaN value.
        pred = pd.concat([week_change, pred_score.rename("score")], axis=1)
        pred_score = self._combine_score(pred["A_Week_HighAndLow"], pred["score"]).rename("value")

        if self.only_tradable:
            # If The strategy only consider tradable stock when make decision
            # It needs following actions to filter stocks
            def is_tradable(si):
                return self.trade_exchange.is_stock_tradable(
                    stock_id=si, start_time=trade_start_time, end_time=trade_end_time
                )

            def get_first_n(l, n, reverse=False):
                # n <= 0 asks for nothing; without this guard one stock would still be returned
                if n <= 0:
                    return []
                res = []
                for si in reversed(l) if reverse else l:
                    if is_tradable(si):
                        res.append(si)
                        if len(res) >= n:
                            break
                return res[::-1] if reverse else res

            def get_last_n(l, n):
                return get_first_n(l, n, reverse=True)

            def filter_stock(l):
                return [si for si in l if is_tradable(si)]

        else:
            # Otherwise, the stock will make decision with out the stock tradable info
            def get_first_n(l, n):
                return list(l)[: max(n, 0)]

            def get_last_n(l, n):
                # guard n <= 0: list(l)[-0:] is the whole list, which would drop every holding
                return list(l)[-n:] if n > 0 else []

            def filter_stock(l):
                return l

        current_temp = copy.deepcopy(self.trade_position)
        # generate order list for this adjust date
        sell_order_list = []
        buy_order_list = []
        # load score
        cash = current_temp.get_cash()
        current_stock_list = current_temp.get_stock_list()
        # last position (sorted by score)
        last = pred_score.reindex(current_stock_list).sort_values(ascending=False).index
        # The new stocks today want to buy **at most**
        if self.method_buy == "top":
            today = get_first_n(
                pred_score[~pred_score.index.isin(last)].sort_values(ascending=False).index,
                self.n_drop + self.topk - len(last),
            )
        elif self.method_buy == "random":
            topk_candi = get_first_n(pred_score.sort_values(ascending=False).index, self.topk)
            candi = list(filter(lambda x: x not in last, topk_candi))
            n = self.n_drop + self.topk - len(last)
            try:
                today = np.random.choice(candi, n, replace=False)
            except ValueError:
                today = candi
        else:
            raise NotImplementedError(f"This type of input is not supported")
        # combine(new stocks + last stocks),  we will drop stocks from this list
        # In case of dropping higher score stock and buying lower score stock.
        comb = pred_score.reindex(last.union(pd.Index(today))).sort_values(ascending=False).index

        # Get the stock list we really want to sell (After filtering the case that we sell high and buy low)
        if self.method_sell == "bottom":
            sell = last[last.isin(get_last_n(comb, self.n_drop))]
        elif self.method_sell == "random":
            candi = filter_stock(last)
            try:
                sell = pd.Index(np.random.choice(candi, self.n_drop, replace=False) if len(last) else [])
            except ValueError:  # No enough candidates
                sell = candi
        else:
            raise NotImplementedError(f"This type of input is not supported")

        # Get the stock list we really want to buy
        buy = today[: len(sell) + self.topk - len(last)]

        # The holding-period check uses the calendar frequency, which does not change inside the loop.
        time_per_step = self.trade_calendar.get_freq()
        for code in current_stock_list:
            if not self.trade_exchange.is_stock_tradable(
                stock_id=code, start_time=trade_start_time, end_time=trade_end_time
            ):
                continue
            if code in sell:
                # check hold limit
                if current_temp.get_stock_count(code, bar=time_per_step) < self.hold_thresh:
                    continue
                # sell the whole holding; the amount is deliberately not rounded to the trade unit
                sell_amount = current_temp.get_stock_amount(code=code)
                sell_order = Order(
                    stock_id=code,
                    amount=sell_amount,
                    start_time=trade_start_time,
                    end_time=trade_end_time,
                    direction=Order.SELL,  # 0 for sell, 1 for buy
                )
                # is order executable
                if self.trade_exchange.check_order(sell_order):
                    sell_order_list.append(sell_order)
                    # simulate the sell on the copied position so the freed cash can fund the buys
                    trade_val, trade_cost, trade_price = self.trade_exchange.deal_order(
                        sell_order, position=current_temp
                    )
                    # update cash
                    cash += trade_val - trade_cost

        # buy new stock: split the investable cash evenly across the stocks to buy
        value = cash * self.risk_degree / len(buy) if len(buy) > 0 else 0

        # open_cost should be considered in the real trading environment, while the backtest in evaluate.py does not
        # consider it as the aim of demo is to accomplish same strategy as evaluate.py, so comment out this line
        # value = value / (1+self.trade_exchange.open_cost) # set open_cost limit
        for code in buy:
            # check is stock suspended
            if not self.trade_exchange.is_stock_tradable(
                stock_id=code, start_time=trade_start_time, end_time=trade_end_time
            ):
                continue
            # buy order
            buy_price = self.trade_exchange.get_deal_price(
                stock_id=code, start_time=trade_start_time, end_time=trade_end_time, direction=OrderDir.BUY
            )
            buy_amount = value / buy_price
            factor = self.trade_exchange.get_factor(stock_id=code, start_time=trade_start_time, end_time=trade_end_time)
            buy_amount = self.trade_exchange.round_amount_by_trade_unit(buy_amount, factor)
            buy_order = Order(
                stock_id=code,
                amount=buy_amount,
                start_time=trade_start_time,
                end_time=trade_end_time,
                direction=Order.BUY,  # 1 for buy
            )
            buy_order_list.append(buy_order)
        return TradeDecisionWO(sell_order_list + buy_order_list, self)

In [None]:
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
import copy
from typing import List, Tuple, Union, TYPE_CHECKING

from qlib.backtest.account import Account

if TYPE_CHECKING:
    from qlib.strategy.base import BaseStrategy
    from qlib.backtest.executor import BaseExecutor
    from qlib.backtest.decision import BaseTradeDecision
from qlib.backtest.position import Position
from qlib.backtest.exchange import Exchange
from qlib.backtest.backtest import backtest_loop
from qlib.backtest.backtest import collect_data_loop
from qlib.backtest.utils import CommonInfrastructure
from qlib.backtest.decision import Order
from qlib.utils import init_instance_by_config
from qlib.log import get_module_logger
from qlib.config import C


def create_account_instance(
    start_time, end_time, benchmark: str, account: Union[float, int, dict], pos_type: str = "Position"
) -> "Account":
    """
    # TODO: is very strange pass benchmark_config in the account(maybe for report)
    # There should be a post-step to process the report.

    Parameters
    ----------
    start_time
        start time of the benchmark
    end_time
        end time of the benchmark
    benchmark : str
        the benchmark for reporting
    account :   Union[
                    float,
                    {
                        "cash": float,
                        "stock1": Union[
                                        int,    # it is equal to {"amount": int}
                                        {"amount": int, "price"(optional): float},
                                  ]
                    },
                ]
        information for describing how to creating the account
        For `float`:
            Using Account with only initial cash
        For `dict`:
            key "cash" means initial cash.
            key "stock1" means the information of first stock with amount and price(optional).
            ...
            The dict passed in is not modified.
    pos_type : str
        position class name forwarded to Account.

    Returns
    -------
    Account

    Raises
    ------
    ValueError
        if `account` is neither a number nor a dict.
    KeyError
        if `account` is a dict without a "cash" key.
    """
    if isinstance(account, (int, float)):
        pos_kwargs = {"init_cash": account}
    elif isinstance(account, dict):
        # Work on a shallow copy so the caller's dict keeps its "cash" entry and can be reused.
        position_dict = dict(account)
        init_cash = position_dict.pop("cash")
        pos_kwargs = {
            "init_cash": init_cash,
            "position_dict": position_dict,
        }
    else:
        raise ValueError("account must be in (int, float, dict)")

    kwargs = {
        "benchmark_config": {
            "benchmark": benchmark,
            "start_time": start_time,
            "end_time": end_time,
        },
        "pos_type": pos_type,
    }
    # pos_kwargs always carries "init_cash" (and "position_dict" for dict input)
    kwargs.update(pos_kwargs)
    return Account(**kwargs)

In [None]:
from qlib.backtest.utils import CommonInfrastructure
from qlib.strategy.base import BaseStrategy  # pylint: disable=C0415
from qlib.backtest.executor import BaseExecutor  # pylint: disable=C0415
from qlib.backtest import get_exchange


# Shared backtest infrastructure: an account starting with 1e9 cash plus an exchange covering
# the same window.
# NOTE(review): `start_time` is not defined in any visible cell; it must come from an earlier
# cell, so confirm it survives Restart & Run All.
trade_account = create_account_instance(
    start_time=start_time,
    end_time='2021-06-11',
    benchmark=benchmark,
    account=1e9,
    pos_type="Position",
)

# The exchange spans the same window as the account.
exchange_kwargs: dict = {"start_time": start_time, "end_time": '2021-06-11'}
trade_exchange = get_exchange(**exchange_kwargs)

common_infra = CommonInfrastructure(trade_account=trade_account, trade_exchange=trade_exchange)
print(type(common_infra))

In [None]:
from abc import abstractmethod
import copy
from qlib.backtest.position import BasePosition
from qlib.log import get_module_logger
from types import GeneratorType
from qlib.backtest.account import Account
import pandas as pd
from typing import List, Tuple, Union
from collections import defaultdict

from qlib.backtest.decision import Order, BaseTradeDecision
from qlib.backtest.exchange import Exchange
from qlib.backtest.utils import TradeCalendarManager, CommonInfrastructure, LevelInfrastructure, get_start_end_idx

from qlib.utils import init_instance_by_config
from qlib.strategy.base import BaseStrategy


class BaseExecutor:
    """Base executor for trading

    An executor owns a trade calendar (through ``level_infra``) and a trade account. ``collect_data``
    executes one trade decision for the current calendar step via ``_collect_data``, updates the
    account at the end of the bar and steps the calendar forward.
    """

    def __init__(
        self,
        time_per_step: str,
        start_time: Union[str, pd.Timestamp] = None,
        end_time: Union[str, pd.Timestamp] = None,
        indicator_config: Union[dict, None] = None,
        generate_portfolio_metrics: bool = False,
        verbose: bool = False,
        track_data: bool = False,
        trade_exchange: Exchange = None,
        common_infra: CommonInfrastructure = None,
        settle_type=BasePosition.ST_NO,
        **kwargs,
    ):
        """
        Parameters
        ----------
        time_per_step : str
            trade time per trading step, used for generate the trade calendar
        start_time : Union[str, pd.Timestamp], optional
            start of the trade calendar, forwarded to `reset`
        end_time : Union[str, pd.Timestamp], optional
            end of the trade calendar, forwarded to `reset`
        indicator_config: dict, optional
            config for calculating trade indicator; None (the default) means an empty config.
            It can include the following fields:
            - 'show_indicator': whether to show indicators, optional, default by False. The indicators includes
                - 'pa', the price advantage
                - 'pos', the positive rate
                - 'ffr', the fulfill rate
            - 'pa_config': config for calculating price advantage(pa), optional
                - 'base_price': the based price than which the trading price is advanced, Optional, default by 'twap'
                    - If 'base_price' is 'twap', the based price is the time weighted average price
                    - If 'base_price' is 'vwap', the based price is the volume weighted average price
                - 'weight_method': weighted method when calculating total trading pa by different orders' pa in each step, optional, default by 'mean'
                    - If 'weight_method' is 'mean', calculating mean value of different orders' pa
                    - If 'weight_method' is 'amount_weighted', calculating amount weighted average value of different orders' pa
                    - If 'weight_method' is 'value_weighted', calculating value weighted average value of different orders' pa
            - 'ffr_config': config for calculating fulfill rate(ffr), optional
                - 'weight_method': weighted method when calculating total trading ffr by different orders' ffr in each step, optional, default by 'mean'
                    - If 'weight_method' is 'mean', calculating mean value of different orders' ffr
                    - If 'weight_method' is 'amount_weighted', calculating amount weighted average value of different orders' ffr
                    - If 'weight_method' is 'value_weighted', calculating value weighted average value of different orders' ffr
            Example:
                {
                    'show_indicator': True,
                    'pa_config': {
                        "agg": "twap",  # "vwap"
                        "price": "$close", # default to use deal price of the exchange
                    },
                    'ffr_config':{
                        'weight_method': 'value_weighted',
                    }
                }
        generate_portfolio_metrics : bool, optional
            whether to generate portfolio_metrics, by default False
        verbose : bool, optional
            whether to print trading info, by default False
        track_data : bool, optional
            whether to generate trade_decision, will be used when training rl agent
            - If `self.track_data` is true, when making data for training, the input `trade_decision` of `execute` will be generated by `collect_data`
            - Else,  `trade_decision` will not be generated

        trade_exchange : Exchange
            exchange that provides market info; passed to `trade_account.update_bar_end` at the end of each step
            - If `trade_exchange` is None, the exchange from `common_infra` is used (see the `trade_exchange` property)

        common_infra : CommonInfrastructure, optional:
            common infrastructure for backtesting, may including:
            - trade_account : Account, optional
                trade account for trading
            - trade_exchange : Exchange, optional
                exchange that provides market info

        settle_type : str
            Please refer to the docs of BasePosition.settle_start
        kwargs :
            accepted for config compatibility; not used by this class
        """
        self.time_per_step = time_per_step
        # A fresh dict per instance: a shared mutable default would leak changes across executors.
        self.indicator_config = {} if indicator_config is None else indicator_config
        self.generate_portfolio_metrics = generate_portfolio_metrics
        self.verbose = verbose
        self.track_data = track_data
        self._trade_exchange = trade_exchange
        self.level_infra = LevelInfrastructure()
        self.level_infra.reset_infra(common_infra=common_infra)
        self._settle_type = settle_type
        self.reset(start_time=start_time, end_time=end_time, common_infra=common_infra)
        if common_infra is None:
            get_module_logger("BaseExecutor").warning(f"`common_infra` is not set for {self}")

        # record deal order amount in one day
        self.dealt_order_amount = defaultdict(float)
        self.deal_day = None

    def reset_common_infra(self, common_infra, copy_trade_account=False):
        """
        reset infrastructure for trading
            - reset trade_account

        Parameters
        ----------
        common_infra : CommonInfrastructure
            infrastructure to store (first call) or merge into the existing one (later calls)
        copy_trade_account : bool
            if True, keep a shallow copy of the account (positions shared, metrics not shared)
        """
        if not hasattr(self, "common_infra"):
            self.common_infra = common_infra
        else:
            self.common_infra.update(common_infra)

        if common_infra.has("trade_account"):
            if copy_trade_account:
                # NOTE: there is a trick in the code.
                # shallow copy is used instead of deepcopy.
                # 1. So positions are shared
                # 2. Others are not shared, so each level has it own metrics (portfolio and trading metrics)
                self.trade_account: Account = copy.copy(common_infra.get("trade_account"))
            else:
                self.trade_account = common_infra.get("trade_account")
            self.trade_account.reset(freq=self.time_per_step, port_metr_enabled=self.generate_portfolio_metrics)

    @property
    def trade_exchange(self) -> Exchange:
        """get trade exchange in a prioritized order"""
        return getattr(self, "_trade_exchange", None) or self.common_infra.get("trade_exchange")

    @property
    def trade_calendar(self) -> TradeCalendarManager:
        """
        Though trade calendar can be accessed from multiple sources, but managing in a centralized way will make the
        code easier
        """
        return self.level_infra.get("trade_calendar")

    def reset(self, common_infra: CommonInfrastructure = None, **kwargs):
        """
        - reset `start_time` and `end_time`, used in trade calendar
        - reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc
        """

        if "start_time" in kwargs or "end_time" in kwargs:
            start_time = kwargs.get("start_time")
            end_time = kwargs.get("end_time")
            self.level_infra.reset_cal(freq=self.time_per_step, start_time=start_time, end_time=end_time)
        if common_infra is not None:
            self.reset_common_infra(common_infra)

    def get_level_infra(self):
        return self.level_infra

    def finished(self):
        return self.trade_calendar.finished()

    def execute(self, trade_decision: BaseTradeDecision, level: int = 0):
        """execute the trade decision and return the executed result

        NOTE: this function is never used directly in the framework. Should we delete it?

        Parameters
        ----------
        trade_decision : BaseTradeDecision

        level : int
            the level of current executor

        Returns
        ----------
        execute_result : List[object]
            the executed result for trade decision
        """
        return_value = {}
        # drain the generator; the result is delivered through `return_value`
        for _decision in self.collect_data(trade_decision, return_value=return_value, level=level):
            pass
        return return_value.get("execute_result")

    @classmethod
    @abstractmethod
    def _collect_data(cls, trade_decision: BaseTradeDecision, level: int = 0) -> Tuple[List[object], dict]:
        """
        Please refer to the doc of collect_data
        The only difference between `_collect_data` and `collect_data` is that some common steps are moved into
        collect_data

        NOTE(review): subclasses (e.g. NestedExecutor) override this as an instance method and it is
        called as `self._collect_data(...)`; BaseExecutor is not an ABC, so `@abstractmethod` is not enforced.

        Parameters
        ----------
        Please refer to the doc of collect_data


        Returns
        -------
        Tuple[List[object], dict]:
            (<the executed result for trade decision>, <the extra kwargs for `self.trade_account.update_bar_end`>)
        """

    def collect_data(
        self, trade_decision: BaseTradeDecision, return_value: dict = None, level: int = 0
    ) -> List[object]:
        """Generator for collecting the trade decision data for rl training

        This function will make a step forward

        Parameters
        ----------
        trade_decision : BaseTradeDecision

        level : int
            the level of current executor. 0 indicates the top level

        return_value : dict
            the mem address to return the value
            e.g.  {"execute_result": <the executed result>}

        Returns
        ----------
        execute_result : List[object]
            the executed result for trade decision.
            ** NOTE!!!! **:
            1) This is necessary,  The return value of generator will be used in NestedExecutor
            2) Please note the executed results are not merged.

        Yields
        -------
        object
            trade decision
        """
        if self.track_data:
            yield trade_decision

        atomic = not issubclass(self.__class__, NestedExecutor)  #  issubclass(A, A) is True

        if atomic and trade_decision.get_range_limit(default_value=None) is not None:
            raise ValueError("atomic executor doesn't support specify `range_limit`")

        if self._settle_type != BasePosition.ST_NO:
            self.trade_account.current_position.settle_start(self._settle_type)

        obj = self._collect_data(trade_decision=trade_decision, level=level)

        if isinstance(obj, GeneratorType):
            res, kwargs = yield from obj
        else:
            # Some concrete executor don't have inner decisions
            res, kwargs = obj

        trade_start_time, trade_end_time = self.trade_calendar.get_step_time()
        # Account will not be changed in this function
        self.trade_account.update_bar_end(
            trade_start_time,
            trade_end_time,
            self.trade_exchange,
            atomic=atomic,
            outer_trade_decision=trade_decision,
            indicator_config=self.indicator_config,
            **kwargs,
        )

        self.trade_calendar.step()

        if self._settle_type != BasePosition.ST_NO:
            self.trade_account.current_position.settle_commit()

        if return_value is not None:
            return_value.update({"execute_result": res})
        return res

    def get_all_executors(self):
        """get all executors"""
        return [self]


class NestedExecutor(BaseExecutor):
    """
    Nested Executor with inner strategy and executor
    - At each time `execute` is called, it will call the inner strategy and executor to execute the `trade_decision` in a higher frequency env.
    """

    def __init__(
        self,
        time_per_step: str,
        inner_executor: Union[BaseExecutor, dict],
        inner_strategy: Union[BaseStrategy, dict],
        start_time: Union[str, pd.Timestamp] = None,
        end_time: Union[str, pd.Timestamp] = None,
        indicator_config: Union[dict, None] = None,
        generate_portfolio_metrics: bool = False,
        verbose: bool = False,
        track_data: bool = False,
        skip_empty_decision: bool = True,
        align_range_limit: bool = True,
        common_infra: CommonInfrastructure = None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        inner_executor : BaseExecutor
            trading env in each trading bar.
        inner_strategy : BaseStrategy
            trading strategy in each trading bar
        indicator_config : dict, optional
            see BaseExecutor; None (the default) means an empty config.
        skip_empty_decision: bool
            Will the executor skip call inner loop when the decision is empty.
            It should be False in following cases
            - The decisions may be updated by steps
            - The inner executor may not follow the decisions from the outer strategy
        align_range_limit: bool
            force to align the trade_range decision
            It is only for nested executor, because range_limit is given by outer strategy
        kwargs :
            forwarded to BaseExecutor (e.g. trade_exchange, settle_type)
        """
        # The inner components must exist before BaseExecutor.__init__ runs: it calls `reset`, which
        # calls the overridden `reset_common_infra` below, and that one reaches into them.
        self.inner_executor: BaseExecutor = init_instance_by_config(
            inner_executor, common_infra=common_infra, accept_types=BaseExecutor
        )
        self.inner_strategy: BaseStrategy = init_instance_by_config(
            inner_strategy, common_infra=common_infra, accept_types=BaseStrategy
        )

        self._skip_empty_decision = skip_empty_decision
        self._align_range_limit = align_range_limit

        super(NestedExecutor, self).__init__(
            time_per_step=time_per_step,
            start_time=start_time,
            end_time=end_time,
            # a fresh dict per instance instead of a shared mutable default
            indicator_config={} if indicator_config is None else indicator_config,
            generate_portfolio_metrics=generate_portfolio_metrics,
            verbose=verbose,
            track_data=track_data,
            common_infra=common_infra,
            **kwargs,
        )

    def reset_common_infra(self, common_infra, copy_trade_account=False):
        """
        reset infrastructure for trading
            - reset inner_strategy and inner_executor common infra
        """
        # NOTE: please refer to the docs of BaseExecutor.reset_common_infra for the meaning of `copy_trade_account`

        # The first level follow the `copy_trade_account` from the upper level
        super(NestedExecutor, self).reset_common_infra(common_infra, copy_trade_account=copy_trade_account)

        # The lower level have to copy the trade_account
        self.inner_executor.reset_common_infra(common_infra, copy_trade_account=True)
        self.inner_strategy.reset_common_infra(common_infra)

    def _init_sub_trading(self, trade_decision):
        """Restrict the inner calendar to the current outer step and reset the inner strategy for it."""
        trade_start_time, trade_end_time = self.trade_calendar.get_step_time()
        self.inner_executor.reset(start_time=trade_start_time, end_time=trade_end_time)
        sub_level_infra = self.inner_executor.get_level_infra()
        self.level_infra.set_sub_level_infra(sub_level_infra)
        self.inner_strategy.reset(level_infra=sub_level_infra, outer_trade_decision=trade_decision)

    def _update_trade_decision(self, trade_decision: BaseTradeDecision) -> BaseTradeDecision:
        # outer strategy have chance to update decision each iterator
        updated_trade_decision = trade_decision.update(self.inner_executor.trade_calendar)
        if updated_trade_decision is not None:
            trade_decision = updated_trade_decision
            # NEW UPDATE
            # create a hook for inner strategy to update outer decision
            self.inner_strategy.alter_outer_trade_decision(trade_decision)
        return trade_decision

    def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0):
        """Run the inner strategy/executor loop over the sub-steps of the current outer step.

        Returns
        -------
        Tuple[List[object], dict]
            the concatenated inner execution results, and the extra kwargs
            (`inner_order_indicators`, `decision_list`) for `trade_account.update_bar_end`.
        """
        execute_result = []
        inner_order_indicators = []
        decision_list = []
        # NOTE:
        # - this is necessary to calculating the steps in sub level
        # - more detailed information will be set into trade decision
        self._init_sub_trading(trade_decision)

        _inner_execute_result = None
        while not self.inner_executor.finished():
            trade_decision = self._update_trade_decision(trade_decision)

            if trade_decision.empty() and self._skip_empty_decision:
                # give one chance for outer strategy to update the strategy
                # - For updating some information in the sub executor(the strategy have no knowledge of the inner
                # executor when generating the decision)
                break

            sub_cal: TradeCalendarManager = self.inner_executor.trade_calendar

            # NOTE: make sure get_start_end_idx is after `self._update_trade_decision`
            start_idx, end_idx = get_start_end_idx(sub_cal, trade_decision)
            if not self._align_range_limit or start_idx <= sub_cal.get_trade_step() <= end_idx:
                # if force align the range limit, skip the steps outside the decision range limit

                res = self.inner_strategy.generate_trade_decision(_inner_execute_result)

                # NOTE: !!!!!
                # the two lines below is for a special case in RL
                # To solve the conflict below
                # - Normally, user will create a strategy and embed it into Qlib's executor and simulator interaction loop
                #   For a _nested qlib example_, (Qlib Strategy) <=> (Qlib Executor[(inner Qlib Strategy) <=> (inner Qlib Executor)])
                # - However, RL-based framework has its own script to run the loop
                #   For an _RL learning example_, (RL Policy) <=> (RL Env[(inner Qlib Executor)])
                # To make it possible to run  _nested qlib example_ and _RL learning example_ together, the solution below is proposed
                # - The entry script follow the example of  _RL learning example_ to be compatible with all kinds of RL Framework
                # - Each step of (RL Env) will make (inner Qlib Executor) one step forward
                #     - (inner Qlib Strategy) is a proxy strategy, it will give the program control right to (RL Env) by `yield from` and wait for the action from the policy
                # So the two lines below is the implementation of yielding control rights
                if isinstance(res, GeneratorType):
                    res = yield from res

                _inner_trade_decision: BaseTradeDecision = res

                trade_decision.mod_inner_decision(_inner_trade_decision)  # propagate part of decision information

                # NOTE sub_cal.get_step_time() must be called before collect_data in case of step shifting
                decision_list.append((_inner_trade_decision, *sub_cal.get_step_time()))

                # NOTE: Trade Calendar will step forward in the follow line
                _inner_execute_result = yield from self.inner_executor.collect_data(
                    trade_decision=_inner_trade_decision, level=level + 1
                )
                self.post_inner_exe_step(_inner_execute_result)
                execute_result.extend(_inner_execute_result)

                inner_order_indicators.append(
                    self.inner_executor.trade_account.get_trade_indicator().get_order_indicator(raw=True)
                )
            else:
                # do nothing and just step forward
                sub_cal.step()

        return execute_result, {"inner_order_indicators": inner_order_indicators, "decision_list": decision_list}

    def post_inner_exe_step(self, inner_exe_res):
        """
        A hook for doing sth after each step of inner strategy

        Parameters
        ----------
        inner_exe_res :
            the execution result of inner task
        """

    def get_all_executors(self):
        """get all executors, including self and inner_executor.get_all_executors()"""
        return [self, *self.inner_executor.get_all_executors()]

In [None]:
# Standalone level infrastructure for driving WeekTopkDropoutStrategy directly, outside an
# executor. Presumably the strategy reads its trade calendar from it; the calendar is set below
# via `reset_cal`.
level_infra = LevelInfrastructure()
print(type(level_infra))

In [None]:
# Inspection only: presumably lists the infrastructure keys CommonInfrastructure supports.
common_infra.get_support_infra()

In [None]:
# Same inspection for LevelInfrastructure.
# NOTE(review): this binds a global named like the method; a distinct name would be clearer.
get_support_infra = level_infra.get_support_infra()
get_support_infra

In [None]:
# Create the trade calendar. The arguments are presumably (freq, start_time, end_time), matching
# the keyword call `reset_cal(freq=..., start_time=..., end_time=...)` in BaseExecutor.reset.
# NOTE(review): the dates are hard-coded; keep them inside the range the strategy can query.
level_infra.reset_cal("day", "2021-05-20", "2021-06-01")

In [None]:
# Attach an empty sub-level infrastructure.
# NOTE(review): purpose unclear from here; presumably it mirrors NestedExecutor._init_sub_trading.
level_infra.set_sub_level_infra(sub_level_infra = LevelInfrastructure())

In [None]:
# Predict scores with the trained model and wrap them as a qlib Signal.
# NOTE(review): `model` and `dataset` are not defined in any visible cell; they must come from an
# earlier training cell, otherwise this cell fails on a fresh kernel.
sl = model.predict(dataset)
from qlib.contrib.strategy.signal_strategy import create_signal_from
sgl = create_signal_from(sl)
# Hold 50 stocks, replace up to 5 per step, and pick candidates without a tradability pre-filter.
tk = WeekTopkDropoutStrategy(signal = sgl, risk_degree = 0.95,common_infra=common_infra, topk = 50, 
                         n_drop = 5, method_sell="bottom", method_buy="top", level_infra = level_infra,
                         hold_thresh=1, only_tradable=False)

In [None]:
# Generate the decision for the current step of the calendar set on `level_infra`.
WO = tk.generate_trade_decision()

In [None]:
# Display the TradeDecisionWO (sell orders followed by buy orders).
WO