In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tqdm
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.compose import ColumnTransformer
from sklearn.metrics import classification_report
from imblearn.over_sampling import SMOTE
from sklearn.linear_model import SGDClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import LinearSVC  
from sklearn.neural_network import MLPClassifier
from xgboost import XGBClassifier
import time
from datetime import datetime

import warnings
warnings.filterwarnings("ignore", category=UserWarning)


In [2]:
# 中证A50成分股
# 信息来源：东方财富 20250324
code_list = [
    '688981sh', '603993sh', '603259sh', '601899sh', '601888sh', 
    '601816sh', '601766sh', '601668sh', '601600sh', '601318sh', 
    '601088sh', '601012sh', '600900sh', '600893sh',     # 中国移动2022年上市，暂时没有相关数据
    '600887sh', '600660sh', '600585sh', '600519sh', '600436sh', 
    '600426sh', '600415sh', '600406sh', '600309sh', '600276sh', 
    '600176sh', '600036sh', '600031sh', '600030sh', '600028sh', 
    '600019sh', '600009sh', '300760sz', '300750sz', '300408sz', 
    '300124sz', '300122sz', '300015sz', '002714sz', '002594sz', 
    '002475sz', '002371sz', '002230sz', '002027sz', '000938sz', 
    '000792sz', '000725sz', '000333sz', '000063sz', '000002sz'
    ]  

REC_CNT = 10
PRED_CNT = 5
IS_BINARY = True

In [None]:
def calc_new_features(df: pd.DataFrame) -> pd.DataFrame:
    """
    计算新特征
    """
    df['mid_price'] = (df['BidPr1'] + df['AskPr1']) / 2
    df['TWAP_mid_price'] = df['mid_price'].rolling(window=PRED_CNT).mean()
    
    df.dropna(axis=0, inplace=True)
    df.reset_index(drop=True, inplace=True)
    return df

def calc_labels(df: pd.DataFrame) -> pd.DataFrame:
    """
    计算标签
    """
    df['label'] = 0
    if IS_BINARY:
        df['label'] = np.where(df['TWAP_mid_price'].shift(-PRED_CNT) > df['mid_price'], 1, 0)
    # TODO: else:
    return df

def single_entry_gen(df: pd.DataFrame):
    """
    生成单条数据样本
    """
    snapshots_df = df.copy().reset_index(drop=True)
    base_price = snapshots_df['mid_price'].iloc[0]
    for level in range(1, 6):  # relative price
        snapshots_df[f'BidPr{level}'] = snapshots_df[f'BidPr{level}'] / base_price - 1
        snapshots_df[f'AskPr{level}'] = snapshots_df[f'AskPr{level}'] / base_price - 1
    # TODO: 其他特征
    
    used_cols = []
    for level in range(1, 6):
        used_cols.append(f'BidPr{level}')
        used_cols.append(f'BidVol{level}')
        used_cols.append(f'AskPr{level}')
        used_cols.append(f'AskVol{level}')
    
    X = snapshots_df.loc[:REC_CNT, used_cols].copy()  # rolling window
    X = X.values.flatten()
    # TODO：X 增加一些特征，比如时间、股票代码等
    y = snapshots_df['label'].iloc[REC_CNT-1].copy()
    
    return X, y

In [None]:
# train_df

folder_path = '../data_202111'

