# Import library

In [1]:
import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
from torch.autograd import Variable
import pickle
import numpy as np
import time
import random
from collections import defaultdict
import torch.utils.data
from sklearn.metrics import mean_squared_error
from sklearn.metrics import mean_absolute_error
from math import sqrt
import datetime
import argparse
import os
from tqdm import tqdm
import pandas as pd
import logging
from datetime import datetime

# Define Class

In [2]:
def setup_logging():
    """Configure root logging to a timestamped file plus the console.

    Creates ``<cwd>/logs`` if it is missing, attaches a UTF-8 file handler and
    a stream handler to the root logger, and writes a session banner.

    Handlers already attached to the root logger are removed and closed first.
    ``logging.basicConfig`` silently does nothing when the root logger already
    has handlers, so without this step a second call (for example re-running
    the cell) would open a new log file that never receives a record and would
    leak its file handle.

    Returns:
        logging.Logger: logger named after the current module, level INFO.
    """
    # Create logs directory if it doesn't exist (no check-then-create race)
    log_dir = os.path.join(os.getcwd(), 'logs')
    os.makedirs(log_dir, exist_ok=True)

    # Create a log file with timestamp
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_file = os.path.join(log_dir, f'training_log_{timestamp}.txt')

    # Drop handlers left over from an earlier configuration so that
    # basicConfig below really installs the new file/console handlers.
    root_logger = logging.getLogger()
    for stale_handler in list(root_logger.handlers):
        root_logger.removeHandler(stale_handler)
        stale_handler.close()

    # Configure logging
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file, encoding='utf-8'),
            logging.StreamHandler(),  # This will also print to console
        ],
    )

    # Create logger
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    # Log the start of a new training session
    logger.info("="*50)
    logger.info("New Training Session Started")
    logger.info("="*50)
    logger.info(f"Log file created at: {log_file}")

    return logger

In [3]:
class Attention(nn.Module):
    """Score a set of neighbour rows against a single anchor vector.

    Each neighbour row is concatenated with the anchor, passed through a
    three-layer MLP, and the resulting scores are softmax-normalised over
    the neighbour axis (dim 0).
    """

    def __init__(self, embedding_dims):
        super().__init__()
        dim = embedding_dims
        self.embed_dim = dim
        # `bilinear` and `softmax` are constructed here but are not used by
        # forward().
        self.bilinear = nn.Bilinear(dim, dim, 1)
        self.att1 = nn.Linear(2 * dim, dim)
        self.att2 = nn.Linear(dim, dim)
        self.att3 = nn.Linear(dim, 1)
        self.softmax = nn.Softmax(0)

    def forward(self, node1, u_rep, num_neighs):
        """Return one softmax weight per row of ``node1``.

        Args:
            node1: neighbour representations, one per row.
            u_rep: anchor representation, tiled ``num_neighs`` times.
            num_neighs: number of rows in ``node1``.
        """
        tiled_anchor = u_rep.repeat(num_neighs, 1)
        hidden = torch.cat([node1, tiled_anchor], dim=1)
        # Two hidden layers, each followed by ReLU and dropout.
        for layer in (self.att1, self.att2):
            hidden = F.dropout(F.relu(layer(hidden)), training=self.training)
        scores = self.att3(hidden)
        return F.softmax(scores, dim=0)

In [4]:
class Social_Encoder(nn.Module):
    """Combine a user's own features with aggregated social-neighbour features.

    Args:
        features: callable mapping a LongTensor of node ids to a feature
            matrix; its result is transposed before use, so it is expected to
            come back with nodes along dim 1.
        embed_dim: embedding width; ``linear1`` maps ``2 * embed_dim`` to it.
        social_adj_lists: mapping node id -> collection of neighbour ids.
        aggregator: object whose ``forward(nodes, to_neighs)`` returns one
            row of neighbour features per node.
        base_model: optional object stored as ``self.base_model``; the
            attribute is only set when a value is supplied.
        cuda: device the self features are moved to.
    """

    def __init__(self, features, embed_dim, social_adj_lists, aggregator, base_model=None, cuda="cpu"):
        super(Social_Encoder, self).__init__()

        self.features = features
        self.social_adj_lists = social_adj_lists
        self.aggregator = aggregator
        # Identity test: `!= None` would invoke the object's own __ne__,
        # which need not return a plain bool.
        if base_model is not None:
            self.base_model = base_model
        self.embed_dim = embed_dim
        self.device = cuda
        self.linear1 = nn.Linear(2 * self.embed_dim, self.embed_dim)

    def forward(self, nodes):
        """Return ``relu(linear1([self_feats, neigh_feats]))`` for ``nodes``."""
        # Nodes missing from the adjacency mapping get an empty neighbour set.
        to_neighs = [self.social_adj_lists.get(int(node), set()) for node in nodes]
        neigh_feats = self.aggregator.forward(nodes, to_neighs)  # user-user network

        self_feats = self.features(torch.LongTensor(nodes.cpu().numpy())).to(self.device)
        self_feats = self_feats.t()

        # self-connection could be considered.
        combined = torch.cat([self_feats, neigh_feats], dim=1)
        combined = F.relu(self.linear1(combined))

        return combined

