<a href="https://colab.research.google.com/github/microsoft/qlib/blob/main/examples/workflow_by_code.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#  Copyright (c) Microsoft Corporation.
#  Licensed under the MIT License.

In [1]:
import sys, site
from pathlib import Path

################################# NOTE #################################
#  Please be aware that if colab installs the latest numpy and pyqlib  #
#  in this cell, users should RESTART the runtime in order to run the  #
#  following cells successfully.                                       #
########################################################################

try:
    import qlib
except ImportError:
    # install qlib
    ! pip install --upgrade numpy
    ! pip install pyqlib
    if "google.colab" in sys.modules:
        # The Google colab environment is a little outdated. We have to downgrade the pyyaml to make it compatible with other packages
        ! pip install pyyaml==5.4.1
    # reload
    site.main()

scripts_dir = Path.cwd().parent.joinpath("scripts")
if not scripts_dir.joinpath("get_data.py").exists():
    # download get_data.py script
    scripts_dir = Path("~/tmp/qlib_code/scripts").expanduser().resolve()
    scripts_dir.mkdir(parents=True, exist_ok=True)
    import requests

    with requests.get("https://raw.githubusercontent.com/microsoft/qlib/main/scripts/get_data.py", timeout=10) as resp:
        with open(scripts_dir.joinpath("get_data.py"), "wb") as fp:
            fp.write(resp.content)

In [2]:
import qlib
import pandas as pd
from qlib.constant import REG_CN
from qlib.utils import exists_qlib_data, init_instance_by_config
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
from qlib.utils import flatten_dict

In [3]:
provider_uri = "~/.qlib/qlib_data/cn_data"  # target_dir
qlib.init(provider_uri=provider_uri, region=REG_CN)

[3760863:MainThread](2025-07-31 19:55:44,632) INFO - qlib.Initialization - [config.py:420] - default_conf: client.
[3760863:MainThread](2025-07-31 19:55:44,639) INFO - qlib.Initialization - [__init__.py:75] - qlib successfully initialized based on client settings.
[3760863:MainThread](2025-07-31 19:55:44,641) INFO - qlib.Initialization - [__init__.py:77] - data_path={'__DEFAULT_FREQ': PosixPath('/DATA/hxy/qlib/cn_data')}


In [4]:
def coverage_metric(y_pred, data):
    
    y_true = data.get_label()
    coverage = (y_true <= y_pred).astype(float).mean()
    return 'coverage', coverage, True

# train model

"model": {
    "class": "LGBModel",
    "module_path": "qlib.contrib.model.gbdt",
    "kwargs": {
        "loss": "mse",
        # "alpha": 0.5,
        "colsample_bytree": 0.8879,
        "learning_rate": 0.0421,
        "subsample": 0.8789,
        "lambda_l1": 205.6999,
        "lambda_l2": 580.9768,
        "max_depth": 8,
        "num_leaves": 210,
        "num_threads": 20,
    },
},

In [8]:
market = "csiall"
benchmark = "SH000905"
# 学习的参数
learn_processors = [
    {"class": "DropnaLabel"},
    {"class": "RobustZScoreNorm", "kwargs": {"fields_group": "feature", "clip_outlier":True}}, # 有状态，标准化特征
    {"class":"CSZScoreNorm", "kwargs": {"fields_group": "label"}} # 无状态，按天横截面标准化
]
# 测试集
infer_processors = [
    {"class": "RobustZScoreNorm", "kwargs": {"fields_group": "feature", "clip_outlier":True}},  # 使用训练集参数
    {"class":"CSZScoreNorm", "kwargs": {"fields_group": "label"}} # 无状态，也要加
]
shared_processors = [
    {"class": "ProcessInf", "kwargs": {}},
    {"class": "Fillna", "kwargs":{"fields_group": "feature"}}
]

