In [1]:
import os
import pandas as pd
from pandas import DataFrame
import talib

from lightgbm import LGBMRanker

In [9]:
def cal_code_feature(df: DataFrame):
    # 计算RSI指标
    df['RSI2'] = talib.RSI(df['close'], timeperiod=2) / 100
    df['RSI5'] = talib.RSI(df['close'], timeperiod=5) / 100
    df['RSI10'] = talib.RSI(df['close'], timeperiod=10) / 100
def read_data(csv_dir: str):
    dataframes = {}
    for filename in os.listdir(csv_dir):
        if filename.endswith('.csv'):
            file_path = os.path.join(csv_dir, filename)
            try:
                df = pd.read_csv(file_path)
                df['ret5'] = (df['close'].shift(-5) - df['close']) / df['close']
                df = df[df["date"] < "2024-12-20"]
                df = df[df["date"] > "2016-02-01"]
                # 过滤ret5为NAN的数据
                df = df.dropna(subset=['ret5'])

                cal_code_feature(df)
                df = df.drop("Unnamed: 0", axis=1)

                dataframes[filename] = df
                print(f"成功读取文件: {filename}, 形状: {df.shape}")
            except Exception as e:
                print(f"读取文件 {filename} 时出错: {e}")
    # 将所有DataFrame进行union操作
    all_data = pd.concat(dataframes.values(), ignore_index=True)
    return all_data

In [10]:
csv_dir = '/Users/rui.chengcr/PycharmProjects/qstrader/qs_data/price/'
all_data = read_data(csv_dir)

成功读取文件: 600909.csv, 形状: (1948, 16)
成功读取文件: 600048.csv, 形状: (2151, 16)
成功读取文件: 600060.csv, 形状: (2157, 16)
成功读取文件: 300724.csv, 形状: (1543, 16)
成功读取文件: 688002.csv, 形状: (1315, 16)
成功读取文件: 601168.csv, 形状: (1965, 16)
成功读取文件: 600276.csv, 形状: (2157, 16)
成功读取文件: 603019.csv, 形状: (2155, 16)
成功读取文件: 000987.csv, 形状: (2049, 16)
成功读取文件: 000039.csv, 形状: (2153, 16)
成功读取文件: 300122.csv, 形状: (2158, 16)
成功读取文件: 300136.csv, 形状: (2158, 16)
成功读取文件: 000830.csv, 形状: (2148, 16)
成功读取文件: 600316.csv, 形状: (2158, 16)
成功读取文件: 600699.csv, 形状: (2133, 16)
成功读取文件: 000401.csv, 形状: (2018, 16)
成功读取文件: 601236.csv, 形状: (1320, 16)
成功读取文件: 600115.csv, 形状: (2157, 16)
成功读取文件: 600673.csv, 形状: (1884, 16)
成功读取文件: 002203.csv, 形状: (2147, 16)
成功读取文件: 601009.csv, 形状: (2156, 16)
成功读取文件: 002001.csv, 形状: (2154, 16)
成功读取文件: 601021.csv, 形状: (2151, 16)
成功读取文件: 688188.csv, 形状: (1302, 16)
成功读取文件: 000825.csv, 形状: (2050, 16)
成功读取文件: 688981.csv, 形状: (1076, 16)
成功读取文件: 601169.csv, 形状: (2150, 16)
成功读取文件: 600511.csv, 形状: (2037, 16)
成功读取文件: 601633.csv, 

In [11]:
def cal_group_feature(df: DataFrame):
    # 将code列的数据类型转换为int
    df['code'] = df['code'].astype(int)

    # 基于date分组，计算ret5的rank百分比
    df['ret5_rank'] = df.groupby('date')['ret5'].rank(pct=True)
    # 基于date分组，计算RSI2的rank百分比
    df['RSI2_rank'] = df.groupby('date')['RSI2'].rank(pct=True)
    # 基于date分组，计算RSI5的rank百分比
    df['RSI5_rank'] = df.groupby('date')['RSI5'].rank(pct=True)
    # 基于date分组，计算RSI10的rank百分比
    df['RSI10_rank'] = df.groupby('date')['RSI10'].rank(pct=True)

In [12]:
cal_group_feature(all_data)

In [13]:
def cal_label(df: DataFrame):

    # 基于date分组，计算ret5的rank，从0开始
    df['label'] = df.groupby('date')['ret5'].rank(pct=True)
    df['label'] = (df['label'] * 16).astype(int)

    # 基于date列对DataFrame进行排序
    df = df.sort_values('date')
    return df

In [14]:
all_data = cal_label(all_data)