In [5]:
class Social_Aggregator(nn.Module):
    """
    Social Aggregator: attention-weighted pooling of the embeddings of each
    node's social neighbours.
    """

    def __init__(self, features, u2e, embed_dim, cuda="cpu"):
        super().__init__()

        self.features = features
        self.device = cuda
        self.u2e = u2e
        self.embed_dim = embed_dim
        self.att = Attention(self.embed_dim)

    def forward(self, nodes, to_neighs):
        """Return one pooled neighbour vector per entry of ``nodes``.

        ``to_neighs[k]`` holds the neighbour ids of ``nodes[k]``. Neighbours
        are represented by their raw rows of ``u2e.weight``; ``self.features``
        is not consulted here.
        """
        pooled = torch.empty(len(nodes), self.embed_dim, dtype=torch.float).to(self.device)
        user_table = self.u2e.weight
        for row, node in enumerate(nodes):
            neighbour_ids = list(to_neighs[row])
            neighbour_embeds = user_table[neighbour_ids]
            anchor = user_table[node]

            # Attention weights over this node's neighbours, then a weighted
            # sum of the neighbour embeddings written into the output row.
            weights = self.att(neighbour_embeds, anchor, len(neighbour_ids))
            pooled[row] = torch.mm(neighbour_embeds.t(), weights).t()

        return pooled

