In [8]:
import sys
sys.path.append('../')

import numpy as np
import networkx as nx
import itertools
import torch

import torch.nn.functional as F
import torch_sparse
import random 
import matplotlib.pyplot as plt


from torch import nn
from models.difussion_models import GraphLaplacianDiffusion, DiagSheafDiffusion, MultiDimSheafDiffusion, DiscreteMultiDimSheafDiffusion
from lib.laplace import build_norm_sheaf_laplacian, remove_duplicate_edges, build_sheaf_difussion_matrix, dirichlet_energy
from torch_geometric.nn.dense.linear import Linear
from scipy import linalg
from torch_geometric.utils import to_dense_adj, from_networkx, degree
from torch_geometric.nn.conv import GCNConv
from data.heterophilic import get_dataset, generate_random_splits

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Newly created floating-point tensors default to single precision.
torch.set_default_dtype(torch.float32)
print("Device:", device)

Device: cuda


#### Initialise heterophilic dataset

In [9]:
# Choose which dataset to load; uncomment exactly one alternative to switch.
dataset_name = 'texas'
# dataset_name = 'wisconsin'
# dataset_name = 'cornell'
# dataset_name = 'squirrel'
# dataset_name = 'chameleon'


dataset = get_dataset(dataset_name)

# Only the first graph of the dataset is used in this notebook.
data = dataset[0]
print("Nodes:", data.x.size())
print("Edges:", data.edge_index.size(1))
# An edge is a self loop when its two endpoints coincide.
print("Self loops:", (data.edge_index[0] == data.edge_index[1]).sum().item())
# Label histogram: (distinct labels, number of nodes per label).
counts = np.unique(data.y.cpu().numpy(), return_counts=True)
print("Counts", counts)
# Compute the degree vector once. Passing the node count explicitly keeps
# nodes that never appear as an edge source in the vector (with degree 0)
# instead of sizing it from the largest index found in the edge list.
node_degrees = degree(data.edge_index[0], num_nodes=data.x.size(0))
print("Max deg", node_degrees.max().item())
print("Min deg", node_degrees.min().item())

Nodes: torch.Size([183, 1703])
Edges: 558
Self loops: 0
Counts (array([0, 1, 2, 3, 4]), array([ 33,   1,  18, 101,  30]))
Max deg 104.0
Min deg 1.0


In [10]:
def eval_model(model, data, mask):
    """Measure accuracy and NLL loss of ``model`` on the nodes selected by ``mask``.

    The model is switched to evaluation mode and run on ``data.x`` with
    gradient tracking disabled; the mode is not switched back afterwards.

    Args:
        model: Callable module mapping ``data.x`` to per-node class scores.
            ``F.nll_loss`` is applied to its output, so it presumably returns
            log-probabilities -- confirm against the model definition.
        data: Object exposing the features ``x`` and the integer labels ``y``.
        mask: Boolean tensor selecting the nodes to evaluate on.

    Returns:
        A ``(accuracy, loss)`` pair of Python floats.
    """
    model.eval()
    with torch.no_grad():
        scores = model(data.x)
        targets = data.y[mask]
        hits = (scores.argmax(dim=1)[mask] == targets).sum()
        accuracy = int(hits) / int(mask.sum())
        nll = F.nll_loss(scores[mask], targets).item()
    return accuracy, nll