X_df_list = []
y_df_list = []
for code in code_list:
    # df = _U.get_stock_data(code, start_date='2021-11-1', end_date='2021-11-2')
    df = pd.read_csv(f'{folder_path}/{code}.csv')
    df['datetime'] = pd.to_datetime(df['datetime'])
    df = df[(df['datetime'] >= '2021-11-1 00:00:00') & (df['datetime'] <= '2021-11-2 23:59:59')]  # 初步只取
    
    df = calc_new_features(df)
    df = calc_labels(df)
    X_list = []
    y_list = []
    for i in tqdm.tqdm(range(len(df) - (REC_CNT + PRED_CNT))):
        single_X, single_y = single_entry_gen(df.iloc[i:i + REC_CNT])
        X_list.append(single_X)
        y_list.append(single_y)
    
    X_df = pd.DataFrame(X_list)
    y_df = pd.DataFrame(y_list)
    
    used_cols = []
    for level in range(1, 6):
        used_cols.append(f'BidPr{level}')
        used_cols.append(f'BidVol{level}')
        used_cols.append(f'AskPr{level}')
        used_cols.append(f'AskVol{level}')
    col_list = []
    for i in range(1, REC_CNT):
        for col in used_cols:
            col_list.append(f'{col}_lag{i}')
    # print(f'X cols: {used_cols + col_list}')
    if len(X_df) == 0:
        print(f'{code}.csv empty.')
        continue

    if len(X_df.columns) != len(used_cols + col_list):
        raise ValueError(f'X_df columns length mismatch: {len(X_df.columns)} != {len(used_cols + col_list)}')
    X_df.columns = used_cols + col_list
    y_df.columns = ['label']
    
    X_df_list.append(X_df)
    y_df_list.append(y_df)
    print(f'{code}.csv loaded.')
    
train_X_df = pd.concat(X_df_list, axis=0)
train_y_df = pd.concat(y_df_list, axis=0)
train_X_df.dropna(axis=0, inplace=True)
train_y_df.dropna(axis=0, inplace=True)
train_X_df.reset_index(drop=True, inplace=True)
train_y_df.reset_index(drop=True, inplace=True)
train_X_df.shape, train_y_df.shape, train_y_df['label'].value_counts(normalize=True)

100%|██████████| 9482/9482 [00:21<00:00, 445.21it/s]


688981sh.csv loaded.


100%|██████████| 9471/9471 [00:20<00:00, 456.60it/s]


603993sh.csv loaded.


100%|██████████| 9460/9460 [00:20<00:00, 456.50it/s]


603259sh.csv loaded.


100%|██████████| 9460/9460 [00:20<00:00, 454.90it/s]


601899sh.csv loaded.


100%|██████████| 4726/4726 [00:10<00:00, 457.86it/s]


601888sh.csv loaded.


100%|██████████| 9502/9502 [00:20<00:00, 456.31it/s]


601816sh.csv loaded.


100%|██████████| 9500/9500 [00:20<00:00, 456.68it/s]


601766sh.csv loaded.


100%|██████████| 9475/9475 [00:20<00:00, 457.21it/s]


601668sh.csv loaded.


100%|██████████| 9460/9460 [00:20<00:00, 456.23it/s]


601600sh.csv loaded.


100%|██████████| 9461/9461 [00:20<00:00, 456.70it/s]


601318sh.csv loaded.


100%|██████████| 9468/9468 [00:20<00:00, 456.85it/s]


601088sh.csv loaded.


100%|██████████| 9461/9461 [00:20<00:00, 457.05it/s]


601012sh.csv loaded.


100%|██████████| 9473/9473 [00:20<00:00, 458.32it/s]


600900sh.csv loaded.


100%|██████████| 9465/9465 [00:20<00:00, 456.59it/s]


600893sh.csv loaded.


100%|██████████| 9462/9462 [00:20<00:00, 457.85it/s]


600887sh.csv loaded.


100%|██████████| 9465/9465 [00:20<00:00, 455.87it/s]


600660sh.csv loaded.


100%|██████████| 9470/9470 [00:20<00:00, 457.13it/s]


600585sh.csv loaded.


100%|██████████| 9499/9499 [00:20<00:00, 452.63it/s]


600519sh.csv loaded.


100%|██████████| 9469/9469 [00:20<00:00, 456.32it/s]


600436sh.csv loaded.


100%|██████████| 9466/9466 [00:20<00:00, 451.81it/s]


600426sh.csv loaded.


100%|██████████| 9484/9484 [00:20<00:00, 454.78it/s]


600415sh.csv loaded.


100%|██████████| 9473/9473 [00:20<00:00, 456.80it/s]


