In [15]:
import pickle
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GINConv, global_mean_pool

# =========================
# Reproducibility
# =========================
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# =========================
# Cost constants
# =========================
t1, t2, t3, b = 0.1, 0.1, 10.0, 0.5

# =========================
# Graph Environment
# =========================
class GraphEnv:
    def __init__(self, A, T):
        self.A = A.copy()
        self.T = T.copy()
        self.n = len(T)

    def clone(self):
        return GraphEnv(self.A.copy(), self.T.copy())

    def done(self):
        return np.all(self.A == 0) and np.all(np.isin(self.T, [0,1]))

    def get_neighbors(self, i):
        return set(np.where(self.A[i] == 1)[0])

    def delete(self, i):
        self.T[i] = 0
        self.A[i,:] = 0
        self.A[:,i] = 0

    def step(self, action):
        a, i, j = action
        cost = 0.0
        if a == 1:
            self.T[i] = 1
            cost = t1 + 4*t2 + b*t3
        elif a == 2:
            self.T[j] = 1
            self.delete(i)
            cost = t1 + t2
        elif a == 3:
            self.delete(i)
            cost = t1 + t2
        elif a == 4:
            self.delete(i)
            cost = t1 + 3*t2
        elif a == 5:
            self.A[i,j] = self.A[j,i] = 0
            cost = t3
        elif a == 6:
            self.delete(i)
            cost = 3*t2 + t3
        reward = -cost
        return reward, self.done()

# =========================
# Enumerate valid actions
# =========================
def enumerate_actions(env):
    actions = []
    n = env.n
    for i in range(n):
        Ni = env.get_neighbors(i)
        if env.T[i] == -1:
            actions.append((1,i,-1))
        if env.T[i] == 1 and len(Ni) == 1:
            j = next(iter(Ni))
            if env.T[j] == -1:
                actions.append((2,i,j))
        if env.T[i] == -1 and len(Ni) == 1:
            j = next(iter(Ni))
            if env.T[j] == 1:
                actions.append((3,i,j))
    for i in range(n):
        for j in range(i+1,n):
            if env.get_neighbors(i) == env.get_neighbors(j):
                if env.T[i] == -1 and env.T[j] == 1:
                    actions.append((4,i,j))
                if env.T[i] == 1 and env.T[j] == 1:
                    actions.append((6,i,j))
            if env.T[i]==1 and env.T[j]==1 and env.A[i,j]==1:
                actions.append((5,i,j))
    return list(set(actions))

# =========================
# GIN Encoder
# =========================
class GINEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        nn1 = nn.Sequential(nn.Linear(3,128), nn.ReLU(), nn.Linear(128,128))
        nn2 = nn.Sequential(nn.Linear(128,128), nn.ReLU(), nn.Linear(128,128))
        self.conv1 = GINConv(nn1)
        self.conv2 = GINConv(nn2)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        g = global_mean_pool(x, batch)
        return g, x

# =========================
# Actor & Critic
# =========================
class PolicyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(6 + 128 + 128 + 128, 128),
            nn.ReLU(),
            nn.Linear(128,1)
        )
    def forward(self, g, ai):
        g = g.expand(ai.size(0), -1)
        return self.mlp(torch.cat([g, ai], dim=1)).squeeze(-1)

class ValueNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(128,128), nn.ReLU(), nn.Linear(128,1))
    def forward(self,g):
        return self.mlp(g).squeeze(-1)

# =========================
# Graph & action encoding
# =========================
def encode_graph(A,T):
    edge_index = torch.tensor(np.array(np.nonzero(A)), dtype=torch.long)
    x = torch.zeros((len(T),3))
    for i,t in enumerate(T):
        x[i,t+1] = 1
    return Data(x=x, edge_index=edge_index)

def encode_action(action, node_emb):
    a,i,j = action
    a_onehot = F.one_hot(torch.tensor(a-1),6).float()
    ei = node_emb[i]
    ej = node_emb[j] if j>=0 else torch.zeros_like(ei)
    return torch.cat([a_onehot,ei,ej])

# =========================
# Actor-Critic Training
# =========================
def train(envs, episodes=300):
    # --- Separate encoders for actor and critic ---
    encoder_actor = GINEncoder().to(device)
    encoder_critic = GINEncoder().to(device)
    actor = PolicyNet().to(device)
    critic = ValueNet().to(device)

    optimizer_actor = optim.Adam(list(encoder_actor.parameters()) + list(actor.parameters()), lr=3e-4)
    optimizer_critic = optim.Adam(list(encoder_critic.parameters()) + list(critic.parameters()), lr=3e-4)

    GAMMA = 0.99

    for ep in range(episodes):
        for env in envs:
            env = env.clone()
            while not env.done():
                actions = enumerate_actions(env)
                if not actions:
                    break

                data = encode_graph(env.A, env.T)
                batch = Batch.from_data_list([data]).to(device)

                # --- Critic forward ---
                g_critic, node_emb_critic = encoder_critic(batch)
                value = critic(g_critic)

                # --- Actor forward ---
                g_actor, node_emb_actor = encoder_actor(batch)
                action_embs = torch.stack([encode_action(a,node_emb_actor) for a in actions]).to(device)
                logits = actor(g_actor, action_embs)
                probs = F.softmax(logits, dim=0)
                dist = torch.distributions.Categorical(probs)
                idx = dist.sample()
                action = actions[idx.item()]

                reward, done = env.step(action)
                advantage = reward + (0 if done else GAMMA*critic(g_critic)) - value

                # Critic update
                critic_loss = advantage.pow(2)
                optimizer_critic.zero_grad()
                critic_loss.backward()
                optimizer_critic.step()

                # Actor update
                actor_loss = -dist.log_prob(idx) * advantage.detach()
                optimizer_actor.zero_grad()
                actor_loss.backward()
                optimizer_actor.step()

        print(f"Episode {ep+1}/{episodes} done")

    return encoder_actor, actor, encoder_critic, critic

