In [1]:
import polars as pl
import numpy as np

# Feature columns loaded together with the keys/target; later fed to the model
# as time-varying unknown reals.
features = ["feature_06", "feature_36", "feature_04", "feature_56", "feature_19", "feature_59",
            "feature_25", "feature_45", "feature_60", "feature_58", "feature_39", "feature_66",
            "feature_08", "feature_68", "feature_52", "feature_70", "feature_48", "feature_24",
            "feature_65", "feature_74"]

class CONFIG:
    """Static settings for loading the training parquet.

    The class is used as a namespace: attributes are read directly off it and
    it is never instantiated.
    """

    # Column the model is trained to predict.
    target_col = "responder_6"
    # Exclusive lower bound: only rows with date_id strictly greater than this are kept.
    start_dt = 1000
    # Columns read from the parquet: keys, target, sample weight, then the features.
    # Built from target_col so the target column is named in exactly one place.
    selected_columns = ["date_id", 'time_id', 'symbol_id', target_col, 'weight'] + features


def load_and_split_data(parquet_path, config):
    """Lazily load the training parquet, keeping the configured columns and dates.

    Despite the name, no train/validation split happens here; the split is done
    later on ``sequential_id``.

    Parameters
    ----------
    parquet_path : str or path-like
        File passed to ``polars.scan_parquet``.
    config : object
        Must expose ``selected_columns`` (columns to read) and ``start_dt``
        (rows with ``date_id <= start_dt`` are dropped).

    Returns
    -------
    polars.DataFrame
        A leading ``id`` column plus ``config.selected_columns``, sorted by
        ``symbol_id``, ``date_id``, ``time_id``. ``id`` is the row position in
        the scanned file, assigned before the date filter.
    """
    # Local import on purpose: a later cell runs ``import lightning.pytorch as pl``,
    # which rebinds the global ``pl`` and would break calls made after that cell.
    import polars as pl

    df = (
        pl.scan_parquet(parquet_path)
        .select(config.selected_columns)
        .select(
            pl.int_range(pl.len(), dtype=pl.UInt32).alias("id"),
            pl.all(),
        )
        .filter(pl.col("date_id").gt(config.start_dt))
        .collect()
    )

    # Put each symbol's rows in chronological order.
    df = df.sort(['symbol_id', 'date_id', 'time_id'])

    print(f"日期範圍: {df['date_id'].min()} - {df['date_id'].max()}")
    return df

# Load every row after CONFIG.start_dt and convert to pandas for the
# pytorch-forecasting pipeline further down.
TRAIN_PARQUET = "/kaggle/input/jane-street-real-time-market-data-forecasting/train.parquet"
df = load_and_split_data(TRAIN_PARQUET, CONFIG).to_pandas()

日期範圍: 1001 - 1698


In [2]:
import multiprocessing

# Polars warns that fork() can deadlock child processes, so use "spawn".
# force=True keeps the cell re-runnable: without it, a second run raises
# "RuntimeError: context has already been set".
multiprocessing.set_start_method('spawn', force=True)

In [3]:
!pip install pytorch-forecasting
!pip --quiet install pytorch_lightning

In addition, using fork() with Python in general is a recipe for mysterious
deadlocks and crashes.

The most likely reason you are seeing this error is because you are using the
multiprocessing module on Linux, which uses fork() by default. This will be
fixed in Python 3.14. Until then, you want to use the "spawn" context instead.

See https://docs.pola.rs/user-guide/misc/multiprocessing/ for details.

  pid, fd = os.forkpty()


Collecting pytorch-forecasting
  Downloading pytorch_forecasting-1.2.0-py3-none-any.whl.metadata (13 kB)
Collecting lightning<3.0.0,>=2.0.0 (from pytorch-forecasting)
  Downloading lightning-2.4.0-py3-none-any.whl.metadata (38 kB)
