In [None]:
%load_ext autoreload
%autoreload 2
import os
import random

import numpy as np
import tensorflow as tf
import keras.backend as K
from keras.callbacks import EarlyStopping
from keras.layers import Input, Dense, Flatten
from keras.models import Model
from keras.optimizers import Adam
from keras.regularizers import l2
from sklearn import metrics
from sklearn.model_selection import train_test_split
from spektral.layers import GraphConv, GlobalAvgPool, EdgeConditionedConv
from spektral.layers.ops import sp_matrix_to_sp_tensor
from spektral.utils import Batch, batch_iterator
from spektral.utils import label_to_one_hot, normalized_laplacian

import graph

In [None]:
# Seed every RNG we control so repeated runs give the same numbers.
SEED = 2020
# NOTE: PYTHONHASHSEED only affects interpreters started after this point;
# the hash seed of the current process was fixed when it started.
os.environ['PYTHONHASHSEED'] = str(SEED)
for _seed_fn in (random.seed, np.random.seed, tf.random.set_random_seed):
    _seed_fn(SEED)
del _seed_fn

# Single-threaded TF execution: multi-threaded scheduling can reorder
# floating-point operations and make results vary between runs.
session_conf = tf.ConfigProto(
    intra_op_parallelism_threads=1,
    inter_op_parallelism_threads=1,
)
sess = tf.Session(config=session_conf, graph=tf.get_default_graph())
K.set_session(sess)

In [None]:
import networkx as nx
import scipy.sparse as sp

In [None]:
from spektral.datasets import mnist

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Load MNIST through spektral; the seventh returned value is not needed here.
X_train, y_train, X_val, y_val, X_test, y_test, _ = mnist.load_data()

# Add a trailing feature axis: every node (pixel) carries a single feature.
X_train, X_val, X_test = [split[..., np.newaxis] for split in (X_train, X_val, X_test)]

N, F = X_train.shape[-2:]  # N: nodes per graph, F: node feature dimensionality
n_out = 10                 # number of target classes
print(X_train.shape, y_train.shape)

In [None]:
def grid_graph(m, k=8, corners=False):
    '''Build a k-nearest-neighbour adjacency matrix over an m x m grid.

    Follows the construction of Defferrard et al. 2016: grid coordinates from
    ``graph.grid``, k-NN Euclidean distances, then ``graph.adjacency``.

    Parameters
    ----------
    m : int
        Side length of the grid; the graph has m**2 nodes.
    k : int
        Number of nearest neighbours per node.
    corners : bool
        If True, zero every weight below max/1.5 so that only vertical and
        horizontal links survive (corner vertices keep 2 neighbours).

    Returns
    -------
    The adjacency matrix as a scipy sparse matrix.
    '''
    coords = graph.grid(m)
    dist, idx = graph.distance_sklearn_metrics(coords, k=k, metric='euclidean')
    adjacency = graph.adjacency(dist, idx)

    if corners:
        # Keep only the strongest (vertical/horizontal) links.
        import scipy.sparse
        dense = adjacency.toarray()
        dense[dense < dense.max() / 1.5] = 0
        adjacency = scipy.sparse.csr_matrix(dense)
        print('{} edges'.format(adjacency.nnz))

    n_edges = adjacency.nnz // 2
    knn_edges = k * m ** 2 // 2
    print("{} > {} edges".format(n_edges, knn_edges))
    return adjacency


