In [18]:
from torch_geometric.data import Data
from torch_geometric.utils import to_dense_adj
import torch
import numpy as np

# batch (Nbatch, Nfeatrs), randomFriend of the shape (Nbatch, Nfeatrs)
# randomEnemies of the shape (Nenem, Nbatch, Nfeatrs)
def loss_fn(batchX, randomFriend, randomEnemies, Q=1):
  fst = -torch.log(torch.sigmoid(torch.sum(batchX*randomFriend, dim=1)))
  snd = -Q*torch.mean(torch.log(torch.sigmoid(-torch.sum(randomEnemies * batchX,  dim=2))), dim=0)
  return torch.mean(fst+snd)

def _test_loss_fn(batchX, randomFriend, randomEnemies):
  fst = -torch.log(torch.sigmoid(torch.sum(batchX*randomFriend, dim=1))) # it is ok
  snd = []
  for i in range(randomEnemies.shape[0]):
    snd.append(-torch.log(torch.sigmoid(-torch.sum(batchX*randomEnemies[i], dim=1))))
  snd = torch.mean(torch.stack(snd), dim=0)
  return torch.mean(fst+snd)

In [290]:
Nbatch  = 100
Nfeatrs = 10
Nenem   = 10
batch = torch.rand((Nbatch, Nfeatrs))
randomFriend = torch.rand((Nbatch, Nfeatrs))
randomEnemies = torch.rand((Nenem, Nbatch, Nfeatrs))

torch.allclose(_test_loss_fn(batch, randomFriend, randomEnemies), 
               loss_fn(batch, randomFriend, randomEnemies)
               )
loss_fn(batch, randomFriend, randomEnemies)

tensor(2.7442)

In [198]:
np.random.seed(42)
Nnodes = 100
Nfeatrs = 10
connctns = torch.tensor(np.random.choice(Nnodes, size=(2, Nnodes), replace=True))
connctns = torch.concat([connctns, connctns[[1,0]]], dim=1)

X = np.random.random((Nnodes, Nfeatrs))
X = torch.tensor(X)
gData = Data(X, connctns)

# make dict out of nodes:
edge_dict = {i:[] for i in range(Nnodes)}
for edge in connctns.reshape(-1,2).numpy():
  edge_dict[edge[0]].append(edge[1])

In [296]:
def get_random_repr(node):
  variants = edge_dict.get(node, [])
  if variants:
    return X[np.random.choice(variants)]
  return torch.zeros(Nfeatrs)

Nepochs = 10
Nbatch  = Nnodes//10
for epoch in range(Nepochs):
  idxperm = torch.randperm(Nnodes)
  for batch in range(Nnodes//Nbatch):
    pidxs  = idxperm[batch*Nbatch:(batch+1)*Nbatch]
    batchX = X[pidxs]

    # get randomFriend and randomEnemies:
    randomFriend = torch.stack(list(map(get_random_repr, pidxs.numpy())), dim=0)
    randomEnemies = X[np.random.choice(Nnodes, Nenem*Nbatch)].reshape(Nenem, Nbatch, Nfeatrs)

    loss = loss_fn(batchX, randomFriend, randomEnemies)

In [299]:
A = np.random.rand(100*100,100)
B = A.reshape((100, 100, 100))
A[1]==B[0][1]

array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True])