In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns

# Plot styling: plain white seaborn background and large (20 x 20 inch) default figures.
sns.set_style("white")
plt.rcParams["figure.figsize"] = (20, 20)

import torch  # NOTE(review): not referenced in the cells shown here — confirm it is still needed.
import pickle
import wikipedia  # NOTE(review): not referenced in the cells shown here — confirm it is still needed.
import numpy as np
import pandas as pd
import networkx as nx
from umap import UMAP
from tqdm import tqdm_notebook as tqdm  # NOTE(review): tqdm_notebook is deprecated (prefer tqdm.auto); also unreferenced in the cells shown.
from sklearn.cluster import AgglomerativeClustering
from scipy.sparse.linalg import svds
# NOTE(review): scipy.spatial.distance.cdist is imported in a later cell; consider moving it here.

In [None]:
# Location of the pickled link graph (kept as a constant so it is easy to repoint).
GOOD_ARTICLE_LINKS_PATH = "/mnt/efs/wikipedia/good_article_links.pkl"

# NOTE(review): pickle.load can execute arbitrary code — only load this file from a trusted source.
with open(GOOD_ARTICLE_LINKS_PATH, "rb") as fp:
    # presumably {article: [linked articles]} — it is consumed by from_dict_of_lists below
    graph_dict = pickle.load(fp)

# Build the graph after the file is closed; nothing below needs the handle.
# from_dict_of_lists returns an undirected nx.Graph by default (create_using=None),
# so no explicit .to_undirected() is required.
G = nx.from_dict_of_lists(graph_dict)

# Canonical article order, used as the row/column labels of the matrices below.
node_names = list(G.nodes)

In [None]:
# Preview only the first few entries — rendering the whole mapping can flood the
# notebook output and bloat the saved file.
{article: graph_dict[article] for article in list(graph_dict)[:5]}

In [None]:
# Dense article-by-article adjacency matrix. Passing nodelist=node_names makes the
# row/column order explicitly match the labels used for the DataFrame below instead
# of relying on the default G.nodes order.
# NOTE: this is dense n x n, so memory grows quadratically with the number of articles.
adjacency_matrix = nx.adjacency_matrix(G, nodelist=node_names).todense()
assert adjacency_matrix.shape == (len(node_names), len(node_names))

# Densifying the sparse adjacency fills absent edges with 0, so there are no NaNs to fill.
df = pd.DataFrame(
    adjacency_matrix,
    index=node_names,
    columns=node_names,
)

In [None]:
# Inspect the dense adjacency matrix built in the previous cell.
adjacency_matrix

In [None]:
# Row-centre the adjacency frame: subtract each row's mean from every entry in that
# row, so the SVD below factorises deviations from the per-article average.
# (The name `user_means` is inherited from a recommender-style recipe; rows here are articles.)
user_means = df.mean(axis="columns")
meaned_df = df.sub(user_means, axis="index")

In [None]:
# Truncated SVD of the row-centred adjacency matrix.
N_COMPONENTS = 30  # rank of the factorisation

# Hand svds a plain ndarray rather than a DataFrame, plus a seeded ARPACK start
# vector (length min(shape), as svds expects) so the factors are reproducible
# across Restart & Run All instead of depending on a random start.
centered_values = meaned_df.to_numpy(dtype=float)
start_vector = np.random.default_rng(42).uniform(-1.0, 1.0, size=min(centered_values.shape))
U, S, Vt = svds(centered_values, k=N_COMPONENTS, v0=start_vector)

# svds makes no largest-first guarantee (ARPACK typically yields ascending order);
# reorder all three factors so component 0 is the dominant one. The product
# U @ diag(S) @ Vt is unchanged by this permutation.
order = np.argsort(S)[::-1]
U, S, Vt = U[:, order], S[order], Vt[order, :]