def draw_graph(A, m=28, ax=None, spring_layout=False, size_factor=10):
    '''Draw the graph with adjacency matrix ``A`` laid out on an m x m grid.

    Nodes start at the coordinates returned by ``graph.grid(m)``. With
    ``spring_layout=True``, isolated nodes are removed and the positions are
    relaxed by ``nx.spring_layout`` starting from the grid positions. Marker
    size is the node degree times ``size_factor``.

    Parameters
    ----------
    A : scipy sparse matrix of shape (m**2, m**2)
    m : int
        Grid side length.
    ax : matplotlib Axes or None
        Forwarded to ``nx.draw``; None lets networkx use the current Axes.
    spring_layout : bool
        Relax the grid layout with a spring layout (200 iterations).
    size_factor : float
        Multiplier from node degree to marker size.

    Returns
    -------
    None. ``nx.draw`` draws in place and returns None; callers should keep
    their own Axes handle instead of using this return value.
    '''
    assert m**2 == A.shape[0] == A.shape[1]
    # networkx >= 2.7 provides from_scipy_sparse_array and 3.x removed the
    # *_matrix variant; only fall back to it on older releases.
    to_graph = getattr(nx, 'from_scipy_sparse_array', None) or nx.from_scipy_sparse_matrix
    G = to_graph(A)
    print('Number of nodes: %d; Number of edges: %d' %
          (G.number_of_nodes(), G.number_of_edges()))

    if spring_layout:
        # Remove nodes without edges so the relaxed layout shows only connected structure.
        G.remove_nodes_from([n for n, deg in G.degree() if deg == 0])
        print('After removing nodes without edges:')
        print('Number of nodes: %d; Number of edges: %d' %
              (G.number_of_nodes(), G.number_of_edges()))

    # Initial positions: each remaining node at its grid coordinate
    # (computed once; it was previously computed twice, one result unused).
    z = graph.grid(m)
    pos = {n: z[n] for n in G.nodes()}

    if spring_layout:
        pos = nx.spring_layout(G, pos=pos, iterations=200)

    nx.draw(G, pos,
            node_size=[G.degree(n) * size_factor for n in G.nodes()],
            ax=ax)
    return None

In [None]:
A = grid_graph(28, k=8)
plt.imshow(A.todense())

In [None]:
# visualize the graph
fig, ax = plt.subplots(figsize=(8, 8))
ax = draw_graph(A, ax=ax, size_factor=1)

In [None]:
ax = draw_graph(A, ax=ax, size_factor=1, spring_layout=True)

# Feature graph as a 2D Euclidean grid in the entire space

In [None]:
fig, axes = plt.subplots(figsize=(20, 5), ncols=4)

axes[0].imshow(A.toarray())
axes[0].set_title('$A$')

# Weighted node degrees (column sums of A, i.e. the diagonal of the degree
# matrix D), reshaped so that each value sits at its pixel position.
D = np.asarray(A.sum(axis=0)).reshape(28, 28)
axes[1].imshow(D)
axes[1].set_title('$D$')

# draw_graph draws in place and returns None; don't overwrite the Axes
# handles stored in `axes` with that value.
draw_graph(A, ax=axes[2], size_factor=1)
axes[2].set_title('grid layout')
draw_graph(A, ax=axes[3], size_factor=1, spring_layout=True)
axes[3].set_title('spring layout')

fig.tight_layout()

# Feature graphs as a "pruned" grid for each digit

In [None]:
# Mean-pixel intensity below which a pixel is treated as carrying no data.
# (0.25 was also tried to reduce the noise of the averaged signals.)
threshold = 0.5
d_digit_graphs = {}  # digit -> grid graph pruned to that digit's active pixels

for i in range(10):
    mask = y_train == i

    fig, axes = plt.subplots(figsize=(20, 5), ncols=4)

    # Mean image of the class, before and after thresholding.
    x_train_i_avg = X_train[mask].mean(axis=0).flatten()
    axes[0].imshow(x_train_i_avg.reshape(28, 28))
    x_train_i_avg[x_train_i_avg < threshold] = 0
    axes[1].imshow(x_train_i_avg.reshape(28, 28))

    # Sparse diagonal matrix with the thresholded intensities on the diagonal.
    A_diag_i = sp.diags(x_train_i_avg, dtype=np.float32).tocsr()

    # "Prune" the grid graph to the subgraph spanned by active pixels.
    # Scaling on BOTH sides zeroes the rows and the columns of inactive pixels
    # and keeps the matrix symmetric (assuming A is, as built by
    # graph.adjacency). Right-multiplying only (A @ diag) scales columns and
    # leaves an asymmetric matrix, which an undirected nx.Graph cannot
    # represent faithfully.
    A_i = A_diag_i.dot(A).dot(A_diag_i)
    d_digit_graphs[i] = A_i

    # draw_graph draws in place and returns None; keep the Axes handles intact.
    draw_graph(A_i, ax=axes[2], size_factor=1)
    draw_graph(A_i, ax=axes[3], size_factor=1, spring_layout=True)
    fig.tight_layout()
    plt.show()

