In [235]:
from einops import rearrange, repeat
import torch
from torch import nn
from torch_geometric.nn import GATConv

In [236]:
in_dim = 2
out_dim = 113
hidden_dim = 128
heads = 4
class GATV3(torch.nn.Module):
  def __init__(self, num_layers=1, dim_model=128, num_heads=4, num_tokens=114, seq_len=3):
    super().__init__()
    self.conv = GATConv(dim_model, dim_model, num_heads=heads).to(torch.float64)
    self.linear = nn.Linear(dim_model, num_tokens - 1).to(torch.float64) 

    self.token_embeddings = nn.Embedding(num_tokens, dim_model).to(torch.float64)  # We have p+1 input tokens: 0,1,...,113.
    self.position_embeddings = nn.Embedding(seq_len, dim_model).to(torch.float64)   # We length 3 sequences, e.g. (10, 25, 113)

  def forward(self, inputs):
    token_embedding = self.token_embeddings(inputs)
    
    positions = repeat(torch.arange(inputs.shape[1]), "p -> b p", b = inputs.shape[0])
    position_embedding = self.position_embeddings(positions)
    
    embedding = token_embedding + position_embedding

    #embedding = rearrange(embedding, 'b s d -> s b d')
    
    res = torch.Tensor([])
    for i in range(embedding.shape[0]):
        x = self.conv(embedding[i], edge_index)
        x = self.linear(x)
        x = torch.mean(x, dim=0)
        res = torch.cat([res, x.reshape(1, -1)], dim=0)
    return res

In [237]:
model = GATV3()

In [238]:
import networkx as nx
edge_index = torch.Tensor(nx.adjacency_matrix(nx.random_regular_graph(2,3)).todense()).nonzero().t().contiguous()

In [252]:
# Params for dataset
n = 2000
p = 0.001
prime = 113
threshold = 0.8
no_digits = 4

# Params for GNN
input_dim = 1
hidden_dim = 512
output_dim = 1

# Params for optimizer
weight_decay = 3
lr = 1e-3
betas = (0.9, 0.98)

In [253]:
lim = 30

In [254]:
node_feats = torch.Tensor([])
for i in range(lim):
    for j in range(lim):
        node_feats = torch.cat([node_feats, torch.Tensor([[i, j, 113]])], dim=0)
node_feats = node_feats.to(torch.int64)

node_labels = (node_feats[: ,0] + node_feats[: ,1]) % 113

In [255]:
import random
poss = [i for i in range(lim**2)]
idx = random.sample(poss, int(threshold * len(poss)))
train_mask = [True if i in idx else False for i in range(lim**2)]
val_mask = [False if train_mask[i] else True for i in range(lim**2)]

In [None]:
from sklearn.metrics import accuracy_score
model = GATV3()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay, betas=betas)

criterion = nn.CrossEntropyLoss()
def train():
      model.train()
      optimizer.zero_grad()  # Clear gradients.
      out = model(node_feats)  # Perform a single forward pass.
      out = out.to(torch.float64)


      loss = criterion(out[train_mask], node_labels[train_mask])  # Compute the loss solely based on the training nodes.
      loss_test = criterion(out[val_mask], node_labels[val_mask])
      loss.backward()  # Derive gradients.
      optimizer.step()  # Update parameters based on gradients.
      return loss, loss_test

def test():
      model.eval()
      pred = model(node_feats)
      pred = pred.argmax(dim=1)  # Use the class with highest probability.
      #true = node_labels.argmax(dim=1)
    
#       test_acc = torch.sqrt(torch.mean((pred[val_mask] - node_labels[val_mask]) ** 2))
#       train_acc = torch.sqrt(torch.mean((pred[train_mask] - node_labels[train_mask]) ** 2))

    
#       return test_acc.item(), train_acc.item()
      return (accuracy_score(node_labels[val_mask], pred[val_mask]),
            accuracy_score(node_labels[train_mask], pred[train_mask]))

# def train():
#       model.train()
#       optimizer.zero_grad()  # Clear gradients.
#       out = model(node_feats, adj_matrix)  # Perform a single forward pass.
#       loss = criterion(out[train_mask], node_labels[train_mask])  # Compute the loss solely based on the training nodes.
#       loss_test = criterion(out[val_mask], node_labels[val_mask])
#       loss.backward()  # Derive gradients.
#       optimizer.step()  # Update parameters based on gradients.
#       return loss, loss_test

