In [1]:
import networkx as nx
import torch
from tqdm import trange, tqdm

from dql import DQGN, Agent
from graph import milp_solve_mds, prepare_graph, mds_is_solved

In [None]:
# Build the test set: tt_g graphs, each labelled by the exact MILP solver, whose
# optimum is the reference the agent is scored against later.
n = 100                          # node count passed to prepare_graph
p = 0.15                         # G(n, p) parameter (presumably edge probability)
tt_g = 500                       # number of test graphs
node_counts = range(n, n + 1)    # one-element range {n}; presumably the size range prepare_graph samples from
graphs = []

print(f'sampling and solving {tt_g} x G{n, p}')
# Appending inside the loop (not a comprehension) keeps the graphs sampled so far
# if this slow cell is interrupted.
for i in trange(tt_g, unit='graph'):
    # i is passed first, presumably as a per-graph seed/index — TODO confirm in graph.py
    g = prepare_graph(i, node_counts, p, milp_solve_mds, g_nx=True)
    graphs.append(g)

sampling and solving 500 x G(100, 0.15)


 56%|█████▋    | 282/500 [14:30<09:02,  2.49s/graph]

In [None]:
# Load the trained Q-network weights from a Lightning checkpoint onto the CPU.
ckp_path = './experiments/2024-12-10-2120/version_0/checkpoints/epoch=839-step=840.ckpt'
device = torch.device('cpu')
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
checkpoint = torch.load(ckp_path, map_location=device)
hyper_parameters = checkpoint['hyper_parameters']

# The network's parameters are stored under the leading 'net.' prefix. Strip only
# that prefix (str.replace would also rewrite any 'net.' occurring later in a key).
# Keys containing 'target' are skipped, as are non-'net.' entries such as
# 'loss_module.pos_weight', which do not belong to DQGN.
NET_PREFIX = 'net.'
state_dict = {
    k[len(NET_PREFIX):]: v
    for k, v in checkpoint['state_dict'].items()
    if k.startswith(NET_PREFIX) and 'target' not in k
}
assert state_dict, f'no {NET_PREFIX!r} weights found in {ckp_path}'

# Input feature width is taken from the sampled graphs' node-feature matrix.
c_in = graphs[0].x.shape[1]
gnn = DQGN(c_in=c_in)
gnn.load_state_dict(state_dict)
gnn.eval()
gnn

In [None]:
# Build the evaluation agent over the pre-solved test graphs (n, p as sampled above).
# NOTE(review): the three None positional arguments are defined by Agent in dql.py —
# presumably components only needed for training; TODO confirm against its signature.
agent = Agent(n, p, None, None, None, graphs)

In [None]:
# Roll out the policy on every test graph and score it against the MILP optimum.
#   valid_ds     – the selected set dominates the graph
#   size_eq_mlip – the selected set is no larger than the MILP optimum
#   apx_ratio    – |agent set| / |MILP set|
#   gap          – |MILP set| / |agent set|  (reported as 1 - mean(gap))
valid_ds = []
size_eq_mlip = []
apx_ratio = []
gap = []

print(f'testing agent on mds')
for g in tqdm(graphs, unit='graph'):
    agent.reset(g)
    # Perform an episode of actions: at most one step per node of this graph.
    for step in range(g.nx.number_of_nodes()):
        rwd, done = agent.play_validation_step(gnn, 'cpu')
        if done:
            break
    # Nodes whose state entry equals 1 form the agent's candidate set.
    g.s = {i for i, v in enumerate(agent.state.x) if v == 1}

    # Dominating: every vertex is in the set or has a neighbour in it.
    valid_ds.append(all(v in g.s or any(u in g.s for u in g.nx[v]) for v in g.nx))
    len_sol = sum(1 for i in g.y if i)  # size of the MILP solution
    len_agent = len(g.s)
    size_eq_mlip.append(len_agent <= len_sol)
    apx_ratio.append(len_agent / len_sol)
    # An empty selection would raise ZeroDivisionError and abort the whole run;
    # score it as gap 0 (it cannot dominate a non-empty graph anyway).
    gap.append(len_sol / len_agent if len_agent else 0.0)

In [None]:
# Average over the graphs actually evaluated. If the sampling cell was interrupted,
# len(graphs) < tt_g, and dividing by tt_g would understate every metric.
n_eval = len(valid_ds)
assert n_eval, 'no graphs were evaluated'
print(f'{100*sum(valid_ds)/n_eval}% valid DS      {100*sum(size_eq_mlip)/n_eval}% equivalent to MILP        {sum(apx_ratio)/n_eval:.3f} apx ratio         {1-sum(gap)/n_eval:.3f} gap')

In [None]:
# Lazily yield the test graphs where the agent's set is exactly as large as the
# MILP solution; each next(g_gen) returns the following such graph.
g_gen = (
    graph
    for graph in graphs
    if len(graph.s) == sum(1 for flag in graph.y if flag)
)

In [None]:
# Draw the next graph on which the agent matched the MILP solution size; the agent's
# nodes are blue. Re-running this cell advances g_gen to the next such graph.
g = next(g_gen, None)
if g is None:
    raise RuntimeError('no (more) graphs where the agent matches the MILP size; re-run the g_gen cell')
g_n = g.nx

# planar_layout raises NetworkXException for non-planar graphs; fall back to a
# spring layout only in that case rather than swallowing every error.
try:
    layout = nx.planar_layout(g_n)
except nx.NetworkXException:
    layout = nx.drawing.spring_layout(g_n)

node_colors = ['blue' if v in g.s else 'gray' for v in g_n]
nx.draw(g_n, with_labels=True, node_color=node_colors, pos=layout)
y = {i for i, v in enumerate(g.y) if v}
print(f'{g.s=} {y=}')

In [None]:
# Q-values with every node feature set to 0, ranked highest first.
# NOTE(review): assumes the network takes one feature per node — TODO confirm against c_in.
num_nodes = g.nx.number_of_nodes()  # size from the graph itself, not the global n
edge_index = g.edge_index
node_feats = torch.zeros(num_nodes, 1, dtype=torch.float32)
with torch.no_grad():
    # reshape(-1) (not squeeze) keeps a list even for a single-node graph
    q_values = gnn(node_feats, edge_index).reshape(-1).tolist()
sorted(enumerate(q_values), key=lambda qv: qv[1], reverse=True)

In [None]:
# Q-values after marking a few nodes with feature 1 (others 0), ranked highest first.
# Feature 1 presumably means "selected", matching the state.x == 1 check in the
# evaluation loop — TODO confirm. The tensor is sized from the graph so it matches
# edge_index for any n; a fixed-length literal only fits one graph size.
selected_nodes = {0, 2, 3}
num_nodes = g.nx.number_of_nodes()
edge_index = g.edge_index
node_feats = torch.zeros(num_nodes, 1, dtype=torch.float32)
node_feats[sorted(selected_nodes), 0] = 1.0
with torch.no_grad():
    q_values = gnn(node_feats, edge_index).reshape(-1).tolist()
sorted(enumerate(q_values), key=lambda qv: qv[1], reverse=True)