In [11]:
def get_fixed_splits(data, dataset_name, seed):
    """Attach the stored train/val/test masks for split ``seed`` to ``data``.

    The masks are read from ``../splits/{dataset_name}_split_0.6_0.2_{seed}.npz``
    (relative to the current working directory) and stored on ``data`` as the
    boolean tensors ``train_mask``, ``val_mask`` and ``test_mask``.

    Args:
        data: Object exposing the node-feature tensor ``x``; mutated in place.
        dataset_name: Dataset name used to build the split file name.
        seed: Index of the split file to load.

    Returns:
        The same ``data`` object with the three mask attributes set.

    Raises:
        AssertionError: If the union of the three masks does not select every
            row of ``data.x``.
    """
    # The context manager closes the archive even when a key is missing or a
    # conversion fails, so the file handle is never leaked.
    with np.load(f'../splits/{dataset_name}_split_0.6_0.2_{seed}.npz') as split:
        data.train_mask = torch.tensor(split['train_mask'], dtype=torch.bool)
        data.val_mask = torch.tensor(split['val_mask'], dtype=torch.bool)
        data.test_mask = torch.tensor(split['test_mask'], dtype=torch.bool)

    # Every node must belong to at least one of the three subsets.
    # NOTE(review): this checks coverage only, not that the subsets are disjoint.
    covered = data.train_mask | data.val_mask | data.test_mask
    assert int(torch.count_nonzero(covered)) == data.x.size(0), \
        "train/val/test masks do not cover every node"

    return data

In [5]:
# Alternative model configuration, kept for reference:
# model = MultiDimSheafDiffusion(data.x.size(0), data.edge_index, 2, 0.1, data.x.size(1), hidden_dim=32, 
#                                    output_dim=dataset.num_classes, normalised=False, augmented=False, att_dim=10, dropout=0.0) 

# Build a small instance of the discrete model to inspect its parameters.
model = DiscreteMultiDimSheafDiffusion(
    data.x.size(0), data.edge_index, 3, 2, data.x.size(1),
    hidden_dim=33,
    output_dim=dataset.num_classes,
    normalised=False,
    augmented=False,
    nonlinear=False,
    att_dim=32,
    deg_normalised=True,
    dropout=0.0,
)

# Print the name and shape of every parameter that will receive gradients.
for name, param in filter(lambda entry: entry[1].requires_grad, model.named_parameters()):
    print(name, param.size())

lin1.weight torch.Size([33, 2325])
lin1.bias torch.Size([33])
sheaf_learners.0.map_builder.weight torch.Size([3, 66])
lin2.weight torch.Size([5, 33])
lin2.bias torch.Size([5])


### Training with fixed time

In [None]:
# Experiment configuration.
seeds = 10        # number of fixed splits (and independent runs) to evaluate
time = 0.0        # only referenced by the commented-out model alternatives below
patience = 200    # epochs without validation improvement before stopping
epochs = 250      # the loop below runs epochs + 1 iterations (0..epochs inclusive)

# Per-seed results, consumed by the summary cell.
train_accs = []       # train accuracy of the model as it is when training ends
test_accs = []        # test accuracy of the model as it is when training ends
best_test_accs = []   # test accuracy at the epoch with the best validation accuracy

