In [5]:
import random
import numpy as np
from collections import deque, namedtuple

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch_geometric.data import Data, Batch
from torch_geometric.nn import GINConv, global_mean_pool

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

############################################
# Constants
############################################
t1, t2, t3, b = 0.1, 0.1, 10.0, 0.5

ACTION_COST = {
    1: 1*t1 + 4*t2 + b*t3,
    2: 1*t1 + 1*t2,
    3: 1*t1 + 1*t2,
    4: 1*t1 + 3*t2,
    5: 1*t3,
    6: 3*t2 + 1*t3
}

############################################
# Graph Environment
############################################
class GraphEnv:
    def __init__(self, A):
        self.reset(A)

    def reset(self, A):
        self.A = A.copy()
        self.n = A.shape[0]
        self.T = np.full(self.n, -1, dtype=int)
        return self.get_state()

    def get_neighbors(self, i):
        return set(np.where(self.A[i] == 1)[0])

    def is_done(self):
        return np.all(self.A == 0) and np.all(np.isin(self.T, [0, 1]))

    def delete_node(self, i):
        self.T[i] = 0
        self.A[i, :] = 0
        self.A[:, i] = 0

    def valid_actions(self):
        actions = []
        for i in range(self.n):
            Ni = self.get_neighbors(i)

            if self.T[i] == -1:
                actions.append((1, i, -1))

            if self.T[i] == 1 and len(Ni) == 1:
                j = next(iter(Ni))
                if self.T[j] == -1:
                    actions.append((2, i, j))

            if self.T[i] == -1 and len(Ni) == 1:
                j = next(iter(Ni))
                if self.T[j] == 1:
                    actions.append((3, i, j))

            for j in range(i + 1, self.n):
                if self.get_neighbors(i) == self.get_neighbors(j):
                    if self.T[i] == -1 and self.T[j] == 1:
                        actions.append((4, i, j))

                if self.T[i] == 1 and self.T[j] == 1:
                    if self.A[i, j] == 1:
                        actions.append((5, i, j))
                    if self.get_neighbors(i) == self.get_neighbors(j):
                        actions.append((6, i, j))

        return list(set(actions))

    def step(self, action):
        a, i, j = action
        cost = ACTION_COST[a]

        if a == 1:
            self.T[i] = 1
        elif a == 2:
            self.T[j] = 1
            self.delete_node(i)
        elif a == 3:
            self.delete_node(i)
        elif a == 4:
            self.delete_node(i)
        elif a == 5:
            self.A[i, j] = self.A[j, i] = 0
        elif a == 6:
            self.delete_node(i)

        reward = -cost
        return self.get_state(), reward, self.is_done()

    def get_state(self):
        return self.A.copy(), self.T.copy()

############################################
# GNN Encoder
############################################
class GINEncoder(nn.Module):
    def __init__(self, in_dim=3, hidden_dim=128):
        super().__init__()

        def mlp():
            return nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim)
            )

        self.lin_in = nn.Linear(in_dim, hidden_dim)
        self.conv1 = GINConv(mlp())
        self.conv2 = GINConv(mlp())

    def forward(self, batch):
        x = self.lin_in(batch.x)
        x = F.relu(self.conv1(x, batch.edge_index))
        x = F.relu(self.conv2(x, batch.edge_index))
        graph_emb = global_mean_pool(x, batch.batch)
        return x, graph_emb

############################################
# Q Network
############################################
class QNetwork(nn.Module):
    def __init__(self, action_dim=262, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden_dim + action_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, graph_emb, action_emb):
        return self.net(torch.cat([graph_emb, action_emb], dim=1))

############################################
# Replay Buffer
############################################
Transition = namedtuple("Transition", "state action reward next_state done")

class ReplayBuffer:
    def __init__(self, cap):
        self.buffer = deque(maxlen=cap)

    def push(self, *args):
        self.buffer.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        return len(self.buffer)

############################################
# Utilities
############################################
def graph_to_data(A, T):
    x = torch.zeros((len(T), 3))
    for i, t in enumerate(T):
        x[i, t + 1] = 1
    edge_index = torch.tensor(np.array(np.nonzero(A)), dtype=torch.long)
    return Data(x=x, edge_index=edge_index)

def encode_action(action, node_emb, device):
    a, i, j = action
    vec = torch.zeros(262, device=device)
    vec[a - 1] = 1
    vec[6:6+128] = node_emb[i]
    if j != -1:
        vec[6+128:] = node_emb[j]
    return vec