600406sh.csv loaded.


100%|██████████| 9461/9461 [00:21<00:00, 436.90it/s]


600309sh.csv loaded.


100%|██████████| 9461/9461 [00:21<00:00, 435.12it/s]


600276sh.csv loaded.


100%|██████████| 9460/9460 [00:21<00:00, 443.93it/s]


600176sh.csv loaded.


100%|██████████| 9460/9460 [00:21<00:00, 448.02it/s]


600036sh.csv loaded.


100%|██████████| 9460/9460 [00:20<00:00, 459.11it/s]


600031sh.csv loaded.


100%|██████████| 9461/9461 [00:21<00:00, 446.88it/s]


600030sh.csv loaded.


100%|██████████| 9470/9470 [00:22<00:00, 426.08it/s]


600028sh.csv loaded.


100%|██████████| 9462/9462 [00:22<00:00, 424.76it/s]


600019sh.csv loaded.


100%|██████████| 9486/9486 [00:21<00:00, 435.81it/s]


600009sh.csv loaded.


100%|██████████| 9446/9446 [00:21<00:00, 429.53it/s]


300760sz.csv loaded.


100%|██████████| 9457/9457 [00:23<00:00, 397.46it/s]


300750sz.csv loaded.


100%|██████████| 9360/9360 [00:22<00:00, 424.76it/s]


300408sz.csv loaded.


100%|██████████| 9457/9457 [00:21<00:00, 440.26it/s]


300124sz.csv loaded.


100%|██████████| 9457/9457 [00:20<00:00, 455.35it/s]


300122sz.csv loaded.


100%|██████████| 9457/9457 [00:21<00:00, 445.46it/s]


300015sz.csv loaded.


100%|██████████| 9457/9457 [00:21<00:00, 445.48it/s]


002714sz.csv loaded.


100%|██████████| 9457/9457 [00:21<00:00, 443.45it/s]


002594sz.csv loaded.


100%|██████████| 9457/9457 [00:20<00:00, 454.07it/s]


002475sz.csv loaded.


100%|██████████| 9442/9442 [00:21<00:00, 436.33it/s]


002371sz.csv loaded.


100%|██████████| 9454/9454 [00:22<00:00, 413.78it/s]


002230sz.csv loaded.


100%|██████████| 9457/9457 [00:22<00:00, 414.74it/s]


002027sz.csv loaded.


100%|██████████| 9452/9452 [00:22<00:00, 426.84it/s]


000938sz.csv loaded.


100%|██████████| 9457/9457 [00:21<00:00, 444.32it/s]


000792sz.csv loaded.


100%|██████████| 9457/9457 [00:20<00:00, 457.20it/s]


000725sz.csv loaded.


100%|██████████| 9457/9457 [00:20<00:00, 459.06it/s]


000333sz.csv loaded.


100%|██████████| 9457/9457 [00:20<00:00, 455.33it/s]


000063sz.csv loaded.


100%|██████████| 9457/9457 [00:21<00:00, 438.97it/s]


000002sz.csv loaded.


((458928, 200),
 (458928, 1),
 label
 0    0.633958
 1    0.366042
 Name: proportion, dtype: float64)

In [None]:
# test_df

folder_path = '../data_202111'

