# In [120]:  (notebook cell marker)
import torch
import torch.nn as nn
import torch.nn.functional as F

# ===================== Embedding tables =====================
# Item-side lookup tables: (vocabulary size, vector width).
shop_embedding = nn.Embedding(num_embeddings=1000, embedding_dim=8)      # shop_id
product_embedding = nn.Embedding(num_embeddings=500, embedding_dim=16)   # product_id
category_embedding = nn.Embedding(num_embeddings=100, embedding_dim=32)  # item category

# User-side lookup tables.
age_embedding = nn.Embedding(num_embeddings=10, embedding_dim=8)         # age
gender_embedding = nn.Embedding(num_embeddings=3, embedding_dim=4)       # gender
city_embedding = nn.Embedding(num_embeddings=1000, embedding_dim=16)     # city

# ===================== DIN model definition =====================
class DIN(nn.Module):
    """Deep-Interest-Network style click-probability model.

    The model attends over a user's shop history and product history with
    respect to the current target shop / product, concatenates the two pooled
    "interest" vectors with the user profile features, and feeds the result
    through an MLP followed by a sigmoid.

    The model consumes already-embedded vectors; it does not own any
    embedding table itself.

    Args:
        shop_emb_dim: width of the shop embedding vectors.
        product_emb_dim: width of the product embedding vectors.
        age_emb_dim: width of the age embedding vector.
        gender_emb_dim: width of the gender embedding vector.
        city_emb_dim: width of the city embedding vector.
        dense_dim: number of dense (non-embedded) user features, e.g. the
            normalized active-days scalar.
    """

    def __init__(self, shop_emb_dim=8, product_emb_dim=16,
                 age_emb_dim=8, gender_emb_dim=4, city_emb_dim=16,
                 dense_dim=1):
        super().__init__()

        self.shop_emb_dim = shop_emb_dim
        self.product_emb_dim = product_emb_dim

        # One attention scorer per behaviour type. The input width is
        # ``emb_dim * 2`` because each score is computed from the
        # concatenation [target, history_item].
        self.shop_attention_fc = nn.Sequential(
            nn.Linear(shop_emb_dim * 2, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )
        self.product_attention_fc = nn.Sequential(
            nn.Linear(product_emb_dim * 2, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

        # MLP input = age + gender + city + dense features
        #             + shop interest + product interest.
        # With the defaults this is 8 + 4 + 16 + 1 + 8 + 16 = 53.
        mlp_in_dim = (age_emb_dim + gender_emb_dim + city_emb_dim + dense_dim
                      + shop_emb_dim + product_emb_dim)
        self.mlp = nn.Sequential(
            nn.Linear(mlp_in_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def attention(self, target_vec, history_vecs, attention_fc):
        """Pool a behaviour history into one vector, weighted by relevance.

        Args:
            target_vec: target item embedding, shape ``(d,)`` for a single
                sample or ``(B, d)`` for a batch.
            history_vecs: history embeddings, shape ``(T, d)`` for a single
                sample or ``(B, T, d)`` for a batch.
            attention_fc: module mapping ``(..., 2d)`` to ``(..., 1)`` scores.

        Returns:
            Softmax-weighted sum of the history, shape ``(d,)`` or ``(B, d)``.
        """
        # Work on the second-to-last axis (the sequence axis) instead of a
        # hard-coded axis 0, so unbatched and batched inputs share one path.
        target_expand = target_vec.unsqueeze(-2).expand_as(history_vecs)  # (..., T, d)
        att_input = torch.cat([target_expand, history_vecs], dim=-1)      # (..., T, 2d)
        att_scores = attention_fc(att_input)                              # (..., T, 1)
        att_weights = F.softmax(att_scores, dim=-2)                       # normalise over T
        weighted_history = att_weights * history_vecs                     # (..., T, d)
        return weighted_history.sum(dim=-2)                               # (..., d)

    def forward(self, age_vec, gender_vec, city_vec, active_days,
                target_shop_vec, history_shop_vec,
                target_product_vec, history_product_vec):
        """Return the predicted click probability.

        All non-history inputs must share the same leading shape (either no
        batch axis or the same batch size ``B``) because they are
        concatenated along the last axis.

        Returns:
            Tensor of shape ``(1,)`` for a single sample or ``(B, 1)`` for a
            batch, with values in ``(0, 1)``.
        """
        shop_interest_vec = self.attention(
            target_shop_vec, history_shop_vec, self.shop_attention_fc)
        product_interest_vec = self.attention(
            target_product_vec, history_product_vec, self.product_attention_fc)

        features = torch.cat([
            age_vec, gender_vec, city_vec, active_days,
            shop_interest_vec, product_interest_vec
        ], dim=-1)

        output = self.mlp(features)
        return torch.sigmoid(output)



# ===================== Data preparation =====================
# Categorical user features for a single sample (ids into the embedding tables).
age = torch.tensor([3])
gender = torch.tensor([1])
city = torch.tensor([25])

# Random ids standing in for the user's browsing history (100 events each).
# NOTE: despite the "*_category" names, these hold shop_id / product_id values.
shop_category = torch.randint(low=0, high=1000, size=(100,))
product_category = torch.randint(low=0, high=500, size=(100,))

# Candidate shop / product being scored.
target_shop = torch.tensor([100])
target_product = torch.tensor([150])

# Dense feature: active days scaled by 365.
active_days = torch.tensor([120.0])
normalized_active_days = active_days.div(365)

# Look up embedding vectors.
age_vec = age_embedding(age)                                # (1, 8)
gender_vec = gender_embedding(gender)                       # (1, 4)
city_vec = city_embedding(city)                             # (1, 16)
target_shop_vec = shop_embedding(target_shop)               # (1, 8)
history_shop_vec = shop_embedding(shop_category)            # (100, 8)
target_product_vec = product_embedding(target_product)      # (1, 16)
history_product_vec = product_embedding(product_category)   # (100, 16)

# Give the dense feature an explicit (1, 1) shape like the embedded features.
normalized_active_days_exp = normalized_active_days.expand(1, 1)  # (1, 1)

# ===================== Model prediction =====================
model = DIN()

# Drop the leading size-1 axis from the per-sample features; the history
# tensors are passed through with their sequence axis intact.
prediction = model(
    age_vec=age_vec[0],
    gender_vec=gender_vec[0],
    city_vec=city_vec[0],
    active_days=normalized_active_days_exp[0],
    target_shop_vec=target_shop_vec[0],
    history_shop_vec=history_shop_vec,
    target_product_vec=target_product_vec[0],
    history_product_vec=history_product_vec,
)

print(f"预测的点击概率: {prediction.item():.4f}")

# Sample output: 预测的点击概率: 0.5118
