In [None]:
from __future__ import print_function
from pymongo import MongoClient
from tqdm import tqdm
from collections import Counter
from collections import defaultdict
import numpy as np

In [None]:
%%latex
\begin{align}
&\min_{\mathbf{u},\mathbf{v}} \sum_{i, t_i} \Big( \sum_{j} ( y_{i,j,t_i} - \mathbf{u}_{i,t_i}^T \mathbf{v}_{j})^2 + C_u ||\mathbf{u}_{i,t_i}-\mathbf{u}_{i,t_i-1}||^2 \Big) + C_v \sum_{j} ||\mathbf{v}_{j}||^2  \\
&\text{where}\quad \mathbf{u}_{i,0} = 0
\end{align}

In [None]:
# Hyper-parameters of the model and limits on how much data is read.
D = 3         # length of every factor vector (users and products)
T = 5         # number of consecutive orders pooled into one time window
C_u = 0.01    # weight of the penalty on change between consecutive user factors
C_v = 0.1     # weight of the squared-norm penalty on product factors
n_max = 1000  # maximum number of order documents fetched from the database
m_max = 100   # number of most frequent products that are kept

In [None]:
# Connect with MongoClient's defaults and select the "instacart" database.
client = MongoClient()
db = client["instacart"]

In [None]:
# Walk the first n_max order documents once, recording each user id and
# every purchased product (with repeats) so product frequencies can be counted.
user_lst, prod_lst = [], []
for doc in db.orders.find().limit(n_max):
    user_lst.append(doc["user_id"])
    for order in doc["orders"]:
        prod_lst.extend(order["products"])
prod_cnt = Counter(prod_lst)

In [None]:
# Distinct users seen, and the m_max most frequently purchased products.
user_set = set(user_lst)
prod_set = {prod for prod, _count in prod_cnt.most_common(m_max)}

In [None]:
# Sliding-window targets, stored under two index orders:
#   Y[i][t][j] -- user i -> window t -> product j
#   Z[j][i][t] -- product j -> user i -> window t
Y = {}
Z = {}
for doc in db.orders.find().limit(n_max):
    i = doc["user_id"]
    # Only orders containing at least one retained product take part.
    orders = [o for o in doc["orders"] if prod_set.intersection(o["products"])]
    if len(orders) < 2:
        user_set.discard(i)
        continue
    # NOTE(review): with more than T + 1 orders this stops two windows short of
    # the last full window -- presumably a hold-out of the latest orders; confirm.
    t_end = max(1, len(orders) - T - 1)
    for t in range(t_end):
        # Occurrences of each retained product across the orders of window t.
        window_counts = Counter(
            prod
            for order in orders[t:t + T]
            for prod in order["products"]
            if prod in prod_set
        )
        for j, raw_count in window_counts.items():
            # Always divided by T, even when the window holds fewer than T orders.
            y_ijt = float(raw_count) / T
            Y.setdefault(i, {}).setdefault(t, {})[j] = y_ijt
            Z.setdefault(j, {}).setdefault(i, {})[t] = y_ijt

In [None]:
# Random starting point for the factors, seeded for reproducibility.
# Every entry is a length-D vector drawn uniformly from [0, 1) and divided by D.
np.random.seed(1)
# U[i][t]: one vector per (user, window) pair present in Y.
U = {i: {t: np.random.rand(D) / D for t in y_i} for i, y_i in Y.items()}
# V[j]: one vector per retained product.
V = {j: np.random.rand(D) / D for j in prod_set}

In [None]:
def update_u_it(y_i, u_i, V, t, c_u=None):
    """Closed-form least-squares update of one user's factor at step ``t``.

    With the product factors ``V`` and the user's other time steps held
    fixed, returns the vector ``u`` that minimises

        sum_j (y_i[t][j] - u . V[j])**2
        + c_u * ||u - u_prev||**2
        + c_u * ||u_next - u||**2        (only when step t + 1 exists)

    where ``u_prev`` is ``u_i[t - 1]`` (the zero vector for ``t == 0``) and
    ``u_next`` is ``u_i[t + 1]``.

    Parameters
    ----------
    y_i : dict
        ``y_i[t][j]`` is the target value for product ``j`` at step ``t``.
    u_i : dict
        ``u_i[t]`` is the user's current factor vector at step ``t``.
    V : dict
        ``V[j]`` is the factor vector of product ``j``.
    t : int
        Step to update; must be a key of both ``y_i`` and ``u_i``.
    c_u : float, optional
        Smoothness weight. Defaults to the notebook-level ``C_u``.

    Returns
    -------
    numpy.ndarray
        The new factor vector, same length as ``u_i[t]``.

    Raises
    ------
    KeyError
        If ``t`` is missing from ``y_i`` or ``u_i``, or if ``t > 0`` and
        ``t - 1`` is missing from ``u_i``.
    """
    if c_u is None:
        c_u = C_u
    # The factor dimension is taken from the data rather than a global.
    dim = np.shape(u_i[t])[0]
    u_prev = u_i[t - 1] if t > 0 else np.zeros(dim)

    j_seq = list(y_i[t])
    n_it = len(j_seq)
    y_it = np.array([y_i[t][j] for j in j_seq], dtype=float)
    v_it = np.array([V[j] for j in j_seq], dtype=float).reshape(n_it, dim)

    # u[t] appears in the smoothness penalty twice: once against the previous
    # step and, unless it is the last step, once against the next step.
    # Leaving out the second term yields a point that is not the minimiser of
    # the loss being tracked.
    n_links = 1
    anchor = np.array(u_prev, dtype=float)
    if (t + 1) in u_i:
        n_links = 2
        anchor = anchor + u_i[t + 1]

    lhs = np.dot(v_it.T, v_it) + c_u * n_links * np.eye(dim)
    rhs = np.dot(v_it.T, y_it) + c_u * anchor
    # Solve the normal equations directly instead of forming the inverse.
    return np.linalg.solve(lhs, rhs)
    