X_df_list = []
y_df_list = []
for code in code_list:
    # df = _U.get_stock_data(code, start_date='2021-11-3', end_date='2021-11-5')
    df = pd.read_csv(f'{folder_path}/{code}.csv')
    df['datetime'] = pd.to_datetime(df['datetime'])
    df = df[(df['datetime'] >= '2021-11-3 00:00:00') & (df['datetime'] <= '2021-11-5 23:59:59')]  
    
    # generate new features, labels
    # rolling window for each entry
    df = calc_new_features(df)
    df = calc_labels(df)
    X_list = []
    y_list = []
    for i in tqdm.tqdm(range(len(df) - (REC_CNT + PRED_CNT))):
        single_X, single_y = single_entry_gen(df.iloc[i:i + REC_CNT])
        X_list.append(single_X)
        y_list.append(single_y)
    X_df = pd.DataFrame(X_list)
    y_df = pd.DataFrame(y_list)
    
    # rename columns for X_df, y_df
    used_cols = []
    for level in range(1, 6):
        used_cols.append(f'BidPr{level}')
        used_cols.append(f'BidVol{level}')
        used_cols.append(f'AskPr{level}')
        used_cols.append(f'AskVol{level}')
    col_list = []
    for i in range(1, REC_CNT):
        for col in used_cols:
            col_list.append(f'{col}_lag{i}')
    # print(f'X cols: {used_cols + col_list}')
    if len(X_df) == 0:
        print(f'{code}.csv empty.')
        continue
    if len(X_df.columns) != len(used_cols + col_list):
        raise ValueError(f'X_df columns length mismatch: {len(X_df.columns)} != {len(used_cols + col_list)}')
    X_df.columns = used_cols + col_list
    y_df.columns = ['label']
    
    X_df_list.append(X_df)
    y_df_list.append(y_df)
    print(f'{code}.csv loaded.')
    
test_X_df = pd.concat(X_df_list, axis=0)
test_y_df = pd.concat(y_df_list, axis=0)
test_X_df.dropna(axis=0, inplace=True)
test_y_df.dropna(axis=0, inplace=True)
test_X_df.reset_index(drop=True, inplace=True)
test_y_df.reset_index(drop=True, inplace=True)
test_X_df.shape, test_y_df.shape, test_y_df['label'].value_counts(normalize=True)

100%|██████████| 14239/14239 [00:32<00:00, 444.23it/s]


688981sh.csv loaded.


100%|██████████| 14210/14210 [00:34<00:00, 416.97it/s]


603993sh.csv loaded.


100%|██████████| 14207/14207 [00:31<00:00, 449.70it/s]


603259sh.csv loaded.


100%|██████████| 14201/14201 [00:32<00:00, 440.17it/s]


601899sh.csv loaded.


100%|██████████| 14212/14212 [00:30<00:00, 460.29it/s]


601888sh.csv loaded.


100%|██████████| 14139/14139 [00:30<00:00, 457.37it/s]


601816sh.csv loaded.


100%|██████████| 14257/14257 [00:31<00:00, 458.23it/s]


601766sh.csv loaded.


100%|██████████| 14248/14248 [00:31<00:00, 445.60it/s]


601668sh.csv loaded.


100%|██████████| 14203/14203 [00:42<00:00, 330.90it/s]


601600sh.csv loaded.


100%|██████████| 14204/14204 [00:34<00:00, 413.16it/s]


601318sh.csv loaded.


100%|██████████| 14226/14226 [00:34<00:00, 416.21it/s]


601088sh.csv loaded.


100%|██████████| 14201/14201 [00:34<00:00, 412.33it/s]


601012sh.csv loaded.


100%|██████████| 14239/14239 [00:34<00:00, 416.46it/s]


600900sh.csv loaded.


100%|██████████| 14209/14209 [00:34<00:00, 417.55it/s]


600893sh.csv loaded.


100%|██████████| 14201/14201 [00:32<00:00, 436.94it/s]


600887sh.csv loaded.


100%|██████████| 14218/14218 [00:32<00:00, 434.04it/s]


600660sh.csv loaded.


100%|██████████| 14265/14265 [00:33<00:00, 423.35it/s]


600585sh.csv loaded.


100%|██████████| 14256/14256 [00:33<00:00, 424.25it/s]


600519sh.csv loaded.


100%|██████████| 14237/14237 [00:33<00:00, 427.91it/s]


600436sh.csv loaded.


100%|██████████| 14233/14233 [00:32<00:00, 431.53it/s]


600426sh.csv loaded.


100%|██████████| 14183/14183 [00:33<00:00, 427.42it/s]


600415sh.csv loaded.


100%|██████████| 14243/14243 [00:32<00:00, 436.57it/s]


