In [None]:
import random
import numpy as np
import pickle
from collections import deque, namedtuple

import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GINConv, global_mean_pool

############################################
# Constants
############################################
# Base weights from which every action cost is assembled.
# NOTE(review): what t1, t2, t3 and b measure is not stated in this file;
# confirm their meaning with the problem definition.
t1 = 0.1
t2 = 0.1
t3 = 10.0
b = 0.5

# Cost charged for each action type (keys 1..6). GraphEnv.step looks the
# action type up here and returns the negated cost as the reward.
ACTION_COST = {
    1: t1 + 4 * t2 + b * t3,
    2: t1 + t2,
    3: t1 + t2,
    4: t1 + 3 * t2,
    5: t3,
    6: 3 * t2 + t3,
}

############################################
# Graph Environment
############################################
class GraphEnv:
    """Graph-reduction environment over an adjacency matrix and node labels.

    The state is the pair ``(A, T)``:

    * ``A`` is an ``n x n`` array; an entry equal to 1 marks an edge.
    * ``T`` holds one label per node: -1 (the default for every node when
      no labels are supplied), 1 (assigned by actions 1 and 2) or 0
      (assigned when a node is deleted).

    An action is a tuple ``(kind, i, j)`` with ``kind`` in 1..6; ``j`` is -1
    for actions that involve a single node. An episode is finished when no
    edges remain and every label is 0 or 1.
    """

    def __init__(self, A, T=None):
        """Create the environment from adjacency ``A`` and optional labels ``T``."""
        self.reset(A, T)

    def reset(self, A=None, T=None):
        """Load a graph and return the resulting state.

        Args:
            A: Adjacency matrix to load. When omitted, the graph most
                recently passed to ``reset``/``__init__`` is restored.
            T: Node labels. When omitted together with ``A`` the labels of
                the restored graph are used; when ``A`` is given and ``T``
                is omitted every node is labelled -1.

        Returns:
            ``(A, T)`` copies of the freshly loaded state.

        Raises:
            ValueError: if ``A`` is omitted before any graph was loaded.
        """
        if A is None:
            if not hasattr(self, "_A0"):
                raise ValueError("reset() without a graph requires a previously loaded graph")
            A = self._A0
            if T is None:
                T = self._T0
        self.A = A.copy()
        self.n = A.shape[0]
        if T is None:
            self.T = np.full(self.n, -1, dtype=int)
        else:
            self.T = T.copy()
        # Pristine snapshot: self.A / self.T are mutated in place by step(),
        # so they cannot be used later to restart the episode.
        self._A0 = self.A.copy()
        self._T0 = self.T.copy()
        return self.get_state()

    def get_neighbors(self, i):
        """Return the set of node indices adjacent to node ``i``."""
        return set(np.flatnonzero(self.A[i] == 1).tolist())

    def is_done(self):
        """Return True when no edges remain and every label is 0 or 1."""
        return bool(np.all(self.A == 0) and np.all(np.isin(self.T, [0, 1])))

    def delete_node(self, i):
        """Label node ``i`` as 0 and remove every edge touching it."""
        self.T[i] = 0
        self.A[i, :] = 0
        self.A[:, i] = 0

    def valid_actions(self):
        """Enumerate the actions applicable in the current state.

        Returns:
            A list of unique ``(kind, i, j)`` tuples:

            * ``(1, i, -1)``: node ``i`` has label -1.
            * ``(2, i, j)``: ``i`` has label 1 and a single neighbour ``j``
              with label -1.
            * ``(3, i, j)``: ``i`` has label -1 and a single neighbour ``j``
              with label 1.
            * ``(4, i, j)``: ``i < j``, identical neighbour sets, ``i`` has
              label -1 and ``j`` has label 1.
            * ``(5, i, j)``: ``i < j``, both labelled 1 and joined by an edge.
            * ``(6, i, j)``: ``i < j``, both labelled 1 with identical
              neighbour sets.
        """
        T = self.T
        # Neighbour sets do not change while enumerating, so build them once
        # instead of recomputing them for every (i, j) pair.
        neighbors = [self.get_neighbors(i) for i in range(self.n)]
        actions = []
        for i in range(self.n):
            Ni = neighbors[i]
            if T[i] == -1:
                actions.append((1, i, -1))
            if len(Ni) == 1:
                j = next(iter(Ni))
                if T[i] == 1 and T[j] == -1:
                    actions.append((2, i, j))
                if T[i] == -1 and T[j] == 1:
                    actions.append((3, i, j))
            for j in range(i + 1, self.n):
                same_neighbors = Ni == neighbors[j]
                # NOTE(review): action 4 is only offered with the label -1
                # node at the smaller index; the mirrored case (T[i] == 1,
                # T[j] == -1) is never generated. Confirm this is intended.
                if same_neighbors and T[i] == -1 and T[j] == 1:
                    actions.append((4, i, j))
                if T[i] == 1 and T[j] == 1:
                    if self.A[i, j] == 1:
                        actions.append((5, i, j))
                    if same_neighbors:
                        actions.append((6, i, j))
        return list(set(actions))

    def step(self, action):
        """Apply ``action`` in place and return ``(state, reward, done)``.

        The reward is the negated ``ACTION_COST`` of the action kind.

        Raises:
            KeyError: if the action kind has no entry in ``ACTION_COST``.
        """
        a, i, j = action
        cost = ACTION_COST[a]
        if a == 1:
            self.T[i] = 1
        elif a == 2:
            # The neighbour takes label 1 before node i is removed.
            self.T[j] = 1
            self.delete_node(i)
        elif a in (3, 4, 6):
            self.delete_node(i)
        elif a == 5:
            self.A[i, j] = self.A[j, i] = 0
        reward = -cost
        return self.get_state(), reward, self.is_done()

    def get_state(self):
        """Return copies of ``(A, T)`` so callers cannot mutate the env."""
        return self.A.copy(), self.T.copy()

