In [1]:
import movielens_data
MovieLensDataset=movielens_data.MovieLensDataset

In [2]:
import os
import sys
sys.path.append("e:\\seq")

In [44]:
from model.bert4rec import *
from config import config  
from tqdm import tqdm

In [4]:
data_dir = 'e:\seq\data\ml-1m'
train_dataset = MovieLensDataset(data_dir, max_len=50, split='train')
test_dataset = MovieLensDataset(data_dir, max_len=50, split='test')

In [12]:

model = BERT4Rec(
        item_num=train_dataset.num_items,
        max_len=config.max_len,
        hidden_units=config.hidden_units,
        num_heads=config.num_heads,
        num_layers=config.num_layers,
        dropout=config.dropout
    )



In [15]:
train_dataset.num_items
test_dataset.num_items

3884

In [13]:
model.item_num

3884

In [29]:
import torch.nn.functional as F
def compute_loss(logits, labels):
    # logits shape: [batch_size, seq_len, vocab_size]
    # labels shape: [batch_size, seq_len]
    
    # Reshape logits and labels
    logits = logits.view(-1, logits.size(-1))  # [batch_size * seq_len, vocab_size]
    labels = labels.view(-1)  # [batch_size * seq_len]
    
    # Create a mask to ignore padding tokens
    mask = (labels > 0).float()
    
    # Compute cross entropy loss
    loss = F.cross_entropy(logits, labels, reduction='none')
    
    # Apply mask and compute mean loss
    loss = (loss * mask).sum() / mask.sum()
    
    return loss

In [33]:
pred=model(train_dataset[1][0].unsqueeze(0))
lab=train_dataset[1][1].unsqueeze(0)
compute_loss(pred,lab)

tensor(8.1502, grad_fn=<DivBackward0>)

In [41]:
pred=model(test_dataset[1][0].unsqueeze(0))
pred=pred[:,-1,:]
torch.argmax(pred,dim=1)

tensor([2546])

In [56]:
import torch
import numpy as np
from torch.utils.data import DataLoader

def calculate_metrics_batch(model, test_dataset, batch_size=128, k_values=[1, 5, 10],device="cpu"):
    model.eval()
    hrs = {k: [] for k in k_values}
    ndcgs = {k: [] for k in k_values}
    
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    
    with torch.no_grad():
        for batch in tqdm(test_dataloader):
            interactions, candidates, labels = batch
            interactions = interactions.to(device)
            candidates = candidates.to(device)
            # 获取模型预测
            logits = model(interactions)
            
            # 我们只关心最后一个时间步的预测
            last_logits = logits[:, -1, :]
            
            batch_size = interactions.size(0)
            for i in range(batch_size):
                # 获取候选项的得分
                candidate_scores = last_logits[i][candidates[i]]
                
                # 计算排序
                _, indices = torch.sort(candidate_scores, descending=True)
                
                # 找到正样本的排名
                rank = (indices == 0).nonzero().item() + 1  # +1 因为索引从0开始
                
                # 计算 HR 和 NDCG
                for k in k_values:
                    hrs[k].append(1 if rank <= k else 0)
                    ndcgs[k].append(1 / np.log2(rank + 1) if rank <= k else 0)
    
    # 计算平均值
    for k in k_values:
        hrs[k] = np.mean(hrs[k])
        ndcgs[k] = np.mean(ndcgs[k])
    
    return hrs, ndcgs

# 使用示例
model.eval()  # 确保模型处于评估模式
hrs, ndcgs = calculate_metrics_batch(model, test_dataset)

# 打印结果
for k in hrs.keys():
    print(f"HR@{k}: {hrs[k]:.4f}, NDCG@{k}: {ndcgs[k]:.4f}")

100%|██████████| 48/48 [00:11<00:00,  4.21it/s]

HR@1: 0.1770, NDCG@1: 0.1770
HR@5: 0.5061, NDCG@5: 0.3481
HR@10: 0.6545, NDCG@10: 0.3962





In [57]:
import time
import torch.optim as optim
import os

