In [1]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import degree

class GCNConvPearson(MessagePassing):
    """Graph convolution whose neighbour messages are weighted by Pearson correlation.

    For every edge the Pearson correlation between the feature vectors of the
    target node ``i`` and the source node ``j`` is computed. The mean-centred
    features of ``j`` are scaled by that coefficient, summed at ``i``
    (``aggr='add'``), and finally projected by a learned
    ``(in_channels, out_channels)`` weight matrix. No bias term is used.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Size of each output node feature vector.
        eps: Small constant added to the correlation denominator so that
            constant (zero-variance) feature vectors do not divide by zero.
            Defaults to ``1e-7``.
    """

    def __init__(self, in_channels, out_channels, eps=1e-7):
        # Summation of the correlation-weighted messages is delegated to PyG.
        super(GCNConvPearson, self).__init__(aggr='add')
        self.eps = eps
        # torch.empty replaces the legacy torch.Tensor(...) constructor; the
        # uninitialised values are overwritten immediately by reset_parameters().
        self.weight = torch.nn.Parameter(torch.empty(in_channels, out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the projection weight with Xavier-uniform values."""
        torch.nn.init.xavier_uniform_(self.weight)

    def forward(self, x, edge_index):
        """Run one round of Pearson-weighted message passing.

        Args:
            x: Node feature matrix, assumed ``(num_nodes, in_channels)``.
            edge_index: Edge list in PyG COO format, ``(2, num_edges)``.

        Returns:
            Tensor of shape ``(num_nodes, out_channels)``.
        """
        # Centre each node's feature vector on its own mean: the dot product of
        # two centred vectors is then the Pearson numerator, and their L2 norms
        # give the denominator (see message()).
        x_centered = x - torch.mean(x, dim=1, keepdim=True)
        return self.propagate(edge_index, x=x_centered)

    def message(self, x_j, x_i):
        """Scale each centred neighbour vector by its correlation with the target node.

        Args:
            x_j: Centred features of the source node of each edge.
            x_i: Centred features of the target node of each edge.

        Returns:
            ``pearson_corr * x_j`` with one correlation coefficient per edge.
        """
        # Numerator: sum((x_i - mean(x_i)) * (x_j - mean(x_j))).
        numerator = torch.sum(x_i * x_j, dim=-1, keepdim=True)

        # Denominator: sqrt(sum((x_i - mean)^2)) * sqrt(sum((x_j - mean)^2)).
        x_i_norm = torch.sqrt(torch.sum(x_i ** 2, dim=-1, keepdim=True))
        x_j_norm = torch.sqrt(torch.sum(x_j ** 2, dim=-1, keepdim=True))
        denominator = x_i_norm * x_j_norm + self.eps

        pearson_corr = numerator / denominator
        return pearson_corr * x_j

    def update(self, aggr_out):
        """Project the summed messages with the learned weight matrix."""
        return aggr_out @ self.weight


In [2]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool
from torch_geometric.data import Data

# Custom GCN with Pearson Correlation and Global Pooling for Graph-Level Prediction
class GCNWithPearsonAggregationGraphLevel(torch.nn.Module):
    """Graph classifier: two Pearson-weighted GCN layers, mean pooling, linear head.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Width of both GCN layers and of the pooled embedding.
        out_channels: Number of output classes.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # Registration order (conv1, conv2, fc) fixes parameter init order and
        # state_dict keys.
        self.conv1 = GCNConvPearson(in_channels, hidden_channels)
        self.conv2 = GCNConvPearson(hidden_channels, hidden_channels)
        self.fc = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch):
        """Map a batch of graphs to per-graph class log-probabilities.

        Args:
            x: Node features of all graphs in the batch.
            edge_index: Edge list of the batched (disjoint-union) graph.
            batch: Graph-membership index for every node, used for pooling.

        Returns:
            Log-probabilities of shape ``(num_graphs, out_channels)``.
        """
        hidden = F.relu(self.conv1(x, edge_index))
        node_embeddings = self.conv2(hidden, edge_index)
        # Average node embeddings per graph -> one vector per graph.
        graph_embeddings = global_mean_pool(node_embeddings, batch)
        logits = self.fc(graph_embeddings)
        return logits.log_softmax(dim=-1)


In [None]:
from torch_geometric.data import Data, Batch
from torch_geometric.utils import dense_to_sparse

# Load the pre-built train/test graph lists from disk and collate each split
# into a single Batch (a disjoint union of its graphs).
num_nodes = 50  # nodes per graph, for reference
num_features = 144
num_classes = 6

# NOTE(review): these .pth files contain pickled Python objects (lists of
# torch_geometric Data), so weights_only=False is required (torch >= 2.6 defaults
# to weights_only=True). Unpickling can execute arbitrary code — only load files
# from a trusted source.
graph_train_data = torch.load('graph_train_data.pth', weights_only=False)
graph_test_data = torch.load('graph_test_data.pth', weights_only=False)

train_batch = Batch.from_data_list(graph_train_data)
test_batch = Batch.from_data_list(graph_test_data)

# Fail fast if the stored graphs disagree with the model configuration above.
for split_name, split_batch in (('train', train_batch), ('test', test_batch)):
    assert split_batch.x.size(-1) == num_features, (
        f'{split_name}: expected {num_features} node features, got {split_batch.x.size(-1)}'
    )
    assert int(split_batch.y.max()) < num_classes, (
        f'{split_name}: label {int(split_batch.y.max())} out of range for {num_classes} classes'
    )


In [None]:
# Initialize the model, optimizer, and loss function for graph-level prediction.
SEED = 0
HIDDEN_CHANNELS = 64
LEARNING_RATE = 0.01
NUM_EPOCHS = 10

torch.manual_seed(SEED)  # reproducible weight init on Restart & Run All
model = GCNWithPearsonAggregationGraphLevel(
    in_channels=num_features, hidden_channels=HIDDEN_CHANNELS, out_channels=num_classes
)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# The model's forward ends in log_softmax, so it emits log-probabilities;
# NLLLoss is the matching criterion (CrossEntropyLoss would apply log_softmax
# a second time).
loss_fn = torch.nn.NLLLoss()

# Full-batch training: every epoch uses all training graphs at once.
model.train()
for epoch in range(NUM_EPOCHS):
    optimizer.zero_grad()

    # One row of class log-probabilities per graph.
    out = model(train_batch.x, train_batch.edge_index, train_batch.batch)

    # train_batch.y holds one label per graph.
    loss = loss_fn(out, train_batch.y)
    loss.backward()
    optimizer.step()

    print(f'Epoch {epoch}, Loss: {loss.item()}')


In [None]:

def calculate_accuracy(pred, target):
    """Calculate accuracy by comparing predicted and true labels.

    Args:
        pred: Tensor of predicted class indices (e.g. ``output.argmax(dim=1)``).
        target: Tensor of true class indices with exactly the same shape.

    Returns:
        Fraction of matching elements as a Python float in ``[0, 1]``.

    Raises:
        ValueError: If the shapes differ — ``eq`` would otherwise broadcast
            (e.g. ``[N]`` vs ``[N, 1]`` -> ``[N, N]``) and silently report a
            wrong accuracy — or if ``target`` is empty.
    """
    if pred.shape != target.shape:
        raise ValueError(
            f'pred and target must have the same shape, got '
            f'{tuple(pred.shape)} and {tuple(target.shape)}'
        )
    total = target.numel()
    if total == 0:
        raise ValueError('cannot compute accuracy on an empty target')
    correct = pred.eq(target).sum().item()
    return correct / total


# Score the trained model on both splits. no_grad skips building the autograd
# graph, which is not needed for inference.
model.eval()
with torch.no_grad():
    test_output = model(test_batch.x, test_batch.edge_index, test_batch.batch)
    train_output = model(train_batch.x, train_batch.edge_index, train_batch.batch)

test_accuracy = calculate_accuracy(test_output.argmax(dim=1), test_batch.y)
print(f'Final Accuracy Testing: {test_accuracy:.4f}')

# Training predictions must be compared against the *training* labels.
train_accuracy = calculate_accuracy(train_output.argmax(dim=1), train_batch.y)
print(f'Final Accuracy Training: {train_accuracy:.4f}')