In [1]:
from torch_geometric.utils.convert import from_scipy_sparse_matrix, to_scipy_sparse_matrix
from scipy.io import loadmat

import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
import torch_sparse
import torch

from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader
from torch_geometric.loader import ClusterData, ClusterLoader

from sklearn.metrics import roc_auc_score

from tqdm import tqdm

In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using {device} device")

Using cuda device


In [3]:
# Load the MATLAB file; the next cell reads its 'Attributes', 'Network' and 'Label' entries.
# NOTE(review): `data` is rebound to a PyG `Data` object further down, so this cell has to be
# re-run before re-running the cell that indexes those keys.
data = loadmat('ACM.mat')

In [4]:
# Pull the node attributes, the adjacency matrix and the labels out of the loaded .mat dict.
X = data['Attributes']
A = data['Network']
labels = data['Label']

# Cast to float32 while the attributes are still sparse and densify with `toarray()`:
# this yields the same float32 tensor as `.todense()` followed by `.float()`, but allocates
# the dense matrix only once and avoids the intermediate `np.matrix`.
# Assumes X is a SciPy sparse matrix (the previous code called `.todense()` on it).
X = torch.from_numpy(X.astype('float32').toarray())

# COO edge list (2, num_edges) and the matching non-zero values of A.
edge_index, edge_weight = from_scipy_sparse_matrix(A)

In [5]:
# Store the labels as a flat integer tensor rather than the raw (N, 1) NumPy array:
# the evaluation code calls `data.y.cpu()`, which only exists on tensors, and
# `roc_auc_score` expects a 1-D label vector.
data = Data(
    x=X,
    edge_index=edge_index,
    edge_weight=edge_weight,
    y=torch.from_numpy(labels.astype('int64')).view(-1),
)

In [6]:
# Shapes of the feature tensor, of the first two adjacency rows (A[0] and A[1] are single
# rows, not the whole matrix), and of the label array.
X.shape, A[0].shape, A[1].shape, labels.shape

(torch.Size([16484, 8337]), (1, 16484), (1, 16484), (16484, 1))

In [7]:
# Distinct attribute values present in X.
# NOTE(review): the displayed tensor is long; `torch.unique(X).numel()` may be enough here.
torch.unique(X)