In [6]:
class UV_Aggregator(nn.Module):
    """Attention-pool each node's interaction history, optionally time-decayed.

    With ``uv=True`` the history entries index ``v2e`` and the node indexes
    ``u2e``; with ``uv=False`` the roles are swapped. Every history entry is
    fused with its rating embedding (``r2e``) through a two-layer MLP before
    attention is applied.

    When ``uv`` is true and ``history_t`` is not None, attention weights are
    multiplied by ``exp(-lamda * (max_t - t))`` computed from
    ``history_t[user_id]`` and renormalised to sum to one.

    Args:
        v2e, r2e, u2e: embedding modules exposing a ``weight`` matrix.
        history_t: mapping user id -> timestamps, or None to disable decay.
            Assumes one timestamp per history entry, in the same order —
            TODO confirm against the data pipeline.
        lamda: decay rate applied to the timestamp differences.
        embed_dim: embedding width.
        cuda: device for the output matrix.
        uv: selects which embedding table the history entries index.
    """

    def __init__(self, v2e, r2e, u2e, history_t, lamda, embed_dim, cuda="cpu", uv=True):
        super(UV_Aggregator, self).__init__()
        self.uv = uv
        self.v2e = v2e
        self.r2e = r2e
        self.u2e = u2e
        self.device = cuda
        self.embed_dim = embed_dim
        self.w_r1 = nn.Linear(self.embed_dim * 2, self.embed_dim)
        self.w_r2 = nn.Linear(self.embed_dim, self.embed_dim)
        self.att = Attention(self.embed_dim)
        self.lamda = lamda
        self.history_t = history_t

    def _time_decayed_attention(self, att_w, user_id):
        """Scale ``att_w`` (n, 1) by exponential time decay and renormalise."""
        timestamps = torch.tensor(self.history_t[user_id], dtype=torch.float, device=att_w.device)
        t_diff = timestamps.max() - timestamps

        # Add safety checks. clamp() alone leaves NaN untouched, so NaN/inf
        # are mapped to finite values first and only then clamped.
        if not torch.isfinite(t_diff).all():
            print(f"Warning: Invalid t_diff values detected: {t_diff}")
            t_diff = torch.nan_to_num(t_diff, nan=0.0, posinf=100.0, neginf=0.0)
            t_diff = torch.clamp(t_diff, min=0.0, max=100.0)  # Clamp to reasonable range

        time_decay = torch.exp(-self.lamda * t_diff)

        # Check for invalid values in time_decay
        if not torch.isfinite(time_decay).all():
            print(f"Warning: Invalid time_decay values detected: {time_decay}")
            time_decay = torch.nan_to_num(time_decay, nan=1e-10, posinf=1.0, neginf=1e-10)
            time_decay = torch.clamp(time_decay, min=1e-10, max=1.0)  # Clamp to reasonable range

        # squeeze only the weight column so a single-item history stays 1-D
        att_w = att_w.squeeze(1) * time_decay

        # Check for zero sum
        att_sum = att_w.sum()
        if att_sum == 0:
            print(f"Warning: Zero sum in attention weights: {att_w}")
            att_w = torch.ones_like(att_w) / att_w.size(0)  # Use uniform distribution as fallback
        else:
            att_w = att_w / att_sum

        return att_w.unsqueeze(1)

    def forward(self, nodes, history_uv, history_r, history_userID):
        """Return one pooled history vector per entry of ``history_uv``.

        Args:
            nodes: ids of the nodes being encoded; ``nodes[i]`` selects the
                anchor embedding for row ``i``.
            history_uv: per-node lists of interacted ids.
            history_r: per-node lists of rating indices into ``r2e``.
            history_userID: per-node keys used to look up ``history_t``.
        """
        embed_matrix = torch.empty(len(history_uv), self.embed_dim, dtype=torch.float).to(self.device)
        for i, history in enumerate(history_uv):
            num_history_item = len(history)
            tmp_label = history_r[i]
            user_id = history_userID[i]

            if self.uv:
                e_uv = self.v2e.weight[history]
                uv_rep = self.u2e.weight[nodes[i]]
            else:
                e_uv = self.u2e.weight[history]
                uv_rep = self.v2e.weight[nodes[i]]

            # Fuse each interacted entity with its rating embedding.
            e_r = self.r2e.weight[tmp_label]
            x = torch.cat((e_uv, e_r), 1)
            x = F.relu(self.w_r1(x))
            o_history = F.relu(self.w_r2(x))

            att_w = self.att(o_history, uv_rep, num_history_item)

            # Time decay applies on the uv=True side only; an empty history
            # has no timestamps to take a max over, so it is skipped.
            if self.history_t is not None and self.uv and num_history_item > 0:
                att_w = self._time_decayed_attention(att_w, user_id)

            att_history = torch.mm(o_history.t(), att_w)
            embed_matrix[i] = att_history.t()
        return embed_matrix

In [7]:
class UV_Encoder(nn.Module):
    """Encode nodes from their own embedding plus their aggregated history.

    ``features`` supplies the per-node embedding table (its ``weight`` is
    indexed directly); ``aggregator`` turns each node's interaction and
    rating history into one vector. The two are concatenated and passed
    through ``linear1`` followed by ReLU.
    """

    def __init__(self, features, embed_dim, history_uv_lists, history_r_lists, history_t, aggregator, cuda="cpu", uv=True):
        super().__init__()
        self.features = features
        self.uv = uv
        self.history_uv_lists = history_uv_lists
        self.history_r_lists = history_r_lists
        self.aggregator = aggregator
        self.embed_dim = embed_dim
        self.device = cuda
        self.linear1 = nn.Linear(2 * self.embed_dim, self.embed_dim)
        self.history_t = history_t

    def forward(self, nodes):
        """Return the combined representation for every id in ``nodes``."""
        node_ids = [int(node) for node in nodes]
        interactions = [self.history_uv_lists[node_id] for node_id in node_ids]
        ratings = [self.history_r_lists[node_id] for node_id in node_ids]

        # Push this encoder's timestamp table onto the aggregator whenever
        # the aggregator exposes a `history_t` attribute.
        aggregator = self.aggregator
        if hasattr(aggregator, 'history_t'):
            aggregator.history_t = self.history_t

        history_feats = aggregator.forward(nodes, interactions, ratings, node_ids)
        own_feats = self.features.weight[nodes]

        joined = torch.cat([own_feats, history_feats], dim=1)
        return F.relu(self.linear1(joined))

# Main Code