############################################
# GIN Encoder
############################################
class GINEncoder(nn.Module):
    """Encode a batch of graphs into node and graph embeddings.

    Node features are projected to ``hidden_dim`` and passed through two
    ``GINConv`` layers; graph embeddings are the per-graph mean of the
    resulting node embeddings.
    """

    def __init__(self, in_dim=3, hidden_dim=128):
        super().__init__()
        # Creation order (projection, conv1, conv2) is kept so seeded
        # parameter initialisation is unchanged.
        self.lin_in = nn.Linear(in_dim, hidden_dim)
        self.conv1 = GINConv(self._make_mlp(hidden_dim))
        self.conv2 = GINConv(self._make_mlp(hidden_dim))

    @staticmethod
    def _make_mlp(width):
        """Build the Linear-ReLU-Linear network wrapped by each GINConv."""
        layers = [
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, width),
        ]
        return nn.Sequential(*layers)

    def forward(self, batch):
        """Return ``(node_embeddings, graph_embeddings)`` for ``batch``.

        ``batch`` must expose ``x``, ``edge_index`` and ``batch`` attributes.
        """
        hidden = self.lin_in(batch.x)
        for conv in (self.conv1, self.conv2):
            hidden = conv(hidden, batch.edge_index)
        pooled = global_mean_pool(hidden, batch.batch)
        return hidden, pooled