############################################
# DQN Train
############################################
def train_dqn(train_envs, episodes=300, batch_size=256):
    encoder = GINEncoder().to(device)
    qnet = QNetwork().to(device)
    target = QNetwork().to(device)
    target.load_state_dict(qnet.state_dict())

    optimizer = optim.Adam(list(encoder.parameters()) + list(qnet.parameters()), lr=1e-3)
    buffer = ReplayBuffer(10000)
    gamma = 0.99
    eps = 1.0
    step_count = 0

    for ep in range(episodes):
        env = random.choice(train_envs)
        state = env.reset(env.A)
        done = False

        while not done:
            A, T = state
            data = graph_to_data(A, T)
            batch = Batch.from_data_list([data]).to(device)
            node_emb, graph_emb = encoder(batch)
            actions = env.valid_actions()

            if random.random() < eps:
                action = random.choice(actions)
            else:
                qvals = [qnet(graph_emb, encode_action(a, node_emb, device).unsqueeze(0)).item() for a in actions]
                action = actions[np.argmax(qvals)]

            next_state, reward, done = env.step(action)
            buffer.push(state, action, reward, next_state, done)
            state = next_state

            if len(buffer) >= batch_size:
                batch_samples = buffer.sample(batch_size)
                states_b, actions_b, rewards_b, next_states_b, dones_b = zip(*batch_samples)

                # Encode current states
                data_list = [graph_to_data(*s) for s in states_b]
                batch_data = Batch.from_data_list(data_list).to(device)
                node_embs, graph_embs = encoder(batch_data)
                action_embs = torch.stack([encode_action(actions_b[i], node_embs, device) for i in range(batch_size)])
                qvals = qnet(graph_embs, action_embs).squeeze()

                # Compute targets
                targets = []
                for i in range(batch_size):
                    if dones_b[i]:
                        targets.append(rewards_b[i])
                    else:
                        A2, T2 = next_states_b[i]
                        env2 = GraphEnv(A2)
                        env2.T = T2.copy()
                        acts2 = env2.valid_actions()
                        d2 = graph_to_data(A2, T2)
                        b2 = Batch.from_data_list([d2]).to(device)
                        ne2, ge2 = encoder(b2)
                        maxq = max(target(ge2, encode_action(a2, ne2, device).unsqueeze(0)).item() for a2 in acts2)
                        targets.append(rewards_b[i] + gamma * maxq)
                targets = torch.tensor(targets, device=device)
                loss = F.mse_loss(qvals, targets)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            if step_count % 500 == 0:
                target.load_state_dict(qnet.state_dict())
            step_count += 1

        eps *= 0.99
        print(f"Episode {ep+1}/{episodes}, epsilon={eps:.3f}")

    return encoder, qnet

############################################
# DQN Inference
############################################
def infer_dqn(env, encoder, qnet):
    state = env.get_state()
    encoder.eval()
    qnet.eval()
    total_reward = 0

    while not env.is_done():
        A, T = state
        data = graph_to_data(A, T)
        batch = Batch.from_data_list([data]).to(device)
        node_emb, graph_emb = encoder(batch)
        actions = env.valid_actions()
        qvals = [qnet(graph_emb, encode_action(a, node_emb, device).unsqueeze(0)).item() for a in actions]
        action = actions[np.argmax(qvals)]
        state, reward, done = env.step(action)
        total_reward += reward

    return {
        "final_T": env.T,
        "final_A": env.A,
        "total_reward": total_reward,
        "done": env.is_done()
    }

import pickle


random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

# =========================
# 1. Load fixed graphs
# =========================
with open("fixed_graph_dataset.pkl", "rb") as f:
    data = pickle.load(f)

TRAIN_GRAPHS = data["train"]
TEST_GRAPHS  = data["test"]

train_envs = [GraphEnv(A.copy()) for A, _ in TRAIN_GRAPHS]

# Test on all test graphs
for idx, (A_test, _) in enumerate(TEST_GRAPHS):
    test_env = GraphEnv(A_test.copy())

print("Training DQN...")
encoder, qnet = train_dqn(train_envs, episodes=300)




Training DQN...
Episode 1/300, epsilon=0.990
Episode 2/300, epsilon=0.980
Episode 3/300, epsilon=0.970
Episode 4/300, epsilon=0.961
Episode 5/300, epsilon=0.951
Episode 6/300, epsilon=0.941
Episode 7/300, epsilon=0.932
Episode 8/300, epsilon=0.923
Episode 9/300, epsilon=0.914
Episode 10/300, epsilon=0.904
Episode 11/300, epsilon=0.895
Episode 12/300, epsilon=0.886
Episode 13/300, epsilon=0.878
Episode 14/300, epsilon=0.869
Episode 15/300, epsilon=0.860
Episode 16/300, epsilon=0.851
Episode 17/300, epsilon=0.843
Episode 18/300, epsilon=0.835
Episode 19/300, epsilon=0.826
Episode 20/300, epsilon=0.818
Episode 21/300, epsilon=0.810
Episode 22/300, epsilon=0.802
Episode 23/300, epsilon=0.794
Episode 24/300, epsilon=0.786
Episode 25/300, epsilon=0.778
Episode 26/300, epsilon=0.770
Episode 27/300, epsilon=0.762
Episode 28/300, epsilon=0.755
Episode 29/300, epsilon=0.747
Episode 30/300, epsilon=0.740
Episode 31/300, epsilon=0.732
Episode 32/300, epsilon=0.725
Episode 33/300, epsilon=0.718
Epi

In [8]:
# Test on all test graphs
print("\nRunning inference on test graphs...")
all_results = []

for idx, (A_test,_) in enumerate(TEST_GRAPHS):
    test_env = GraphEnv(A_test.copy())
    result = infer_dqn(test_env, encoder, qnet)
    all_results.append({
        "graph_id": idx,
        "solved": result["done"],
        "cost": -result["total_reward"],
        "final_T": result["final_T"]
    })

    print(f"\nTest Graph {idx}")
    print("  n:", A_test.shape[0])
    print("  Total reward:", result["total_reward"])
    print("  Final T:", result["final_T"])


Running inference on test graphs...

Test Graph 0
  n: 10
  Total reward: -123.80000000000001
  Final T: [1 0 1 0 1 0 1 0 1 1]

Test Graph 1
  n: 10
  Total reward: -174.6
  Final T: [1 0 1 1 0 1 1 1 1 1]

Test Graph 2
  n: 14
  Total reward: -326.4
  Final T: [1 1 1 1 1 0 1 1 0 1 1 1 1 1]

Test Graph 3
  n: 14
  Total reward: -225.8
  Final T: [1 1 1 1 1 1 1 0 1 0 1 0 0 1]

Test Graph 4
  n: 13
  Total reward: -275.59999999999997
  Final T: [1 1 1 1 1 1 1 0 1 0 1 0 1]