> GraphRec Definition



In [8]:
class GraphRec(nn.Module):
    """Rating predictor built on a user encoder and an item encoder.

    Each encoder output goes through its own two-layer tower; the two
    results are concatenated and reduced by a three-layer head to a single
    score per (user, item) pair. ``loss`` applies mean-squared error.
    """

    def __init__(self, enc_u, enc_v_history, r2e):
        super().__init__()
        self.enc_u = enc_u
        self.enc_v_history = enc_v_history
        dim = enc_u.embed_dim
        self.embed_dim = dim

        # user tower
        self.w_ur1 = nn.Linear(dim, dim)
        self.w_ur2 = nn.Linear(dim, dim)
        # item tower
        self.w_vr1 = nn.Linear(dim, dim)
        self.w_vr2 = nn.Linear(dim, dim)
        # prediction head
        self.w_uv1 = nn.Linear(2 * dim, dim)
        self.w_uv2 = nn.Linear(dim, 16)
        self.w_uv3 = nn.Linear(16, 1)
        self.r2e = r2e
        self.bn1 = nn.BatchNorm1d(dim, momentum=0.5)
        self.bn2 = nn.BatchNorm1d(dim, momentum=0.5)
        self.bn3 = nn.BatchNorm1d(dim, momentum=0.5)
        self.bn4 = nn.BatchNorm1d(16, momentum=0.5)
        self.criterion = nn.MSELoss()

    def _hidden(self, linear, norm, inputs):
        """Apply linear -> batch norm -> ReLU -> dropout."""
        activated = F.relu(norm(linear(inputs)))
        return F.dropout(activated, training=self.training)

    def forward(self, nodes_u, nodes_v):
        """Return the predicted score for each (user, item) pair."""
        # Both encoders run before either tower.
        user_embeds = self.enc_u(nodes_u)
        item_embeds = self.enc_v_history(nodes_v)

        user_latent = self.w_ur2(self._hidden(self.w_ur1, self.bn1, user_embeds))
        item_latent = self.w_vr2(self._hidden(self.w_vr1, self.bn2, item_embeds))

        pair = torch.cat((user_latent, item_latent), 1)
        pair = self._hidden(self.w_uv1, self.bn3, pair)
        pair = self._hidden(self.w_uv2, self.bn4, pair)
        return self.w_uv3(pair).squeeze()

    def loss(self, nodes_u, nodes_v, labels_list):
        """Return the MSE between predicted scores and ``labels_list``."""
        predictions = self.forward(nodes_u, nodes_v)
        return self.criterion(predictions, labels_list)

In [9]:
def train(model, device, train_loader, optimizer, epoch, best_rmse, best_mae, logger):
    """Run one training epoch and log the running loss every 100 batches.

    Args:
        model: object exposing ``train()`` and ``loss(u, v, labels)``.
        device: device every batch tensor is moved to.
        train_loader: iterable of ``(users, items, labels)`` batches.
        optimizer: optimizer stepped once per batch.
        epoch: epoch number, used only in the log line.
        best_rmse, best_mae: best metrics so far, used only in the log line.
        logger: object exposing ``info(message)``.

    Returns:
        int: always 0.
    """
    model.train()
    running_loss = 0.0
    # Batches accumulated since the last log line. The first line (i == 0)
    # covers a single batch, so dividing by a fixed 100 would under-report
    # its loss by a factor of 100.
    batches_in_window = 0
    for i, (batch_nodes_u, batch_nodes_v, labels_list) in enumerate(train_loader):
        optimizer.zero_grad()
        loss = model.loss(batch_nodes_u.to(device), batch_nodes_v.to(device), labels_list.to(device))
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        batches_in_window += 1
        if i % 100 == 0:
            log_message = '[%d, %5d] loss: %.3f, The best rmse/mae: %.6f / %.6f' % (
                epoch, i, running_loss / batches_in_window, best_rmse, best_mae)
            logger.info(log_message)
            running_loss = 0.0
            batches_in_window = 0
    return 0