In [None]:
# Rank-k reconstruction of the adjacency matrix: scale each column of U by its
# singular value (broadcasting instead of multiplying by a diagonal matrix),
# project back through Vt, then add back the row means removed before the SVD.
US = U * S
predictions = US @ Vt + user_means.to_numpy()[:, np.newaxis]

# Round the reconstructed entries and label them like the original adjacency frame.
predictions = pd.DataFrame(
    np.round(predictions),
    index=node_names,
    columns=node_names,
)

In [None]:
# Sanity check: US should be (number of articles, k) — one row per article, one column per SVD component.
US.shape

# Cosine distance matrix between articles

In [None]:
from scipy.spatial.distance import cdist

In [None]:
# Pairwise cosine distances between articles, using each article's column of Vt
# (i.e. a row of Vt.T) as its vector; both axes are labelled by node_names.
# NOTE(review): these vectors are not weighted by S, and the UMAP cell below embeds U
# rather than Vt.T — confirm which factor/weighting is intended.
distance_matrix = pd.DataFrame(
    data=cdist(XA=Vt.T, XB=Vt.T, metric="cosine"),
    columns=node_names,
    index=node_names,
)

# Nearest articles to a single query article

In [None]:
# Articles closest to a query article by cosine distance in the SVD component space.
# The query itself (distance 0 to itself) is dropped so the list starts at its real
# neighbours, and only the top N are shown instead of every article.
QUERY_ARTICLE = "Apple"
N_NEIGHBOURS = 20

distance_matrix[QUERY_ARTICLE].drop(index=QUERY_ARTICLE).sort_values().head(N_NEIGHBOURS)

# Plot 2-D UMAP embeddings coloured by cluster

In [None]:
# Project the SVD row factors U to 2-D for plotting. UMAP is stochastic, so a fixed
# random_state keeps the layout (and the clusters derived from it) reproducible;
# note umap-learn disables parallel execution when random_state is set.
# NOTE(review): the distance matrix above uses Vt.T while this uses U — confirm intent.
embeddings_2d = UMAP(n_components=2, random_state=42).fit_transform(U)

In [None]:
# 2-D embedding as a frame with columns 0 and 1, one row per article (labelled by node_names).
# NOTE(review): this rebinds `df`, discarding the adjacency DataFrame built earlier;
# re-running the row-centring cell afterwards would silently use the embedding instead.
# Consider a distinct name such as `embedding_df` (requires updating the cells below).
df = pd.DataFrame(embeddings_2d, index=node_names)

In [None]:
# Agglomerative clustering of the 2-D embedding into 300 groups; after fit(),
# labels_ holds exactly what fit_predict would return.
cluster = AgglomerativeClustering(n_clusters=300)
df["cluster"] = cluster.fit(embeddings_2d).labels_

In [None]:
# 2-D UMAP embedding coloured by cluster label. Cluster ids are arbitrary categories,
# so the colorbar pandas attaches when a cmap is given is meaningless — suppress it.
# NOTE: "Paired" has only 12 colours, so with more than 12 clusters distinct clusters
# will share colours.
fig, ax = plt.subplots()
df.plot.scatter(x=0, y=1, c="cluster", cmap="Paired", colorbar=False, ax=ax)
ax.set(
    title="UMAP projection of article SVD factors, coloured by cluster",
    xlabel="UMAP dimension 1",
    ylabel="UMAP dimension 2",
);

In [None]:
# Print a random sample of article names from every cluster to eyeball what each is about.
SAMPLE_SIZE = 10
rng = np.random.default_rng(42)  # seeded so the printed samples are reproducible

for selected_cluster in range(df["cluster"].max() + 1):
    members = df.index.values[(df["cluster"] == selected_cluster).to_numpy()]
    # choice(..., replace=False) raises ValueError when asked for more items than
    # exist, so cap the sample at the cluster's size.
    sample = rng.choice(members, size=min(SAMPLE_SIZE, len(members)), replace=False)
    print(sample, "\n\n")