In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import random

from torch_geometric.data import Data, Batch
from torch_geometric.nn import GINConv, global_mean_pool

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# =========================
# Hyperparameters
# =========================
GAMMA = 0.99       # discount factor applied to future rewards
LAMBDA = 0.95      # presumably a GAE lambda; not referenced by the visible code
CLIP_EPS = 0.2     # PPO ratio clipping range: [1 - CLIP_EPS, 1 + CLIP_EPS]
PPO_EPOCHS = 4     # gradient passes over each batch of collected transitions
LR = 3e-4          # Adam learning rate

# Cost constants combined by GraphEnv.step to price each action type.
t1 = 0.1
t2 = 0.1
t3 = 10.0
b = 0.5

# =========================
# Environment
# =========================
class GraphEnv:
    """Mutable graph-reduction environment.

    State:
        A: (n, n) adjacency matrix. An entry of 1 marks an edge, and edges are
           always added/removed in both directions.
        T: length-n node label array. In this code, -1 is the initial label,
           1 is set by actions 1 and 2, and 0 marks a deleted node.

    The episode is finished when no edges remain and every label is 0 or 1.
    """

    def __init__(self, A, T):
        # Defensive copies: stepping this env never mutates the caller's arrays.
        self.A = A.copy()
        self.T = T.copy()
        self.n = len(T)

    def clone(self):
        """Return an independent copy of this environment."""
        # __init__ already copies A and T, so no extra copy is needed here.
        return GraphEnv(self.A, self.T)

    def done(self):
        """True when the graph has no edges and no node is still labelled -1
        (or any value other than 0/1)."""
        return bool(np.all(self.A == 0) and np.all(np.isin(self.T, (0, 1))))

    def get_neighbors(self, i):
        """Set of (plain int) indices adjacent to node ``i``."""
        return {int(k) for k in np.flatnonzero(self.A[i] == 1)}

    def delete(self, i):
        """Remove node ``i``: mark its label 0 and drop all its incident edges."""
        self.T[i] = 0
        self.A[i, :] = 0
        self.A[:, i] = 0

    def step(self, action):
        """Apply ``action = (a, i, j)`` in place.

        Args:
            action: action code ``a`` in 1..6, primary node ``i`` and partner
                node ``j`` (unused by some codes).

        Returns:
            (reward, done): reward is the negated action cost; done is
            ``self.done()`` after the update.

        Raises:
            ValueError: if ``a`` is not one of the six known action codes.
                Previously an unknown code was a silent, free no-op.
        """
        a, i, j = action

        if a == 1:
            self.T[i] = 1
            cost = t1 + 4*t2 + b*t3
        elif a == 2:
            # Promote i's neighbour j to label 1, then remove i.
            self.T[j] = 1
            self.delete(i)
            cost = t1 + t2
        elif a == 3:
            self.delete(i)
            cost = t1 + t2
        elif a == 4:
            self.delete(i)
            cost = t1 + 3*t2
        elif a == 5:
            # Remove only the edge (i, j); both nodes stay.
            self.A[i, j] = 0
            self.A[j, i] = 0
            cost = t3
        elif a == 6:
            self.delete(i)
            cost = 3*t2 + t3
        else:
            raise ValueError(f"unknown action type {a!r} (expected 1-6)")

        return -cost, self.done()

# =========================
# Generate random connected graph
# =========================
def random_connected_graph(n, p=0.3):
    """Build a random connected undirected graph on ``n`` nodes.

    A path 0-1-...-(n-1) guarantees connectivity. Every unordered pair
    (u, v) with u < v then independently gets an extra edge with
    probability ``p``, drawn from Python's ``random`` module.

    Returns:
        (A, T): symmetric int adjacency matrix of shape (n, n), and a length-n
        label array filled with -1.
    """
    adjacency = np.zeros((n, n), dtype=int)

    # Backbone path (vectorised); this consumes no random numbers.
    chain = np.arange(n - 1)
    adjacency[chain, chain + 1] = 1
    adjacency[chain + 1, chain] = 1

    # Random extra edges, visiting pairs in the same row-major order as always
    # so a given seed yields the same graph.
    for u in range(n):
        for v in range(u + 1, n):
            if random.random() < p:
                adjacency[u, v] = adjacency[v, u] = 1

    node_types = np.full(n, -1)
    return adjacency, node_types