In [10]:
def test(model, device, test_loader, logger):
    """Evaluate ``model`` on ``test_loader`` and return ``(rmse, mae)``.

    Args:
        model: module whose ``forward(users, items)`` returns predictions.
        device: device every batch tensor is moved to.
        test_loader: iterable of ``(users, items, targets)`` batches.
        logger: accepted for signature symmetry with ``train``; not used.

    Returns:
        tuple: root-mean-squared error and mean absolute error over all
        batches.
    """
    model.eval()
    pred_chunks = []
    target_chunks = []
    with torch.no_grad():
        for test_u, test_v, tmp_target in test_loader:
            test_u, test_v, tmp_target = test_u.to(device), test_v.to(device), tmp_target.to(device)
            val_output = model.forward(test_u, test_v)
            # reshape(-1) keeps a one-sample batch (a 0-d prediction) usable;
            # list() on a 0-d array would raise.
            pred_chunks.append(val_output.detach().cpu().numpy().reshape(-1))
            target_chunks.append(tmp_target.detach().cpu().numpy().reshape(-1))
    # One linear-time concatenation instead of repeated list addition.
    tmp_pred = np.concatenate(pred_chunks)
    target = np.concatenate(target_chunks)
    # Both metrics are symmetric in their arguments; the ground truth is
    # passed first to follow the (y_true, y_pred) convention.
    expected_rmse = sqrt(mean_squared_error(target, tmp_pred))
    mae = mean_absolute_error(target, tmp_pred)
    return expected_rmse, mae

In [11]:
def create_checkpoint_dir():
    """Ensure ``./checkpoints`` exists and return its (relative) path."""
    checkpoint_dir = './checkpoints'
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(checkpoint_dir, exist_ok=True)
    return checkpoint_dir

In [12]:
def save_checkpoint(model, optimizer, epoch, rmse, mae, checkpoint_dir, is_best=False, logger=None):
    """Write the training state to ``checkpoint_dir``.

    ``checkpoint.pt`` is always (over)written. When ``is_best`` is true the
    same payload is also written to ``best_model.pt`` and a message is sent
    to ``logger.info`` — or printed when no logger is given.
    """
    state = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        rmse=rmse,
        mae=mae,
    )

    # Rolling checkpoint
    torch.save(state, os.path.join(checkpoint_dir, 'checkpoint.pt'))

    if not is_best:
        return

    # Best-so-far snapshot
    torch.save(state, os.path.join(checkpoint_dir, 'best_model.pt'))
    announcement = f"New best model saved! RMSE: {rmse:.4f}, MAE: {mae:.4f}"
    report = logger.info if logger else print
    report(announcement)

In [17]:
# Setup logging
logger = setup_logging()
logger.info("Starting training process...")

# Hyper-parameters, keyed by the command-line flag that carries them.
args_dict = {
    '--batch_size': 128,
    '--embed_dim': 64,
    '--lr': 0.001,
    '--test_batch_size': 1000,
    '--epochs': 20,
}

# Flatten the mapping into an argv-style token list that argparse can parse
# in place of a real command line.
args_list = []
for k, v in args_dict.items():
    args_list.append(k)
    args_list.append(str(v))

parser = argparse.ArgumentParser(description='Social Recommendation: GraphRec model')
parser.add_argument('--batch_size', type=int, default=128, metavar='N', help='input batch size for training')
parser.add_argument('--embed_dim', type=int, default=64, metavar='N', help='embedding size')
parser.add_argument('--lr', type=float, default=0.001, metavar='LR', help='learning rate')
parser.add_argument('--test_batch_size', type=int, default=1000, metavar='N', help='input batch size for testing')
parser.add_argument('--epochs', type=int, default=5, metavar='N', help='number of epochs to train')
parser.add_argument('--use_resume', action='store_true', help='resume from checkpoint')
args = parser.parse_args(args_list)

# Log training configuration
logger.info("Training Configuration:")
logger.info(f"Batch Size: {args.batch_size}")
logger.info(f"Embedding Dimension: {args.embed_dim}")
logger.info(f"Learning Rate: {args.lr}")
logger.info(f"Test Batch Size: {args.test_batch_size}")
logger.info(f"Total Epochs: {args.epochs}")

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
logger.info(f"Using device: {device}")

embed_dim = args.embed_dim
dir_data = './data/Epinions/pickle/dataset_Epinions_train80val10test10'

path_data = dir_data + ".pickle"
# SECURITY: pickle.load can execute arbitrary code embedded in the file;
# only point this at dataset files from a trusted source.
# The context manager closes the file handle once the data is loaded.
with open(path_data, 'rb') as data_file:
    (history_u_lists, history_ur_lists, history_v_lists, history_vr_lists,
     train_u, train_v, train_r, val_u, val_v, val_r, test_u, test_v, test_r,
     social_adj_lists, ratings_list, history_timestamp_lists) = pickle.load(data_file)