def update_u(Y, U, V):
    """Refresh every user factor in ``U`` in place.

    For each user the steps ``0 .. len(U[i]) - 1`` are visited in increasing
    order and ``U[i][t]`` is overwritten with the result of ``update_u_it``,
    so a later step is computed from the already-updated earlier steps.
    """
    for user, factors in U.items():
        for step in range(len(factors)):
            factors[step] = update_u_it(Y[user], factors, V, step)
            
def update_v(Z, U, V, c_v=None):
    """Ridge-regression update of every product factor in ``V``, in place.

    For each product ``j`` in ``V``, with the user factors held fixed,
    ``V[j]`` is replaced by the vector ``v`` that minimises

        sum_{i, t} (Z[j][i][t] - U[i][t] . v)**2 + c_v * ||v||**2

    Parameters
    ----------
    Z : dict
        ``Z[j][i][t]`` is the target value of product ``j`` for user ``i`` at
        step ``t``.
    U : dict
        ``U[i][t]`` is the factor vector of user ``i`` at step ``t``.
    V : dict
        ``V[j]`` is the factor vector of product ``j``; overwritten in place.
    c_v : float, optional
        Weight of the squared-norm penalty. Defaults to the notebook-level
        ``C_v``.

    Returns
    -------
    None
    """
    if c_v is None:
        c_v = C_v
    for j in V.keys():
        # The factor dimension is taken from the data rather than a global.
        dim = np.shape(V[j])[0]
        # A product can have a factor in V without a single observation in Z.
        # Treat that as an empty design matrix, so the penalty alone drives the
        # factor to zero instead of Z[j] raising KeyError.
        observations = [(i, t, z_jit)
                        for i, z_ji in Z.get(j, {}).items()
                        for t, z_jit in z_ji.items()]
        n_j = len(observations)
        y_j = np.array([z_jit for _i, _t, z_jit in observations], dtype=float)
        u_j = np.array([U[i][t] for i, t, _z in observations],
                       dtype=float).reshape(n_j, dim)
        lhs = np.dot(u_j.T, u_j) + c_v * np.eye(dim)
        rhs = np.dot(u_j.T, y_j)
        # Solve the normal equations directly instead of forming the inverse.
        V[j] = np.linalg.solve(lhs, rhs)
        
def loss_function(Y, U, V, c_u=None, c_v=None):
    """Evaluate the regularised squared-error objective.

    The value returned is

        sum_{i, t} ( c_u * ||U[i][t] - U[i][t-1]||**2
                     + sum_j (Y[i][t][j] - U[i][t] . V[j])**2 )
        + c_v * sum_j ||V[j]||**2

    where ``U[i][t-1]`` is the zero vector for ``t == 0``.

    Parameters
    ----------
    Y : dict
        ``Y[i][t][j]`` is the target value for user ``i``, step ``t``,
        product ``j``.
    U : dict
        ``U[i][t]`` is the factor vector of user ``i`` at step ``t``.
    V : dict
        ``V[j]`` is the factor vector of product ``j``.
    c_u : float, optional
        Smoothness weight. Defaults to the notebook-level ``C_u``.
    c_v : float, optional
        Weight of the squared-norm penalty on product factors. Defaults to
        the notebook-level ``C_v``.

    Returns
    -------
    float
        The objective value (the integer ``0`` when ``Y`` and ``V`` are both
        empty).

    Raises
    ------
    KeyError
        If a ``(user, step)`` pair of ``Y`` (or the preceding step, for
        ``t > 0``) has no vector in ``U``, or a product in ``Y`` has none in
        ``V``.
    """
    if c_u is None:
        c_u = C_u
    if c_v is None:
        c_v = C_v
    loss = 0
    for i, y_i in Y.items():
        for t, y_it in y_i.items():
            u_it = U[i][t]
            # Step 0 is penalised against the zero vector.
            u_prev = U[i][t - 1] if t > 0 else np.zeros(np.shape(u_it))
            dU = u_it - u_prev
            loss += c_u * np.dot(dU, dU)
            for j, y_itj in y_it.items():
                loss += (y_itj - np.dot(u_it, V[j])) ** 2
    for v_j in V.values():
        loss += c_v * np.dot(v_j, v_j)
    return loss

In [None]:
num_iteration = 100
# Alternate between the two closed-form updates and report the loss after
# every full sweep (user factors first, then product factors).
for k in range(num_iteration):
    update_u(Y, U, V)
    update_v(Z, U, V)
    current_loss = loss_function(Y, U, V)
    print(k, current_loss)

In [None]:
# Display the fitted user factors U[i][t]; the whole nested dict is rendered.
U