# Graph convolutional network for classification with different feature graphs

In [None]:
# Hyper-parameters shared by the models below.
l2_reg = 0.0005        # L2 regularization rate
learning_rate = 3e-2   # learning rate passed to the Adam optimizer
batch_size = 100       # mini-batch size
epochs = 20            # number of training epochs
# es_patience = 10     # patience for early stopping (currently unused)

In [None]:
from keras.layers import MaxPooling2D, Reshape
from spektral.layers import GraphConv, ChebConv

In [None]:
def GCN_single_layer(A, N=28*28, F=1,
                     n_out=10,
                     l2_reg=l2_reg,
                     learning_rate=learning_rate,
                    ):
    '''Build and compile a one-layer ChebConv classifier over graph ``A``.

    The normalized Laplacian of ``A`` enters the model as a constant sparse
    tensor input; node features are the only data fed at fit/predict time.

    Parameters
    ----------
    A : scipy sparse adjacency matrix, N x N.
    N : int
        Number of nodes in each graph.
    F : int
        Node feature dimensionality.
    n_out : int
        Number of output classes.
    l2_reg : float
        L2 penalty on the ChebConv kernel.
    learning_rate : float
        Adam learning rate.

    Returns
    -------
    A compiled keras ``Model`` with inputs [features, filter] and a softmax output.
    '''
    # NOTE(review): confirm which filter the installed spektral ChebConv
    # expects; this passes normalized_laplacian(A) unchanged.
    laplacian = normalized_laplacian(A)

    node_features = Input(shape=(N, F))
    # The filter is a fixed tensor input; otherwise Keras complains about
    # inputs of different rank.
    laplacian_in = Input(tensor=sp_matrix_to_sp_tensor(laplacian))

    cheb = ChebConv(10,
                    activation='relu',
                    kernel_regularizer=l2(l2_reg),
                    use_bias=True)
    hidden = cheb([node_features, laplacian_in])
    flat = Flatten()(hidden)
    probs = Dense(n_out, activation='softmax')(flat)

    model = Model(inputs=[node_features, laplacian_in], outputs=probs)
    model.compile(optimizer=Adam(lr=learning_rate),
                  loss='sparse_categorical_crossentropy',
                  metrics=['acc'])
    return model


