In [1]:
import os
import sys

import numpy as np
from scipy import sparse
from sklearn.metrics import mean_squared_error

In [4]:
def load_data(path='data/'):
    """Load the train/test matrices stored as scipy sparse ``.npz`` files.

    Parameters
    ----------
    path : str or os.PathLike
        Directory containing ``train.npz`` and ``test.npz``. The file names
        are joined with ``os.path.join``, so a trailing separator is optional.
        Plain string concatenation used to turn ``"data"`` into
        ``"datatrain.npz"``.

    Returns
    -------
    train : ndarray
        Dense copy of the training matrix (``toarray()``).
    test_X : ndarray, shape (nnz, 2)
        ``(row, col)`` coordinates of the entries stored in ``test.npz``.
    test_y : ndarray, shape (nnz,)
        Values stored at those coordinates, aligned row-for-row with ``test_X``.
    """
    train = sparse.load_npz(os.path.join(path, "train.npz"))
    # Convert to COO once: row/col/data are then guaranteed to come from the
    # same object (and hence the same ordering), and we avoid 3 conversions.
    test = sparse.load_npz(os.path.join(path, "test.npz")).tocoo()
    test_X = np.column_stack((test.row, test.col))
    test_y = test.data
    return train.toarray(), test_X, test_y


def J(X, A, U, V, lbd):
    """Regularised masked squared-error objective of the factorisation.

        J = 0.5 * ||A * (X - U @ V.T)||_F^2 + lbd * ||U||_F^2 + lbd * ||V||_F^2

    ``A`` is applied element-wise, so only entries where ``A`` is non-zero
    contribute to the data term. The 0.5 on the data term and its absence on
    the penalties match the gradients used in ``decomposition``
    (``... + 2 * lbd * U``).

    Returns
    -------
    float
        The objective value.
    """
    masked_residual = A * (X - U @ V.T)
    data_term = 0.5 * np.linalg.norm(masked_residual, ord='fro') ** 2
    u_penalty = lbd * np.linalg.norm(U, ord='fro') ** 2
    v_penalty = lbd * np.linalg.norm(V, ord='fro') ** 2
    return data_term + u_penalty + v_penalty


def rmse(U, V, test_X, test_y):
    """Root-mean-squared error of ``U @ V.T`` on held-out entries.

    Only the requested entries of ``U @ V.T`` are evaluated: one length-k dot
    product per test pair. The previous version materialised the full
    prediction matrix on every call, and this function runs once per
    accepted training step.

    Parameters
    ----------
    U : ndarray, shape (m, k)
    V : ndarray, shape (n, k)
    test_X : integer array, shape (t, 2)
        ``(row, col)`` indices into the m x n prediction ``U @ V.T``.
    test_y : array-like, shape (t,)
        Observed values at those positions.

    Returns
    -------
    float
        ``sqrt(mean((test_y - pred_y) ** 2))``.

    Raises
    ------
    ValueError
        If there are no test entries, or ``test_X`` and ``test_y`` differ in
        length.
    """
    test_X = np.asarray(test_X)
    rows, cols = test_X[:, 0], test_X[:, 1]
    # Row-wise dot products == (U @ V.T)[rows, cols] without the m x n matrix.
    pred_y = np.sum(U[rows] * V[cols], axis=1)
    test_y = np.asarray(test_y, dtype=float).reshape(-1)
    if test_y.size == 0:
        raise ValueError("rmse needs at least one test entry")
    if test_y.shape != pred_y.shape:
        raise ValueError("test_X and test_y have different lengths")
    return np.sqrt(np.mean((test_y - pred_y) ** 2))


def _masked_gradients(X, A, U, V, lbd):
    """Return ``(dJ/dU, dJ/dV)`` for the objective ``J`` at ``(U, V)``.

    The masked residual ``A * (U @ V.T - X)`` is formed once and shared by
    both gradients.
    """
    residual = A * (U @ V.T - X)
    dU = residual @ V + 2 * lbd * U
    dV = residual.T @ U + 2 * lbd * V
    return dU, dV