# def test():
#       model.eval()
#       out = model(node_feats, adj_matrix)
#       pred = out.argmax(dim=1)  # Use the class with highest probability.
#       true = node_labels.argmax(dim=1)
#       test_correct = pred[val_mask] == true[val_mask]  # Check against ground-truth labels.
#       test_acc = int(np.array(test_correct).sum()) / int(np.array(val_mask).sum())  # Derive ratio of correct predictions.


#       train_correct = pred[train_mask] == true[train_mask]  # Check against ground-truth labels.
#       train_acc = int(np.array(train_correct).sum()) / int(np.array(train_mask).sum())  # Derive ratio of correct predictions.
#       return test_acc, train_acc
import tqdm.auto as tqdm
train_loss = []
test_loss = []
test_aa = []
train_aa = []
for epoch in tqdm.tqdm(range(1000)):
#     if epoch % 100 == 0:
#       print(epoch)
    loss, loss_test = train()
    test_acc, train_acc = test()
    train_loss.append(loss.item())
    test_loss.append(loss_test.item())
    test_aa.append(test_acc)
    train_aa.append(train_acc)
    if epoch % 10 == 0:
        print('train test')
        print(loss)
        print(loss_test)
        print('----------')

  0%|                                          | 1/1000 [00:01<17:16,  1.04s/it]