for seed in range(0, seeds):
    print(f"===== Seed {seed} ======")

    # Seed every RNG in use so that weight initialisation and dropout are
    # reproducible for a given seed; the seed also selects the split file.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Reload the dataset so the masks attached by get_fixed_splits always
    # land on a fresh object.
    dataset = get_dataset(dataset_name)
    data = get_fixed_splits(dataset[0], dataset_name, seed).to(device)    

    # model = DiagSheafDiffusion(data.x.size(0), data.edge_index, 5, time, data.x.size(1), hidden_dim=70, 
    #                            output_dim=dataset.num_classes, normalised=False, nonlinear=False, deg_normalised=True, dropout=0.5) 
    # model = MultiDimSheafDiffusion(data.x.size(0), data.edge_index, 2, time, data.x.size(1), hidden_dim=32, 
    #                                output_dim=dataset.num_classes, normalised=False, augmented=False, att_dim=32, dropout=0.0, nonlinear=True, deg_normalised=True) 
    model = DiscreteMultiDimSheafDiffusion(data.x.size(0), data.edge_index, 12, 3, data.x.size(1), hidden_dim=120, 
                                   output_dim=dataset.num_classes, normalised=False, augmented=False, nonlinear=True, att_dim=32, deg_normalised=True, dropout=0.7, 
                                   use_weights=False, act=False, bn=False) 
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.05, weight_decay=5e-4)
    # optimizer = torch.optim.RMSprop(model.parameters(), lr=0.01, weight_decay=1e-4)

    # Early-stopping bookkeeping.
    old_val_acc = 0
    patience_counter = 0
    best_test_acc = 0
    best_epoch = 0
    
    # Halve the learning rate every 100 optimisation steps.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)
    
    for epoch in range(epochs + 1):
        # eval_model leaves the model in eval mode, so re-enable training
        # mode at the start of every epoch.
        model.train()
        optimizer.zero_grad()
        out = model(data.x)
        
        # Full-batch step: the loss is computed on the training nodes only.
        loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        scheduler.step()
        
        # Track the test accuracy at the epoch with the highest validation
        # accuracy seen so far; stop after `patience` epochs without a new best.
        val_acc, _ = eval_model(model, data, data.val_mask)
        if val_acc > old_val_acc:
            old_val_acc = val_acc
            best_test_acc, _ = eval_model(model, data, data.test_mask)
            best_epoch = epoch
            patience_counter = 0
        else:
            patience_counter += 1 
            if patience_counter >= patience:
                break

        
        if epoch % 100 == 0 and epoch > 0:
            print(f"Epoch: {epoch} | Loss: {loss.item()}")
    
    # Metrics of the model in its final state (not restored to the best epoch).
    train_acc, _ = eval_model(model, data, data.train_mask)
    val_acc, _ = eval_model(model, data, data.val_mask)
    test_acc, _ = eval_model(model, data, data.test_mask)
    
    print(f"Train acc: {train_acc}")
    print(f"Val acc: {val_acc}")
    print(f"Test acc: {test_acc}")
    print(f"Best test acc: {best_test_acc}")
    print(f"Best epoch: {best_epoch}")
    
    train_accs.append(train_acc)
    test_accs.append(test_acc)
    best_test_accs.append(best_test_acc)

Epoch: 100 | Loss: 0.10687236487865448
Epoch: 200 | Loss: 0.10275166481733322
Train acc: 1.0
Val acc: 0.7627118644067796
Test acc: 0.7837837837837838
Best test acc: 0.7837837837837838
Best epoch: 87
Epoch: 100 | Loss: 0.10657243430614471
Epoch: 200 | Loss: 0.07840995490550995
Train acc: 1.0
Val acc: 0.7796610169491526
Test acc: 0.8918918918918919
Best test acc: 0.918918918918919
Best epoch: 91
Epoch: 100 | Loss: 0.08386233448982239
Epoch: 200 | Loss: 0.09970902651548386
Train acc: 1.0
Val acc: 0.7966101694915254
Test acc: 0.7297297297297297
Best test acc: 0.7297297297297297
Best epoch: 46
Epoch: 100 | Loss: 0.18335121870040894
Epoch: 200 | Loss: 0.06836207956075668
Train acc: 1.0
Val acc: 0.8135593220338984
Test acc: 0.8108108108108109
Best test acc: 0.8648648648648649
Best epoch: 61
Epoch: 100 | Loss: 0.10896273702383041
Epoch: 200 | Loss: 0.08449429273605347
Train acc: 1.0
Val acc: 0.864406779661017
Test acc: 0.8108108108108109
Best test acc: 0.8378378378378378
Best epoch: 79
Epoch: 

In [50]:
# Aggregate the per-seed results collected by the training cell.
# test_accs: test accuracy of each run's model as it was when training ended.
final_mean, final_std = np.mean(test_accs), np.std(test_accs)
# best_test_accs: test accuracy at each run's best-validation epoch.
selected_mean, selected_std = np.mean(best_test_accs), np.std(best_test_accs)

print(f"======== {dataset_name} ========")
print(f"Early stop acc: {final_mean:.4f} +/- {final_std:.4f}")
print(f"Best val acc  : {selected_mean:.4f} +/- {selected_std:.4f}")

Early stop acc: 0.8270 +/- 0.0422
Best val acc  : 0.8000 +/- 0.0664
