In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from model import mf, mfDataset, compute_rating

In [2]:
file_path = './ml-1m/ratings.dat'
batch_size = 1
device = torch.device('cuda:0')
learning_rate = 1e-2
weight_decay = 1e-5
epochs = 10
embedding_size = 100
loss_func = torch.nn.MSELoss().to(device)

In [3]:
df = pd.read_csv(file_path, header=None, delimiter='::')
x, y = df.iloc[:, :2], df.iloc[:, 2]

x_train, x_val_test, y_train, y_val_test = train_test_split(x, y, test_size=0.1)
x_val, x_test, y_val, y_test = train_test_split(x_val_test, y_val_test, test_size=0.5)

train_dataset = mfDataset(np.array(x_train[0]), np.array(
        x_train[1]),  np.array(y_train).astype(np.float32))
val_dataset = mfDataset(np.array(x_val[0]), np.array(
        x_val[1]),  np.array(y_val).astype(np.float32))
test_dataset = mfDataset(np.array(x_test[0]), np.array(
        x_test[1]), np.array(y_test).astype(np.float32))

train_DataLoader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_DataLoader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
test_DataLoader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

mean_rating = df.iloc[:, 2].mean()
num_users = max(df[0])+1
num_items = max(df[1])+1
print(f"num_users:{num_users-1}")
print(f"num_items:{num_items-1}")

  return func(*args, **kwargs)


num_users:6040
num_items:3952


In [5]:
mean_rating

3.581564453029317

In [4]:
# create user & item latent factor vectors
user_emb = torch.empty(num_users, embedding_size).to(device)
user_bias = torch.empty(num_users, 1).to(device)
item_emb = torch.empty(num_items, embedding_size).to(device)
item_bias = torch.empty(num_items, 1).to(device)

nn.init.normal_(user_emb, mean=0, std=0.1)
nn.init.normal_(user_bias, mean=0, std=0.1)
nn.init.normal_(item_emb, mean=0, std=0.1)
nn.init.normal_(item_bias, mean=0, std=0.1)

tensor([[-0.0009],
        [ 0.0775],
        [ 0.0659],
        ...,
        [-0.0332],
        [-0.0676],
        [ 0.0711]], device='cuda:0')

In [24]:
index=torch.tensor([1])
U=user_emb[index]
I=item_emb[index]
U_b=user_bias[index]
I_b=item_bias[index]
((U*I).sum(1)+U_b+I_b+mean_rating).squeeze(0)

tensor([3.8086], device='cuda:0')

In [8]:
for x_u, x_i, y in train_DataLoader:
    print(x_u)
    print(x_i)
    print(y)
    break

tensor([2063])
tensor([597])
tensor([3.])


In [7]:
# training & val
for epoch in range(epochs):
    # training phase
    total_loss, total_len = 0, 0
    for x_u, x_i, y in train_DataLoader:
        x_u, x_i, y = x_u.to(device), x_i.to(device), y.to(device)
        y_pre = compute_rating(user_emb[x_u], item_emb[x_i], user_bias[x_u], item_bias[x_i], mean_rating).to(device)
        loss = loss_func(y, y_pre)
        e_ui=y-y_pre

        # gradient descent
        user_emb[x_u] -= learning_rate * 2*(-e_ui*item_emb[x_i] + weight_decay*user_emb[x_u])
        item_emb[x_i] -= learning_rate * 2*(-e_ui*user_emb[x_u] + weight_decay*item_emb[x_i])
        user_bias[x_u] -= learning_rate * 2*(-e_ui* + weight_decay*user_bias[x_u])
        item_bias[x_i] -= learning_rate * 2*(-e_ui* + weight_decay*item_bias[x_i])

        total_loss += loss.item()*len(y)
        total_len += len(y)
    train_loss = total_loss/total_len
    
    # val phase
    labels, predicts = [], []
    for x_u, x_i, y in val_DataLoader:
        x_u, x_i, y = x_u.to(device), x_i.to(device), y.to(device)
        y_pre = compute_rating(user_emb[x_u], item_emb[x_i], user_bias[x_u], item_bias[x_i], mean_rating)
        labels.extend(y.tolist())
        predicts.extend(y_pre.tolist())
    mse = mean_squared_error(np.array(labels), np.array(predicts))

    print("epoch {}, train loss is {}, val mse is {}".format(
        epoch, train_loss, mse))

  return F.mse_loss(input, target, reduction=self.reduction)


KeyboardInterrupt: 

In [23]:
a=torch.empty(1,1, dtype=torch.float32)
print(a)
a.squeeze(0)

tensor([[4.4842e-44]])


tensor([4.4842e-44])