# 01 序列兴趣建模：DIN 工业级实战示例

- **核心功能**：严谨的特征对齐、早停机制、可选的模型导出 (ONNX，默认关闭，由 `EXPORT_ONNX` 控制)。
- **数据**：Amazon-Electronics Sample。
- **代码风格**：高度参数化，易于迁移到生产环境。

In [None]:
import os
import random

import numpy as np
import pandas as pd
import torch
from torch_rechub.basic.features import SparseFeature, SequenceFeature
from torch_rechub.models.ranking import DIN
from torch_rechub.trainers import CTRTrainer
from torch_rechub.utils.data import DataGenerator, create_seq_features, df_to_dict

# Configuration parameters
SEED = 2022
DEVICE = "cpu"
DATASET_PATH = "../examples/ranking/data/amazon-electronics/amazon_electronics_sample.csv"
SEQ_MAX_LEN = 50
EPOCH = 2
BATCH_SIZE = 4096
LR = 1e-3
WEIGHT_DECAY = 1e-3
EARLYSTOP_PATIENCE = 4
EXPORT_ONNX = False
ONNX_PATH = "din.onnx"

# Seed every global RNG, not only torch's. torch.manual_seed alone leaves
# Python's `random` and NumPy's legacy global generator unseeded, so any
# shuffling or sampling that draws from them would differ between runs.
# NOTE(review): which RNG create_seq_features uses for shuffle=True is not
# visible in this notebook -- seeding all three covers either case.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Print the resolved path so a wrong working directory is obvious up front.
print("DATASET_PATH:", os.path.abspath(DATASET_PATH))

## 数据处理与严谨的词表计算

In [None]:
# Read the raw CSV, then let create_seq_features split it into the three
# frames used below (train / validation / test).
data = pd.read_csv(DATASET_PATH)
train_df, val_df, test_df = create_seq_features(
    data,
    seq_feature_col=["item_id", "cate_id"],
    max_len=SEQ_MAX_LEN,
    drop_short=0,
    shuffle=True,
)

def max_from_list_col(df: pd.DataFrame, col: str) -> int:
    """Return the largest value found in a column whose cells are sequences.

    Each row is reduced on its own, so rows of different lengths are
    accepted (stacking all rows into one array requires equal lengths).
    Empty rows are skipped.

    Args:
        df: Frame holding the column.
        col: Name of a column whose cells are list-like collections of IDs.

    Returns:
        The maximum element over all rows, or 0 when the column has no
        elements at all (no rows, or only empty rows).
    """
    # np.size / np.max accept lists, tuples and arrays alike.
    row_maxima = [np.max(seq) for seq in df[col] if np.size(seq) > 0]
    if not row_maxima:
        return 0
    return int(max(row_maxima))

# Largest ID per feature, taken over ALL three splits. The history columns of
# the test split are scanned too: an ID that occurred only there would
# otherwise be larger than the vocab_size derived from these values.
_splits = (train_df, val_df, test_df)
n_users = int(max(df["user_id"].max() for df in _splits))
n_items = int(max(
    max(df["target_item"].max(), max_from_list_col(df, "history_item"))
    for df in _splits
))
n_cates = int(max(
    max(df["target_cate"].max(), max_from_list_col(df, "history_cate"))
    for df in _splits
))

print({"n_users": n_users, "n_items": n_items, "n_cates": n_cates})

# Split each frame into its label column (y) and the remaining columns (x),
# the latter converted with df_to_dict.
train_y, val_y, test_y = train_df["label"], val_df["label"], test_df["label"]
train_x, val_x, test_x = df_to_dict(train_df.drop(columns="label")), df_to_dict(val_df.drop(columns="label")), df_to_dict(test_df.drop(columns="label"))

## 构造特征与模型初始化

In [None]:
# vocab_size is "largest ID + 1" so the largest ID is still a valid index.
# NOTE(review): presumably 0 is reserved for padding -- confirm against
# create_seq_features.
target_features = [
    SparseFeature(feat_name, vocab_size=max_id + 1, embed_dim=64)
    for feat_name, max_id in (("target_item", n_items), ("target_cate", n_cates))
]

# Non-sequence inputs: the two target features plus the user ID.
features = [
    *target_features,
    SparseFeature("user_id", vocab_size=n_users + 1, embed_dim=64),
]

# Each history sequence is declared as shared with its matching target feature.
history_features = [
    SequenceFeature(
        feat_name,
        vocab_size=max_id + 1,
        embed_dim=64,
        pooling="concat",
        shared_with=target_name,
    )
    for feat_name, max_id, target_name in (
        ("history_item", n_items, "target_item"),
        ("history_cate", n_cates, "target_cate"),
    )
]

dg = DataGenerator(train_x, train_y)
train_dl, val_dl, test_dl = dg.generate_dataloader(
    x_val=val_x,
    y_val=val_y,
    x_test=test_x,
    y_test=test_y,
    batch_size=BATCH_SIZE,
)

model = DIN(
    features=features,
    history_features=history_features,
    target_features=target_features,
    mlp_params={"dims": [256, 128]},
    attention_mlp_params={"dims": [256, 128]},
)

## 训练、评估与导出

In [None]:
# Trainer setup: optimizer hyper-parameters, epoch budget, early-stop patience
# and device all come from the configuration constants.
ctr_trainer = CTRTrainer(
    model,
    optimizer_params=dict(lr=LR, weight_decay=WEIGHT_DECAY),
    n_epoch=EPOCH,
    earlystop_patience=EARLYSTOP_PATIENCE,
    device=DEVICE,
    model_path="./",
)

# Fit on the training loader with the validation loader supplied alongside,
# then score the trainer's model on the held-out test loader.
ctr_trainer.fit(train_dl, val_dl)
auc = ctr_trainer.evaluate(ctr_trainer.model, test_dl)
print(f"Test AUC: {auc}")

# ONNX export is opt-in and best-effort: a failure is reported, not raised,
# so the notebook still finishes.
if EXPORT_ONNX:
    try:
        ctr_trainer.export_onnx(ONNX_PATH, verbose=False, device=DEVICE)
        print("Model exported to:", ONNX_PATH)
    except Exception as exc:
        print("ONNX export failed:", repr(exc))