In [7]:
import pickle
import random
import numpy as np

# =========================
# Reproducibility
# =========================
# Fix the seeds of Python's and NumPy's global RNGs so any random draws are
# repeatable across runs. NOTE(review): nothing in the visible code draws
# random numbers; presumably kept for parity with other cells - confirm.
random.seed(0)
np.random.seed(0)

# =========================
# Graph Environment (same as DQN)
# =========================
class GraphEnv:
    """Mutable graph state with six reduction actions and fixed action costs.

    State:
        A: square adjacency matrix; an entry equal to 1 is treated as an edge.
        T: per-node label vector. The code distinguishes the values -1 and 1;
           0 is written by ``delete_node`` and such nodes never yield actions.
        n: number of nodes, taken as ``len(T)``.

    Actions are ``(code, nodes)`` tuples. ``nodes`` is a single node index for
    code 1 and an ``(i, j)`` pair for codes 2-6.
    """

    # Cost parameters. NOTE(review): the meaning of t1/t2/t3/b is not visible
    # in this file - confirm against the problem definition.
    _T1, _T2, _T3, _B = 0.1, 0.1, 10.0, 0.5

    # Cost per action code. Built once at class creation instead of on every
    # ``action_cost`` call (the greedy solver calls it for every candidate).
    _ACTION_COSTS = {
        1: _T1 + 4 * _T2 + _B * _T3,
        2: _T1 + _T2,
        3: _T1 + _T2,
        4: _T1 + 3 * _T2,
        5: _T3,
        6: 3 * _T2 + _T3,
    }

    def __init__(self, A, T):
        """Store private copies of ``A`` and ``T`` so the caller's arrays are never mutated."""
        self.A = A.copy()
        self.T = T.copy()
        self.n = len(T)

    def clone(self):
        """Return an independent environment with the same state."""
        # __init__ already copies both arrays, so no extra copy is needed here.
        return GraphEnv(self.A, self.T)

    def get_neighbors(self, i):
        """Return the set of column indices ``j`` with ``A[i, j] == 1``."""
        return set(np.where(self.A[i] == 1)[0])

    def delete_node(self, i):
        """Remove node ``i``: clear its row and column in ``A`` and set ``T[i]`` to 0."""
        self.A[i, :] = 0
        self.A[:, i] = 0
        self.T[i] = 0

    def action_cost(self, a):
        """Return the fixed cost of action code ``a``.

        Raises:
            KeyError: if ``a`` is not one of the codes 1-6.
        """
        return self._ACTION_COSTS[a]

    def apply_action(self, action):
        """Apply ``action`` in place and return its reward (the negated cost).

        Effects by action code:
            1: set ``T[node]`` to 1.
            2: set ``T[j]`` to 1, then delete node ``i``.
            3, 4, 6: delete node ``i``.
            5: remove the edge between ``i`` and ``j`` (both directions).

        Raises:
            KeyError: if the action code is not 1-6 (the state is left
                untouched in that case).
        """
        a, nodes = action
        if a == 1:
            self.T[nodes] = 1
        elif a == 2:
            i, j = nodes
            self.T[j] = 1
            self.delete_node(i)
        elif a in {3, 4, 6}:
            i, _ = nodes
            self.delete_node(i)
        elif a == 5:
            i, j = nodes
            self.A[i, j] = self.A[j, i] = 0
        return -self.action_cost(a)

    def get_valid_actions(self):
        """List every action applicable to the current state.

        The order of the returned list is significant: callers that pick the
        cheapest action with ``min`` break cost ties by position in this list.
        Single-node rules (codes 1-3) come first in node order, followed by
        pair rules (codes 4-6) over index pairs ``i < j``.
        """
        labels = self.T
        count = self.n

        # Neighbour sets cannot change during this call, so compute each one
        # once instead of re-scanning the adjacency row inside the pair loop.
        # Only nodes labelled 1 or -1 can take part in any rule.
        neighbor_sets = {
            i: self.get_neighbors(i)
            for i in range(count)
            if labels[i] == 1 or labels[i] == -1
        }

        actions = []

        # Rules on a single node (and its unique neighbour, if it has one).
        for i in range(count):
            label = labels[i]
            if label == -1:
                actions.append((1, i))
            elif label != 1:
                continue
            nbrs = neighbor_sets[i]
            if len(nbrs) != 1:
                continue
            j = next(iter(nbrs))
            if label == 1 and labels[j] == -1:
                actions.append((2, (i, j)))
            elif label == -1 and labels[j] == 1:
                actions.append((3, (i, j)))

        # Rules on pairs i < j; every pair rule requires T[j] == 1.
        # NOTE(review): code 4 is only offered when the -1 node has the lower
        # index; the mirrored case (T[i] == 1, T[j] == -1) yields no action.
        # Left as-is because this environment is meant to match the DQN one.
        for i in range(count):
            label_i = labels[i]
            if label_i != 1 and label_i != -1:
                continue
            nbrs_i = neighbor_sets[i]
            for j in range(i + 1, count):
                if labels[j] != 1:
                    continue
                same_neighbors = nbrs_i == neighbor_sets[j]
                if label_i == -1:
                    if same_neighbors:
                        actions.append((4, (i, j)))
                else:
                    if self.A[i, j] == 1:
                        actions.append((5, (i, j)))
                    if same_neighbors:
                        actions.append((6, (i, j)))
        return actions

