In [40]:
import os.path as osp
from tqdm.auto import tqdm
import numpy as np
import torch
from pathlib import Path
import os
from torch_geometric.utils import erdos_renyi_graph

from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GraphSAGE
from torch_geometric.data import Data, Dataset

In [4]:
MODELS_PATH = Path("./models")

if not MODELS_PATH.exists():
    os.makedirs(MODELS_PATH)

In [66]:
n_graphs = 1000
class RandomDataset(Dataset):
    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        return ''
    
    @property
    def processed_file_names(self):
        l = []
        for i in range(n_graphs):
            l.append("data_" + str(i) + ".pt")
        return l

    def download(self):
        pass

    def process(self):
        idx = 0
        for g in range(n_graphs):                    
            x = torch.rand([torch.randint(low=20, high=50, size=[1]), 1], dtype=torch.float32)
            edge_index = erdos_renyi_graph(x.size()[0], 0.3)            
            data = Data(x=x, edge_index=edge_index.contiguous())

            torch.save(data, osp.join(self.processed_dir, f'data_{idx}.pt'))
            idx += 1

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        data = torch.load(osp.join(self.processed_dir, f'data_{idx}.pt'))
        return data

In [72]:
dataset = RandomDataset(root="./data/")

Processing...
Done!


In [56]:
class SAGEConvModel(torch.nn.Module):
    def __init__(self, hidden_channels=64, num_layers=2, out_channels=1):
        super(SAGEConvModel, self).__init__()
        self.sage = GraphSAGE(dataset.num_features, hidden_channels, num_layers, out_channels)

    def forward(self, x, edge_index):
        x = self.sage(x, edge_index)
        return torch.softmax(x, dim=1)

In [126]:
def loss_fn(batchX, randomFriend, randomEnemies, Q=1):
  fst = -torch.log(torch.sigmoid(torch.sum(batchX * randomFriend, dim=1)))
  snd = -Q * torch.mean(torch.log(torch.sigmoid(-torch.sum(randomEnemies * batchX,  dim=2))), dim=0)
  return torch.mean(fst+snd)

def get_random_repr(node, x, edge_dict):
    variants = edge_dict.get(node, [])
    if variants:
        return x[np.random.choice(variants)]
    return torch.zeros(n_features)

def get_dict_out_of_nodes(Nnodes, edge_index):
    edge_dict = {i:[] for i in range(Nnodes)}
    for edge in edge_index.reshape(-1,2).numpy():
        edge_dict[edge[0]].append(edge[1])
    edge_dict = dict(zip(*gData.edge_index.numpy()))
    return edge_dict

In [150]:
# criterion = torch.nn.CrossEntropyLoss()
from functools import partial
criterion = loss_fn

epochs = 200
n_models = 2
epsilon = 1e-5
Nenem = 10
n_features = 1

models_result = []

def train():
    for i in range(n_models):
        model = SAGEConvModel()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        train_loss = []
        for epoch in range(epochs):
            for g in dataset[0:10]:
                model.train()
                out = model(g.x, g.edge_index)

                idxperm = torch.randperm(g.x.size()[0])
                # pidxs  = idxperm[batch*Nbatch:(batch+1)*Nbatch]
                out = out[idxperm]
                edge_dict = dict(zip(*g.edge_index.numpy()))
                get_random_repr_p = partial(get_random_repr, x=out, edge_dict=edge_dict)
                randomFriend = torch.stack(list(map(get_random_repr_p, idxperm.numpy())), dim=0)
                randomEnemies = out[np.random.choice(g.x.size()[0], Nenem)].reshape(Nenem, 1, 1)
                
                loss = criterion(out, randomFriend, randomEnemies)            
                print(out.requires_grad)
                loss.backward()
                optimizer.step()
                
                train_loss.append(loss.item())
    
            if (epoch + 1) % 10 == 0:
                print(f'Model {i}: epoch: {epoch + 1:03d}, loss: {np.mean(train_loss)}')

        models_result.append(model(dataset[0].x, dataset[0].edge_index))
        for g in dataset[1:]:
            models_result[i] = torch.cat([models_result[i], model(g.x, g.edge_index)])
        torch.save(model.state_dict(), f'./models/{i+1}')


def eval_stability():
    disagr = 0
    for i in range(n_models):
        for j in range(i, n_models):   
            disagr += torch.sum(torch.abs(models_result[i] - models_result[j])) / models_result[i].size()[0]
    return disagr / (n_models * (n_models + 1) / 2)

In [151]:
train()
# print(eval_stability1())

True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
Model 0: epoch: 010, loss: 1.6265232920646668
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True

KeyboardInterrupt: 

In [36]:
model.double()
for param in model.parameters():
    print(param.dtype)

torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64


In [148]:
eval_stability()

tensor(0., grad_fn=<DivBackward0>)