def GCN(A, N=28*28, F=1,
        n_out=10,
        l2_reg=l2_reg,
        learning_rate=learning_rate,
       ):
    '''Build a graph convolution network given A.

    Two 32-channel GraphConv layers; the node outputs are then folded back
    onto a square sqrt(N) x sqrt(N) grid for 2x2 max-pooling, followed by a
    512-unit ReLU layer and a softmax classifier.

    Parameters
    ----------
    A : scipy sparse adjacency matrix, N x N.
    N : int
        Number of nodes per graph. Must be a perfect square because of the
        grid reshape (784 -> 28 x 28 for MNIST).
    F : int
        Node feature dimensionality.
    n_out : int
        Number of output classes.
    l2_reg : float
        L2 penalty on both GraphConv kernels.
    learning_rate : float
        Adam learning rate.

    Returns
    -------
    A compiled keras ``Model`` with inputs [X_in, A_in].

    Raises
    ------
    ValueError
        If N is not a perfect square.
    '''
    n_channels = 32
    # Derive the grid side from N instead of assuming 28 x 28.
    side = int(round(N ** 0.5))
    if side * side != N:
        raise ValueError('GCN reshapes the node outputs onto a square grid; '
                         'N={} is not a perfect square'.format(N))

    # Computes a normalized Laplacian (as the conv filter).
    # NOTE(review): spektral's GraphConv is usually fed a renormalized
    # adjacency (localpooling_filter) rather than a Laplacian; confirm intent.
    L = normalized_laplacian(A)

    X_in = Input(shape=(N, F))
    # Pass the filter as a fixed tensor, otherwise Keras will complain about
    # inputs of different rank.
    A_in = Input(tensor=sp_matrix_to_sp_tensor(L))

    graph_conv = GraphConv(n_channels,
                           activation='relu',
                           kernel_regularizer=l2(l2_reg),
                           use_bias=True)([X_in, A_in])
    graph_conv = GraphConv(n_channels,
                           activation='relu',
                           kernel_regularizer=l2(l2_reg),
                           use_bias=True)([graph_conv, A_in])

    # Fold nodes back onto the pixel grid so 2D pooling is meaningful
    # (assumes node order is row-major pixel order).
    rs = Reshape((side, side, n_channels))(graph_conv)
    pooled = MaxPooling2D(pool_size=(2, 2))(rs)
    flatten = Flatten()(pooled)
    fc = Dense(512, activation='relu')(flatten)
    output = Dense(n_out, activation='softmax')(fc)

    model = Model(inputs=[X_in, A_in], outputs=output)
    optimizer = Adam(lr=learning_rate)
    model.compile(optimizer=optimizer,
                  loss='sparse_categorical_crossentropy',
                  metrics=['acc'])
    return model

# GCN Model with full grid

In [None]:
# original model params
# model.summary()

In [None]:
print(A.nnz)
model_full_grid = GCN_single_layer(A)
model_full_grid.summary()

In [None]:
# Train model
validation_data = (X_val, y_val)
model_full_grid.fit(X_train,
                    y_train,
                    batch_size=batch_size,
                    validation_data=validation_data,
                    epochs=epochs)

In [None]:
# Evaluate model
print('Evaluating model.')
eval_results = model_full_grid.evaluate(X_test,
                              y_test,
                              batch_size=batch_size)
print('Done.\n'
      'Test loss: {}\n'
      'Test acc: {}'.format(*eval_results))

# GCN model with an empty adjacency matrix

In [None]:
A0 = sp.csr_matrix(A.shape, dtype=np.float32)
print(A0.shape)

model_no_graph = GCN_single_layer(A0)
model_no_graph.summary()

In [None]:
model_no_graph.fit(X_train,
                    y_train,
                    batch_size=batch_size,
                    validation_data=validation_data,
                    epochs=epochs)

In [None]:
# Evaluate model
print('Evaluating model.')
eval_results = model_no_graph.evaluate(X_test,
                              y_test,
                              batch_size=batch_size)
print('Done.\n'
      'Test loss: {}\n'
      'Test acc: {}'.format(*eval_results))

In [None]:
def fc_model(N=28*28, F=1,
             n_out=10,
             l2_reg=l2_reg,
             learning_rate=learning_rate):
    '''Graph-free baseline: flatten the nodes, one 10-unit ReLU layer, softmax.

    Parameters
    ----------
    N : int
        Number of nodes (pixels) per sample.
    F : int
        Features per node.
    n_out : int
        Number of output classes.
    l2_reg : float
        L2 penalty on the hidden layer's kernel.
    learning_rate : float
        Adam learning rate.

    Returns
    -------
    A compiled keras ``Model`` taking a single (N, F) input.
    '''
    pixels = Input(shape=(N, F))
    flat = Flatten()(pixels)
    hidden = Dense(10,
                   activation='relu',
                   kernel_regularizer=l2(l2_reg),
                   use_bias=True)(flat)
    probs = Dense(n_out, activation='softmax')(hidden)

    model = Model(inputs=pixels, outputs=probs)
    model.compile(optimizer=Adam(lr=learning_rate),
                  loss='sparse_categorical_crossentropy',
                  metrics=['acc'])
    return model
    