600406sh.csv loaded.


100%|██████████| 14208/14208 [00:33<00:00, 423.24it/s]


600309sh.csv loaded.


100%|██████████| 14202/14202 [00:31<00:00, 448.59it/s]


600276sh.csv loaded.


100%|██████████| 14231/14231 [00:33<00:00, 428.84it/s]


600176sh.csv loaded.


100%|██████████| 14207/14207 [00:32<00:00, 442.94it/s]


600036sh.csv loaded.


100%|██████████| 14200/14200 [00:31<00:00, 457.52it/s]


600031sh.csv loaded.


100%|██████████| 14207/14207 [00:33<00:00, 429.15it/s]


600030sh.csv loaded.


100%|██████████| 14232/14232 [00:32<00:00, 442.56it/s]


600028sh.csv loaded.


100%|██████████| 14226/14226 [00:31<00:00, 446.24it/s]


600019sh.csv loaded.


100%|██████████| 14226/14226 [00:31<00:00, 450.64it/s]


600009sh.csv loaded.


100%|██████████| 14174/14174 [00:31<00:00, 453.17it/s]


300760sz.csv loaded.


100%|██████████| 14195/14195 [00:30<00:00, 460.25it/s]


300750sz.csv loaded.


100%|██████████| 13972/13972 [00:30<00:00, 458.34it/s]


300408sz.csv loaded.


100%|██████████| 14195/14195 [00:31<00:00, 456.76it/s]


300124sz.csv loaded.


100%|██████████| 14190/14190 [00:32<00:00, 441.10it/s]


300122sz.csv loaded.


100%|██████████| 14195/14195 [00:30<00:00, 458.94it/s]


300015sz.csv loaded.


100%|██████████| 14195/14195 [00:31<00:00, 457.27it/s]


002714sz.csv loaded.


100%|██████████| 14195/14195 [00:31<00:00, 454.07it/s]


002594sz.csv loaded.


100%|██████████| 14195/14195 [00:30<00:00, 457.97it/s]


002475sz.csv loaded.


100%|██████████| 14157/14157 [00:32<00:00, 435.61it/s]


002371sz.csv loaded.


100%|██████████| 14194/14194 [00:31<00:00, 446.49it/s]


002230sz.csv loaded.


100%|██████████| 14195/14195 [00:31<00:00, 448.14it/s]


002027sz.csv loaded.


100%|██████████| 14146/14146 [00:31<00:00, 448.64it/s]


000938sz.csv loaded.


100%|██████████| 14193/14193 [00:30<00:00, 461.54it/s]


000792sz.csv loaded.


100%|██████████| 14195/14195 [00:31<00:00, 457.34it/s]


000725sz.csv loaded.


100%|██████████| 14195/14195 [00:31<00:00, 451.64it/s]


000333sz.csv loaded.


100%|██████████| 14195/14195 [00:32<00:00, 438.15it/s]


000063sz.csv loaded.


100%|██████████| 14195/14195 [00:33<00:00, 423.75it/s]


000002sz.csv loaded.


((695941, 200),
 (695941, 1),
 label
 0    0.644467
 1    0.355533
 Name: proportion, dtype: float64)

In [None]:
# save data to csv
train_X_df.to_csv('train_X.csv', index=False)
train_y_df.to_csv('train_y.csv', index=False)
test_X_df.to_csv('test_X.csv', index=False)
test_y_df.to_csv('test_y.csv', index=False)
# 2d-train, 3d-test, saving
# total time cost: 18m + 27.5m + 3.3m

In [7]:
# # load data from csv
# train_X_df = pd.read_csv('train_X.csv')
# train_y_df = pd.read_csv('train_y.csv')
# test_X_df = pd.read_csv('test_X.csv')
# test_y_df = pd.read_csv('test_y.csv')

In [16]:
train_X_df.shape, train_y_df.shape, test_X_df.shape, test_y_df.shape

((458928, 200), (458928, 1), (695941, 200), (695941, 1))

