In [1]:
from einops import rearrange, repeat
import torch
from torch import nn
from torch_geometric.nn import GATConv

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
in_dim = 2
out_dim = 113
hidden_dim = 128
heads = 4
class GATV3(torch.nn.Module):
  def __init__(self, num_layers=1, dim_model=128, num_heads=4, num_tokens=114, seq_len=3):
    super().__init__()
    self.conv = GATConv(dim_model, dim_model, num_heads=heads).to(torch.float64)
    self.linear = nn.Linear(dim_model, num_tokens - 1).to(torch.float64) 

    self.token_embeddings = nn.Embedding(num_tokens, dim_model).to(torch.float64)  # We have p+1 input tokens: 0,1,...,113.
    self.position_embeddings = nn.Embedding(seq_len, dim_model).to(torch.float64)   # We length 3 sequences, e.g. (10, 25, 113)

  def forward(self, inputs):
    token_embedding = self.token_embeddings(inputs)
    
    positions = repeat(torch.arange(inputs.shape[1]), "p -> b p", b = inputs.shape[0])
    position_embedding = self.position_embeddings(positions)
    
    embedding = token_embedding + position_embedding

    #embedding = rearrange(embedding, 'b s d -> s b d')
    
    res = torch.Tensor([])
    for i in range(embedding.shape[0]):
        x = self.conv(embedding[i], edge_index)
        x = self.linear(x)
        x = torch.mean(x, dim=0)
        res = torch.cat([res, x.reshape(1, -1)], dim=0)
    return res

In [3]:
model = GATV3()

In [4]:
import networkx as nx
edge_index = torch.Tensor(nx.adjacency_matrix(nx.random_regular_graph(2,3)).todense()).nonzero().t().contiguous()

In [5]:
# Params for dataset
n = 2000
p = 0.001
prime = 113
threshold = 0.3
no_digits = 4

# Params for GNN
input_dim = 1
hidden_dim = 512
output_dim = 1

# Params for optimizer
weight_decay = 3
lr = 1e-3
betas = (0.9, 0.98)

In [6]:
lim = 30

In [7]:
node_feats = torch.Tensor([])
for i in range(lim):
    for j in range(lim):
        node_feats = torch.cat([node_feats, torch.Tensor([[i, j, 113]])], dim=0)
node_feats = node_feats.to(torch.int64)

node_labels = (node_feats[: ,0] + node_feats[: ,1]) % 113

In [8]:
import random
poss = [i for i in range(lim**2)]
idx = random.sample(poss, int(threshold * len(poss)))
train_mask = [True if i in idx else False for i in range(lim**2)]
val_mask = [False if train_mask[i] else True for i in range(lim**2)]

In [None]:
from sklearn.metrics import accuracy_score
model = GATV3()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay, betas=betas)

criterion = nn.CrossEntropyLoss()
def train():
      model.train()
      optimizer.zero_grad()  # Clear gradients.
      out = model(node_feats)  # Perform a single forward pass.
      out = out.to(torch.float64)


      loss = criterion(out[train_mask], node_labels[train_mask])  # Compute the loss solely based on the training nodes.
      loss_test = criterion(out[val_mask], node_labels[val_mask])
      loss.backward()  # Derive gradients.
      optimizer.step()  # Update parameters based on gradients.
      return loss, loss_test

def test():
      model.eval()
      pred = model(node_feats)
      pred = pred.argmax(dim=1)  # Use the class with highest probability.
      #true = node_labels.argmax(dim=1)
    
#       test_acc = torch.sqrt(torch.mean((pred[val_mask] - node_labels[val_mask]) ** 2))
#       train_acc = torch.sqrt(torch.mean((pred[train_mask] - node_labels[train_mask]) ** 2))

    
#       return test_acc.item(), train_acc.item()
      return (accuracy_score(node_labels[val_mask], pred[val_mask]),
            accuracy_score(node_labels[train_mask], pred[train_mask]))

# def train():
#       model.train()
#       optimizer.zero_grad()  # Clear gradients.
#       out = model(node_feats, adj_matrix)  # Perform a single forward pass.
#       loss = criterion(out[train_mask], node_labels[train_mask])  # Compute the loss solely based on the training nodes.
#       loss_test = criterion(out[val_mask], node_labels[val_mask])
#       loss.backward()  # Derive gradients.
#       optimizer.step()  # Update parameters based on gradients.
#       return loss, loss_test

