## PyTorch GCN with the Cora dataset

In [None]:
import torch
import pytorch_gcn as gcn
import matplotlib.pyplot as plt
import pandas as pd

from torch import nn
from pprint import pprint

In [None]:
# Load the three Cora tables (node features, citation edges, class labels)
# in a single pass over their CSV paths.
df_nodes, df_edges, df_labels = map(
    pd.read_csv,
    ('data/cora-nodes.csv', 'data/cora-edges.csv', 'data/cora-labels.csv'),
)

### Set up nodes

In [None]:
# Sort by paper_id and renumber the index 0..N-1.
# `sort_values` alone keeps the pre-sort index labels, so any lookup that goes
# through `df_nodes.index` would return the original CSV row numbers rather
# than positions in the sorted table. The feature matrix is built positionally
# (`.iloc[...].to_numpy()`), so labels and positions must coincide.
df_nodes = df_nodes.sort_values('paper_id').reset_index(drop=True)
df_nodes

In [None]:
# Every column after the first becomes a float feature; one row per node,
# in the current row order of df_nodes.
nodes = torch.from_numpy(df_nodes.iloc[:, 1:].to_numpy()).float()
num_nodes = nodes.size(0)

print(f"nodes: {nodes.shape}")
print(nodes)

### Set up edges

In [None]:
# Display the raw edge table; the adjacency cell below reads its
# `citing_paper_id` and `cited_paper_id` columns.
df_edges

In [None]:
# Map each paper_id to its row position in df_nodes. Positions (not index
# labels) are what line up with the rows of `nodes`, and a dict lookup avoids
# rescanning the whole DataFrame for every edge endpoint.
paper_pos = {pid: pos for pos, pid in enumerate(df_nodes['paper_id'])}
if len(paper_pos) != num_nodes:
    # Fail fast: an ambiguous paper_id would silently pick one row.
    raise ValueError("df_nodes contains duplicate paper_id values")

# Dense (num_nodes x num_nodes) adjacency matrix, 1 where a citation exists.
adj = torch.zeros(num_nodes, num_nodes)
for row in df_edges.itertuples():
    # A paper_id missing from df_nodes raises KeyError here.
    src_idx = paper_pos[row.citing_paper_id]
    dst_idx = paper_pos[row.cited_paper_id]
    # using undirected edges
    adj[src_idx, dst_idx] = 1
    adj[dst_idx, src_idx] = 1

print(f"adjacency: {adj.shape}")
print(adj)

### Set up labels

In [None]:
# Assign each distinct class name an integer id, in sorted-name order.
label_mapping = dict(
    (name, i)
    for i, name in enumerate(sorted(df_labels['class_label'].unique()))
)
label_mapping

In [None]:
num_classes = len(label_mapping)

# Map each paper_id to its row position in df_nodes so every label lands on
# the same row as that paper's features in `nodes`. The dict also replaces a
# full DataFrame scan per labelled paper.
paper_pos = {pid: pos for pos, pid in enumerate(df_nodes['paper_id'])}
if len(paper_pos) != num_nodes:
    # Fail fast: an ambiguous paper_id would silently pick one row.
    raise ValueError("df_nodes contains duplicate paper_id values")

# Integer class id per node; nodes without a row in df_labels stay at 0.
labels = torch.zeros(num_nodes, dtype=torch.long)
for row in df_labels.itertuples():
    # A paper_id missing from df_nodes raises KeyError here.
    labels[paper_pos[row.paper_id]] = label_mapping[row.class_label]

print(f"labels: {labels.shape}")
print(labels[:15], '...')

### Attempt to visualize graph

In [None]:
def plot_graph(adj, points, labels):
    """Draw a graph: gray line segments for edges, colored markers for nodes.

    Args:
        adj: 2-D adjacency matrix; a non-zero entry at [u][v] draws a
            segment between points[u] and points[v].
        points: per-node coordinates; only columns 0 (x) and 1 (y) are used.
        labels: per-node values passed to ``plt.scatter`` as the color
            argument.
    """
    for u, neighbours in enumerate(adj):
        for v, weight in enumerate(neighbours):
            if weight != 0:
                start, end = points[u], points[v]
                plt.plot(
                    [start[0], end[0]],
                    [start[1], end[1]],
                    linewidth=1,
                    c='gray',
                )
    # Every node gets the same marker area (10 ** 2).
    marker_area = torch.full((adj.shape[0],), 10.0) ** 2
    plt.scatter(points[:, 0], points[:, 1], s=marker_area, c=labels)
    plt.show()

# Random 2-D position for every node; purely for visualization.
points = torch.rand((num_nodes, 2))
print(f"points:\n{points}")

# Plot only the first 25 nodes and the edges among them.
preview = slice(25)
plot_graph(adj[preview, preview], points[preview], labels[preview])

### Set up model

In [None]:
# Input width is the number of feature columns per node.
in_features = nodes.size(1)
# NOTE(review): the meaning of the positional `2` is defined by
# pytorch_gcn.GraphConvNet, which is not visible here - confirm there.
model = gcn.GraphConvNet(
    adj,
    2,
    in_features,
    num_classes,
)
model

### Try a single forward pass

In [None]:
# Run the model once over all node features.
logits = model(nodes)
print(f"logits: {logits.shape}")
print(logits)

# Normalize each row of logits; `softmax` is reused by a later cell.
softmax = nn.Softmax(dim=1)
probs = softmax(logits)
print(f"probabilities: {probs.shape}")
print(probs)

# Predicted class is the most probable one; accuracy is taken over all nodes.
y_pred = torch.argmax(probs, dim=1)
correct = torch.eq(y_pred, labels).float().sum().item()
accuracy = correct / num_nodes * 100
print(f"y_predicted: {y_pred[:15]}")
print(f"true_labels: {labels[:15]}")
print(f"accuracy: {accuracy:>0.1f}%")

# The first two logit columns of the first 25 nodes serve as x/y coordinates.
plot_graph(
    adj[:25, :25],
    logits.detach().numpy()[:25, :],
    y_pred[:25],
)

### Training

In [None]:
# Hyperparameters. batch_size equals the node count; the training cells
# visible in this notebook do not reference it.
batch_size = num_nodes
epochs = 200
learning_rate = 0.01

# Cross-entropy loss on the model outputs, optimized with Adam.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
# Run `epochs` rounds over the full node set, printing a 1-based epoch header
# before each round and a blank line after it.
# NOTE(review): gcn.train / gcn.evaluate are defined in pytorch_gcn and are
# not visible here - presumably one optimization pass followed by a metrics
# report; confirm against that module.
for i in range(epochs):
    print(f"Epoch {i+1:<3} ------")
    gcn.train(nodes, labels, model, loss_fn, optimizer)
    gcn.evaluate(nodes, labels, model)
    print()
print("done")

In [None]:
# Re-run the model after training and plot the first 25 nodes, positioned by
# their first two logit columns and colored by predicted class.
logits = model(nodes)
y_pred = torch.argmax(softmax(logits), dim=1)
plot_graph(
    adj[:25, :25],
    logits.detach().numpy()[:25, :],
    y_pred[:25],
)