In [None]:
import pickle
import random
import numpy as np
from collections import deque
import torch
import torch.nn as nn
import torch.optim as optim

# =========================
# Reproducibility
# =========================
# Seed each random number generator used in this notebook (Python's `random`,
# NumPy's legacy global RNG, and PyTorch) with the same fixed value.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

# Use CUDA when PyTorch reports it as available, otherwise fall back to CPU.
# The functions below move their network inputs and models to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# =========================
# Graph Environment
# =========================
class GraphEnv:
    """Mutable graph environment manipulated through six action kinds.

    The state is an adjacency matrix ``A`` (edges are entries equal to 1) and
    a per-node label vector ``T``.  The code distinguishes three label values:
    ``-1`` (the node is still a candidate for action 1), ``1``, and ``0``
    (written by :meth:`delete_node`).

    An action is a tuple ``(a, nodes)`` with ``a`` in ``1..6``.  ``nodes`` is a
    single node index for action 1 and an ``(i, j)`` pair for actions 2-6.
    """

    def __init__(self, A, T):
        """Store private copies of ``A`` and ``T`` so the caller's arrays are never mutated."""
        self.A = A.copy()
        self.T = T.copy()
        self.n = len(T)
        self.done = False

    def clone(self):
        """Return an independent environment with the same ``A`` and ``T``.

        The ``done`` flag is not carried over; the clone starts with ``done=False``.
        """
        # __init__ already copies both arrays, so no extra copy is needed here.
        return GraphEnv(self.A, self.T)

    def get_neighbors(self, i):
        """Return the set of node indices ``j`` with ``A[i, j] == 1``."""
        return set(np.where(self.A[i] == 1)[0])

    def delete_node(self, i):
        """Remove every edge incident to node ``i`` and set its label to 0."""
        self.A[i, :] = 0
        self.A[:, i] = 0
        self.T[i] = 0

    def action_cost(self, a):
        """Return the positive cost of action kind ``a`` (1..6).

        Raises:
            KeyError: if ``a`` is not one of the six known action kinds.
        """
        # NOTE(review): the meaning of the constants t1, t2, t3 and b is not
        # documented anywhere in this notebook - confirm with the author.
        t1, t2, t3, b = 0.1, 0.1, 10.0, 0.5
        return {
            1: t1 + 4*t2 + b*t3,
            2: t1 + t2,
            3: t1 + t2,
            4: t1 + 3*t2,
            5: t3,
            6: 3*t2 + t3
        }[a]

    def apply_action(self, action):
        """Apply ``action`` in place and return ``(reward, done)``.

        ``reward`` is the negated action cost.  ``done`` becomes True when no
        node is labelled -1 any more, when the graph has no edges left, or when
        no valid action remains.
        """
        a, nodes = action
        cost = self.action_cost(a)
        if a == 1:
            self.T[nodes] = 1
        elif a == 2:
            i, j = nodes
            self.T[j] = 1
            self.delete_node(i)
        elif a in {3, 4, 6}:
            # For these kinds only the first node of the pair is removed.
            i, _ = nodes
            self.delete_node(i)
        elif a == 5:
            i, j = nodes
            self.A[i, j] = self.A[j, i] = 0
        # Cast to a plain bool so callers never receive a numpy.bool_.
        self.done = bool(
            np.all(self.T != -1)
            or np.all(self.A.sum(axis=0) == 0)
            or not self.get_valid_actions()
        )
        return -cost, self.done

    def get_valid_actions(self):
        """Enumerate every action that is applicable in the current state.

        Returns a list of ``(a, nodes)`` tuples:

        * ``(1, i)``       - node ``i`` is labelled -1.
        * ``(2, (i, j))``  - ``i`` is labelled 1, has exactly one neighbour ``j``, and ``j`` is labelled -1.
        * ``(3, (i, j))``  - ``i`` is labelled -1, has exactly one neighbour ``j``, and ``j`` is labelled 1.
        * ``(4, (u, v))``  - ``u`` is labelled -1, ``v`` is labelled 1, and both have identical neighbour sets.
        * ``(5, (i, j))``  - ``i < j``, both labelled 1, and adjacent.
        * ``(6, (i, j))``  - ``i < j``, both labelled 1, with identical neighbour sets.
        """
        T = self.T
        # Build each neighbour set once instead of once per pair comparison.
        nbrs = [self.get_neighbors(i) for i in range(self.n)]
        acts = []
        for i in range(self.n):
            if T[i] == -1:
                acts.append((1, i))
            if T[i] == 1 and len(nbrs[i]) == 1:
                j = next(iter(nbrs[i]))
                if T[j] == -1:
                    acts.append((2, (i, j)))
            if T[i] == -1 and len(nbrs[i]) == 1:
                j = next(iter(nbrs[i]))
                if T[j] == 1:
                    acts.append((3, (i, j)))
        for i in range(self.n):
            for j in range(i + 1, self.n):
                same_nbrs = nbrs[i] == nbrs[j]
                if same_nbrs:
                    if T[i] == -1 and T[j] == 1:
                        acts.append((4, (i, j)))
                    elif T[i] == 1 and T[j] == -1:
                        # Mirrored case: the -1 node has the higher index.  It
                        # must still come first in the pair because
                        # apply_action deletes the first node.
                        acts.append((4, (j, i)))
                if T[i] == 1 and T[j] == 1:
                    if self.A[i, j] == 1:
                        acts.append((5, (i, j)))
                    if same_nbrs:
                        acts.append((6, (i, j)))
        return acts

    def get_state(self):
        """Return the state as one flat vector: row-major ``A`` followed by ``T``."""
        return np.concatenate([self.A.flatten(), self.T])