def _is_valid_timestamp(t):
    """Return True for an int/float timestamp that is non-negative and not NaN."""
    # NaN fails the `>= 0` comparison, so no separate NaN test is needed.
    return isinstance(t, (int, float)) and t >= 0


# Validate and normalize timestamps
logger.info("Processing timestamps...")

# First, validate all timestamps. Plain Python comparisons are used here:
# building a torch tensor per timestamp is far too slow on a full dataset.
invalid_timestamps = False
for user_id, timestamps in history_timestamp_lists.items():
    if not timestamps:  # Skip empty lists
        continue
    # Only numeric entries are compared below; comparing a non-numeric
    # entry with 0 would raise TypeError.
    numeric = [t for t in timestamps if isinstance(t, (int, float))]
    if len(numeric) != len(timestamps):
        logger.warning(f"Invalid timestamp type found for user {user_id}")
        invalid_timestamps = True
    if any(t < 0 for t in numeric):
        logger.warning(f"Negative timestamp found for user {user_id}")
        invalid_timestamps = True
    if any(t != t for t in numeric):  # NaN is the only value unequal to itself
        logger.warning(f"NaN timestamp found for user {user_id}")
        invalid_timestamps = True

if invalid_timestamps:
    logger.warning("Found invalid timestamps. Attempting to clean data...")
    # Clean invalid timestamps
    for user_id in list(history_timestamp_lists.keys()):
        timestamps = history_timestamp_lists[user_id]
        validity = [_is_valid_timestamp(t) for t in timestamps]
        if any(validity):
            # Replace invalid entries with the default value instead of
            # dropping them, so the list keeps its length and positions.
            # UV_Aggregator multiplies these element-wise with per-item
            # attention weights, so the lengths must stay in step.
            # NOTE(review): a 0.0 placeholder widens the global range used for
            # normalization below — confirm this default is acceptable.
            history_timestamp_lists[user_id] = [
                t if ok else 0.0 for t, ok in zip(timestamps, validity)
            ]
        else:
            # If no valid timestamps, use a default value
            history_timestamp_lists[user_id] = [0.0] * len(history_u_lists[user_id])
            logger.warning(f"User {user_id} had no valid timestamps. Using default values.")

# Get global min and max timestamps
logger.info("Calculating global timestamp range...")
all_timestamps = []
for timestamps in history_timestamp_lists.values():
    all_timestamps.extend(timestamps)

if all_timestamps:
    global_min = min(all_timestamps)
    global_max = max(all_timestamps)
else:
    # min()/max() raise on an empty sequence; nothing needs normalizing.
    logger.warning("No timestamps found; skipping range computation.")
    global_min = global_max = 0.0
logger.info(f"Global timestamp range: {global_min:.2f} to {global_max:.2f}")

# Normalize timestamps using global min and max
logger.info("Normalizing timestamps using global range...")
normalized_history_timestamp_lists = {}
timestamp_span = global_max - global_min

for user_id, timestamps in history_timestamp_lists.items():
    if not timestamps:  # Skip empty lists
        normalized_history_timestamp_lists[user_id] = []
        continue

    # Normalize using global min and max
    if global_max > global_min:
        normalized_timestamps = [(t - global_min) / timestamp_span for t in timestamps]
    else:
        # If all timestamps are the same, set to 0.5
        normalized_timestamps = [0.5] * len(timestamps)

    normalized_history_timestamp_lists[user_id] = normalized_timestamps

# Replace original timestamps with normalized ones
history_timestamp_lists = normalized_history_timestamp_lists
logger.info("Timestamp normalization completed")

trainset = torch.utils.data.TensorDataset(torch.LongTensor(train_u), torch.LongTensor(train_v),
                                        torch.FloatTensor(train_r))
# Evaluation during training uses the validation split; test_u / test_v /
# test_r are loaded above but not used in this script.
testset = torch.utils.data.TensorDataset(torch.LongTensor(val_u), torch.LongTensor(val_v),
                                        torch.FloatTensor(val_r))
train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size, shuffle=True)
num_users = len(history_u_lists)
num_items = len(history_v_lists)
num_ratings = len(ratings_list)