# =========================
# Enumerate valid actions
# =========================
def enumerate_actions(env):
    """List every legal action ``(a, i, j)`` in the current state of ``env``.

    Reads ``env.n``, ``env.T``, ``env.A`` and ``env.get_neighbors``. Rules:
        1: node i has label -1                       -> (1, i, -1)
        2: T[i] == 1, i has exactly one neighbour j with T[j] == -1
        3: T[i] == -1, i has exactly one neighbour j with T[j] == 1
        4: i, j have identical neighbour sets, T[i] == -1, T[j] == 1
        5: T[i] == T[j] == 1 and (i, j) is an edge    (i < j)
        6: i, j have identical neighbour sets, T[i] == T[j] == 1 (i < j)

    Returns:
        Sorted list of unique ``(a, i, j)`` tuples of plain Python ints.
        The order is deterministic.
    """
    n = env.n
    T = env.T
    # Neighbour sets are compared O(n^2) times below; compute each one once.
    nbrs = [{int(k) for k in env.get_neighbors(i)} for i in range(n)]
    actions = set()

    for i in range(n):
        if T[i] == -1:
            actions.add((1, i, -1))
        if len(nbrs[i]) == 1:
            (j,) = nbrs[i]
            if T[i] == 1 and T[j] == -1:
                actions.add((2, i, j))
            if T[i] == -1 and T[j] == 1:
                actions.add((3, i, j))

    for i in range(n):
        for j in range(i + 1, n):
            if nbrs[i] == nbrs[j]:
                # Action 4 is directional: it deletes the twin with label -1 in
                # favour of the twin with label 1. Only i < j is visited, so
                # both orderings must be checked here.
                if T[i] == -1 and T[j] == 1:
                    actions.add((4, i, j))
                elif T[i] == 1 and T[j] == -1:
                    actions.add((4, j, i))
                if T[i] == 1 and T[j] == 1:
                    actions.add((6, i, j))
            if T[i] == 1 and T[j] == 1 and env.A[i, j] == 1:
                actions.add((5, i, j))

    return sorted(actions)

# =========================
# GIN Encoder
# =========================
class GINEncoder(nn.Module):
    """Two-layer GIN encoder.

    Turns 3-feature node inputs into 128-d node embeddings and mean-pools them
    per graph in the batch.
    """

    def __init__(self):
        super().__init__()

        def two_layer_mlp(in_features):
            # Linear -> ReLU -> Linear, always ending in 128 features.
            return nn.Sequential(
                nn.Linear(in_features, 128),
                nn.ReLU(),
                nn.Linear(128, 128),
            )

        # Both MLPs are built before either conv, exactly as before, so
        # weight initialisation consumes the RNG in the same order.
        first_mlp = two_layer_mlp(3)
        second_mlp = two_layer_mlp(128)
        self.conv1 = GINConv(first_mlp)
        self.conv2 = GINConv(second_mlp)

    def forward(self, data):
        """Encode a PyG batch.

        Returns:
            (graph_emb, node_emb): pooled per-graph embeddings and the
            per-node embeddings after the second conv + ReLU.
        """
        h = data.x
        for conv in (self.conv1, self.conv2):
            h = F.relu(conv(h, data.edge_index))
        pooled = global_mean_pool(h, data.batch)
        return pooled, h