# =========================
# Action encoding
# =========================
def encode_action(action, n):
    """Encode an ``(a, nodes)`` action as a 0/1 vector of length ``2*n + 6``.

    Layout:
      * slots ``0..5``       - one-hot of the action kind ``a`` (1..6);
      * slots ``6..6+n-1``   - one-hot of the single node, or of the first node of a pair;
      * slots ``6+n..6+2n-1``- one-hot of the second node of a pair (all zero for a single node).

    Args:
        action: tuple ``(a, nodes)``; ``nodes`` is an integer index or an ``(i, j)`` pair.
        n: number of nodes in the graph.

    Returns:
        ``numpy.ndarray`` of dtype float64 and shape ``(2*n + 6,)``.
    """
    a, nodes = action
    v = np.zeros(2*n + 6)
    v[a-1] = 1
    # Accept NumPy integer scalars too: indices produced by np.where and similar
    # calls are np.integer, not int, and would otherwise be mistaken for a pair.
    if isinstance(nodes, (int, np.integer)):
        v[6+nodes] = 1
    else:
        i, j = nodes
        v[6+i] = 1
        v[6+n+j] = 1
    return v

# =========================
# DQN Model
# =========================
class DQN(nn.Module):
    """Fully connected network that scores one (state, action) input vector.

    Architecture: ``input_dim -> 256 -> 128 -> 1`` with ReLU after each hidden
    layer.  The output has a trailing dimension of size 1.
    """

    def __init__(self, input_dim):
        super().__init__()
        hidden_widths = (256, 128)
        layers = []
        fan_in = input_dim
        # Hidden stack: Linear followed by ReLU for every hidden width, in order.
        for fan_out in hidden_widths:
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
            fan_in = fan_out
        # Scalar output head.
        layers.append(nn.Linear(fan_in, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the network output for ``x``; the last dimension has size 1."""
        return self.net(x)

# =========================
# Training
# =========================
def _state_action_inputs(state, actions, n):
    """Build one network input row per candidate action for a single state vector."""
    state_rows = torch.tensor(state, dtype=torch.float32).repeat(len(actions), 1)
    # Stack into a single ndarray first: building a tensor from a Python list
    # of ndarrays goes through a slow element-by-element path.
    action_rows = torch.tensor(
        np.stack([encode_action(a, n) for a in actions]), dtype=torch.float32
    )
    return torch.cat([state_rows, action_rows], dim=1).to(device)


def train_dqn(envs, episodes=300, batch_size=64, gamma=0.99, lr=1e-3):
    """Train a DQN that scores (state, action) pairs on the given environments.

    Each episode clones a randomly chosen environment and plays it to the end
    with an epsilon-greedy policy.  Transitions go into a replay buffer and one
    gradient step is taken per environment step once the buffer holds at least
    ``batch_size`` transitions.  The target network is synchronised with the
    online network at the end of every 10th episode (starting with the first).

    Args:
        envs: non-empty sequence of ``GraphEnv``; all must have the same ``n``
            because the network input width is derived from it.
        episodes: number of training episodes.
        batch_size: replay minibatch size.
        gamma: discount factor.
        lr: Adam learning rate.

    Returns:
        The trained online network (``DQN``), located on ``device``.

    Raises:
        ValueError: if ``envs`` is empty or the environments differ in ``n``.
    """
    if len(envs) == 0:
        raise ValueError("train_dqn requires at least one environment")
    n = envs[0].n
    if any(e.n != n for e in envs):
        raise ValueError("all environments must have the same number of nodes")
    # Flattened adjacency (n*n) + labels (n) + action encoding (2n + 6).
    input_dim = n**2 + n + 2*n + 6

    q_net = DQN(input_dim).to(device)
    target_net = DQN(input_dim).to(device)
    target_net.load_state_dict(q_net.state_dict())
    optimizer = optim.Adam(q_net.parameters(), lr=lr)
    memory = deque(maxlen=20000)
    criterion = nn.MSELoss()

    eps_start, eps_end, eps_decay = 1.0, 0.05, 0.992

    for ep in range(episodes):
        env = random.choice(envs).clone()
        state = env.get_state()
        # Geometric per-episode decay, floored at eps_end.
        eps = max(eps_end, eps_start*(eps_decay**ep))
        done = False
        while not done:
            acts = env.get_valid_actions()
            if not acts:
                break
            if random.random() < eps:
                action = random.choice(acts)
            else:
                q_net.eval()
                with torch.no_grad():
                    qs = q_net(_state_action_inputs(state, acts, n))
                    action = acts[torch.argmax(qs).item()]
                q_net.train()
            reward, done = env.apply_action(action)
            next_state = env.get_state()
            # Next-state actions are only needed for bootstrapping, which is
            # skipped for terminal transitions.
            next_acts = [] if done else env.get_valid_actions()
            memory.append((state, action, reward, next_state, next_acts, done))
            state = next_state

            # Train
            if len(memory) >= batch_size:
                batch = random.sample(memory, batch_size)
                s_b, a_b, r_b, ns_b, na_b, d_b = zip(*batch)
                s_t = torch.tensor(np.stack(s_b), dtype=torch.float32)
                a_t = torch.tensor(np.stack([encode_action(a, n) for a in a_b]), dtype=torch.float32)
                # squeeze(1) keeps the batch dimension even when batch_size == 1,
                # so the prediction always matches target_q's shape.
                curr_q = q_net(torch.cat([s_t, a_t], dim=1).to(device)).squeeze(1)
                target_q = torch.zeros(batch_size, device=device)
                with torch.no_grad():
                    for i in range(batch_size):
                        if d_b[i] or not na_b[i]:
                            target_q[i] = r_b[i]
                        else:
                            next_inputs = _state_action_inputs(ns_b[i], na_b[i], n)
                            target_q[i] = r_b[i] + gamma*torch.max(target_net(next_inputs))
                loss = criterion(curr_q, target_q)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        if ep % 10 == 0:
            target_net.load_state_dict(q_net.state_dict())
        print(f"Episode {ep+1}/{episodes}, epsilon={eps:.3f}")
    return q_net

# =========================
# Inference
# =========================
def run_inference(env, q_net):
    """Play one greedy episode on a clone of ``env`` using ``q_net``.

    At every step all valid actions are scored and the highest-scoring one is
    applied, until the environment reports done or no valid action remains.
    ``env`` itself is not modified; ``q_net`` is switched to eval mode and
    left in that mode.

    Returns:
        dict with keys
          * ``"solved"``       - bool; True when no node is labelled -1 or the
            graph has no edges left.  An episode that stopped only because no
            valid action remained is reported as not solved.
          * ``"steps"``        - number of actions applied.
          * ``"total_reward"`` - sum of the per-step rewards.
          * ``"final_T"`` / ``"final_A"`` - the clone's final label vector and
            adjacency matrix.
    """
    env_copy = env.clone()
    q_net.eval()
    total_reward, steps = 0, 0
    while True:
        acts = env_copy.get_valid_actions()
        if not acts:
            break
        with torch.no_grad():
            st_tensor = torch.tensor(env_copy.get_state(), dtype=torch.float32).repeat(len(acts), 1)
            # Stack into a single ndarray first: building a tensor from a
            # Python list of ndarrays goes through a slow path.
            ac_tensor = torch.tensor(
                np.stack([encode_action(a, env_copy.n) for a in acts]), dtype=torch.float32
            )
            inputs = torch.cat([st_tensor, ac_tensor], dim=1).to(device)
            best_action = acts[torch.argmax(q_net(inputs)).item()]
        reward, done = env_copy.apply_action(best_action)
        total_reward += reward
        steps += 1
        if done:
            break
    # Plain bool rather than numpy.bool_ so the result can be serialised or
    # compared by identity.
    solved = bool(np.all(env_copy.T != -1) or np.all(env_copy.A.sum(axis=0) == 0))
    return {"solved":solved,"steps":steps,"total_reward":total_reward,"final_T":env_copy.T,"final_A":env_copy.A}




In [None]:
# NOTE(review): pickle.load can execute arbitrary code embedded in the file.
# Only load a dataset file that comes from a trusted source.
with open("fixed_graph_dataset.pkl","rb") as f:
    data = pickle.load(f)
# The loaded object is indexed by the keys "train" and "test".
TRAIN_GRAPHS = data["train"]
TEST_GRAPHS = data["test"]

# Re-seed right before training so that re-running this cell on its own starts
# from the same RNG state.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

# Each training entry is unpacked as an (A, T) pair and wrapped in its own environment.
train_envs = [GraphEnv(A.copy(),T.copy()) for A,T in TRAIN_GRAPHS]
# Report each training graph's size, taken from the first dimension of A.
for k,(A,_) in enumerate(TRAIN_GRAPHS):
    print(f"Train graph {k}: n={A.shape[0]}")

print("\nTraining DQN...")
# `qnet` is the trained network used by the inference cell that follows.
qnet = train_dqn(train_envs,episodes=300)

In [None]:
# Evaluate the trained network greedily on every test graph and print a short
# per-graph summary. Requires `qnet` and `TEST_GRAPHS` from the training cell.
print("\nRunning inference on all test graphs...")
for idx,(A_test,T_test) in enumerate(TEST_GRAPHS):
    test_env = GraphEnv(A_test.copy(),T_test.copy())
    # run_inference works on a clone, so test_env itself is left unchanged.
    res = run_inference(test_env,qnet)
    print(f"\nTest Graph {idx}")
    print("  n:", A_test.shape[0])
    print("  Steps:", res["steps"])
    print("  Total reward:", res["total_reward"])
    # `res` also carries "solved" and "final_A", which are not printed here.
    print("  Final T:", res["final_T"])