logger.info(f"Dataset loaded: {num_users} users, {num_items} items, {num_ratings} ratings")

u2e = nn.Embedding(num_users, embed_dim).to(device)
v2e = nn.Embedding(num_items, embed_dim).to(device)
r2e = nn.Embedding(num_ratings, embed_dim).to(device)
lamda = 1.0

# user feature
agg_u_history = UV_Aggregator(v2e, r2e, u2e, history_timestamp_lists, lamda, embed_dim, cuda=device, uv=True)
enc_u_history = UV_Encoder(u2e, embed_dim, history_u_lists, history_ur_lists, history_timestamp_lists, agg_u_history, cuda=device, uv=True)

# neighbors
agg_u_social = Social_Aggregator(lambda nodes: enc_u_history(nodes).t(), u2e, embed_dim, cuda=device)
enc_u = Social_Encoder(lambda nodes: enc_u_history(nodes).t(), embed_dim, social_adj_lists, agg_u_social,
                      base_model=enc_u_history, cuda=device)

# item feature
agg_v_history = UV_Aggregator(v2e, r2e, u2e, history_timestamp_lists, lamda, embed_dim, cuda=device, uv=False)
enc_v_history = UV_Encoder(v2e, embed_dim, history_v_lists, history_vr_lists, history_timestamp_lists, agg_v_history, cuda=device, uv=False)

# model
graphrec = GraphRec(enc_u, enc_v_history, r2e).to(device)
optimizer = torch.optim.RMSprop(graphrec.parameters(), lr=args.lr, alpha=0.9)

# Load checkpoint
checkpoint_path = './checkpoints/best_model.pt'
# Resume is forced on here, regardless of what the parser produced.
args.use_resume = True
if os.path.exists(checkpoint_path) and args.use_resume:
    logger.info(f"Loading checkpoint from {checkpoint_path}")
    # map_location lets a checkpoint saved on another device (e.g. CUDA)
    # load on the device selected for this run.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    graphrec.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1  # Start from next epoch
    best_rmse = checkpoint['rmse']
    best_mae = checkpoint['mae']
    logger.info(f"Resuming from epoch {start_epoch}")
    logger.info(f"Previous best RMSE: {best_rmse:.4f}, MAE: {best_mae:.4f}")
else:
    logger.info("No checkpoint found, starting from scratch")
    start_epoch = 1
    best_rmse = 9999.0
    best_mae = 9999.0

# Create checkpoint directory
checkpoint_dir = create_checkpoint_dir()

# Number of consecutive epochs without an RMSE improvement.
endure_count = 0
logger.info(f'Total epochs: {args.epochs}')

for epoch in range(start_epoch, args.epochs + 1):
    train(graphrec, device, train_loader, optimizer, epoch, best_rmse, best_mae, logger)
    expected_rmse, mae = test(graphrec, device, test_loader, logger)

    # Early stopping
    if best_rmse > expected_rmse:
        best_rmse = expected_rmse
        best_mae = mae
        endure_count = 0
        # Save best model
        save_checkpoint(graphrec, optimizer, epoch, expected_rmse, mae, checkpoint_dir, is_best=True, logger=logger)
    else:
        endure_count += 1
    logger.info(f"Epoch {epoch} - RMSE: {expected_rmse:.4f}, MAE: {mae:.4f}")

    if endure_count > 5:
        logger.info("Early stopping triggered")
        break

logger.info("Training completed!")

2025-05-21 21:12:45,059 - INFO - New Training Session Started
2025-05-21 21:12:45,060 - INFO - Log file created at: d:\workspace\GraphRec\logs\training_log_20250521_211245.txt
2025-05-21 21:12:45,061 - INFO - Starting training process...
2025-05-21 21:12:45,063 - INFO - Training Configuration:
2025-05-21 21:12:45,063 - INFO - Batch Size: 128
2025-05-21 21:12:45,064 - INFO - Embedding Dimension: 64
2025-05-21 21:12:45,064 - INFO - Learning Rate: 0.001
2025-05-21 21:12:45,064 - INFO - Test Batch Size: 1000
2025-05-21 21:12:45,065 - INFO - Total Epochs: 20
2025-05-21 21:12:45,065 - INFO - Using device: cuda
2025-05-21 21:12:46,307 - INFO - Processing timestamps...


KeyboardInterrupt: 

# 結果如上