Downloading pytorch_forecasting-1.2.0-py3-none-any.whl (181 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m181.9/181.9 kB[0m [31m6.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lightning-2.4.0-py3-none-any.whl (810 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m811.0/811.0 kB[0m [31m22.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: lightning, pytorch-forecasting
Successfully installed lightning-2.4.0 pytorch-forecasting-1.2.0


In [4]:
import copy
from pathlib import Path
import warnings

import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
import numpy as np
import pandas as pd
import torch

from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import MAE, SMAPE, PoissonLoss, QuantileLoss
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters

# Fill missing values per symbol without looking into the future.
# df is sorted by symbol_id, date_id, time_id (see load_and_split_data), so a
# grouped forward-fill carries each symbol's last observed value forward in time.
# Back-filling would instead copy *later* observations into earlier rows
# (look-ahead leakage that cannot be reproduced at inference time), so it is off
# by default; set ALLOW_BACKFILL = True to restore it.
ALLOW_BACKFILL = False

columns_with_na = df.columns[df.isna().any()].tolist()
if columns_with_na:
    # One grouped pass over all affected columns instead of two groupbys per column.
    df[columns_with_na] = df.groupby('symbol_id')[columns_with_na].ffill()
    if ALLOW_BACKFILL:
        df[columns_with_na] = df.groupby('symbol_id')[columns_with_na].bfill()
    # Whatever is still missing (e.g. values before a symbol's first observation) becomes 0.
    df[columns_with_na] = df[columns_with_na].fillna(0)

def create_sequential_id_with_duplicates(df, time_cols=("date_id", "time_id"),
                                         group_col="symbol_id", id_col="sequential_id"):
    """Merge date_id and time_id into one consecutive integer time index.

    Each distinct (date_id, time_id) combination gets an ID equal to its rank
    in ascending (date_id, time_id) order, starting at 0. Rows sharing a
    combination (e.g. different symbols at the same timestamp) share the ID.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns in ``time_cols`` and ``group_col``.
        It is not modified.
    time_cols : sequence of str, default ("date_id", "time_id")
        Columns whose combination defines a timestamp, most significant first.
    group_col : str, default "symbol_id"
        Primary sort key of the returned frame.
    id_col : str, default "sequential_id"
        Name of the new ID column.

    Returns
    -------
    pandas.DataFrame
        A new frame with ``id_col`` added, sorted by ``group_col`` then
        ``id_col``. Original index labels are preserved.
    """
    # groupby(sort=True).ngroup() numbers the groups in sorted key order, i.e. the
    # dense rank of each timestamp combination. This is the same mapping as
    # enumerating the sorted unique pairs, but vectorised rather than a per-row
    # Python apply.
    sequential_ids = df.groupby(list(time_cols), sort=True).ngroup()
    # assign() returns a copy, so the caller's frame is left untouched.
    return df.assign(**{id_col: sequential_ids}).sort_values([group_col, id_col])
df = create_sequential_id_with_duplicates(df)

# Fraction of distinct trading days used for training; the remaining
# (1 - TRAIN_FRACTION) most recent days are held out.
TRAIN_FRACTION = 0.8

# 1. Distinct date_ids in ascending order.
unique_dates = sorted(df['date_id'].unique())
if not unique_dates:
    raise ValueError("df has no rows; cannot choose a train/validation cutoff")

# 2. Index of the first held-out date (the TRAIN_FRACTION position, i.e. the
#    start of the last 1 - TRAIN_FRACTION of dates).
cutoff_index = int(len(unique_dates) * TRAIN_FRACTION)

# 3. The cutoff date itself (first held-out date).
cutoff_date = unique_dates[cutoff_index]

# 4. Smallest sequential_id on or after the cutoff date.
cutoff_sequential_id = df.loc[df['date_id'] >= cutoff_date, 'sequential_id'].min()

print(f"切分date_id: {cutoff_date}")
print(f"對應的sequential_id: {cutoff_sequential_id}")

切分date_id: 1559
對應的sequential_id: 540144


In [5]:
# symbol_id is the group id and a static categorical; cast it to str so it is
# handled as a label rather than a number.
df['symbol_id'] = df['symbol_id'].astype(str)

max_prediction_length = 10  # decoder horizon, in sequential_id steps
max_encoder_length = 500    # maximum history (sequential_id steps) given to the encoder

# Train strictly on timesteps before the cutoff date. cutoff_sequential_id is the
# first step *on* the cutoff date, so it belongs to the held-out period (`<`, not `<=`).
training = TimeSeriesDataSet(
    df[lambda x: x.sequential_id < cutoff_sequential_id],
    time_idx="sequential_id",
    target="responder_6",
    group_ids=["symbol_id"],
    min_encoder_length=max_encoder_length // 2,  # allow histories down to half the maximum
    max_encoder_length=max_encoder_length,
    min_prediction_length=1,
    max_prediction_length=max_prediction_length,
    static_categoricals=["symbol_id"],
    time_varying_known_reals=["date_id", "time_id", "sequential_id"],
    time_varying_unknown_reals=features + ["responder_6"],
    weight="weight",
    # target_normalizer is left at the library default. A GroupNormalizer with
    # transformation="softplus" was tried; softplus is meant for positive targets,
    # so only re-enable it if responder_6 is known to be positive.
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
    allow_missing_timesteps=True,
)

# Validation set built with the same dataset parameters as `training`.
# predict=True means: predict the last max_prediction_length points in time of
# each series in the full frame.
# NOTE(review): only that final window is scored, so most of the held-out period
# after cutoff_sequential_id serves as encoder context only.
validation = TimeSeriesDataSet.from_dataset(training, df, predict=True, stop_randomization=True)

# Dataloaders for the model.
batch_size = 128  # typically between 32 and 128
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size * 10, num_workers=0)