In [1]:
import numpy as np
import numpy.random as npr
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.cluster import KMeans
from sklearn.linear_model import LogisticRegression
import torch

In [2]:
data_dir = "../ml-1m/processed"
df_users = pd.read_csv(f"{data_dir}/users.csv", sep="\t")
df_movies = pd.read_csv(f"{data_dir}/movies.csv", sep="\t")

X = np.load(f"{data_dir}/X.npy")
Y = np.load(f"{data_dir}/Y.npy")
R = np.load(f"{data_dir}/R.npy")

n_users, n_movies = Y.shape
d_users = X.shape[1]

train_users = 5000
users_perm = npr.permutation(n_users)

X_trn = X[:train_users]
Y_trn = Y[:train_users]
X_tst = X[train_users:]
Y_tst = Y[train_users:]

In [3]:
n_clusters = 16
kmeans = KMeans(n_clusters=n_clusters, init='k-means++', n_init=10).fit(Y_trn.T)

In [4]:
vs = []
smoother = 0.001
for k in range(n_clusters):
    v_k = np.zeros(n_movies) + smoother
    v_k[kmeans.labels_ == k] += 1
    v_k = v_k / v_k.sum()
    vs.append(v_k.reshape((n_movies, 1)))

V = np.hstack(vs)

print("\nV.shape =", V.shape, "\n\n\n=== Clusters ===\n")
for k in range(n_clusters):
    print(f"cluster #{k}\n")
    print(df_movies.iloc[np.argsort(-V[:,k])[:5]].title, "\n")


V.shape = (3883, 16) 


=== Clusters ===

cluster #0

2899    Time Bandits (1981)
3569       Moonraker (1979)
2242            2010 (1984)
1255      Highlander (1986)
2219      Thing, The (1982)
Name: title, dtype: object 

cluster #1

3882                                Contender, The (2000)
1598               I Know What You Did Last Summer (1997)
467                           Hudsucker Proxy, The (1994)
2378                                 Varsity Blues (1999)
464     Englishman Who Went Up a Hill, But Came Down a...
Name: title, dtype: object 

cluster #2

3352                    Animal House (1978)
1023                        Die Hard (1988)
1064    Monty Python's Life of Brian (1979)
1202             Blues Brothers, The (1980)
1892                        Rain Man (1988)
Name: title, dtype: object 

cluster #3

1305    Amityville: A New Generation (1993)
3103                Ulysses (Ulisse) (1954)
1633           Chairman of the Board (1998)
1631                   Critical Care (19

In [5]:
V.T.shape, Y_trn.shape

((16, 3883), (5000, 3883))

In [6]:
def log_loss(user_hist, cluster_dist):
    return -np.dot(user_hist, np.log(cluster_dist))

def get_best_clusters(Y, V):
    assignments = []
    for i in range(Y.shape[0]):
        best_loss = np.inf
        best_cluster = None
        y_i = Y[i]
        for k in range(V.shape[1]):
            v_k = V[:,k]
            ll = log_loss(y_i, v_k)
            if ll < best_loss:
                best_loss = ll
                best_cluster = k
        assignments.append(best_cluster)
    return np.array(assignments)

In [7]:
cluster_assignments = get_best_clusters(Y_trn, V)

In [8]:
lr = LogisticRegression(max_iter=1000)

In [9]:
lr.fit(X_trn, cluster_assignments)

In [10]:
(lr.predict(X_trn) == 15).sum()

0

In [11]:
cluster_assignments[:50]

array([ 5, 14,  5,  9,  5, 14,  9, 14,  5, 14, 14, 11,  5,  4, 14,  5,  5,
        5, 14,  9,  5, 14,  5,  5,  5, 14, 14, 14,  9,  4, 14,  4, 14, 14,
        5, 14, 14, 14,  8,  5,  9, 14,  5,  5, 14,  5,  4, 14,  5,  5])

In [12]:
lr.predict_proba(X_trn)[:5]

array([[7.45620149e-04, 4.44042637e-04, 1.04789915e-03, 4.55846287e-02,
        2.94936760e-01, 7.29077288e-04, 1.10303867e-02, 5.16989454e-04,
        2.92358274e-02, 2.91863156e-04, 3.57853064e-03, 1.11389229e-02,
        7.63838903e-05, 6.00306118e-01, 3.36949931e-04],
       [9.00167453e-03, 1.07724054e-03, 7.72515181e-03, 5.45599223e-02,
        3.08171864e-01, 6.94887523e-03, 8.74612329e-03, 2.41443512e-03,
        1.40041653e-02, 1.71163389e-03, 3.06837099e-02, 1.91230925e-03,
        2.86794064e-04, 5.16981795e-01, 3.57743053e-02],
       [3.70797956e-03, 1.46726486e-04, 6.71289780e-03, 7.24978622e-02,
        4.32840609e-01, 4.46041578e-04, 1.70789035e-03, 3.37300315e-03,
        1.20555738e-01, 5.83599778e-04, 2.68415458e-03, 2.40037374e-03,
        1.07064205e-04, 3.51709794e-01, 5.26266028e-04],
       [1.27039297e-02, 7.44162402e-04, 5.26353222e-03, 1.76264079e-01,
        2.82067058e-01, 2.07517389e-03, 2.07229281e-02, 4.30402257e-03,
        7.78798956e-02, 1.14303160e-0