GAT:
https://arxiv.org/pdf/1710.10903.pdf

GATv2:

https://arxiv.org/pdf/2105.14491.pdf

https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GATv2Conv.html?highlight=gatv2

In [1]:
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GATConv, GATv2Conv
from torch_geometric.loader import DataLoader


In [2]:
dataset = Planetoid(root='./tmp/Cora', name='Cora')

In [3]:
class GAT(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(GAT, self).__init__()
        self.conv1 = GATv2Conv(in_channels, 8, heads=8, dropout=0.6)
        # On the Pubmed dataset, use heads=8 in conv2.
        self.conv2 = GATv2Conv(8 * 8, out_channels, heads=1, concat=False, dropout=0.6)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model_normal = GAT(dataset.num_node_features, dataset.num_classes).to(device)
model_orthogonal = GAT(dataset.num_node_features, dataset.num_classes).to(device)

data = dataset[0].to(device)
# optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)


In [5]:
# Assuming 'model' is your PyTorch model
for name, param in model_orthogonal.named_parameters():
#     print(f"Name: {name}, Value: {param}")
        print(f"Name: {name}, Value: {param.size()}")

Name: conv1.att, Value: torch.Size([1, 8, 8])
Name: conv1.bias, Value: torch.Size([64])
Name: conv1.lin_l.weight, Value: torch.Size([64, 1433])
Name: conv1.lin_l.bias, Value: torch.Size([64])
Name: conv1.lin_r.weight, Value: torch.Size([64, 1433])
Name: conv1.lin_r.bias, Value: torch.Size([64])
Name: conv2.att, Value: torch.Size([1, 1, 7])
Name: conv2.bias, Value: torch.Size([7])
Name: conv2.lin_l.weight, Value: torch.Size([7, 64])
Name: conv2.lin_l.bias, Value: torch.Size([7])
Name: conv2.lin_r.weight, Value: torch.Size([7, 64])
Name: conv2.lin_r.bias, Value: torch.Size([7])


In [6]:
from StiefelOptimizers import StiefelAdam, CombinedOptimizer

euclidean_parameters = []
stiefel_parameters = []

all_parameters = []

# put the Euclidean and Stiefel parameters into 2 different list
for name, param in model_orthogonal.named_parameters():
    # print(name)
    if name=='conv1.lin_l.weight' or name=='conv1.lin_r.weight':
        # torch.nn.init.orthogonal_(param) # optional
        stiefel_parameters.append(param)
    else:
        # print(param)
        euclidean_parameters.append(param)

# add all parameters to the all_parameters list:
for name, param in model_normal.named_parameters():
    all_parameters.append(param)


if len(euclidean_parameters)==0:
    optimizer_stiefel=StiefelAdam(stiefel_parameters, lr=0.001)
    optimizer_orthogonal=optimizer_stiefel
else:
    optimizer_euclidean=torch.optim.Adam(euclidean_parameters, lr=0.001)
    optimizer_stiefel=StiefelAdam(stiefel_parameters, lr=0.001)
    # combine the two optimizers
    optimizer_orthogonal=CombinedOptimizer(optimizer_euclidean, optimizer_stiefel)



optimizer_normal = torch.optim.Adam(all_parameters, lr=0.001)

In [7]:
from tqdm import tqdm

model_normal.train()
pbar = tqdm(range(200), desc='Training', unit='epoch')
for epoch in pbar:
    def closure():
        optimizer_normal.zero_grad()
        out = model_normal(data)
        loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        return loss
    optimizer_normal.step(closure)
    pbar.set_postfix({'loss': closure().item()})


model_orthogonal.train()
pbar = tqdm(range(200), desc='Training', unit='epoch')
for epoch in pbar:
    def closure():
        optimizer_orthogonal.zero_grad()
        out = model_orthogonal(data)
        loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        return loss
    optimizer_orthogonal.step(closure)
    pbar.set_postfix({'loss': closure().item()})

Training: 100%|██████████| 200/200 [00:43<00:00,  4.60epoch/s, loss=0.589]
Training: 100%|██████████| 200/200 [00:46<00:00,  4.31epoch/s, loss=0.901]


In [8]:
model_normal.eval()
_, pred = model_normal(data).max(dim=1)
correct = int(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
accuracy = correct / int(data.test_mask.sum())
print(f'Accuracy, non-orthogonal learning: {accuracy:.4f}')


Accuracy, non-orthogonal learning: 0.7920


In [9]:
model_orthogonal.eval()
_, pred = model_orthogonal(data).max(dim=1)
correct = int(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
accuracy = correct / int(data.test_mask.sum())
print(f'Accuracy, orthogonal learning: {accuracy:.4f}')

Accuracy, orthogonal learning: 0.8260
