In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# 用户塔网络
class UserTower(nn.Module):
    def __init__(self, user_feature_size, embedding_dim):
        super(UserTower, self).__init__()
        self.embedding = nn.Linear(user_feature_size, embedding_dim)
        self.fc1 = nn.Linear(embedding_dim, embedding_dim)
    
    def forward(self, user_features):
        # 用户特征 -> 低维嵌入
        user_embedding = F.relu(self.embedding(user_features))
        user_embedding = F.relu(self.fc1(user_embedding))
        return user_embedding

# 视频塔网络
class VideoTower(nn.Module):
    def __init__(self, video_feature_size, embedding_dim):
        super(VideoTower, self).__init__()
        self.embedding = nn.Linear(video_feature_size, embedding_dim)
        self.fc1 = nn.Linear(embedding_dim, embedding_dim)
    
    def forward(self, video_features):
        # 视频特征 -> 低维嵌入
        video_embedding = F.relu(self.embedding(video_features))
        video_embedding = F.relu(self.fc1(video_embedding))
        return video_embedding

# 双塔召回模型
class TwoTowerModel(nn.Module):
    def __init__(self, user_feature_size, video_feature_size, embedding_dim):
        super(TwoTowerModel, self).__init__()
        self.user_tower = UserTower(user_feature_size, embedding_dim)
        self.video_tower = VideoTower(video_feature_size, embedding_dim)
    
    def forward(self, user_features, video_features):
        user_embedding = self.user_tower(user_features)
        video_embedding = self.video_tower(video_features)
        # 计算用户与视频的相似度（余弦相似度）
        similarity = F.cosine_similarity(user_embedding, video_embedding)
        return similarity

In [11]:
# 模拟用户特征和视频特征
import numpy as np

# 随机生成用户特征和视频特征数据
num_users = 1000
num_videos = 2000
user_feature_size = 32  # 用户特征维度
video_feature_size = 32  # 视频特征维度
embedding_dim = 16  # 嵌入维度

# 随机生成用户和视频的特征
user_features = torch.FloatTensor(np.random.rand(num_users, user_feature_size))
video_features = torch.FloatTensor(np.random.rand(num_videos, video_feature_size))

# 模拟标签 (0-1的相似性标签)
labels = torch.FloatTensor(np.random.randint(0, 2, [num_users, num_videos]))

# 模型初始化
model = TwoTowerModel(user_feature_size, video_feature_size, embedding_dim)
criterion = nn.BCELoss()  # 使用二分类交叉熵作为损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 训练
for epoch in range(10):
    optimizer.zero_grad()
    
    # 随机选择视频进行训练
    similarities = torch.tensor([])
    for _ in range(num_videos):
        video_idx = torch.randint(0, num_videos, (num_users,))
        selected_video_features = video_features[video_idx]
        
        # 计算相似度
        similarity = model(user_features, selected_video_features)
        similarities = torch.concat([similarities, similarity.unsqueeze(0).T], dim=1)
    
    # 计算损失
    print(similarities.shape, labels.shape)
    loss = criterion(similarities, labels)
    
    # 反向传播和优化
    loss.backward()
    optimizer.step()
    
    print(f"Epoch {epoch+1}, Loss: {loss.item()}")


torch.Size([1000, 2000]) torch.Size([1000, 2000])
Epoch 1, Loss: 0.7250649929046631
torch.Size([1000, 2000]) torch.Size([1000, 2000])
Epoch 2, Loss: 0.7203171253204346
torch.Size([1000, 2000]) torch.Size([1000, 2000])
Epoch 3, Loss: 0.7178353667259216
torch.Size([1000, 2000]) torch.Size([1000, 2000])
Epoch 4, Loss: 0.7142857313156128
torch.Size([1000, 2000]) torch.Size([1000, 2000])
Epoch 5, Loss: 0.7114763855934143
torch.Size([1000, 2000]) torch.Size([1000, 2000])
Epoch 6, Loss: 0.7093394994735718
torch.Size([1000, 2000]) torch.Size([1000, 2000])
Epoch 7, Loss: 0.7078996300697327
torch.Size([1000, 2000]) torch.Size([1000, 2000])
Epoch 8, Loss: 0.7064034938812256
torch.Size([1000, 2000]) torch.Size([1000, 2000])
Epoch 9, Loss: 0.7048524022102356
torch.Size([1000, 2000]) torch.Size([1000, 2000])
Epoch 10, Loss: 0.703533411026001


In [16]:
import torch
import torch.nn as nn
import torch.optim as optim

# 定义用户塔
class UserTower(nn.Module):
    def __init__(self, user_feature_dim, embedding_dim):
        super(UserTower, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(user_feature_dim, 128),
            nn.ReLU(),
            nn.Linear(128, embedding_dim)
        )
        
    def forward(self, user_features):
        return self.mlp(user_features)

# 定义视频塔
class VideoTower(nn.Module):
    def __init__(self, video_feature_dim, embedding_dim):
        super(VideoTower, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(video_feature_dim, 128),
            nn.ReLU(),
            nn.Linear(128, embedding_dim)
        )
        
    def forward(self, video_features):
        return self.mlp(video_features)

# 定义双塔排序模型
class TwoTowerRankingModel(nn.Module):
    def __init__(self, user_feature_dim, video_feature_dim, embedding_dim):
        super(TwoTowerRankingModel, self).__init__()
        self.user_tower = UserTower(user_feature_dim, embedding_dim)
        self.video_tower = VideoTower(video_feature_dim, embedding_dim)
    
    def forward(self, user_features, video_features):
        # 编码用户和视频特征
        user_embedding = self.user_tower(user_features)
        video_embedding = self.video_tower(video_features)
        
        # 计算匹配分数 (内积相似度)
        score = torch.sum(user_embedding * video_embedding, dim=1)
        print(user_embedding.shape, video_embedding.shape, score.shape)
        return score

# 假设用户和视频的特征维度
user_feature_dim = 10  # 用户特征维度
video_feature_dim = 15  # 视频特征维度
embedding_dim = 64      # 嵌入维度

# 实例化模型
model = TwoTowerRankingModel(user_feature_dim, video_feature_dim, embedding_dim)

# 损失函数和优化器
criterion = nn.BCEWithLogitsLoss()  # 二分类损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 示例输入数据
batch_size = 32
user_features = torch.randn(batch_size, user_feature_dim)  # 随机用户特征
video_features = torch.randn(batch_size, video_feature_dim)  # 随机视频特征
labels = torch.randint(0, 2, (batch_size,)).float()  # 随机0/1标签

# 训练一步
optimizer.zero_grad()
scores = model(user_features, video_features)
loss = criterion(scores, labels)
loss.backward()
optimizer.step()

print("Loss:", loss.item())

torch.Size([32, 64]) torch.Size([32, 64]) torch.Size([32])
Loss: 0.6499403715133667