def train(model, train_dataset, test_dataset, device, 
          batch_size=128, num_epochs=100, lr=1e-3, 
          eval_steps=1000, patience=5, k_values=[1, 5, 10]):
    
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    
    best_hr = 0
    best_epoch = -1
    patience_counter = 0
    global_step = 0
    
    # Save initial model
    torch.save(model.state_dict(), 'initial_model.pth')
    
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        start_time = time.time()
        
        for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            interactions, labels = batch
            interactions, labels = interactions.to(device), labels.to(device)
            
            optimizer.zero_grad()
            logits = model(interactions)
            loss = compute_loss(logits, labels)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            global_step += 1
            
            if global_step % eval_steps == 0:
                avg_loss = total_loss / eval_steps
                print(f"\nStep {global_step}, Average Loss: {avg_loss:.4f}")
                
                # Evaluate on test set
                hrs, ndcgs = calculate_metrics_batch(model, test_dataset, batch_size, k_values, device)
                
                print("Evaluation Results:")
                for k in k_values:
                    print(f"HR@{k}: {hrs[k]:.4f}, NDCG@{k}: {ndcgs[k]:.4f}")
                
                # Check for improvement
                if hrs[k_values[0]] > best_hr:
                    best_hr = hrs[k_values[0]]
                    best_epoch = epoch
                    patience_counter = 0
                    torch.save(model.state_dict(), 'best_model.pth')
                    print("New best model saved!")
                else:
                    patience_counter += 1
                
                # Early stopping
                if patience_counter >= patience:
                    print(f"No improvement for {patience} evaluations. Early stopping.")
                    break
                
                total_loss = 0
                model.train()
        
        end_time = time.time()
        print(f"Epoch {epoch+1} completed in {end_time - start_time:.2f} seconds")
        
        if patience_counter >= patience:
            break
    
    print(f"Training completed. Best HR@{k_values[0]}: {best_hr:.4f} at epoch {best_epoch+1}")
    
    # Load best model if exists, otherwise keep the current model
    if os.path.exists('best_model.pth'):
        model.load_state_dict(torch.load('best_model.pth'))
        print("Loaded the best model.")
    elif os.path.exists('initial_model.pth'):
        model.load_state_dict(torch.load('initial_model.pth'))
        print("No improvement during training. Loaded the initial model.")
    else:
        print("No saved model found. Returning the current model state.")
    
    return model

In [58]:
train(model,train_dataset,test_dataset,device="cuda",batch_size=64,num_epochs=30)

Epoch 1/30: 100%|██████████| 95/95 [00:02<00:00, 40.38it/s]


Epoch 1 completed in 2.35 seconds


Epoch 2/30: 100%|██████████| 95/95 [00:02<00:00, 43.48it/s]


Epoch 2 completed in 2.19 seconds


Epoch 3/30: 100%|██████████| 95/95 [00:02<00:00, 43.81it/s]


Epoch 3 completed in 2.17 seconds


Epoch 4/30: 100%|██████████| 95/95 [00:02<00:00, 43.19it/s]


Epoch 4 completed in 2.20 seconds


Epoch 5/30: 100%|██████████| 95/95 [00:02<00:00, 38.72it/s]


Epoch 5 completed in 2.46 seconds


Epoch 6/30: 100%|██████████| 95/95 [00:02<00:00, 40.70it/s]


Epoch 6 completed in 2.34 seconds


Epoch 7/30: 100%|██████████| 95/95 [00:02<00:00, 44.05it/s]


Epoch 7 completed in 2.16 seconds


Epoch 8/30: 100%|██████████| 95/95 [00:02<00:00, 44.79it/s]


Epoch 8 completed in 2.12 seconds


Epoch 9/30: 100%|██████████| 95/95 [00:02<00:00, 42.43it/s]


Epoch 9 completed in 2.24 seconds


Epoch 10/30: 100%|██████████| 95/95 [00:02<00:00, 44.89it/s]


Epoch 10 completed in 2.12 seconds


Epoch 11/30:  47%|████▋     | 45/95 [00:00<00:01, 45.50it/s]


Step 1000, Average Loss: 0.2549


100%|██████████| 95/95 [00:02<00:00, 39.55it/s]
Epoch 11/30:  58%|█████▊    | 55/95 [00:03<00:05,  7.78it/s]

Evaluation Results:
HR@1: 0.1813, NDCG@1: 0.1813
HR@5: 0.4902, NDCG@5: 0.3418
HR@10: 0.6381, NDCG@10: 0.3896
New best model saved!


Epoch 11/30: 100%|██████████| 95/95 [00:04<00:00, 20.51it/s]


Epoch 11 completed in 4.63 seconds


Epoch 12/30: 100%|██████████| 95/95 [00:02<00:00, 44.23it/s]


Epoch 12 completed in 2.15 seconds


Epoch 13/30: 100%|██████████| 95/95 [00:02<00:00, 44.73it/s]


Epoch 13 completed in 2.13 seconds


Epoch 14/30: 100%|██████████| 95/95 [00:02<00:00, 44.80it/s]


Epoch 14 completed in 2.12 seconds


Epoch 15/30: 100%|██████████| 95/95 [00:02<00:00, 44.23it/s]


Epoch 15 completed in 2.15 seconds