train test
tensor(4.9347, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(4.9640, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


  1%|▍                                        | 11/1000 [00:11<17:05,  1.04s/it]

train test
tensor(4.0690, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(4.2685, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


  2%|▊                                        | 21/1000 [00:21<17:07,  1.05s/it]

train test
tensor(3.7626, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(4.1873, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


  3%|█▎                                       | 31/1000 [00:32<17:10,  1.06s/it]

train test
tensor(3.5485, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(4.1870, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


  4%|█▋                                       | 41/1000 [00:43<16:54,  1.06s/it]

train test
tensor(3.3799, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(4.1856, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


  5%|██                                       | 51/1000 [00:53<16:51,  1.07s/it]

train test
tensor(3.2434, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(4.1793, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


  6%|██▌                                      | 61/1000 [01:04<16:35,  1.06s/it]

train test
tensor(3.1371, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(4.1875, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


  7%|██▉                                      | 71/1000 [01:14<15:57,  1.03s/it]

train test
tensor(3.0558, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(4.1891, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


  8%|███▎                                     | 81/1000 [01:25<15:49,  1.03s/it]

train test
tensor(2.9926, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(4.1938, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


  9%|███▋                                     | 91/1000 [01:35<16:03,  1.06s/it]

train test
tensor(2.9420, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(4.2020, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 10%|████                                    | 101/1000 [01:46<15:39,  1.04s/it]

train test
tensor(2.8999, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(4.2146, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 11%|████▍                                   | 111/1000 [01:56<15:19,  1.03s/it]

train test
tensor(2.8631, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(4.2304, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 12%|████▊                                   | 121/1000 [02:07<15:27,  1.06s/it]

train test
tensor(2.8295, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(4.2484, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 13%|█████▏                                  | 131/1000 [02:17<15:10,  1.05s/it]

train test
tensor(2.7977, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(4.2629, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 14%|█████▋                                  | 141/1000 [02:28<15:10,  1.06s/it]

train test
tensor(2.7666, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(4.2683, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 15%|██████                                  | 151/1000 [02:38<14:45,  1.04s/it]

train test
tensor(2.7359, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(4.2628, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 16%|██████▍                                 | 161/1000 [02:49<14:49,  1.06s/it]

train test
tensor(2.7044, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(4.2553, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 17%|██████▊                                 | 171/1000 [02:59<14:22,  1.04s/it]

train test
tensor(2.6745, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(4.2529, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 18%|███████▏                                | 181/1000 [03:09<14:08,  1.04s/it]

train test
tensor(2.6394, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(4.2344, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 19%|███████▋                                | 191/1000 [03:20<14:13,  1.06s/it]

train test
tensor(2.5880, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(4.1872, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 20%|████████                                | 201/1000 [03:30<14:00,  1.05s/it]

train test
tensor(2.5206, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(4.0937, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 21%|████████▍                               | 211/1000 [03:41<13:59,  1.06s/it]

train test
tensor(2.4548, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(4.0502, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 22%|████████▊                               | 221/1000 [03:52<13:29,  1.04s/it]

train test
tensor(2.3937, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(3.9982, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 23%|█████████▏                              | 231/1000 [04:02<13:24,  1.05s/it]

train test
tensor(2.3316, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(3.9433, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 24%|█████████▋                              | 241/1000 [04:13<13:31,  1.07s/it]

train test
tensor(2.2649, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(3.8573, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 25%|██████████                              | 251/1000 [04:23<13:07,  1.05s/it]

train test
tensor(2.2068, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(3.7837, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 26%|██████████▍                             | 261/1000 [04:34<12:38,  1.03s/it]

train test
tensor(2.1524, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(3.7107, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 27%|██████████▊                             | 271/1000 [04:44<12:45,  1.05s/it]

train test
tensor(2.1045, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(3.6388, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 28%|███████████▏                            | 281/1000 [04:54<12:12,  1.02s/it]

train test
tensor(2.0610, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(3.5694, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 29%|███████████▋                            | 291/1000 [05:05<12:14,  1.04s/it]

train test
tensor(2.0220, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(3.5114, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 30%|████████████                            | 301/1000 [05:15<12:18,  1.06s/it]

train test
tensor(1.9868, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(3.4623, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 31%|████████████▍                           | 311/1000 [05:26<11:55,  1.04s/it]

train test
tensor(1.9572, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(3.4259, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 32%|████████████▊                           | 321/1000 [05:36<11:34,  1.02s/it]

train test
tensor(1.9286, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(3.3928, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 33%|█████████████▏                          | 331/1000 [05:47<11:47,  1.06s/it]

train test
tensor(1.9034, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(3.3704, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 34%|█████████████▋                          | 341/1000 [05:57<11:21,  1.03s/it]

train test
tensor(1.8815, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(3.3523, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 35%|██████████████                          | 351/1000 [06:07<10:51,  1.00s/it]

train test
tensor(1.8583, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(3.3344, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 36%|██████████████▍                         | 361/1000 [06:18<11:20,  1.07s/it]

train test
tensor(1.8381, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(3.3211, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 37%|██████████████▊                         | 371/1000 [06:29<11:13,  1.07s/it]

train test
tensor(1.8182, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(3.3110, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 38%|███████████████▏                        | 381/1000 [06:39<10:42,  1.04s/it]

train test
tensor(1.8001, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(3.3013, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 39%|███████████████▋                        | 391/1000 [06:50<10:58,  1.08s/it]

train test
tensor(1.7866, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(3.2953, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 40%|████████████████                        | 401/1000 [07:00<10:19,  1.03s/it]

train test
tensor(1.7708, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(3.2841, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 41%|████████████████▍                       | 411/1000 [07:11<10:25,  1.06s/it]

train test
tensor(1.7565, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(3.2785, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 42%|████████████████▊                       | 421/1000 [07:22<10:25,  1.08s/it]

train test
tensor(1.7413, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(3.2728, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 43%|█████████████████▏                      | 431/1000 [07:32<09:49,  1.04s/it]

train test
tensor(1.7312, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(3.2657, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 44%|█████████████████▋                      | 441/1000 [07:42<09:46,  1.05s/it]

train test
tensor(1.7198, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(3.2601, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 45%|██████████████████                      | 451/1000 [07:53<09:46,  1.07s/it]

train test
tensor(1.7087, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(3.2492, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 46%|██████████████████▍                     | 461/1000 [08:03<09:25,  1.05s/it]

train test
tensor(1.6977, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(3.2441, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 47%|██████████████████▊                     | 471/1000 [08:14<09:13,  1.05s/it]

train test
tensor(1.6857, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(3.2414, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 48%|███████████████████▏                    | 481/1000 [08:25<09:12,  1.06s/it]

train test
tensor(1.6792, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(3.2384, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 49%|███████████████████▌                    | 488/1000 [08:32<08:59,  1.05s/it]

In [234]:
rearrange(torch.randn(2,3,4), 'a b c -> a c b')

tensor([[[-0.6565, -0.1264, -0.4744],
         [-0.7860, -0.2021,  1.3956],
         [-0.3140, -0.5562,  0.3549],
         [-0.2646, -1.6739,  0.7716]],

        [[-0.3312,  0.6003,  0.3360],
         [ 0.4989, -0.6734,  1.2114],
         [ 1.4119,  0.3367,  0.4190],
         [-2.5377,  0.8328, -2.8228]]])