###################################
# train model
###################################
data_handler_config = {
    "start_time": "2016-12-31",
    "end_time": "2022-12-31",
    "fit_start_time": "2016-12-31",
    "fit_end_time": "2020-12-31",
    "instruments": market,
    "infer_processors":infer_processors,
    "learn_processors":learn_processors,
}

task = {
    "model": {
        "class": "GRU",
        "module_path": "qlib.contrib.model.pytorch_gru_ts",
        "kwargs": {
            "d_feat": 158,
            "hidden_size": 64,
            "num_layers": 2,
            "dropout": 0.0,
            "n_epochs": 40,
            "lr": 2e-4,
            "early_stop": 10,
            "batch_size": 3000,
            "metric": "loss",
            "loss": "mse",
            "n_jobs": 20,
            "GPU": 0,
        },
    },
    "dataset": {
        "class": "TSDatasetH",
        "module_path": "qlib.data.dataset",
        "kwargs": {
            "handler": {
                "class": "Alpha158",
                "module_path": "qlib.contrib.data.handler",
                "kwargs": data_handler_config,
            },
            "segments": {
                "train": ("2016-12-31", "2020-12-31"),
                "valid": ("2021-01-01", "2021-12-31"),
                "test": ("2022-01-01", "2022-12-31"),
            },
            
        },
    },
}

# model initialization
model = init_instance_by_config(task["model"])


[3760863:MainThread](2025-07-31 20:04:18,512) INFO - qlib.GRU - [pytorch_gru_ts.py:61] - GRU pytorch version...
[3760863:MainThread](2025-07-31 20:04:18,516) INFO - qlib.GRU - [pytorch_gru_ts.py:79] - GRU parameters setting:
d_feat : 158
hidden_size : 64
num_layers : 2
dropout : 0.0
n_epochs : 40
lr : 0.0002
metric : loss
batch_size : 3000
early_stop : 10
optimizer : adam
loss_type : mse
device : cuda:0
n_jobs : 20
use_GPU : True
seed : None
[3760863:MainThread](2025-07-31 20:04:18,520) INFO - qlib.GRU - [pytorch_gru_ts.py:124] - model:
GRUModel(
  (rnn): GRU(158, 64, num_layers=2, batch_first=True)
  (fc_out): Linear(in_features=64, out_features=1, bias=True)
)
[3760863:MainThread](2025-07-31 20:04:18,522) INFO - qlib.GRU - [pytorch_gru_ts.py:125] - model size: 0.0649 MB


In [9]:
dataset = init_instance_by_config(task["dataset"])


[3760863:MainThread](2025-07-31 20:06:23,329) INFO - qlib.timer - [log.py:127] - Time cost: 83.636s | Loading data Done
[3760863:MainThread](2025-07-31 20:07:10,977) INFO - qlib.timer - [log.py:127] - Time cost: 40.234s | RobustZScoreNorm Done
[3760863:MainThread](2025-07-31 20:07:15,744) INFO - qlib.timer - [log.py:127] - Time cost: 4.758s | CSZScoreNorm Done
[3760863:MainThread](2025-07-31 20:07:23,739) INFO - qlib.timer - [log.py:127] - Time cost: 2.865s | DropnaLabel Done
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df[self.cols] = X
[3760863:MainThread](2025-07-31 20:07:57,345) INFO - qlib.timer - [log.py:127] - Time cost: 33.602s | RobustZScoreNorm Done
[3760863:MainThread](2025-07-31 20:08:02,688) INFO - qlib.timer - [log.py:127] - Time cost: 5.340s | CSZScoreNor

In [11]:
# start exp to train model
with R.start(experiment_name="train_model"):
    R.log_params(**flatten_dict(task))
    model.fit(dataset, 
            #   feval=coverage_metric,
              )
    R.save_objects(trained_model=model)
    rid = R.get_recorder().id