def decomposition(X, test_X, test_y, k=50, lbd=1e-2, learning_rate=1e-3, print_every=1,
                  tol=100, seed=None):
    """Factorise ``X ~= U @ V.T`` by full-batch gradient descent on ``J``.

    Entries of ``X`` that are zero get mask value 0 (``A``), so they do not
    contribute to the data term. The step size adapts as follows:
      - A step that would increase the loss is rejected and the rate is
        multiplied by 0.2.
      - An accepted step divides the rate by 0.2, capped at the initial
        ``learning_rate``. The cap was previously hard-coded to 1e-3, which
        silently overrode any other ``learning_rate``.

    Parameters
    ----------
    X : ndarray, shape (m, n)
        Matrix to factorise.
    test_X, test_y :
        Held-out coordinates and values, forwarded to ``rmse``.
    k : int
        Number of latent factors (columns of ``U`` and ``V``).
    lbd : float
        L2 regularisation weight in ``J``.
    learning_rate : float
        Initial, and maximum, gradient step size.
    print_every : int
        Print progress every this many accepted steps. ``0``/``None``
        disables printing; previously ``0`` raised ``ZeroDivisionError``.
    tol : float
        Stop once an accepted step lowers the loss by less than this absolute
        amount. The default 100 is the previously hard-coded value.
    seed : int or None
        Seed for the initial factors. ``None`` draws from the global
        ``np.random`` state, as before.

    Returns
    -------
    loss_array, rmse_array : list
        Training objective and test RMSE after each accepted step.
    """
    rng = np.random if seed is None else np.random.RandomState(seed)
    A = X.astype('bool').astype('int')
    U = rng.uniform(-1e-2, 1e-2, (X.shape[0], k))
    V = rng.uniform(-1e-2, 1e-2, (X.shape[1], k))
    max_learning_rate = learning_rate
    delta_loss = np.inf
    old_loss = np.inf  # inf => the very first step is always accepted
    i = 0
    loss_array = []
    rmse_array = []
    grads = None  # gradient at the current (U, V); reused across rejected steps
    while delta_loss >= tol:
        if grads is None:
            grads = _masked_gradients(X, A, U, V, lbd)
        dU, dV = grads
        new_U = U - learning_rate * dU
        new_V = V - learning_rate * dV
        loss = J(X, A, new_U, new_V, lbd)
        new_delta_loss = old_loss - loss
        if new_delta_loss < 0:
            # Overshot: shrink the step and retry from the same (U, V).
            learning_rate *= 0.2
            continue
        # Accepted: grow the step back, but never beyond the initial rate.
        learning_rate = min(max_learning_rate, learning_rate / 0.2)
        U, V = new_U, new_V
        grads = None  # (U, V) moved, so the cached gradient is stale
        old_loss = loss
        delta_loss = new_delta_loss
        loss_array.append(loss)
        rmse_array.append(rmse(U, V, test_X, test_y))

        i += 1
        if print_every and i % print_every == 0:
            print("#%d: loss=%f, rmse=%f" % (i, loss_array[-1], rmse_array[-1]), flush=True)
    return loss_array, rmse_array

In [3]:
# Dense training matrix (its zeros are masked out by `decomposition`) plus
# the held-out (row, col) pairs and their values.
A, test_X, test_y = load_data(path="data/")

In [5]:
# Fit with default hyper-parameters, logging every 50 accepted steps.
loss_array, rmse_array = decomposition(
    A, test_X, test_y,
    print_every=50,
)

#50: loss=3021308.536056, rmse=0.939372
#100: loss=2612751.213571, rmse=0.875938
#150: loss=2434943.510018, rmse=0.848873
#200: loss=2321396.384582, rmse=0.832441