In [None]:
model_fc = fc_model()
model_fc.summary()

In [None]:
model_fc.fit(X_train,
                    y_train,
                    batch_size=batch_size,
                    validation_data=validation_data,
                    epochs=epochs)

In [None]:
# Evaluate model
print('Evaluating model.')
eval_results = model_fc.evaluate(X_test,
                              y_test,
                              batch_size=batch_size)
print('Done.\n'
      'Test loss: {}\n'
      'Test acc: {}'.format(*eval_results))

# Models with digit feature graph as conv filter

In [None]:
X_train[mask].reshape(-1, 784).shape

In [None]:
d = metrics.pairwise_distances(X_train[mask].reshape(-1, 784).T, metric='cosine', n_jobs=-2)
print(d.shape)

In [None]:
from scipy.spatial.distance import cosine

In [None]:
d.min(), d.max()

In [None]:
plt.hist(d.flatten(), bins=50);

In [None]:
W = 1 - d
W = sp.coo_matrix(W, dtype=np.float32)

# No self-connections.
W.setdiag(0)

# Non-directed graph.
bigger = W.T > W
W = W - W.multiply(bigger) + W.T.multiply(bigger)

assert W.nnz % 2 == 0
assert np.abs(W - W.T).mean() < 1e-10
assert type(W) is sp.csr.csr_matrix

In [None]:
W

In [None]:
W = W.multiply(W > 0.8)
W

In [None]:
# k-NN (k=4) cosine distances between pixel columns, using digit-7 images only.
mask = y_train == 7
pixel_columns = X_train[mask].reshape(-1, 784).T
d, idx = graph.distance_sklearn_metrics(pixel_columns, k=4, metric='cosine')
print(d.shape, idx.shape)

In [None]:
# W = graph.adjacency(d, idx)

In [None]:
# Sparse k-NN weight matrix: row p links pixel p to its k nearest pixels,
# weighted by cosine similarity (1 - distance); non-positive weights are dropped.
M, k = d.shape

I = np.arange(0, M).repeat(k)
J = idx.reshape(M * k)
V = (1 - d).reshape(M * k)
nnz_mask = V > 0
W = sp.coo_matrix((V[nnz_mask], (I[nnz_mask], J[nnz_mask])), shape=(M, M))

# No self-connections.
W.setdiag(0)

# Non-directed graph: keep the larger of W[p, q] and W[q, p].
bigger = W.T > W
W = W - W.multiply(bigger) + W.T.multiply(bigger)

assert W.nnz % 2 == 0
assert np.abs(W - W.T).mean() < 1e-10
# Format check instead of `sp.csr.csr_matrix`, whose namespace is deprecated
# since SciPy 1.8.
assert W.format == 'csr'

In [None]:
V.min(), V.max()

In [None]:
W

In [None]:
W.getnnz()

In [None]:
plt.hist(W.toarray().flatten(), bins=50, log=True);

In [None]:
# A = graph.adjacency(d, idx)
print(W.shape, W.nnz)

In [None]:
fig, ax = plt.subplots()
draw_graph(W, ax=ax, size_factor=1)

In [None]:
fig, ax = plt.subplots()
draw_graph(W, ax=ax, size_factor=1, spring_layout=True)

In [None]:
d_digit_corr_graphs = {}  # digit -> thresholded pixel-similarity graph

# Graphs built from pixel co-activation connect pixels that are far apart on
# the grid but light up together for a digit, so the GCN can see global as
# well as local structure.

