In [2]:
%load_ext autoreload
%autoreload 2

In [4]:
import numpy as np
import networkx as nx

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing

In [36]:
from envs.GraphEnv.impnode import ImpnodeEnv

seed = 5545

In [37]:
env = ImpnodeEnv(8,2,3,seed)
env.action_space.seed(seed)

[5545]

In [38]:
for i in range(1):
    terminated = False
    #num_nodes = env.num_nodes()
    print(env.update_mask())
    while terminated is False:
        print(env.observation_space.shape[0])
        print(env.action_space.n)

        action = env.action_space.sample(mask=env.update_mask())
        print(action)



        observation, reward, terminated, truncated, info = env.step(action)
        print(observation)
        print(env.update_mask())
    env.reset()

[1 1 1 1 1 1 1 1]
8
8
2
Graph with 8 nodes and 12 edges
[1 1 0 1 1 1 1 1]
8
8
4
Graph with 8 nodes and 12 edges
[1 1 0 1 0 1 1 1]
8
8
3
Graph with 8 nodes and 12 edges
[1 1 0 0 0 1 1 1]


## torch geometric ##

In [52]:
graph = nx.barabasi_albert_graph(8,2)
nx.set_node_attributes(graph, np.ones(5, dtype=int), 'features')

In [53]:
x = torch.Tensor(list(nx.get_node_attributes(graph, "features").values()))
edge_index = torch.Tensor([list(e) for e in graph.edges]).long().T


In [54]:
x

tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])

In [55]:
edge_index.shape

torch.Size([2, 12])

In [56]:
from torch.nn import Linear, ReLU
from torch_geometric.nn import Sequential, GCNConv

in_channels = 5
out_channels = len(graph.nodes)
model = Sequential('x, edge_index', [
    (GCNConv(in_channels, 64), 'x, edge_index -> x'),
    ReLU(inplace=True),
    (GCNConv(64, 64), 'x, edge_index -> x'),
    ReLU(inplace=True),
    Linear(64, out_channels),
])

In [57]:
class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = Sequential('x, edge_index', [
                        (GCNConv(5, 64), 'x, edge_index -> x'),
                        ReLU(inplace=True),
                        (GCNConv(64, 64), 'x, edge_index -> x'),
                        ReLU(inplace=True),
                        Linear(64, 8),
                    ])

    def forward(self, x, edge_index):

        x = self.model(x, edge_index)

        return x

In [58]:
model = GCN()

In [59]:
out = model(x, edge_index)

In [60]:
out.shape

torch.Size([8, 8])

In [134]:
edge_index.shape

torch.Size([2, 12])

In [135]:
class SumAgg(MessagePassing):
    def __init__(self):
        super().__init__(aggr='add')

    def forward(self, x, edge_index):
        out = self.propagate(edge_index, x=x)
        return out

    def message(self, x_j):
        return x_j

In [136]:

class CustomGNN(nn.Module):
    def __init__(self, in_channels, out_channels, num_layers):
        super(CustomGNN, self).__init__()

        self.linear1 = nn.Linear(in_channels, out_channels)
        self.linear2 = nn.Linear(out_channels, out_channels // 2)
        self.linear3 = nn.Linear(out_channels, out_channels // 2)
        self.sum_agg = SumAgg()

    def forward(self, x, edge_index):

        x = F.relu(self.linear1(x))
        x = x / x.norm(dim=-1, keepdim=True)

        for _ in range(num_layers):

            neighbor_messages = self.sum_agg(x,edge_index)

            x = F.relu(torch.cat([self.linear2(x),self.linear3(neighbor_messages)], dim=-1))
            x = x / x.norm(dim=-1, keepdim=True)

        return x


In [137]:
in_channels = 5
out_channels = len(graph.nodes)
num_layers = 2

model = CustomGNN(in_channels, out_channels, num_layers)

In [138]:
out = model(x, edge_index)

tensor([[0.5332, 0.4966, 0.0000, 0.0000, 0.1524, 0.4732, 0.4710, 0.0000],
        [0.5332, 0.4966, 0.0000, 0.0000, 0.1524, 0.4732, 0.4710, 0.0000],
        [0.5332, 0.4966, 0.0000, 0.0000, 0.1524, 0.4732, 0.4710, 0.0000],
        [0.5332, 0.4966, 0.0000, 0.0000, 0.1524, 0.4732, 0.4710, 0.0000],
        [0.5332, 0.4966, 0.0000, 0.0000, 0.1524, 0.4732, 0.4710, 0.0000],
        [0.5332, 0.4966, 0.0000, 0.0000, 0.1524, 0.4732, 0.4710, 0.0000],
        [0.5332, 0.4966, 0.0000, 0.0000, 0.1524, 0.4732, 0.4710, 0.0000],
        [0.5332, 0.4966, 0.0000, 0.0000, 0.1524, 0.4732, 0.4710, 0.0000]],
       grad_fn=<DivBackward0>)
tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5332, 0.4966, 0.0000, 0.0000, 0.1524, 0.4732, 0.4710, 0.0000],
        [0.5332, 0.4966, 0.0000, 0.0000, 0.1524, 0.4732, 0.4710, 0.0000],
        [1.0665, 0.9933, 0.0000, 0.0000, 0.3049, 0.9465, 0.9420, 0.0000],
        [1.0665, 0.9933, 0.0000, 0.0000, 0.3049, 0.9465, 0.9420, 0.0000],
      

In [146]:
graph = nx.barabasi_albert_graph(8,2)
nx.set_node_attributes(graph, np.ones(5, dtype=int), 'features')

In [147]:
graph.edges

EdgeView([(0, 1), (0, 2), (0, 5), (1, 3), (1, 4), (1, 5), (1, 7), (2, 3), (3, 4), (3, 6), (4, 7), (5, 6)])

In [148]:
[graph.remove_edge(*i) for i in graph.edges if i[0] == 1 or i[1] == 1]

[None, None, None, None, None]

In [149]:
graph.edges

EdgeView([(0, 2), (0, 5), (2, 3), (3, 4), (3, 6), (4, 7), (5, 6)])