# def test():
#       model.eval()
#       out = model(node_feats, adj_matrix)
#       pred = out.argmax(dim=1)  # Use the class with highest probability.
#       true = node_labels.argmax(dim=1)
#       test_correct = pred[val_mask] == true[val_mask]  # Check against ground-truth labels.
#       test_acc = int(np.array(test_correct).sum()) / int(np.array(val_mask).sum())  # Derive ratio of correct predictions.


#       train_correct = pred[train_mask] == true[train_mask]  # Check against ground-truth labels.
#       train_acc = int(np.array(train_correct).sum()) / int(np.array(train_mask).sum())  # Derive ratio of correct predictions.
#       return test_acc, train_acc
import tqdm.auto as tqdm
train_loss = []
test_loss = []
test_aa = []
train_aa = []
for epoch in tqdm.tqdm(range(1000)):
#     if epoch % 100 == 0:
#       print(epoch)
    loss, loss_test = train()
    test_acc, train_acc = test()
    train_loss.append(loss.item())
    test_loss.append(loss_test.item())
    test_aa.append(test_acc)
    train_aa.append(train_acc)
    if epoch % 10 == 0:
        print('train test')
        print(loss)
        print(loss_test)
        print('----------')

  0%|                                          | 1/1000 [00:00<16:22,  1.02it/s]

train test
tensor(4.8870, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(4.8466, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


  1%|▍                                        | 11/1000 [00:10<16:17,  1.01it/s]

train test
tensor(3.9065, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(4.3486, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


  2%|▊                                        | 21/1000 [00:20<16:26,  1.01s/it]

train test
tensor(3.3496, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(4.3192, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


  3%|█▎                                       | 31/1000 [00:30<16:05,  1.00it/s]

train test
tensor(3.0231, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(4.3390, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


  4%|█▋                                       | 41/1000 [00:41<16:29,  1.03s/it]

train test
tensor(2.7989, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(4.4426, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


  5%|██                                       | 51/1000 [00:51<16:07,  1.02s/it]

train test
tensor(2.6528, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(4.5814, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


  6%|██▌                                      | 61/1000 [01:01<15:56,  1.02s/it]

train test
tensor(2.5512, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(4.7202, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


  7%|██▉                                      | 71/1000 [01:11<15:45,  1.02s/it]

train test
tensor(2.4714, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(4.8624, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


  8%|███▎                                     | 81/1000 [01:21<15:11,  1.01it/s]

train test
tensor(2.4054, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(5.0073, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


  9%|███▋                                     | 91/1000 [01:31<15:03,  1.01it/s]

train test
tensor(2.3459, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(5.1398, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 10%|████                                    | 101/1000 [01:41<15:06,  1.01s/it]

train test
tensor(2.2926, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(5.2815, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 11%|████▍                                   | 111/1000 [01:51<14:53,  1.01s/it]

train test
tensor(2.2453, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(5.4261, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 12%|████▊                                   | 121/1000 [02:01<14:53,  1.02s/it]

train test
tensor(2.2047, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(5.5773, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 13%|█████▏                                  | 131/1000 [02:11<14:26,  1.00it/s]

train test
tensor(2.1691, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(5.7130, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 14%|█████▋                                  | 141/1000 [02:21<14:51,  1.04s/it]

train test
tensor(2.1372, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(5.8413, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 15%|██████                                  | 151/1000 [02:31<14:09,  1.00s/it]

train test
tensor(2.1076, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(5.9616, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 16%|██████▍                                 | 161/1000 [02:41<14:04,  1.01s/it]

train test
tensor(2.0802, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(6.0735, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 17%|██████▊                                 | 171/1000 [02:51<13:47,  1.00it/s]

train test
tensor(2.0555, dtype=torch.float64, grad_fn=<NllLossBackward0>)
tensor(6.1743, dtype=torch.float64, grad_fn=<NllLossBackward0>)
----------


 18%|███████▏                                | 180/1000 [03:01<14:07,  1.03s/it]

In [234]:
rearrange(torch.randn(2,3,4), 'a b c -> a c b')

tensor([[[-0.6565, -0.1264, -0.4744],
         [-0.7860, -0.2021,  1.3956],
         [-0.3140, -0.5562,  0.3549],
         [-0.2646, -1.6739,  0.7716]],

        [[-0.3312,  0.6003,  0.3360],
         [ 0.4989, -0.6734,  1.2114],
         [ 1.4119,  0.3367,  0.4190],
         [-2.5377,  0.8328, -2.8228]]])