############################################
# Q Network
############################################
class QNetwork(nn.Module):
    """MLP that scores a (graph embedding, action encoding) pair.

    Args:
        graph_dim: Width of the graph embedding (default 128).
        action_dim: Width of the action encoding (default 262).
        hidden_dim: Width of the two hidden layers (default 128).

    The defaults reproduce the original fixed-size network, so existing
    checkpoints load unchanged.
    """

    def __init__(self, graph_dim=128, action_dim=262, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(graph_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, graph_emb, action_emb):
        """Return a ``(batch, 1)`` tensor of Q-values.

        ``graph_emb`` and ``action_emb`` are concatenated along dim 1, so
        both must be 2-D with the same number of rows.
        """
        return self.net(torch.cat([graph_emb, action_emb], dim=1))

############################################
# Replay Buffer
############################################
# One environment step as stored in the replay buffer.
Transition = namedtuple(
    "Transition",
    ["state", "action", "reward", "next_state", "done"],
)


class ReplayBuffer:
    """Bounded FIFO store of transitions with uniform random sampling."""

    def __init__(self, cap):
        # A deque with maxlen silently drops the oldest entry once full.
        self.buffer = deque(maxlen=cap)

    def __len__(self):
        """Return the number of stored transitions."""
        return len(self.buffer)

    def push(self, *args):
        """Store one transition; ``args`` are the ``Transition`` fields in order."""
        record = Transition(*args)
        self.buffer.append(record)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct transitions chosen at random."""
        return random.sample(self.buffer, batch_size)

############################################
# Utilities
############################################
def graph_to_data(A, T):
    """Convert an ``(A, T)`` state into a ``Data`` graph object.

    Node features are a 3-wide one-hot of the label shifted by one, so
    labels -1, 0 and 1 select columns 0, 1 and 2. ``edge_index`` lists the
    (row, column) positions where ``A`` equals 1.
    """
    num_nodes = len(T)
    columns = torch.as_tensor(np.asarray(T), dtype=torch.long) + 1
    x = torch.zeros((num_nodes, 3))
    x[torch.arange(num_nodes), columns] = 1

    endpoints = np.stack(np.nonzero(A == 1))
    edge_index = torch.as_tensor(endpoints, dtype=torch.long)
    return Data(x=x, edge_index=edge_index)

def encode_action(action, node_emb, device):
    """Encode ``action`` as a 262-element vector on ``device``.

    Layout: 6 slots holding a one-hot of the action kind (kinds 1..6),
    followed by the 128-wide embedding of node ``i`` and then of node ``j``.
    The last segment stays zero when ``j`` is -1.

    ``node_emb`` is indexed by ``i`` and ``j``, so it must contain the
    embeddings of the graph the action refers to.
    """
    kind, first, second = action
    n_kinds, emb_width = 6, 128
    first_end = n_kinds + emb_width

    encoded = torch.zeros(n_kinds + 2 * emb_width, device=device)
    encoded[kind - 1] = 1
    encoded[n_kinds:first_end] = node_emb[first]
    if second == -1:
        return encoded
    encoded[first_end:] = node_emb[second]
    return encoded

############################################
# Training Loop
############################################
def _embed_state(encoder, state, device):
    """Run ``encoder`` on a single ``(A, T)`` state; returns (node_emb, graph_emb)."""
    batch = Batch.from_data_list([graph_to_data(*state)]).to(device)
    return encoder(batch)


def _score_actions(q_model, graph_emb, node_emb, actions, device):
    """Return a 1-D tensor with ``q_model``'s value for each action of one graph."""
    action_embs = torch.stack([encode_action(a, node_emb, device) for a in actions])
    # graph_emb has a single row; repeat it (as a view) once per action.
    return q_model(graph_emb.expand(len(actions), -1), action_embs).squeeze(-1)


def train(envs, episodes=300, *, batch_size=256, gamma=0.99, lr=1e-3, target_sync=500):
    """Train the encoder and Q-network with replay and a target network.

    Args:
        envs: Non-empty sequence of ``GraphEnv`` instances. Their state at
            call time is taken as the starting graph of every episode.
        episodes: Number of episodes to run.
        batch_size: Replay mini-batch size; learning starts once the buffer
            holds this many transitions.
        gamma: Discount factor used in the bootstrap target.
        lr: Adam learning rate.
        target_sync: Number of environment steps between target-network
            refreshes.

    Returns:
        ``(encoder, qnet)``, the trained modules.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    encoder = GINEncoder().to(device)
    qnet = QNetwork().to(device)
    target = QNetwork().to(device)
    target.load_state_dict(qnet.state_dict())
    optimizer = optim.Adam(list(encoder.parameters()) + list(qnet.parameters()), lr=lr)
    buffer = ReplayBuffer(10000)
    eps = 1.0
    step = 0

    # step() mutates env.A / env.T in place, so resetting an env to its own
    # current attributes would restart it from the already-reduced graph.
    # Keep pristine copies and reset from those instead.
    initial_states = [(env.A.copy(), env.T.copy()) for env in envs]

    for ep in range(episodes):
        env_idx = random.randrange(len(envs))
        env = envs[env_idx]
        state = env.reset(*initial_states[env_idx])
        done = env.is_done()
        while not done:
            actions = env.valid_actions()
            if not actions:
                # Nothing applicable although the graph is not finished.
                break

            if random.random() < eps:
                action = random.choice(actions)
            else:
                # Action selection needs no gradients.
                with torch.no_grad():
                    node_emb, graph_emb = _embed_state(encoder, state, device)
                    qvals = _score_actions(qnet, graph_emb, node_emb, actions, device)
                action = actions[int(torch.argmax(qvals).item())]

            next_state, reward, done = env.step(action)
            buffer.push(state, action, reward, next_state, done)
            state = next_state

            # Training from buffer
            if len(buffer) >= batch_size:
                transitions = buffer.sample(batch_size)
                states, actions_b, rewards, next_states, dones = zip(*transitions)
                batch_data = Batch.from_data_list([graph_to_data(*s) for s in states]).to(device)
                node_embs, graph_embs = encoder(batch_data)

                # node_embs stacks the nodes of all sampled graphs, while the
                # indices inside an action are local to its own graph: slice
                # out each graph's rows before encoding its action.
                action_rows = []
                offset = 0
                for (_, labels), act in zip(states, actions_b):
                    count = len(labels)
                    action_rows.append(encode_action(act, node_embs[offset:offset + count], device))
                    offset += count
                action_embs = torch.stack(action_rows)
                q = qnet(graph_embs, action_embs).squeeze(-1)

                with torch.no_grad():
                    targets = []
                    for r, nxt, finished in zip(rewards, next_states, dones):
                        next_actions = [] if finished else GraphEnv(*nxt).valid_actions()
                        if not next_actions:
                            # Terminal (or dead-end) state: no bootstrap term.
                            targets.append(r)
                            continue
                        ne2, ge2 = _embed_state(encoder, nxt, device)
                        best = _score_actions(target, ge2, ne2, next_actions, device).max().item()
                        targets.append(r + gamma * best)
                    target_q = torch.tensor(targets, dtype=q.dtype, device=device)

                loss = nn.functional.mse_loss(q, target_q)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            if step % target_sync == 0:
                target.load_state_dict(qnet.state_dict())
            step += 1
        eps *= 0.99
        print(f"Episode {ep+1}/{episodes}, epsilon={eps:.3f}")
    return encoder, qnet

############################################
# Inference
############################################
def run_inference(env, encoder, qnet):
    """Greedily roll out ``qnet`` on ``env`` from its current state.

    At every step the valid action with the highest Q-value is applied. The
    rollout stops when the environment reports completion, or when no action
    is applicable (in which case ``"done"`` in the result is False).

    Args:
        env: Environment to reduce; it is mutated in place.
        encoder: Graph encoder returning ``(node_emb, graph_emb)``.
        qnet: Q-network scoring ``(graph_emb, action_emb)`` pairs; its
            parameters determine the device used.

    Returns:
        Dict with keys ``"done"``, ``"steps"``, ``"total_reward"``,
        ``"final_T"`` and ``"final_A"`` (the last two are the environment's
        own arrays, not copies).
    """
    device = next(qnet.parameters()).device
    state = env.get_state()
    total_reward = 0
    steps = 0
    done = env.is_done()
    # Pure evaluation: skip autograd bookkeeping for every forward pass.
    with torch.no_grad():
        while not done:
            actions = env.valid_actions()
            if not actions:
                # np.argmax on an empty list would raise; report unsolved.
                break
            A, T = state
            data = graph_to_data(A, T)
            batch = Batch.from_data_list([data]).to(device)
            node_emb, graph_emb = encoder(batch)
            qvals = [qnet(graph_emb, encode_action(a, node_emb, device).unsqueeze(0)).item() for a in actions]
            action = actions[int(np.argmax(qvals))]
            state, reward, done = env.step(action)
            total_reward += reward
            steps += 1
    return {"done": env.is_done(), "steps": steps, "total_reward": total_reward, "final_T": env.T, "final_A": env.A}

In [None]:
# Seed every random number generator used in this notebook (Python, NumPy,
# PyTorch) so runs start from the same random state.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

# Load fixed graphs
# The pickle is expected to hold a dict with "train" and "test" entries, each
# an iterable of (A, T) pairs (see make_envs below).
# NOTE(review): pickle.load can execute arbitrary code; only open dataset
# files that come from a trusted source.
with open("fixed_graph_dataset.pkl", "rb") as f:
    data = pickle.load(f)
TRAIN_GRAPHS = data["train"]
TEST_GRAPHS  = data["test"]

def make_envs(graphs):
    """Build one GraphEnv per (A, T) pair, copying the arrays first."""
    return [GraphEnv(A.copy(), T.copy()) for A, T in graphs]

# Training envs
train_envs = make_envs(TRAIN_GRAPHS)
print("Creating training graphs...")
# Report the node count of every training graph.
for k, (A, _) in enumerate(TRAIN_GRAPHS):
    print(f"Train graph {k}: n={A.shape[0]}")

# Train DQN
print("\nTraining DQN...")
encoder, qnet = train(train_envs, episodes=300)

# Save checkpoint
# Both state dicts go into one file under the keys "encoder_state" and
# "qnet_state"; the inference cell reads the same keys.
torch.save({"encoder_state": encoder.state_dict(), "qnet_state": qnet.state_dict()}, "graph_dqn_checkpoint.pt")
print("Model saved to graph_dqn_checkpoint.pt")

In [None]:
# Load checkpoint for inference
# Fresh modules are built with their default sizes and then filled from the
# saved state dicts; map_location lets a checkpoint saved on another device
# load on the one selected here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder = GINEncoder().to(device)
qnet = QNetwork().to(device)
checkpoint = torch.load("graph_dqn_checkpoint.pt", map_location=device)
encoder.load_state_dict(checkpoint["encoder_state"])
qnet.load_state_dict(checkpoint["qnet_state"])
# Switch both modules to evaluation mode.
encoder.eval()
qnet.eval()
print("\nModel loaded successfully")

# Inference on all test graphs
print("\nRunning inference on ALL fixed test graphs...")
# One summary dict per test graph: graph_id, solved, steps, cost.
all_results = []

for idx, (A_test, T_test) in enumerate(TEST_GRAPHS):
    # Work on copies so the rollout does not modify the loaded dataset.
    test_env = GraphEnv(A_test.copy(), T_test.copy())
    result = run_inference(test_env, encoder, qnet)
    all_results.append({
        "graph_id": idx,
        "solved": result["done"],
        "steps": result["steps"],
        # Rewards are negated action costs, so negate again for total cost.
        "cost": -result["total_reward"]
    })
    print(f"\nTest Graph {idx}")
    print("  n:", A_test.shape[0])
    print("  Steps:", result["steps"])
    print("  Total reward:", result["total_reward"])
    print("  Final T:", result["final_T"])