for i in range(10):
    mask = y_train == i

    # Cosine similarity between pixel columns over this digit's images.
    dist = metrics.pairwise_distances(X_train[mask].reshape(-1, 784).T, metric='cosine', n_jobs=-2)
    W = sp.coo_matrix(1 - dist, dtype=np.float32)

    # No self-connections.
    W.setdiag(0)

    # Non-directed graph: keep the larger of W[p, q] and W[q, p].
    bigger = W.T > W
    W = W - W.multiply(bigger) + W.T.multiply(bigger)

    assert W.nnz % 2 == 0
    assert np.abs(W - W.T).mean() < 1e-10
    # Format check; the `scipy.sparse.csr` namespace is deprecated since SciPy 1.8.
    assert W.format == 'csr'

    # Keep only strongly similar pixel pairs.
    W = W.multiply(W > 0.8)
    d_digit_corr_graphs[i] = W

    fig, axes = plt.subplots(figsize=(15, 5), ncols=3)
    x_train_i_avg = X_train[mask].mean(axis=0).flatten()
    axes[0].imshow(x_train_i_avg.reshape(28, 28))
    axes[0].set_title('digit {}'.format(i))
    # draw_graph draws in place and returns None; keep the Axes handles intact.
    draw_graph(W, ax=axes[1], size_factor=1)
    draw_graph(W, ax=axes[2], size_factor=1, spring_layout=True)
    fig.tight_layout()
    plt.show()

In [None]:
# The same similarity-graph construction over ALL training images (digits pooled).
dist = metrics.pairwise_distances(X_train.reshape(-1, 784).T, metric='cosine', n_jobs=-2)

W = sp.coo_matrix(1 - dist, dtype=np.float32)

# No self-connections.
W.setdiag(0)

# Non-directed graph: keep the larger of W[p, q] and W[q, p].
bigger = W.T > W
W = W - W.multiply(bigger) + W.T.multiply(bigger)

assert W.nnz % 2 == 0
assert np.abs(W - W.T).mean() < 1e-10
# Format check; the `scipy.sparse.csr` namespace is deprecated since SciPy 1.8.
assert W.format == 'csr'

# Keep only strongly similar pixel pairs.
W = W.multiply(W > 0.8)

# Keep the pooled graph under its own name. It must NOT go into
# d_digit_corr_graphs: indexing with the loop variable left over from the
# per-digit cell (i == 9) would replace digit 9's graph before training.
all_digits_corr_graph = W

fig, axes = plt.subplots(figsize=(15, 5), ncols=3)
x_train_avg = X_train.mean(axis=0).flatten()
axes[0].imshow(x_train_avg.reshape(28, 28))
axes[0].set_title('all digits')
# draw_graph draws in place and returns None; keep the Axes handles intact.
draw_graph(W, ax=axes[1], size_factor=1)
draw_graph(W, ax=axes[2], size_factor=1, spring_layout=True)
fig.tight_layout()
plt.show()

In [None]:
d_digit_models = {}  # digit -> model whose filter is that digit's similarity graph
for digit in range(10):
    # Swap in d_digit_graphs[digit] to use the pruned grid graphs instead.
    graph_i = d_digit_corr_graphs[digit]
    model_i = GCN_single_layer(graph_i)
    # Report the size of the graph this model is actually built on.
    print(digit, graph_i.nnz)

    # Train on all digits with this digit's feature graph fixed in the model.
    model_i.fit(X_train, y_train,
                batch_size=batch_size,
                validation_data=validation_data,
                epochs=epochs)

    d_digit_models[digit] = model_i

In [None]:
# Test-set loss and accuracy of every digit-graph model.
for digit in d_digit_models:
    model_i = d_digit_models[digit]
    eval_results = model_i.evaluate(X_test, y_test, batch_size=batch_size)
    print('Digit %d' % digit)
    print('Test loss: {}\n'
          'Test acc: {}'.format(*eval_results))

