In [100]:
import networkx as nx
import seaborn as sns
from pathlib import Path
import torch
import DQN_agent_new
from envs.GraphEnv.impnode import ImpnodeEnv
import  numpy as np
from model import DQNNet
from replay_memory import ReplayMemory
from torch_geometric.utils import get_laplacian
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [101]:
# Pick the compute device: CUDA when PyTorch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cpu


In [117]:
# Dataset folder, resolved relative to the directory the notebook is run from.
subdir = 'data/synthetic/uniform_cost/6-10'
data_path = Path.cwd().joinpath(subdir)

In [118]:
# Build the ImpnodeEnv instance that the cells below interact with.
# Every setting is passed by keyword, one per line, so it is easy to tweak.
env = ImpnodeEnv(
    anc='dw_nd',
    ba_nodes=(3, 5),
    ba_edges=4,
    max_removed_nodes=4,
    seed=1234,
    render_option=False,
    data=False,
    data_path=data_path,
    train_mode=True,
)

In [119]:
# Replay buffer that the collection loop writes transitions into.
# NOTE(review): capacity=4 is presumably the maximum number of stored transitions,
# kept tiny so that a single sample(4, ...) call returns the whole buffer — confirm
# against replay_memory.ReplayMemory.
memory = ReplayMemory(capacity=4)

In [120]:
# Collect random-policy episodes and store n-step transitions in `memory`.
#
# For every start index t of an episode the stored transition is
#   (state_t, action_t, state_{t+k}, sum_{i<k} GAMMA**i * reward_{t+i}, done)
# with k == N_STEP for full windows. When the episode terminates, the last
# N_STEP-1 start indices (whose window is cut short by the terminal state) are
# flushed as well, with done=True, so that no experience next to the terminal
# state is lost and episodes shorter than N_STEP still contribute.
N_STEP = 2    # horizon of the n-step return
GAMMA = 0.99  # per-step discount factor

for _ in range(10):
    state_history, action_history, reward_history = [], [], []
    done = False
    truncated = False
    state, info = env.reset()

    # Stop on termination OR truncation; looping on `done` alone never ends
    # an episode that the environment truncates.
    while not (done or truncated):
        # Sample a random action among those allowed by the action mask.
        action = env.action_space.sample(mask=info['node_action_mask'])
        next_state, reward, done, truncated, info = env.step(action)
        state_history.append(state)
        action_history.append(action)
        reward_history.append(reward)

        if len(state_history) >= N_STEP:
            # Oldest entry of the current window is the transition's origin.
            n_step_states = state_history[-N_STEP]
            n_step_actions = action_history[-N_STEP]
            n_step_rewards = reward_history[-N_STEP:]

            # Discounted sum of the rewards inside the window.
            n_step_return = sum(r * (GAMMA ** k) for k, r in enumerate(n_step_rewards))
            memory.store(
                state=n_step_states,
                action=n_step_actions,
                next_state=next_state,
                reward=n_step_return,
                done=done
            )
        state = next_state

    if done:
        # Flush the start indices whose window was shortened by the terminal
        # state. They are terminal transitions, so no bootstrapping happens on
        # them and the shorter horizon is harmless. Truncated episodes are NOT
        # flushed: their targets would bootstrap with the wrong horizon.
        episode_length = len(state_history)
        first_unstored = max(episode_length - N_STEP + 1, 0)
        for start in range(first_unstored, episode_length):
            tail_return = sum(r * (GAMMA ** k) for k, r in enumerate(reward_history[start:]))
            memory.store(
                state=state_history[start],
                action=action_history[start],
                next_state=next_state,
                reward=tail_return,
                done=True
            )

In [121]:
# Draw 4 transitions from the replay buffer and unpack the returned fields.
(
    states,
    actions,
    next_states,
    rewards,
    dones,
) = memory.sample(4, device)

In [122]:
# Instantiate the Q-network, then move it to the selected device.
policy_net = DQNNet(5, 2, 0.01)
policy_net = policy_net.to(device)

In [123]:
import torch_geometric

# Batch of the sampled states as they are (no virtual node).
pyg_states_novir = [torch_geometric.utils.from_networkx(state) for state in states]
graph = torch_geometric.data.Batch.from_data_list(pyg_states_novir)