# =========================
# Greedy solver
# =========================
def greedy_solve(env):
    """Run a cheapest-action-first rollout on a clone of ``env``.

    At every step the valid action with the lowest ``action_cost`` (looked up
    by action code) is applied; ``min`` keeps the first such action when
    several share the lowest cost. The rollout stops when no valid action
    remains. ``env`` itself is not modified.

    Returns:
        dict with keys
            "solved": true if no entry of ``T`` equals -1, or every column
                sum of ``A`` is zero, at the end of the rollout.
            "steps": number of actions applied.
            "total_cost": sum of the costs of the applied actions.
            "final_T", "final_A": the clone's arrays after the rollout.
    """
    sim = env.clone()
    n_applied = 0
    spent = 0

    candidates = sim.get_valid_actions()
    while candidates:
        chosen = min(candidates, key=lambda cand: sim.action_cost(cand[0]))
        spent += sim.action_cost(chosen[0])
        sim.apply_action(chosen)
        n_applied += 1
        candidates = sim.get_valid_actions()

    solved = np.all(sim.T != -1) or np.all(sim.A.sum(axis=0) == 0)
    return {
        "solved": solved,
        "steps": n_applied,
        "total_cost": spent,
        "final_T": sim.T,
        "final_A": sim.A,
    }

In [8]:
# Load the fixed dataset and keep its "test" split.
# NOTE(review): pickle.load can execute arbitrary code while unpickling;
# only use this with a dataset file from a trusted source.
with open("fixed_graph_dataset.pkl", "rb") as f:
    data = pickle.load(f)
TEST_GRAPHS = data["test"]

# Solve every test graph with the greedy baseline and print a short report.
print("\nRunning greedy inference on all test graphs...")
for idx, graph in enumerate(TEST_GRAPHS):
    A_test, T_test = graph
    test_env = GraphEnv(A_test.copy(), T_test.copy())
    res = greedy_solve(test_env)
    print(f"\nTest Graph {idx}")
    print(f"  n: {A_test.shape[0]}")
    print(f"  Steps: {res['steps']}")
    print(f"  Total cost: {res['total_cost']}")
    print("  Final T:", res["final_T"])


Running greedy inference on all test graphs...

Test Graph 0
  n: 10
  Steps: 24
  Total cost: 175.30000000000004
  Final T: [0 0 0 0 0 0 0 0 0 1]

Test Graph 1
  n: 10
  Steps: 35
  Total cost: 302.1000000000001
  Final T: [0 0 0 0 0 0 0 0 0 1]

Test Graph 2
  n: 14
  Steps: 55
  Total cost: 490.90000000000015
  Final T: [0 0 0 0 0 0 0 0 0 0 0 0 0 1]

Test Graph 3
  n: 14
  Steps: 42
  Total cost: 344.10000000000014
  Final T: [0 0 0 0 0 0 0 0 0 0 0 0 0 1]

Test Graph 4
  n: 13
  Steps: 48
  Total cost: 419.5000000000001
  Final T: [0 0 0 0 0 0 0 0 0 0 0 0 1]
