## GCN (PyTorch) with random graphs

In [None]:
import torch
import pytorch_gcn as gcn
import matplotlib.pyplot as plt

from torch import nn
from pprint import pprint

In [None]:
# Uncomment to seed torch's global RNG so the random graph, labels and point
# layout (and, presumably, the model's weight initialisation) are reproducible
# between runs.
# torch.manual_seed(19870314)

### Generate a random graph

In [None]:
num_nodes = 16
edges = []
# Unordered {v, u} pairs already recorded. The edges are undirected, so (v, u)
# and (u, v) are the same edge and should appear only once in `edges`.
seen_pairs = set()

# Each node v asks for 1 or 2 random neighbours (randint's high bound is
# exclusive). Self-loops and repeats are dropped, so the realised degree can
# be lower than requested.
degrees = torch.randint(1, 3, (num_nodes,))
for v, num_neighbors in enumerate(degrees.tolist()):
    # .tolist() yields plain Python ints, so the comparison below and the
    # stored tuples hold ints rather than 0-d tensors.
    for u in torch.randint(0, num_nodes, (num_neighbors,)).tolist():
        pair = frozenset((v, u))
        if v == u or pair in seen_pairs:
            continue
        seen_pairs.add(pair)
        edges.append((v, u))

print(f"nodes: {num_nodes} {list(range(num_nodes))}")
print(f"edges: {len(edges)} {edges}")

### Build adjacency matrix

In [None]:
# Symmetric 0/1 adjacency matrix: every listed edge is written in both
# directions, because the graph is undirected.
adj = torch.zeros(num_nodes, num_nodes)
for src, dst in edges:
    adj[src, dst] = adj[dst, src] = 1

print(f"adjacency: {adj.shape}")
print(adj)

### Assign ground truth labels

In [None]:
# One ground-truth class per node, drawn uniformly at random from
# {0, ..., num_classes - 1}.
num_classes = 3
labels = torch.randint(low=0, high=num_classes, size=(num_nodes,))
print(f"labels: {labels}")

### Plot the graph as points to visualize it

In [None]:
def plot_graph(adj, points, labels, node_size=10):
    """Draw the graph with matplotlib and show the figure.

    Each node i is a dot at ``(points[i][0], points[i][1])`` coloured by
    ``labels[i]``. Each pair of connected nodes is joined by a gray line.

    Args:
        adj: square adjacency matrix (a torch tensor in this notebook; anything
            ``torch.as_tensor`` accepts works). A non-zero ``adj[u][v]`` or
            ``adj[v][u]`` means u and v are connected.
        points: per-node coordinates indexable as ``points[i]`` and
            ``points[:, k]``. Only columns 0 and 1 are used, so wider arrays
            (e.g. 3-column logits) are accepted.
        labels: per-node colour values passed to ``plt.scatter(c=...)``.
        node_size: marker diameter in points. matplotlib's ``s`` is the
            marker area, so ``node_size ** 2`` is passed. Defaults to 10.
    """
    connected = torch.as_tensor(adj) != 0
    # Lines have no direction, so merge both orientations and keep only the
    # upper triangle. For a symmetric matrix this draws each edge once instead
    # of twice. The diagonal is kept, so any self-loop still gets its
    # (zero-length) segment as before.
    upper = torch.triu((connected | connected.T).to(torch.uint8))
    for u, v in upper.nonzero().tolist():
        point_u = points[u]
        point_v = points[v]
        plt.plot([point_u[0], point_v[0]], [point_u[1], point_v[1]],
                 linewidth=1, c='gray')
    # Nodes are drawn after the edges so the dots sit on top of the lines.
    plt.scatter(points[:, 0], points[:, 1], s=node_size ** 2, c=labels)
    plt.show()

In [None]:
# Random 2-D layout, one (x, y) position in [0, 1) per node. It is used only
# for plotting.
points = torch.rand(num_nodes, 2)
print("points:", points, sep="\n")

In [None]:
# Show the random graph at the random layout, coloured by ground-truth label.
plot_graph(adj, points=points, labels=labels)

### Set up the GCN model

In [None]:
# Build the GCN over the fixed adjacency matrix. gcn.GraphConvNet is defined
# outside this notebook. The positional arguments are presumably
# (adjacency, <2>, input feature size = num_nodes because X is the identity,
# number of output classes).
# NOTE(review): confirm what the literal `2` controls (layer count? hidden width?).
model = gcn.GraphConvNet(adj, 2, num_nodes, num_classes)
# Bare expression as the cell's last line: the notebook displays the module repr.
model

### Set up input nodes with no features (one-hot identity matrix)

In [None]:
# Featureless nodes: each node's input is its own one-hot vector, so the
# feature matrix is the (num_nodes x num_nodes) identity.
X = torch.eye(num_nodes, num_nodes)
print(X.size())

### Feed the inputs forward through the model

In [None]:
# Forward pass over all nodes at once. Expected shape: (num_nodes, num_classes)
# of unnormalised class scores.
logits = model(X)
print(logits, logits.shape, sep="\n")

### Classify the nodes

In [None]:
# Normalise each row of logits into a probability distribution over classes.
# `softmax` is kept as a named module because a later cell reuses it.
softmax = nn.Softmax(dim=1)
probs = softmax(logits)
print(probs, probs.shape, sep="\n")

### Check against ground truth

In [None]:
# Predicted class per node = index of its most probable class.
y_pred = torch.argmax(probs, dim=1)
# Percentage of nodes whose prediction matches the (randomly drawn) labels.
matches = y_pred == labels
correct = matches.float().sum().item()
accuracy = (correct / num_nodes) * 100

print(f"y_predicted: {y_pred}")
print(f"true_labels: {labels}")
print(f"accuracy: {accuracy:>0.1f}%")

### Try to plot it (using the first two logit dimensions as coordinates)

In [None]:
# Use the untrained logits as coordinates (plot_graph reads only columns 0 and
# 1) and colour the nodes by predicted class.
plot_graph(adj, points=logits.detach().numpy(), labels=y_pred)

### Let's do some training

In [None]:
# Training hyperparameters.
# NOTE: batch_size is not referenced by any visible cell. The training loop
# passes the whole graph (X, labels) on every epoch.
batch_size, learning_rate, epochs = num_nodes, 0.01, 200

In [None]:
# nn.CrossEntropyLoss expects raw (unnormalised) logits and integer class
# targets. Adam updates every parameter of the model.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
# Full-graph training: every epoch calls gcn.train and then gcn.evaluate on
# the complete feature matrix and label vector (both helpers live outside
# this notebook).
for epoch in range(1, epochs + 1):
    print(f"Epoch {epoch:<3} ------")
    gcn.train(X, labels, model, loss_fn, optimizer)
    gcn.evaluate(X, labels, model)
    print()
print("done")

### Now try visualizing it

In [None]:
# Inference only: no_grad stops autograd from recording a computation graph
# that would never be backpropagated through.
# NOTE(review): gcn.train/gcn.evaluate are defined outside this notebook. If
# the model has dropout or other mode-dependent layers, confirm it is in eval
# mode here.
with torch.no_grad():
    logits = model(X)
y_pred = softmax(logits).argmax(dim=1)

# Trained logits as 2-D coordinates (plot_graph reads columns 0 and 1),
# coloured by predicted class.
plot_graph(adj, logits.detach().numpy(), y_pred)