In [15]:
# 对所有浮点数列保留4位小数
float_columns = all_data.select_dtypes(include=['float64', 'float32']).columns
all_data[float_columns] = all_data[float_columns].round(4)

In [16]:
def train_model(df: DataFrame):
    print("start to train model")
    model = LGBMRanker(
        objective="lambdarank",
        metric="ndcg",        
        n_estimators=100,
        learning_rate=0.05,
        max_depth=5
    )
    features = ['RSI2_rank', 'RSI5_rank', 'RSI10_rank', 'ret5_rank', 'RSI2', 'RSI5', 'RSI10']
    model.fit(df[features], df['label'], group=df.groupby('date').size().values)
    return model

In [38]:
train_data = all_data[all_data['date'] <= '2021']

In [39]:
model = train_model(train_data)

start to train model
[LightGBM] [Info] Auto-choosing row-wise multi-threading, the overhead of testing was 0.001242 seconds.
You can set `force_row_wise=true` to remove the overhead.
And if memory is not enough, you can set `force_col_wise=true`.
[LightGBM] [Info] Total Bins 1785
[LightGBM] [Info] Number of data points in the train set: 464807, number of used features: 7


In [44]:
def predict_model(model, df: DataFrame):
    predict_df = df
    features = ['RSI2_rank', 'RSI5_rank', 'RSI10_rank', 'ret5_rank', 'RSI2', 'RSI5', 'RSI10']
    predict_df['prediction'] = model.predict(predict_df[features])
    return predict_df

In [45]:
test_data = all_data[all_data['date'] >= '2021-01-01'].copy()

In [46]:
predict_data = predict_model(model, test_data)

In [49]:
top_predictions = predict_data.groupby('date').apply(
        lambda x: x.nlargest(30, 'prediction')
    ).reset_index(drop=True)

  top_predictions = predict_data.groupby('date').apply(


In [51]:
top_predictions[top_predictions["date"]== '2024-01-04']

Unnamed: 0,date,code,open,close,high,low,adjust_open,adjust_close,adjust_high,adjust_low,...,ret5,RSI2,RSI5,RSI10,ret5_rank,RSI2_rank,RSI5_rank,RSI10_rank,label,prediction
21870,2024-01-04,2466,55.12,53.66,55.12,53.35,472.97,460.53,472.97,457.89,...,0.1148,0.098,0.3911,0.4799,0.9978,0.1499,0.3803,0.5168,15,4.673849
21871,2024-01-04,300896,285.0,276.35,286.0,273.01,528.65,513.08,530.45,507.06,...,0.1598,0.1493,0.4084,0.4388,1.0,0.2953,0.396,0.4295,16,4.494989
21872,2024-01-04,603259,69.98,69.16,69.98,68.64,169.67,167.74,169.67,166.52,...,0.0374,0.0739,0.2389,0.2869,0.9597,0.0828,0.0671,0.0559,15,2.674231
21873,2024-01-04,2352,38.71,38.25,38.75,37.91,124.07,122.69,124.19,121.67,...,0.0314,0.0957,0.2595,0.3254,0.9485,0.1432,0.0917,0.1432,15,2.349835
21874,2024-01-04,300017,7.93,8.08,8.22,7.75,186.64,189.99,193.13,182.61,...,0.0446,0.9818,0.8085,0.6976,0.9821,0.9642,0.906,0.8971,15,2.336006
21875,2024-01-04,600529,26.13,26.22,26.42,25.87,150.96,151.43,152.49,149.58,...,0.0351,0.9702,0.7838,0.6285,0.9553,0.9396,0.8702,0.8121,15,2.336006
21876,2024-01-04,601717,12.86,13.09,13.12,12.76,30.5,30.96,31.02,30.3,...,0.0283,0.9336,0.8988,0.7847,0.9441,0.8792,0.9799,0.9821,15,2.336006
21877,2024-01-04,786,24.74,25.1,25.29,24.5,341.4,345.78,348.1,338.48,...,0.0279,0.9767,0.8527,0.7277,0.9418,0.9553,0.9553,0.9396,15,2.336006
21878,2024-01-04,600563,86.0,86.1,86.92,82.0,152.2,152.35,153.58,146.2,...,0.0476,0.2043,0.4197,0.4405,0.9843,0.4049,0.4206,0.4362,15,1.960662
21879,2024-01-04,300274,84.47,82.66,84.95,82.47,553.14,541.41,556.25,540.18,...,0.0513,0.1248,0.4165,0.4887,0.9888,0.2237,0.4161,0.5369,15,1.960662