# =========================
# Policy & Value
# =========================
class PolicyNet(nn.Module):
    """Action scorer: one logit per candidate action.

    Each row of ``ai`` (6 action-type one-hot + two 128-d node embeddings) is
    concatenated with the 128-d graph embedding and passed through an MLP.
    """

    def __init__(self):
        super().__init__()
        graph_dim, type_dim, node_dim = 128, 6, 128
        in_features = graph_dim + type_dim + node_dim + node_dim
        self.mlp = nn.Sequential(
            nn.Linear(in_features, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

    def forward(self, g, ai):
        """Return a 1-D tensor of logits, one per row of ``ai``.

        ``g`` is broadcast over the rows of ``ai`` with ``expand``, so its
        leading dimension should be 1 (or already match ``ai``).
        """
        n_candidates = ai.size(0)
        features = torch.cat([g.expand(n_candidates, -1), ai], dim=1)
        scores = self.mlp(features)
        return scores.squeeze(-1)

class ValueNet(nn.Module):
    """Critic head: maps each 128-d graph embedding to one scalar estimate."""

    def __init__(self):
        super().__init__()
        hidden = 128
        layers = [nn.Linear(128, hidden), nn.ReLU(), nn.Linear(hidden, 1)]
        self.mlp = nn.Sequential(*layers)

    def forward(self, g):
        """Return one value per row of ``g`` (trailing size-1 dim dropped)."""
        estimates = self.mlp(g)
        return estimates.squeeze(-1)

# =========================
# Graph & action encoding
# =========================
def encode_graph(A, T):
    """Convert (adjacency, labels) into a PyG ``Data`` object.

    ``edge_index`` lists every nonzero entry of ``A`` as a (row, col) pair.
    Node features are a 3-way one-hot of ``T[k] + 1``, so labels -1/0/1 map
    to columns 0/1/2.
    """
    rows, cols = np.nonzero(A)
    edge_index = torch.tensor(np.array([rows, cols]), dtype=torch.long)

    num_nodes = len(T)
    x = torch.zeros((num_nodes, 3))
    for node in range(num_nodes):
        x[node, T[node] + 1] = 1
    return Data(x=x, edge_index=edge_index)

def encode_action(action, node_emb):
    """Embed a candidate action as ``[one_hot(a) | emb(i) | emb(j)]``.

    Args:
        action: ``(a, i, j)`` with action code ``a`` in 1..6. A negative
            ``j`` means "no partner node" and is encoded as zeros.
        node_emb: 2-D tensor of node embeddings; row k belongs to node k.

    Returns:
        1-D tensor of length ``6 + 2 * node_emb.size(1)``, on the same
        device and dtype as ``node_emb``.
    """
    a, i, j = action
    # Build the one-hot on node_emb's device and dtype. torch.cat below needs
    # every input on one device; a CPU one-hot crashed on CUDA.
    a_index = torch.tensor(a - 1, dtype=torch.long, device=node_emb.device)
    a_onehot = F.one_hot(a_index, 6).to(node_emb.dtype)
    ei = node_emb[i]
    ej = node_emb[j] if j >= 0 else torch.zeros_like(ei)
    return torch.cat([a_onehot, ei, ej])

# =========================
# PPO Training
# =========================
def _policy_logits(encoder, policy, A, T, actions):
    """Encode state (A, T) and score every candidate in ``actions``.

    Returns (graph_embedding, logits) with one logit per action.
    """
    batch = Batch.from_data_list([encode_graph(A, T)]).to(device)
    g, node_emb = encoder(batch)
    action_embs = torch.stack([encode_action(a, node_emb) for a in actions]).to(device)
    return g, policy(g, action_embs)


def _collect_episode(env, encoder, policy):
    """Sample one episode on a clone of ``env``; return its transitions in order."""
    env = env.clone()
    transitions = []
    while not env.done():
        actions = enumerate_actions(env)
        if not actions:
            break
        # Snapshot the state the action is chosen FROM (env.step mutates it).
        state_A, state_T = env.A.copy(), env.T.copy()
        with torch.no_grad():
            _, logits = _policy_logits(encoder, policy, state_A, state_T, actions)
            dist = torch.distributions.Categorical(logits=logits)
            idx = dist.sample()
            old_log_prob = dist.log_prob(idx)
        reward, _ = env.step(actions[idx.item()])
        transitions.append({
            "A": state_A,
            "T": state_T,
            "actions": actions,            # full candidate set, for re-scoring
            "action_idx": idx.item(),
            "old_log_prob": old_log_prob,
            "reward": reward,
        })
    return transitions


def _discounted_returns(rewards, gamma):
    """Reward-to-go ``G_t = r_t + gamma * G_{t+1}`` for ONE episode."""
    returns = [0.0] * len(rewards)
    running = 0.0
    for t in range(len(rewards) - 1, -1, -1):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


def _ppo_update(memory, returns, encoder, policy, value_net, optimizer):
    """Run PPO_EPOCHS clipped-PPO gradient steps over all collected transitions."""
    for _ in range(PPO_EPOCHS):
        losses = []
        for step, R in zip(memory, returns):
            g, logits = _policy_logits(encoder, policy, step["A"], step["T"], step["actions"])
            # Normalise over ALL legal actions of that state. A log_softmax over
            # the single taken action is identically 0 and gives no policy gradient.
            log_prob = F.log_softmax(logits, dim=0)[step["action_idx"]]
            ratio = torch.exp(log_prob - step["old_log_prob"])
            value = value_net(g).squeeze(0)
            advantage = (R - value).detach()
            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1 - CLIP_EPS, 1 + CLIP_EPS) * advantage
            policy_loss = -torch.min(surr1, surr2)
            value_loss = F.mse_loss(value, R)
            losses.append(policy_loss + value_loss)
        # Mean (not sum) keeps the step size independent of the number of transitions.
        loss = torch.stack(losses).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def train(envs, episodes=300):
    """Train a GIN actor-critic with a clipped PPO objective.

    Each iteration samples one episode per environment in ``envs``. Every
    episode runs on a clone, so ``envs`` is never mutated. The iteration then
    runs ``PPO_EPOCHS`` update passes over all collected transitions.
    Advantages are Monte-Carlo returns minus the value baseline (LAMBDA/GAE
    is not used).

    Args:
        envs: environments exposing ``clone``, ``done``, ``step``, ``A``
            and ``T``.
        episodes: number of rollout + update iterations.

    Returns:
        (encoder, policy, value_net): the trained modules.
    """
    encoder = GINEncoder().to(device)
    policy = PolicyNet().to(device)
    value_net = ValueNet().to(device)

    optimizer = optim.Adam(
        list(encoder.parameters()) + list(policy.parameters()) + list(value_net.parameters()),
        lr=LR,
    )

    for ep in range(episodes):
        memory, returns = [], []
        for env in envs:
            episode = _collect_episode(env, encoder, policy)
            memory.extend(episode)
            # Per-episode returns: rewards must not leak across environments.
            returns.extend(_discounted_returns([s["reward"] for s in episode], GAMMA))

        # Skip the update when nothing was collected (e.g. every env already solved).
        if memory:
            returns_t = torch.tensor(returns, dtype=torch.float32, device=device)
            _ppo_update(memory, returns_t, encoder, policy, value_net, optimizer)
        print(f"Episode {ep} done")

    return encoder, policy, value_net

# =========================
# Inference
# =========================
@torch.no_grad()
def run_inference(env, encoder, policy, deterministic=True, max_steps=1000):
    """Roll the learned policy out on ``env`` IN PLACE (env is stepped directly).

    Args:
        env: environment to act on; it is mutated.
        encoder, policy: trained modules; both are switched to eval mode.
        deterministic: pick the highest-probability action instead of sampling.
        max_steps: hard cap on the number of actions taken.

    Returns:
        dict with ``total_reward``, ``steps``, ``trajectory`` (list of
        ``(action, reward)``) and ``done`` (``env.done()`` at exit).
    """
    encoder.eval()
    policy.eval()

    steps = 0
    total_reward = 0
    trajectory = []

    while not env.done() and steps < max_steps:
        candidates = enumerate_actions(env)
        if not candidates:
            break

        graph = Batch.from_data_list([encode_graph(env.A, env.T)]).to(device)
        g, node_emb = encoder(graph)
        embedded = torch.stack([encode_action(c, node_emb) for c in candidates]).to(device)
        probs = F.softmax(policy(g, embedded), dim=0)

        if deterministic:
            choice = torch.argmax(probs)
        else:
            choice = torch.distributions.Categorical(probs).sample()

        chosen = candidates[choice.item()]
        reward, _finished = env.step(chosen)
        total_reward += reward
        trajectory.append((chosen, reward))
        steps += 1

    return {
        "total_reward": total_reward,
        "steps": steps,
        "trajectory": trajectory,
        "done": env.done(),
    }

In [None]:
import pickle
import random
import numpy as np
import torch

# =========================
# Reproducibility
# =========================
# Seed Python's `random`, NumPy's legacy global RNG and PyTorch before any
# model is built or any action is sampled.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

# =========================
# Load fixed graphs
# =========================
# NOTE(review): pickle.load can execute arbitrary code. Only open a trusted,
# locally generated fixed_graph_dataset.pkl.
with open("fixed_graph_dataset.pkl", "rb") as f:
    data = pickle.load(f)

# Assumes a mapping with "train"/"test" sequences of (A, T) pairs, matching
# how they are unpacked by make_envs and the inference loop.
TRAIN_GRAPHS, TEST_GRAPHS = data["train"], data["test"]

def make_envs(graphs):
    """Return one fresh, independent GraphEnv per (A, T) pair in ``graphs``."""
    envs = []
    for adjacency, node_types in graphs:
        # Copy explicitly so no environment shares arrays with the dataset.
        envs.append(GraphEnv(adjacency.copy(), node_types.copy()))
    return envs


# =========================
# Training graphs (FIXED)
# =========================
train_envs = make_envs(TRAIN_GRAPHS)

# Index the pair instead of unpacking into `_`. At notebook top level,
# assigning `_` stops IPython from updating its `_`/`__`/`___` output history.
for k, graph in enumerate(TRAIN_GRAPHS):
    print(f"Train graph {k}: n={graph[0].shape[0]}")


# =========================
# Train Actor Critic
# =========================
print("\nTraining Actor Critic...")
encoder, policy, value = train(
    train_envs,
    episodes=300,  # number of rollout + PPO-update iterations
)

In [None]:
# =========================
# Inference on ALL test graphs
# =========================
print("\nRunning inference on ALL fixed test graphs...")

all_results = []

for idx, (A_test, T_test) in enumerate(TEST_GRAPHS):
    test_env = GraphEnv(A_test.copy(), T_test.copy())

    # Greedy (argmax) rollout; run_inference steps test_env in place, so its
    # final state can be printed below.
    result = run_inference(
        test_env,
        encoder,
        policy,
        deterministic=True
    )

    all_results.append({
        "graph_id": idx,
        "solved": result["done"],
        "steps": result["steps"],
        "cost": -result["total_reward"]
    })

    print(f"\nTest Graph {idx}")
    print("  n:", A_test.shape[0])
    print("  Steps:", result["steps"])
    print("  Total reward:", result["total_reward"])
    print("  Final T:", test_env.T)

# Summarise the per-graph results collected above.
n_solved = sum(bool(r["solved"]) for r in all_results)
print(f"\nSolved {n_solved}/{len(all_results)} test graphs")
if all_results:
    mean_cost = sum(r["cost"] for r in all_results) / len(all_results)
    print(f"Mean cost: {mean_cost:.3f}")