tensor([0.0000, 0.0414, 0.0452, 0.0460, 0.0474, 0.0476, 0.0476, 0.0478, 0.0480,
        0.0483, 0.0484, 0.0487, 0.0489, 0.0491, 0.0492, 0.0493, 0.0494, 0.0496,
        0.0497, 0.0500, 0.0501, 0.0502, 0.0503, 0.0504, 0.0505, 0.0507, 0.0510,
        0.0511, 0.0513, 0.0514, 0.0515, 0.0516, 0.0516, 0.0517, 0.0518, 0.0519,
        0.0520, 0.0521, 0.0521, 0.0522, 0.0523, 0.0523, 0.0525, 0.0526, 0.0526,
        0.0528, 0.0529, 0.0531, 0.0531, 0.0532, 0.0533, 0.0534, 0.0535, 0.0536,
        0.0538, 0.0538, 0.0539, 0.0540, 0.0541, 0.0542, 0.0542, 0.0543, 0.0544,
        0.0545, 0.0546, 0.0546, 0.0547, 0.0548, 0.0549, 0.0550, 0.0551, 0.0552,
        0.0553, 0.0554, 0.0555, 0.0556, 0.0556, 0.0557, 0.0558, 0.0559, 0.0560,
        0.0561, 0.0562, 0.0563, 0.0563, 0.0564, 0.0565, 0.0566, 0.0567, 0.0568,
        0.0569, 0.0570, 0.0571, 0.0572, 0.0573, 0.0574, 0.0574, 0.0575, 0.0576,
        0.0577, 0.0578, 0.0579, 0.0581, 0.0582, 0.0583, 0.0584, 0.0585, 0.0586,
        0.0587, 0.0588, 0.0589, 0.0590, 

In [8]:
class GAE_Encoder(nn.Module):
    """Two stacked GCN layers that turn node features into node embeddings.

    Args:
        in_channels: width of the input node features.
        hidden_channels: output width of the first GCN layer.
        out_channels: output width of the second GCN layer (the embedding size).
    """

    def __init__(self, in_channels, hidden_channels=128, out_channels=64):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Apply both GCN layers, each followed by a ReLU, and return the result."""
        hidden = self.conv1(x, edge_index).relu()
        embedding = self.conv2(hidden, edge_index).relu()
        return embedding
    
# Smoke test: push a random graph (100 nodes, 50 features, 100 random edges) through the
# encoder and display the output shape. `dummy_edge_index` is reused by the decoder cells below.
dummy_input = torch.rand(100, 50)
dummy_edge_index = torch.randint(0, 100, (2, 100))
encoder = GAE_Encoder(50, 128, 64)
output = encoder(dummy_input, dummy_edge_index)
output.shape

torch.Size([100, 64])

In [9]:
class GAE_AttrDecoder(nn.Module):
    """Two stacked GCN layers that map node embeddings back to attribute space.

    Args:
        in_channels: width of the incoming embeddings.
        hidden_channels: output width of the first GCN layer.
        out_channels: output width of the second GCN layer (the reconstructed feature size).
    """

    def __init__(self, in_channels, hidden_channels=64, out_channels=128):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, z, edge_index):
        """Return the reconstructed attributes; a ReLU follows each GCN layer."""
        hidden = self.conv1(z, edge_index).relu()
        reconstructed = self.conv2(hidden, edge_index).relu()
        return reconstructed
    
# Smoke test: decode random 64-dimensional embeddings for 100 nodes and display the output
# shape. Uses `dummy_edge_index` defined in the encoder cell above.
dummy_input = torch.rand(100, 64)
decoder = GAE_AttrDecoder(64, 64, 128)
output = decoder(dummy_input, dummy_edge_index)
output.shape

torch.Size([100, 128])

In [10]:
class GAE_StructDecoder(nn.Module):
    """Structure decoder: one GCN layer followed by an inner product between all node pairs.

    Args:
        channels: width of the embeddings; the GCN layer keeps this width.
    """

    def __init__(self, channels=64):
        super().__init__()
        self.conv = GCNConv(channels, channels)

    def forward(self, z, edge_index):
        """Return the (num_nodes, num_nodes) matrix of pairwise inner products."""
        transformed = self.conv(z, edge_index).relu()
        # Entry (i, j) is the dot product of the transformed embeddings of nodes i and j.
        return torch.matmul(transformed, transformed.t())

# Smoke test: the structure decoder should return one score per node pair, i.e. a
# 100 x 100 matrix for 100 nodes. Uses `dummy_edge_index` defined in the encoder cell above.
dummy_input = torch.rand(100, 64)
decoder = GAE_StructDecoder(64)
output = decoder(dummy_input, dummy_edge_index)
output.shape

torch.Size([100, 100])

In [11]:
class GraphAutoencoder(nn.Module):
    """Graph autoencoder with a shared encoder and two decoders.

    The encoder output ``z`` is fed to an attribute decoder (reconstructing the node
    features) and to a structure decoder (reconstructing pairwise node scores).

    Args:
        in_channels: width of the input node features, also the width of the
            reconstructed features.
        hidden_channels: hidden width used by the encoder and the attribute decoder.
        out_channels: embedding width.
    """

    def __init__(self, in_channels, hidden_channels=128, out_channels=64):
        super().__init__()
        self.encoder = GAE_Encoder(in_channels, hidden_channels, out_channels)
        # The attribute decoder mirrors the encoder: embedding -> hidden -> input width.
        self.attr_decoder = GAE_AttrDecoder(
            in_channels=out_channels,
            hidden_channels=hidden_channels,
            out_channels=in_channels,
        )
        self.struct_decoder = GAE_StructDecoder(channels=out_channels)

    def forward(self, x, edge_index):
        """Return ``(X_hat, A_hat, z)``: both reconstructions and the embeddings."""
        z = self.encoder(x, edge_index)
        return self.attr_decoder(z, edge_index), self.struct_decoder(z, edge_index), z
    
# Smoke test of the full model on a random graph (100 nodes, 50 features, 100 random edges):
# display the shapes of the reconstructed attributes, the pairwise score matrix and the embeddings.
# NOTE: `model` is rebound to the real model in the training cell.
dummy_input = torch.rand(100, 50)
dummy_edge_index = torch.randint(0, 100, (2, 100))
model = GraphAutoencoder(50, 128, 64)
X_hat, A_hat, z = model(dummy_input, dummy_edge_index)
X_hat.shape, A_hat.shape, z.shape

(torch.Size([100, 50]), torch.Size([100, 100]), torch.Size([100, 64]))

In [12]:
def gae_loss(x, x_hat, a, a_hat, alpha=0.8):
    """Weighted sum of the attribute and structure reconstruction errors.

    Computes ``alpha * ||x - x_hat||_F^2 + (1 - alpha) * ||a - a_hat||_F^2``.

    The squared Frobenius norms are evaluated directly as sums of squared differences
    instead of squaring ``torch.norm(..., p='fro')``: that avoids the deprecated
    ``torch.norm`` API and a square root that is immediately undone.

    Args:
        x: target node attributes.
        x_hat: reconstructed node attributes, same shape as ``x``.
        a: target adjacency matrix.
        a_hat: reconstructed adjacency matrix, same shape as ``a``.
        alpha: weight of the attribute term; the structure term gets ``1 - alpha``.

    Returns:
        Scalar tensor holding the combined loss.

    Raises:
        ValueError: if a target and its reconstruction have different shapes
            (otherwise broadcasting could silently produce a meaningless loss).
    """
    if x.shape != x_hat.shape:
        raise ValueError(
            f"attribute shapes differ: {tuple(x.shape)} vs {tuple(x_hat.shape)}"
        )
    if a.shape != a_hat.shape:
        raise ValueError(
            f"adjacency shapes differ: {tuple(a.shape)} vs {tuple(a_hat.shape)}"
        )

    attr_loss = (x - x_hat).pow(2).sum()
    struct_loss = (a - a_hat).pow(2).sum()

    return alpha * attr_loss + (1 - alpha) * struct_loss

In [13]:
def compute_node_reconstruction_errors(x, x_hat, A, A_hat):
    """Per-node reconstruction error, summed over attributes and structure.

    For every row (node), adds the squared attribute error ``sum_j (x - x_hat)[i, j]^2``
    to the squared structure error ``sum_j (A - A_hat)[i, j]^2``.

    Args:
        x: target node attributes.
        x_hat: reconstructed node attributes.
        A: target adjacency matrix.
        A_hat: reconstructed adjacency matrix.

    Returns:
        Tensor with one combined error value per row.
    """
    attr_sq = (x - x_hat).pow(2)
    struct_sq = (A - A_hat).pow(2)
    return attr_sq.sum(dim=1) + struct_sq.sum(dim=1)

In [14]:
# Mini-batch loader over `data`: shuffled batches of 32 nodes, with `num_neighbors=[15, 10]`
# controlling how many neighbours are sampled at each of the two hops.
loader = NeighborLoader(data, num_neighbors=[15, 10], batch_size=32, shuffle=True)



In [None]:
from torch.optim import Adam


def _dense_adjacency(edge_index, num_nodes):
    """Build a dense float ``(num_nodes, num_nodes)`` adjacency matrix on ``edge_index``'s device.

    Each edge contributes 1 to its ``(row, col)`` entry and repeated edges add up.
    ``num_nodes`` is passed explicitly so the matrix always matches the model output,
    even when the highest-numbered nodes have no edges.
    """
    adj = torch.zeros(num_nodes, num_nodes, device=edge_index.device)
    edge_counts = torch.ones(edge_index.size(1), device=edge_index.device)
    adj.index_put_((edge_index[0], edge_index[1]), edge_counts, accumulate=True)
    return adj


model = GraphAutoencoder(in_channels=X.shape[1], hidden_channels=128, out_channels=64)
model = model.to(device)
optimizer = Adam(model.parameters(), lr=0.004)

num_epochs = 5
eval_every = 5  # run the full-graph evaluation every `eval_every` epochs

# Ground-truth labels as a flat NumPy vector. `torch.as_tensor` accepts both an ndarray and
# a tensor, and a separate name keeps the global `labels` array from being overwritten.
y_true = torch.as_tensor(data.y).reshape(-1).cpu().numpy()

for epoch in range(num_epochs):
    model.train()
    total_loss = 0

    for batch in tqdm(loader):
        optimizer.zero_grad()

        batch_x = batch.x.to(device)
        batch_edge_index = batch.edge_index.to(device)

        x_hat, a_hat, z = model(batch_x, batch_edge_index)

        # Target adjacency of the sampled subgraph, built directly on the device and sized
        # by the number of nodes in the batch so it always lines up with `a_hat`.
        a = _dense_adjacency(batch_edge_index, batch_x.size(0))

        loss = gae_loss(batch_x, x_hat, a, a_hat)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch}, Loss: {total_loss:.4f}")

    if (epoch + 1) % eval_every == 0:
        model.eval()
        with torch.no_grad():
            full_x = data.x.to(device)
            full_edge_index = data.edge_index.to(device)
            x_hat, a_hat, z = model(full_x, full_edge_index)

            # Score against the adjacency of the WHOLE graph; the mini-batch `a` left over
            # from training has a different shape than the full-graph `a_hat`.
            # NOTE: this materialises dense num_nodes x num_nodes matrices.
            full_a = _dense_adjacency(full_edge_index, full_x.size(0))
            errors = compute_node_reconstruction_errors(full_x, x_hat, full_a, a_hat)
            errors = errors.cpu().numpy()

        # Higher reconstruction error is used as the score for the positive label.
        auc = roc_auc_score(y_true, errors)
        print(f"Epoch {epoch}, AUC: {auc:.4f}")


100%|██████████| 516/516 [00:51<00:00, 10.06it/s]


Epoch 0, Loss: 531079.8444


100%|██████████| 516/516 [00:50<00:00, 10.27it/s]


Epoch 1, Loss: 527126.2182


100%|██████████| 516/516 [00:50<00:00, 10.20it/s]


Epoch 2, Loss: 526235.4852


100%|██████████| 516/516 [00:51<00:00, 10.08it/s]


Epoch 3, Loss: 523778.7846


100%|██████████| 516/516 [00:50<00:00, 10.14it/s]

Epoch 4, Loss: 522340.9470



