In [None]:
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

from algorithms import gmf2

In [None]:
# Random-graph settings
n = 25      # number of nodes
p = 0.1     # edge probability

# Embedding settings
dim = 8     # embedding size for the main GMF fit and rank of the truncated SVD

# Options forwarded unchanged to every gmf2.GMF(...) call
is_symmetric = False
n_iters = 1000
plot = True  # passed as GMF's `plot` flag (presumably toggles training plots — confirm)

In [None]:
# Build a reproducible Erdős–Rényi graph (fixed seed) and turn its adjacency
# matrix into a signed score matrix S: +5 where an edge exists, -5 elsewhere.
nx_G = nx.erdos_renyi_graph(n, p, seed=0)
# `nx.to_numpy_matrix` was removed in NetworkX 3.0; `to_numpy_array` returns a
# plain float ndarray with the same entries (1.0 for an unweighted edge).
adjacency = nx.to_numpy_array(nx_G)
# Vectorised replacement for the element-wise double loop; stays float64.
S = np.where(adjacency == 1, 5.0, -5.0)

In [None]:
# Fit GMF at the configured embedding size `dim`.
# Inputs: exp(S) as the first matrix and an all-ones (n, n) matrix as the
# second (presumably positive / negative pair weights — confirm against gmf2).
Nij_n_np = np.ones((n, n))
Nij_p_np = np.exp(S)
model = gmf2.GMF(
    Nij_p_np,
    Nij_n_np,
    embed_dim=dim,
    is_symmetric=is_symmetric,
    n_iters=n_iters,
    plot=plot,
)
# The two factors; U1 @ V1.T is the reconstruction plotted later.
U1, V1 = model.V.numpy(), model.W.numpy()

In [None]:
# Rank-`dim` truncated SVD baseline. sqrt(s) is split evenly between the two
# factors so that U2 @ V2.T equals the rank-`dim` truncation of S.
u, s, vh = np.linalg.svd(S)
u = u * s**0.5                  # same as u @ np.diag(s**0.5): scale column k
vh = (s**0.5)[:, np.newaxis] * vh  # same as np.diag(s**0.5) @ vh: scale row k
v = vh.T
U2, V2 = u[:, :dim], v[:, :dim]

In [None]:
Nij_p_np = np.exp(S)
Nij_n_np = np.ones((n, n))
model = gmf2.GMF(Nij_p_np, Nij_n_np, embed_dim=5, is_symmetric=is_symmetric, n_iters=n_iters, plot=plot)
U3 = model.V.numpy()
V3 = model.W.numpy()

In [None]:
Nij_p_np = np.exp(S)
Nij_n_np = np.ones((n, n))
model = gmf2.GMF(Nij_p_np, Nij_n_np, embed_dim=15, is_symmetric=is_symmetric, n_iters=n_iters, plot=plot)
U4 = model.V.numpy()
V4 = model.W.numpy()

In [None]:
Nij_p_np = np.exp(S)
Nij_n_np = np.ones((n, n))
model = gmf2.GMF(Nij_p_np, Nij_n_np, embed_dim=25, is_symmetric=is_symmetric, n_iters=n_iters, plot=plot)
U5 = model.V.numpy()
V5 = model.W.numpy()

In [None]:
# Compare the original ±5 score matrix with each reconstruction on one shared
# colour scale, so colours are directly comparable across panels.
_min = -5
_max = 5

# (matrix, title) per panel, in row-major order of the 2x3 grid.
panels = [
    (S, "Original"),
    (U1 @ V1.T, f"The proposed (d={dim})"),
    (U2 @ V2.T, f"Truncated SVD (d={dim})"),
    (U3 @ V3.T, "The proposed (d=5)"),
    (U4 @ V4.T, "The proposed (d=15)"),
    (U5 @ V5.T, "The proposed (d=25)"),
]

f, axes = plt.subplots(2, 3, sharey=True, figsize=(12, 12))

images = []
for ax, (matrix, title) in zip(axes.flat, panels):
    images.append(ax.imshow(matrix, cmap="coolwarm", vmin=_min, vmax=_max))
    ax.set_title(title, fontsize=16)
    ax.tick_params(axis="both", labelsize=12)

# Every panel shares vmin/vmax, so one colorbar (built from any image) covers all.
ticks = list(range(_min, _max + 1))
cbar = f.colorbar(images[0], ax=axes, orientation="horizontal",
                  anchor=(0, 2), aspect=32, ticks=ticks)
cbar.ax.tick_params(labelsize=14)

save_path = None  # set to e.g. "example.png" to also write the figure to disk
if save_path is not None:
    f.savefig(save_path)

plt.show()