In [17]:
# 标准化

def create_scalers(train_df, prefixes):
    """为多个特征前缀创建标准化器"""
    scalers = {}
    for prefix in prefixes:
        scalers[prefix] = {}
        for level in range(1, 6):  # 假设有5档
            cols = [f"{prefix}{level}"] + [f"{prefix}{level}_lag{i}" for i in range(1, 10)]
            scaler = StandardScaler().fit(train_df[cols])
            scalers[prefix][level] = scaler
    return scalers

def apply_full_scaling(df, scalers):
    """应用标准化并合并所有特征"""
    scaled_dfs = []
    for prefix in scalers.keys():
        for level in scalers[prefix].keys():
            cols = [f"{prefix}{level}"] + [f"{prefix}{level}_lag{i}" for i in range(1, 10)]
            scaled_data = scalers[prefix][level].transform(df[cols])
            scaled_df = pd.DataFrame(scaled_data, columns=cols, index=df.index)
            scaled_dfs.append(scaled_df)
    return pd.concat(scaled_dfs, axis=1)

# BidPr/BidVol/AskPr/AskVol统一处理
all_prefixes = ['BidPr', 'BidVol', 'AskPr', 'AskVol']
scalers = create_scalers(train_X_df, all_prefixes)
train_X_scaled = apply_full_scaling(train_X_df, scalers)
test_X_scaled = apply_full_scaling(test_X_df, scalers)

In [18]:
train_X_scaled.shape, train_y_df.shape, test_X_scaled.shape, test_y_df.shape

((458928, 200), (458928, 1), (695941, 200), (695941, 1))

In [19]:
train_y_df['label'].value_counts(), test_y_df['label'].value_counts()

(label
 0    290941
 1    167987
 Name: count, dtype: int64,
 label
 0    448511
 1    247430
 Name: count, dtype: int64)

In [20]:
# 平衡数据集

def create_balanced_dataset(X, y, sample_size=10000, random_state=42) :
    """
    生成平衡的数据集
    """
    # 平衡采样
    X.reset_index(drop=True, inplace=True)
    y.reset_index(drop=True, inplace=True)
    each_class_size = sample_size // 2
    sampled = []
    for class_label in [0, 1]:
        class_indices = y[y['label'] == class_label].index
        n_samples = min(each_class_size, len(class_indices))
        sampled.extend(np.random.choice(class_indices, n_samples, replace=False))
    
    # 划分数据集
    X_balanced = X.loc[sampled]
    y_balanced = y.loc[sampled]

    return X_balanced, y_balanced
        

sample_size = 50000
X_train_balanced, y_train_balanced = create_balanced_dataset(train_X_scaled, train_y_df, sample_size)
X_test_balanced, y_test_balanced = create_balanced_dataset(test_X_scaled, test_y_df, sample_size)

# 验证输出形状
print(f"训练集: {X_train_balanced.shape}, 测试集: {X_test_balanced.shape}")
print("类别分布：")
print("Train:", y_train_balanced['label'].value_counts())
print("Test:", y_test_balanced['label'].value_counts())

训练集: (50000, 200), 测试集: (50000, 200)
类别分布：
Train: label
0    25000
1    25000
Name: count, dtype: int64
Test: label
0    25000
1    25000
Name: count, dtype: int64


