In [1]:
from pathlib import Path

import networkx as nx
import torch
import yaml
from tqdm import tqdm

from dql import DQNLightning
from graph import milp_solve_mds, is_ds, generate_graphs

In [2]:
# Experiment run to evaluate: checkpoint and hyper-parameter files live under this directory.
base_path = '../experiments/2025-04-17-0752/version_0'
hparams_path = f'{base_path}/hparams.yaml'
model_path = f'{base_path}/checkpoints/epoch=39-step=40.ckpt'

In [3]:
# Load the trained model onto the CPU, using the run's hparams.yaml.
# The extra keyword arguments override stored hyper-parameters:
#   s=1               -- presumably the number of graphs sampled at construction
#                        (the captured output shows "Sampling 1 instances"); confirm
#   warm_start_steps=0 -- presumably skips replay-buffer warm-up for evaluation; confirm
dqn_model: DQNLightning = DQNLightning.load_from_checkpoint(
    model_path,
    map_location=torch.device("cpu"),
    hparams_file=hparams_path,
    s=1,
    warm_start_steps=0,
)

Sampling 1 instances from G(range(100, 101), 0.15)...


100%|██████████| 1/1 [00:01<00:00,  1.29s/graph]


In [4]:
# Per-graph evaluation flags, filled by the evaluation loop.
valid_ds, size_eq_mlip = [], []

# Read the run's hyper-parameters so the test graphs use the same n / p / attributes.
conf = yaml.safe_load(Path(hparams_path).read_text())
n = conf['n']
p = conf['p']
tt_g = 300  # number of test graphs to sample

# NOTE(review): the upper bound is conf['delta_n'] + 1, not n + conf['delta_n'] + 1;
# confirm whether delta_n is an absolute max size or an offset from n.
node_range = range(n, conf['delta_n'] + 1)
graphs = generate_graphs(node_range, p, tt_g, milp_solve_mds, attrs=conf['graph_attr'])

sampling 300 x G(100, 0.15)


100%|██████████| 300/300 [02:01<00:00,  2.47it/s]


In [7]:
# Roll out the trained agent on every test graph. Each chosen node is marked in g.x
# until the chosen set is a dominating set. For each graph, record whether the result
# is a valid DS and whether it is no larger than the MILP solution stored in g.y.
print('solving mds')

# Reset the accumulators here so that re-running this cell does not double-count.
valid_ds = []
size_eq_mlip = []

with torch.no_grad():  # inference only; no autograd graph is needed
    for g in tqdm(graphs):
        # The episode marks chosen nodes in g.x; keep the original features so the
        # graph is unchanged afterwards and the cell can be re-run.
        x_orig = g.x
        s = []
        dqn_model.agent.reset(g)
        try:
            # Taking every node always dominates, so the graph's own size is a safe cap.
            # The training `n` is not, because graphs are sampled from a size range starting at n.
            for _ in range(g.nx.number_of_nodes()):
                # epsilon = 0 -> presumably a purely greedy action; confirm get_action signature
                action = dqn_model.agent.get_action(dqn_model.net, 0, 'cpu')
                s.append(action)
                g.x = g.x.clone()  # fresh tensor per step, so earlier states are not mutated
                g.x[action][0] = 1
                dqn_model.agent.state = g
                if is_ds(g.nx, s):
                    break
            else:
                raise RuntimeError('Could not find a DS')
        finally:
            g.x = x_orig
        g.s = s

        # Independent check: every vertex is chosen or has a chosen neighbour.
        valid_ds.append(all(v in s or len(g.nx[v].keys() & s) > 0 for v in g.nx))
        # g.y flags the MILP solution; a valid DS can be no smaller than the optimum.
        size_eq_mlip.append(len(s) <= len([i for i in g.y if i]))

solving mds


100%|██████████| 300/300 [00:44<00:00,  6.74it/s]


In [8]:
# Summary over all test graphs:
#   apx-ratio -- mean of |S_gnn| / |S*| per graph
#   S_gnn     -- set chosen by the agent (g.s)
#   S*        -- nodes flagged 1 in the MILP solution g.y
n_graphs = len(graphs)
opt_sizes = [(g.y == 1).sum() for g in graphs]
apx_ratio = sum(len(g.s) / opt for g, opt in zip(graphs, opt_sizes)) / n_graphs
avg_gnn = sum(len(g.s) for g in graphs) / n_graphs
avg_opt = sum(opt_sizes) / n_graphs
f'apx-ratio {apx_ratio:.2f} -- avg S_gnn {avg_gnn:.2f} -- avg S* {avg_opt:.2f}'

'apx-ratio 2.48 -- avg S_gnn 22.70 -- avg S* 9.17'

In [None]:
# Percentage of evaluated graphs with a valid DS, and with a DS no larger than the MILP optimum.
# Divide by the number of recorded results, not the requested count `tt_g`, so the figures
# remain correct if a different number of graphs was actually evaluated.
def _pct(flags):
    """Percentage of truthy entries in `flags`; 0.0 when `flags` is empty."""
    return 100 * sum(flags) / len(flags) if flags else 0.0

print(f'{_pct(valid_ds):.1f}% valid DS      {_pct(size_eq_mlip):.1f}% equivalent to MILP')

In [None]:
# Hyper-parameters of the evaluated run (last expression -> rich display instead of print).
conf

In [None]:
# Inspect the node list the agent chose for the first test graph.
sample_graph = graphs[0]
sample_graph.s

In [None]:
# Draw a test graph on which the agent's DS is exactly as small as the MILP solution,
# with the chosen nodes highlighted in blue. `g` stays bound for later cells.
g = next((graph for graph in graphs if len(graph.s) == len([i for i in graph.y if i])), None)
if g is None:
    # Keep the previous fall-through (last graph), but report it instead of silently
    # plotting a graph that does not match the MILP optimum.
    g = graphs[-1]
    print('no graph where the agent matches the MILP size; showing the last graph')
g_n = g.nx

try:
    layout = nx.planar_layout(g_n)
except nx.NetworkXException:
    # planar_layout raises NetworkXException for non-planar graphs.
    layout = nx.spring_layout(g_n, seed=0)  # fixed seed -> reproducible figure

node_colors = ['blue' if node in g.s else 'gray' for node in g_n]
nx.draw(g_n, with_labels=True, node_color=node_colors, pos=layout)
print(g.s)

In [None]:
# Raw per-node GNN output on the graph selected above, using all-ones node features.
# FIXME(review): `agent` was never defined in this notebook (NameError on a fresh kernel).
# The only agent in scope is `dqn_model.agent`; confirm it exposes `.gnn`. If it does not,
# the network is presumably `dqn_model.net`, the one passed to `get_action` above.
agent = dqn_model.agent
# Use this graph's size; the training `n` need not match it.
num_nodes = g.nx.number_of_nodes()
edge_index, node_feats = g.edge_index, torch.ones(num_nodes, 1)
with torch.no_grad():
    gnn_out = agent.gnn(node_feats, edge_index).squeeze()
gnn_out