In [0]:
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
from torch.utils.data import TensorDataset

# ---- Graph / task shape ----
n_nodes = 5          # nodes per generated graph
n_classes = 4        # number of per-node label classes
dtype = torch.long   # dtype of adjacency entries and labels in get_dataset

# ---- Model ----
hidden_dim = 10      # width of the hidden GCN layer
dropout = 0.5        # dropout probability applied between the two GCN layers

# ---- Optimisation ----
batch_size = 128
learning_rate = 1e-4
num_epochs = 50

def Normalize_Adj(A):
    """Return the symmetrically normalised adjacency with self-loops.

    Computes ``D^{-1/2} (A + I) D^{-1/2}``, where ``D`` is the diagonal matrix
    of row sums of ``A + I``.

    Args:
        A: square adjacency matrix ``(n, n)`` or a batch ``(batch, n, n)``.
            Integer/bool inputs are converted to the default floating dtype;
            floating inputs keep their dtype. The device of ``A`` is preserved.

    Returns:
        Floating tensor with the same shape as ``A``.
    """
    if not torch.is_floating_point(A):
        A = A.to(torch.get_default_dtype())
    n = A.shape[-1]
    # Broadcasting adds I to every matrix in the batch without materialising
    # one identity per batch element; created on A's device/dtype.
    A_tilda = A + torch.eye(n, dtype=A.dtype, device=A.device)
    # Self-loops guarantee every row sum is >= 1 for 0/1 adjacencies, so the
    # inverse square root is finite.
    deg_inv_sqrt = A_tilda.sum(dim=-1).pow(-0.5)
    # Multiplying by a diagonal matrix on both sides is just row/column
    # scaling: A_hat[i, j] = d[i] * A_tilda[i, j] * d[j].
    A_hat = deg_inv_sqrt.unsqueeze(-1) * A_tilda * deg_inv_sqrt.unsqueeze(-2)
    return A_hat

def get_dataset(n_train=65536, n_valid=8192, n_nodes=n_nodes, n_classes=n_classes):
    """Build random graph datasets for training and validation.

    Each sample is a normalised adjacency matrix of an undirected graph with
    ``n_nodes`` nodes (edges drawn uniformly from {0, 1}), paired with one
    class label per node drawn uniformly from ``range(n_classes)``. The labels
    are sampled independently of the graph structure.

    Returns:
        ``(train_dataset, valid_dataset)`` as two ``TensorDataset`` objects.
    """
    n_total = n_train + n_valid

    # Keep only the strict upper triangle and mirror it, so every graph is
    # symmetric with an empty diagonal before normalisation.
    raw = torch.randint(2, [n_total, n_nodes, n_nodes]).to(dtype)
    strict_upper = raw.triu(diagonal=1)
    adjacency = Normalize_Adj(strict_upper + strict_upper.transpose(1, 2))
    train_adj, valid_adj = adjacency.split([n_train, n_valid], dim=0)

    # Per-node labels; the draw order (train first, then valid) is significant
    # for reproducibility under a fixed seed.
    train_labels = torch.randint(n_classes, (n_train, n_nodes), dtype=dtype)
    valid_labels = torch.randint(n_classes, (n_valid, n_nodes), dtype=dtype)

    return (
        TensorDataset(train_adj, train_labels),
        TensorDataset(valid_adj, valid_labels),
    )

train_dataset, valid_dataset = get_dataset()

# Training batches are reshuffled on every pass; validation order stays fixed.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
valid_loader = torch.utils.data.DataLoader(
    dataset=valid_dataset,
    batch_size=batch_size,
    shuffle=False,
)

# Featureless setup: each node's input feature vector is its one-hot row of I.
X = torch.eye(n_nodes)

                             ##############

class GraphConvolutionLayer(nn.Module):
  """Bias-free graph convolution: ``A @ (X @ W^T)``.

  Args:
    input_dim: number of input features per node.
    output_dim: number of output features per node.
  """

  def __init__(self, input_dim, output_dim):
    super().__init__()
    # Linear map W without bias, Xavier-uniform initialised.
    self.fc = nn.Linear(input_dim, output_dim, bias=False)
    torch.nn.init.xavier_uniform_(self.fc.weight)

  def forward(self, X, A):
    """Project node features with W, then aggregate them through A.

    Args:
      X: node features whose last dimension is ``input_dim``, e.g.
        ``(batch, nodes, input_dim)`` or ``(nodes, input_dim)``.
      A: adjacency, ``(batch, nodes, nodes)`` or ``(nodes, nodes)``; its dtype
        and device must match the projected features.

    Returns:
      Tensor of shape ``(..., nodes, output_dim)``.
    """
    out = self.fc(X)
    # torch.matmul equals torch.bmm for 3-D operands, but additionally accepts
    # unbatched adjacencies and broadcasts a shared 2-D X over a batch of A.
    return torch.matmul(A, out)