In [None]:
from sklearn import metrics
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
import pandas as pd
def plot_confusion_matrix(cm, classes=None):
    '''Draw a confusion matrix as an annotated seaborn heat map.

    Parameters
    ----------
    cm : array-like, shape (n_classes, n_classes)
        Rows are true labels and columns predicted labels (the axis labels
        set below). Cells are annotated with 3 decimals, so a normalized
        matrix reads best.
    classes : sequence of labels or None
        Tick labels for both axes. If None, ``range(n_classes)`` is derived
        from ``cm`` instead of assuming exactly 10 classes.

    Returns
    -------
    The matplotlib Figure holding the heat map.
    '''
    if classes is None:
        classes = list(range(np.shape(cm)[0]))
    cm_df = pd.DataFrame(cm, index=classes, columns=classes)
    fig, ax = plt.subplots(figsize=(8, 8))
    ax = sns.heatmap(cm_df,
                     fmt='.3f',
                     annot=True, cmap='Reds', ax=ax)
    ax.set_ylabel('True label')
    ax.set_xlabel('Predicted label')
    fig.tight_layout()
    return fig

In [None]:
y_test_preds = model_full_grid.predict(X_test)
print(y_test_preds.shape)
cm = metrics.confusion_matrix(y_test, np.argmax(y_test_preds, axis=1))
fig = plot_confusion_matrix(cm/cm.sum(axis=1))
fig.get_axes()[0].set_title('Full grid graph');

In [None]:
y_test_preds = model_no_graph.predict(X_test)
print(y_test_preds.shape)
cm = metrics.confusion_matrix(y_test, np.argmax(y_test_preds, axis=1))
fig = plot_confusion_matrix(cm/cm.sum(axis=1))
fig.get_axes()[0].set_title('No graph');

In [None]:
acc_df = {}

for model_name, model in d_digit_models.items():
    
    y_test_preds = model.predict(X_test)
    cm = metrics.confusion_matrix(y_test, np.argmax(y_test_preds, axis=1))
    
#     y_train_preds = model.predict(X_train)
#     cm = metrics.confusion_matrix(y_train, np.argmax(y_train_preds, axis=1))
    
#     acc_per_classes = np.diag(cm/cm.sum(axis=1))
    acc_per_classes = np.diag(cm)
    
    acc_df[model_name] = acc_per_classes

In [None]:
acc_df = pd.DataFrame.from_dict(acc_df)
acc_df

In [None]:
y_test_preds = model_full_grid.predict(X_test)
cm = metrics.confusion_matrix(y_test, np.argmax(y_test_preds, axis=1))

# y_train_preds = model_full_grid.predict(X_train)
# cm = metrics.confusion_matrix(y_train, np.argmax(y_train_preds, axis=1))


# acc_per_class_full_model = np.diag(cm/cm.sum(axis=1))
acc_per_class_full_model = np.diag(cm)
acc_per_class_full_model

In [None]:
# accuracy gain compared to full model
sns.heatmap(acc_df - acc_per_class_full_model,
            cmap='RdBu_r',
#             vmin=-0.15,
#             vmax=0.15
           )

In [None]:
# accuracy gain compared to 10 averaged models
sns.heatmap(acc_df - acc_df.mean(axis=1),
            cmap='RdBu_r',
#             vmin=-0.15,
#             vmax=0.15
           )

In [None]:
y_test_preds = model_no_graph.predict(X_test)
cm = metrics.confusion_matrix(y_test, np.argmax(y_test_preds, axis=1))


acc_per_class_no_graph = np.diag(cm/cm.sum(axis=0))

# accuracy gain compared to no graph model
sns.heatmap(acc_df - acc_per_class_no_graph,
            cmap='RdBu_r',
            vmin=-0.15,
            vmax=0.15
           )

In [None]:
y_test_preds = model_fc.predict(X_test)
cm = metrics.confusion_matrix(y_test, np.argmax(y_test_preds, axis=1))


acc_per_class_fc = np.diag(cm/cm.sum(axis=0))

# accuracy gain compared to FC model
sns.heatmap(acc_df - acc_per_class_fc,
            cmap='RdBu_r',
            vmin=-0.15,
            vmax=0.15
           )