## データセット作成

In [1]:
import numpy as np
import pandas as pd

def make_data(arr):
    arr = np.where(arr<0, 0, arr)
    arr = np.where(arr>10, 10, arr)
    res = np.zeros((arr.shape[0], arr.shape[1]))
    for i in range(arr.shape[0]):
        rand = np.random.randint(10, 31)
        rand_index = np.random.choice(arr.shape[1], rand, replace=False)
        for j in rand_index:
            res[i][j] = arr[i][j]
    return res

# 正規分布(平均=5, 分散=2)
rate_normal = np.round(np.random.normal(5, 2, (100, 1000)))
rate_normal = make_data(rate_normal)

## 評価予測

In [2]:
np.random.seed(seed=0)

# 更新式の誤差
def get_rating_error(r, p, q):
    rating_error = r - np.dot(p, q)
    return rating_error

# 損失関数
def get_error(R, P, Q, beta):
    error = 0
    for i in range(len(R)):
        for j in range(len(R[i])):
            if R[i][j] == 0:
                continue
            error += pow(get_rating_error(R[i][j], P[:,i], Q[:,j]), 2)
    error += beta/2 * (np.linalg.norm(P) + np.linalg.norm(Q))
    return error

# R=近似したい行列、K=次元数
def matrix_factorization(R, K, steps=1000, alpha=0.0005, beta=0.1, threshold=100):
    P = np.random.rand(K, len(R))
    Q = np.random.rand(K, len(R[0]))
    for step in range(steps):
        for i in range(len(R)):
            for j in range(len(R[i])):
                if R[i][j] == 0:
                    continue
                err = get_rating_error(R[i][j], P[:, i], Q[:, j])
                for k in range(K):
                    P[k][i] += alpha * (2 * err * Q[k][j])
                    Q[k][j] += alpha * (2 * err * P[k][i])
        error = get_error(R, P, Q, beta)
        if step%10 == 0:
            print('Step{}  error : {}'.format(step, error))
        if error < threshold:
            print('Step{}  error : {}'.format(step, error))
            print('学習終了')
            break
    return P, Q

P, Q = matrix_factorization(rate_normal, 10)

Step0  error : 19456.898212298893
Step10  error : 9985.105908030713
Step20  error : 6757.651822472697
Step30  error : 5443.872754502462
Step40  error : 4722.521391669171
Step50  error : 4236.239846654498
Step60  error : 3868.6205396308574
Step70  error : 3569.8451096285057
Step80  error : 3313.9800108767076
Step90  error : 3086.02030529994
Step100  error : 2876.8545959365765
Step110  error : 2680.8934435516126
Step120  error : 2494.7910735853197
Step130  error : 2316.682311827278
Step140  error : 2145.6867894006064
Step150  error : 1981.5638569516045
Step160  error : 1824.4609770137456
Step170  error : 1674.7269554771265
Step180  error : 1532.7750881397728
Step190  error : 1398.987314758191
Step200  error : 1273.6526083524327
Step210  error : 1156.9333101876537
Step220  error : 1048.853209036127
Step230  error : 949.3014403545119
Step240  error : 858.0468909788351
Step250  error : 774.7586549161642
Step260  error : 699.0290457480907
Step270  error : 630.396603202809
Step280  error : 56