class TwoLayerGCN(nn.Module):
  """Two stacked graph-convolution layers with ReLU and dropout in between.

  The forward pass returns raw per-node class scores (logits); no softmax is
  applied, so the output can be fed directly to ``nn.CrossEntropyLoss``.
  """

  def __init__(self, input_dim=n_nodes, hidden_dim=hidden_dim, n_classes=n_classes, dropout=dropout):
    super().__init__()
    self.gc1 = GraphConvolutionLayer(input_dim, hidden_dim)
    self.gc2 = GraphConvolutionLayer(hidden_dim, n_classes)
    self.dropout = dropout  # drop probability passed to F.dropout

  def forward(self, X, A):
    """Compute node logits for a batch of graphs.

    Args:
      X: node features shared by all graphs, ``(nodes, input_dim)``; an
        already-batched ``(batch, nodes, input_dim)`` tensor is used as-is.
      A: batch of (normalised) adjacency matrices, ``(batch, nodes, nodes)``.

    Returns:
      Logits of shape ``(batch, nodes, n_classes)``.
    """
    if X.dim() == 2:
      # expand() returns a broadcast view, so the shared feature matrix is not
      # copied once per graph in the batch.
      X = X.to(A.device).expand(A.shape[0], -1, -1)
    out = F.relu(self.gc1(X, A))
    # F.dropout defaults to training=True; tie it to the module's mode so
    # model.eval() actually disables dropout.
    out = F.dropout(out, self.dropout, training=self.training)
    return self.gc2(out, A)

model = TwoLayerGCN()

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

total_step = len(train_loader)
for epoch in range(num_epochs):
    # Put the model in training mode at the start of every epoch.
    model.train()
    for i, (A, labels) in enumerate(train_loader):
        # Forward pass: outputs are (batch, nodes, n_classes) logits.
        outputs = model(X, A)
        # CrossEntropyLoss expects the class dimension second:
        # (batch, n_classes, nodes) against labels of shape (batch, nodes).
        loss = criterion(outputs.transpose(1, 2), labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 8 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))

# Evaluation mode, so layers that honour model.training (e.g. dropout) are
# switched off while measuring accuracy.
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for A, labels in valid_loader:
        outputs = model(X, A)
        predicted = outputs.argmax(dim=2)
        total += labels.numel()
        # .item() keeps the running count a plain Python int.
        correct += (predicted == labels).sum().item()

    print('Accuracy of the network on the validation data: {:.2f} %'.format(100 * correct / total))


Epoch [1/50], Step [8/512], Loss: 1.4101
Epoch [1/50], Step [16/512], Loss: 1.4020
Epoch [1/50], Step [24/512], Loss: 1.3996
Epoch [1/50], Step [32/512], Loss: 1.4092
Epoch [1/50], Step [40/512], Loss: 1.3967
Epoch [1/50], Step [48/512], Loss: 1.4024
Epoch [1/50], Step [56/512], Loss: 1.3967
Epoch [1/50], Step [64/512], Loss: 1.4090
Epoch [1/50], Step [72/512], Loss: 1.3973
Epoch [1/50], Step [80/512], Loss: 1.3938
Epoch [1/50], Step [88/512], Loss: 1.4014
Epoch [1/50], Step [96/512], Loss: 1.3905
Epoch [1/50], Step [104/512], Loss: 1.3965
Epoch [1/50], Step [112/512], Loss: 1.3962
Epoch [1/50], Step [120/512], Loss: 1.3920
Epoch [1/50], Step [128/512], Loss: 1.4039
Epoch [1/50], Step [136/512], Loss: 1.3997
Epoch [1/50], Step [144/512], Loss: 1.3972
Epoch [1/50], Step [152/512], Loss: 1.4010
Epoch [1/50], Step [160/512], Loss: 1.4068
Epoch [1/50], Step [168/512], Loss: 1.4023
Epoch [1/50], Step [176/512], Loss: 1.3876
Epoch [1/50], Step [184/512], Loss: 1.3897
Epoch [1/50], Step [192/