Epoch 16/30: 100%|██████████| 95/95 [00:02<00:00, 44.30it/s]


Epoch 16 completed in 2.15 seconds


Epoch 17/30: 100%|██████████| 95/95 [00:02<00:00, 44.65it/s]


Epoch 17 completed in 2.13 seconds


Epoch 18/30: 100%|██████████| 95/95 [00:02<00:00, 43.98it/s]


Epoch 18 completed in 2.16 seconds


Epoch 19/30: 100%|██████████| 95/95 [00:02<00:00, 44.77it/s]


Epoch 19 completed in 2.12 seconds


Epoch 20/30: 100%|██████████| 95/95 [00:02<00:00, 43.80it/s]


Epoch 20 completed in 2.17 seconds


Epoch 21/30: 100%|██████████| 95/95 [00:02<00:00, 44.68it/s]


Epoch 21 completed in 2.13 seconds


Epoch 22/30:   4%|▍         | 4/95 [00:00<00:02, 35.94it/s]


Step 2000, Average Loss: 0.0242


100%|██████████| 95/95 [00:02<00:00, 42.08it/s]
Epoch 22/30:  14%|█▎        | 13/95 [00:02<00:15,  5.38it/s]

Evaluation Results:
HR@1: 0.1652, NDCG@1: 0.1652
HR@5: 0.4960, NDCG@5: 0.3351
HR@10: 0.6478, NDCG@10: 0.3844


Epoch 22/30: 100%|██████████| 95/95 [00:04<00:00, 21.43it/s]


Epoch 22 completed in 4.44 seconds


Epoch 23/30: 100%|██████████| 95/95 [00:02<00:00, 45.16it/s]


Epoch 23 completed in 2.11 seconds


Epoch 24/30: 100%|██████████| 95/95 [00:02<00:00, 44.80it/s]


Epoch 24 completed in 2.12 seconds


Epoch 25/30: 100%|██████████| 95/95 [00:02<00:00, 44.39it/s]


Epoch 25 completed in 2.14 seconds


Epoch 26/30: 100%|██████████| 95/95 [00:02<00:00, 44.42it/s]


Epoch 26 completed in 2.14 seconds


Epoch 27/30: 100%|██████████| 95/95 [00:02<00:00, 44.40it/s]


Epoch 27 completed in 2.14 seconds


Epoch 28/30: 100%|██████████| 95/95 [00:02<00:00, 44.78it/s]


Epoch 28 completed in 2.12 seconds


Epoch 29/30: 100%|██████████| 95/95 [00:02<00:00, 44.69it/s]


Epoch 29 completed in 2.13 seconds


Epoch 30/30: 100%|██████████| 95/95 [00:02<00:00, 44.53it/s]

Epoch 30 completed in 2.14 seconds
Training completed. Best HR@1: 0.1813 at epoch 11
Loaded the best model.





BERT4Rec(
  (item_emb): Embedding(3886, 256, padding_idx=0)
  (pos_emb): Embedding(50, 256)
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-1): 2 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_features=256, out_features=1024, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=1024, out_features=256, bias=True)
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (out): Linear(in_features=256, out_features=3885, bias=True)
)

In [32]:
test_dataset[1]

(tensor([1246, 3741,  586, 1960, 2503,  454, 2848, 1569,  477,  162,  377, 3350,
         1492,  346,   21, 1386, 3188, 2285, 1938, 2210, 1350,  643, 2359, 1514,
         1733, 1352, 2422, 1365,  771,  164,  456, 1741, 3039, 2813,  365,  439,
         1557, 2560, 1645, 3189,  728, 1934, 2058,  290,   94,  431, 1506, 1642,
         1849, 3885]),
 tensor([1849, 1223, 2879,  771, 3011, 3636,  524, 3695, 2328,  552, 2648, 2297,
         1167,   62,  976, 1371, 1643,   39,  346, 2321, 3555, 3653, 2624, 3635,
         3002, 1213, 2072, 1392, 1550, 3399,  471,   57, 2446, 1274,  202,  377,
         3828, 2804, 1886, 1269, 1203, 1280,  252, 1266, 2724, 2645, 2560,  494,
          870,  908, 1120,  353, 2180,  316, 1376, 1459, 2773, 2729, 3185, 3546,
         2400, 2791, 1523, 1827,  231, 2993, 3570,  490, 2745, 3382, 1275, 2861,
         1943, 2883, 3631, 1065, 3489, 1879, 2004, 2960, 1655, 3033, 2702, 2723,
         3475,  221, 1227, 2900, 2632, 2847,  585, 3117,  350,  586,  555,  590,
     