new_states = []
# Add a virtual node, its edges and its features to every state.
for state in states:
    new_state = state.to_directed()

    # Choose a label for the virtual node that is guaranteed to be unused.
    # Starting from the node count reproduces the usual label when nodes are
    # numbered 0..n-1; if that label is already taken (non-contiguous labels),
    # add_node would silently do nothing and an existing node would be turned
    # into the hub and have its features overwritten, so skip ahead instead.
    new_node = len(new_state)
    while new_node in new_state:
        new_node += 1
    new_state.add_node(new_node)

    # Directed edges from the virtual node to every original node.
    new_state.add_edges_from((new_node, node) for node in state.nodes)
    nx.set_node_attributes(new_state, {new_node: np.ones(5, dtype=int)}, name='features')

    new_states.append(new_state)

# networkx -> PyG conversion for the augmented states, then batch them.
pyg_states = [torch_geometric.utils.from_networkx(state) for state in new_states]
batch_of_states = torch_geometric.data.Batch.from_data_list(pyg_states)


In [124]:
# Run the policy network on the batched states; with embedding=True it returns
# the Q-values together with the node embeddings.
all_q_values_policy, embeddings = policy_net.forward(
    batch_of_states,
    embedding=True,
)

In [125]:
# Show the dimensions of the embedding matrix returned by the network.
embeddings.size()

torch.Size([34, 4])

In [126]:
# Display the batch built from the states without virtual nodes, to inspect
# the sizes of its tensors (edge_index, features, batch vector, ptr).
graph

DataBatch(edge_index=[2, 94], features=[34, 5], num_nodes=34, batch=[34], ptr=[5])

In [143]:
# Graph-reconstruction style loss: sum over graphs of trace(E^T L E), divided by
# the number of edges in the batch, where E holds the embeddings of one graph
# and L is that graph's dense Laplacian.
num_graphs = graph.batch_size

# Rows of `embeddings` that belong to each graph of the batch.
embed = [embeddings[graph.batch == idx] for idx in range(num_graphs)]

# Dense Laplacian of every graph. The explicit (n, n) size keeps isolated
# trailing nodes in the matrix.
laplacians = []
for idx in range(num_graphs):
    single_graph = graph[idx]
    lap_index, lap_weight = get_laplacian(single_graph.edge_index)
    lap_size = (single_graph.num_nodes, single_graph.num_nodes)
    laplacians.append(torch.sparse_coo_tensor(lap_index, lap_weight, lap_size).to_dense())

# One scalar per graph: trace(E^T (L E)).
loss_vals = []
for dense_lap, block in zip(laplacians, embed):
    smoothness = torch.matmul(block.transpose(0, 1), torch.matmul(dense_lap, block))
    loss_vals.append(torch.trace(smoothness))

loss = sum(loss_vals) / graph.edge_index.size(1)
print(loss)

[tensor([[ 6., -1., -1., -1., -1., -1., -1.,  0.],
        [-1.,  3.,  0.,  0.,  0., -1., -1.,  0.],
        [-1.,  0.,  2.,  0.,  0., -1.,  0.,  0.],
        [-1.,  0.,  0.,  1.,  0.,  0.,  0.,  0.],
        [-1.,  0.,  0.,  0.,  3., -1., -1.,  0.],
        [-1., -1., -1.,  0., -1.,  5., -1.,  0.],
        [-1., -1.,  0.,  0., -1., -1.,  4.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]]), tensor([[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  3.,  0.,  0.,  0., -1.,  0.,  0., -1., -1.],
        [ 0.,  0.,  2.,  0.,  0., -1.,  0.,  0.,  0., -1.],
        [ 0.,  0.,  0.,  1.,  0.,  0.,  0., -1.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  2., -1.,  0., -1.,  0.,  0.],
        [ 0., -1., -1.,  0., -1.,  3.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0., -1., -1.,  0.,  0.,  4., -1., -1.],
        [ 0., -1.,  0.,  0.,  0.,  0.,  0., -1.,  2.,  0.],
        [ 0., -1., -1.,  0.,  0.,  0.,  0., -1.,  0.,  3.