[3760863:MainThread](2025-07-31 20:19:25,453) INFO - qlib.workflow - [exp.py:258] - Experiment 377830769695977458 starts running ...
[3760863:MainThread](2025-07-31 20:19:25,560) INFO - qlib.workflow - [recorder.py:345] - Recorder e4c2dd9e8e0748d9ab94b3fbac523c5e starts running under Experiment 377830769695977458 ...
[3760863:MainThread](2025-07-31 20:20:00,848) INFO - qlib.GRU - [pytorch_gru_ts.py:249] - training...
[3760863:MainThread](2025-07-31 20:20:00,852) INFO - qlib.GRU - [pytorch_gru_ts.py:253] - Epoch0:
[3760863:MainThread](2025-07-31 20:20:00,854) INFO - qlib.GRU - [pytorch_gru_ts.py:254] - training...
[3760863:MainThread](2025-07-31 20:21:27,106) INFO - qlib.GRU - [pytorch_gru_ts.py:256] - evaluating...
[3760863:MainThread](2025-07-31 20:23:25,079) INFO - qlib.GRU - [pytorch_gru_ts.py:259] - train nan, valid nan
[3760863:MainThread](2025-07-31 20:23:25,084) INFO - qlib.GRU - [pytorch_gru_ts.py:253] - Epoch1:
[3760863:MainThread](2025-07-31 20:23:25,087) INFO - qlib.GRU - [p

KeyboardInterrupt: 

In [None]:
model.predict(dataset)

# prediction, backtest & analysis

In [None]:
###################################
# prediction, backtest & analysis
###################################
port_analysis_config = {
    "executor": {
        "class": "SimulatorExecutor",
        "module_path": "qlib.backtest.executor",
        "kwargs": {
            "time_per_step": "day",
            "generate_portfolio_metrics": True,
        },
    },
    "strategy": {
        "class": "TopkDropoutStrategy",
        "module_path": "qlib.contrib.strategy.signal_strategy",
        "kwargs": {
            "model": model,
            "dataset": dataset,
            "topk": 50,
            "n_drop": 5,
        },
    },
    "backtest": {
        "start_time": "2017-01-01",
        "end_time": "2020-08-01",
        "account": 100000000,
        "benchmark": benchmark,
        "exchange_kwargs": {
            "freq": "day",
            "limit_threshold": 0.095,
            "deal_price": "close",
            "open_cost": 0.0005,
            "close_cost": 0.0015,
            "min_cost": 5,
        },
    },
}

# backtest and analysis
with R.start(experiment_name="backtest_analysis"):
    recorder = R.get_recorder(recorder_id=rid, experiment_name="train_model")
    model = recorder.load_object("trained_model")

    # prediction
    recorder = R.get_recorder()
    ba_rid = recorder.id
    sr = SignalRecord(model, dataset, recorder)
    sr.generate()

    # backtest & analysis
    par = PortAnaRecord(recorder, port_analysis_config, "day")
    par.generate()

# analyze graphs

In [None]:
from qlib.contrib.report import analysis_model, analysis_position
from qlib.data import D

recorder = R.get_recorder(recorder_id=ba_rid, experiment_name="backtest_analysis")
print(recorder)
pred_df = recorder.load_object("pred.pkl")
report_normal_df = recorder.load_object("portfolio_analysis/report_normal_1day.pkl")
positions = recorder.load_object("portfolio_analysis/positions_normal_1day.pkl")
analysis_df = recorder.load_object("portfolio_analysis/port_analysis_1day.pkl")

## analysis position

### report

In [None]:
analysis_position.report_graph(report_normal_df)

### risk analysis

In [None]:
analysis_position.risk_analysis_graph(analysis_df, report_normal_df)

## analysis model

In [None]:
label_df = dataset.prepare("test", col_set="label")
label_df.columns = ["label"]

### score IC

In [None]:
pred_label = pd.concat([label_df, pred_df], axis=1, sort=True).reindex(label_df.index)
analysis_position.score_ic_graph(pred_label)

### model performance

In [None]:
analysis_model.model_performance_graph(pred_label)