# =========================
# Inference
# =========================
@torch.no_grad()
def run_inference(env, encoder, actor, deterministic=True, max_steps=1000):
    encoder.eval()
    actor.eval()
    total_reward = 0
    steps = 0
    trajectory = []

    while not env.done() and steps < max_steps:
        actions = enumerate_actions(env)
        if not actions:
            break

        data = encode_graph(env.A, env.T)
        batch = Batch.from_data_list([data]).to(device)
        g, node_emb = encoder(batch)
        action_embs = torch.stack([encode_action(a,node_emb) for a in actions]).to(device)
        logits = actor(g, action_embs)
        probs = F.softmax(logits, dim=0)

        if deterministic:
            idx = torch.argmax(probs)
        else:
            dist = torch.distributions.Categorical(probs)
            idx = dist.sample()

        action = actions[idx.item()]
        reward, done = env.step(action)
        total_reward += reward
        trajectory.append((action,reward))
        steps += 1

    return {
        "total_reward": total_reward,
        "steps": steps,
        "trajectory": trajectory,
        "done": env.done()
    }


In [16]:
import pickle
import random
import numpy as np
import torch

# =========================
# Reproducibility
# =========================
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

# =========================
# Load fixed graphs
# =========================
with open("fixed_graph_dataset.pkl", "rb") as f:
    data = pickle.load(f)

TRAIN_GRAPHS = data["train"]
TEST_GRAPHS  = data["test"]

def make_envs(graphs):
    # IMPORTANT: copy so environments don't share state
    return [GraphEnv(A.copy(), T.copy()) for A, T in graphs]


# =========================
# Training graphs (FIXED)
# =========================
train_envs = make_envs(TRAIN_GRAPHS)

for k, (A, _) in enumerate(TRAIN_GRAPHS):
    print(f"Train graph {k}: n={A.shape[0]}")


# =========================
# Train Actor Critic
# =========================
print("\nTraining Actor Critic...")
encoder_actor, actor, encoder_critic, critic = train(
    train_envs,
    episodes=300
)


Train graph 0: n=11
Train graph 1: n=12
Train graph 2: n=8
Train graph 3: n=9
Train graph 4: n=12

Training Actor Critic...
Episode 1/300 done
Episode 2/300 done
Episode 3/300 done
Episode 4/300 done
Episode 5/300 done
Episode 6/300 done
Episode 7/300 done
Episode 8/300 done
Episode 9/300 done
Episode 10/300 done
Episode 11/300 done
Episode 12/300 done
Episode 13/300 done
Episode 14/300 done
Episode 15/300 done
Episode 16/300 done
Episode 17/300 done
Episode 18/300 done
Episode 19/300 done
Episode 20/300 done
Episode 21/300 done
Episode 22/300 done
Episode 23/300 done
Episode 24/300 done
Episode 25/300 done
Episode 26/300 done
Episode 27/300 done
Episode 28/300 done
Episode 29/300 done
Episode 30/300 done
Episode 31/300 done
Episode 32/300 done
Episode 33/300 done
Episode 34/300 done
Episode 35/300 done
Episode 36/300 done
Episode 37/300 done
Episode 38/300 done
Episode 39/300 done
Episode 40/300 done
Episode 41/300 done
Episode 42/300 done
Episode 43/300 done
Episode 44/300 done
Episo

In [17]:
# =========================
# Inference on ALL test graphs
# =========================
print("\nRunning inference on ALL fixed test graphs...")

all_results = []

for idx, (A_test, T_test) in enumerate(TEST_GRAPHS):
    test_env = GraphEnv(A_test.copy(), T_test.copy())

    # Use encoder_actor and actor for inference
    result = run_inference(
        test_env,
        encoder_actor,
        actor,
        deterministic=True
    )

    all_results.append({
        "graph_id": idx,
        "solved": result["done"],
        "steps": result["steps"],
        "cost": -result["total_reward"]
    })

    print(f"\nTest Graph {idx}")
    print("  n:", A_test.shape[0])
    print("  Steps:", result["steps"])
    print("  Total reward:", result["total_reward"])
    print("  Final T:", test_env.T)



Running inference on ALL fixed test graphs...

Test Graph 0
  n: 10
  Steps: 23
  Total reward: -185.0
  Final T: [1 1 1 1 1 1 1 1 1 1]

Test Graph 1
  n: 10
  Steps: 27
  Total reward: -225.60000000000002
  Final T: [1 0 1 0 1 1 1 1 1 1]

Test Graph 2
  n: 14
  Steps: 40
  Total reward: -338.50000000000006
  Final T: [1 0 0 1 0 0 0 1 1 1 1 1 1 1]

Test Graph 3
  n: 14
  Steps: 35
  Total reward: -288.20000000000005
  Final T: [1 0 1 1 0 0 1 1 0 1 1 1 1 1]

Test Graph 4
  n: 13
  Steps: 38
  Total reward: -322.70000000000005
  Final T: [1 0 1 1 0 0 0 1 1 1 1 1 1]