In [21]:
# 模型列表
models = [
    ('Logistic Regression (SGD)',
    SGDClassifier(
        loss='log_loss',
        penalty='l2',
        alpha=1e-4,          # 正则化参数
        max_iter=1000, 
        tol=1e-3,
        n_jobs=-1,              # 并行计算
        random_state=42
    )),

    ('Linear SVM',
    LinearSVC(
        C=1.0,                  # 正则化参数
        dual=False,             # 避免大数据集的求解问题
        max_iter=2000,          # 迭代次数增加以防止收敛失败
        tol=1e-4,
        random_state=42
    )),

    ('XGBoost',
    XGBClassifier(
        objective='binary:logistic',  # 二分类问题
        n_estimators=500,
        # early_stopping_rounds=50,  # 早停
        learning_rate=0.05, 
        max_depth=6, 
        subsample=0.8, 
        colsample_bytree=0.8, 
        tree_method='hist',           # 适用于中等数据
        n_jobs=-1,                    # 并行加速
        random_state=42
    )),

    ('MLP',
    MLPClassifier(
        hidden_layer_sizes=(128, 64), # 两层隐藏层，神经元数 128 → 64
        activation='relu',            # ReLU 激活函数
        solver='adam',                # Adam 优化
        alpha=1e-4,                   # L2 正则化
        batch_size=128,               # 小批量梯度下降
        learning_rate_init=0.001,      # 学习率
        max_iter=500,                 # 训练 500 轮
        early_stopping=True,          # 提前停止，防止过拟合
        n_iter_no_change=10,          # 10 轮无提升则停止
        random_state=42
    ))
]

X_train = X_train_balanced.copy()
y_train = np.squeeze(y_train_balanced)
X_test = X_test_balanced.copy()
y_test = np.squeeze(y_test_balanced)

# 训练评估
results = []
for name, model in models:
    start_time = time.time()
    
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    report = classification_report(y_test, y_pred, output_dict=True)
    
    elapsed = time.time() - start_time
    result = {
        'Model': name,
        'Accuracy': report['accuracy'],
        'Precision': report['1']['precision'],
        'Recall': report['1']['recall'],
        'F1': report['1']['f1-score']
    }
    results.append(result)
    print(f"--- {name} ---")
    print("Time elapsed: ", elapsed, "(s)")
    report_df = pd.DataFrame(report).transpose()
    print(report_df)
    
result_df = pd.DataFrame(results)
result_df.set_index('Model', inplace=True)
result_df.sort_values('F1', ascending=False, inplace=True)
result_df

--- Logistic Regression (SGD) ---
Time elapsed:  2.310307025909424 (s)
              precision   recall  f1-score      support
0              0.743386  0.46756  0.574060  25000.00000
1              0.611652  0.83860  0.707369  25000.00000
accuracy       0.653080  0.65308  0.653080      0.65308
macro avg      0.677519  0.65308  0.640714  50000.00000
weighted avg   0.677519  0.65308  0.640714  50000.00000
--- Linear SVM ---
Time elapsed:  6.683095216751099 (s)
              precision   recall  f1-score      support
0              0.759179  0.43920  0.556471  25000.00000
1              0.605482  0.86068  0.710871  25000.00000
accuracy       0.649940  0.64994  0.649940      0.64994
macro avg      0.682330  0.64994  0.633671  50000.00000
weighted avg   0.682330  0.64994  0.633671  50000.00000
--- XGBoost ---
Time elapsed:  16.049792051315308 (s)
              precision   recall  f1-score      support
0              0.736806  0.63884  0.684335  25000.00000
1              0.681224  0.77180  0

Unnamed: 0_level_0,Accuracy,Precision,Recall,F1
Model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
XGBoost,0.70532,0.681224,0.7718,0.723689
Linear SVM,0.64994,0.605482,0.86068,0.710871
Logistic Regression (SGD),0.65308,0.611652,0.8386,0.707369
MLP,0.67724,0.658886,0.735,0.694865


In [23]:
cur_time = datetime.now().strftime('%Y-%m-%d %H:%M')
args = {
    'Time': cur_time,
    'REC_CNT': REC_CNT,
    'PRED_CNT': PRED_CNT,
    'IS_BINARY': IS_BINARY,
    'stock_cnt': len(code_list),
    'sample_size': sample_size,
    'train_start': '2021-11-1 00:00:00',
    'train_end': '2021-11-2 23:59:59',
    'test_start': '2021-11-3 00:00:00',
    'test_end': '2021-11-5 23:59:59'
}
result_df = pd.concat([pd.Series(args), result_df])
result_df.to_csv(f'results_{cur_time}.csv', index=True)