In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random
import numpy as np

from torch_geometric.data import Data, Batch
from torch_geometric.nn import GINConv, global_mean_pool

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# =========================
# Constants
# =========================
t1, t2, t3, b = 0.1, 0.1, 10.0, 0.5
GAMMA = 0.99
LAMBDA = 0.95
CLIP_EPS = 0.2
PPO_EPOCHS = 4
BATCH_SIZE = 256
LR = 3e-4

# =========================
# Environment
# =========================
class GraphEnv:
    def __init__(self, A, T):
        self.A = A.copy()
        self.T = T.copy()
        self.n = len(T)

    def clone(self):
        return GraphEnv(self.A.copy(), self.T.copy())

    def done(self):
        return np.all(self.A == 0) and np.all(np.isin(self.T, [0, 1]))

    def get_neighbors(self, i):
        return set(np.where(self.A[i] == 1)[0])

    def delete(self, i):
        self.T[i] = 0
        self.A[i, :] = 0
        self.A[:, i] = 0

    def step(self, action):
        a, i, j = action
        cost = 0.0

        if a == 1:
            self.T[i] = 1
            cost = t1 + 4*t2 + b*t3
        elif a == 2:
            self.T[j] = 1
            self.delete(i)
            cost = t1 + t2
        elif a == 3:
            self.delete(i)
            cost = t1 + t2
        elif a == 4:
            self.delete(i)
            cost = t1 + 3*t2
        elif a == 5:
            self.A[i, j] = 0
            self.A[j, i] = 0
            cost = t3
        elif a == 6:
            self.delete(i)
            cost = 3*t2 + t3

        reward = -cost
        return reward, self.done()

# =========================
# Action Enumeration
# =========================
def enumerate_actions(env):
    actions = []
    n = env.n

    for i in range(n):
        Ni = env.get_neighbors(i)

        if env.T[i] == -1:
            actions.append((1, i, -1))

        if env.T[i] == 1 and len(Ni) == 1:
            j = next(iter(Ni))
            if env.T[j] == -1:
                actions.append((2, i, j))

        if env.T[i] == -1 and len(Ni) == 1:
            j = next(iter(Ni))
            if env.T[j] == 1:
                actions.append((3, i, j))

    for i in range(n):
        for j in range(i+1, n):
            if env.get_neighbors(i) == env.get_neighbors(j):
                if env.T[i] == -1 and env.T[j] == 1:
                    actions.append((4, i, j))
                if env.T[i] == 1 and env.T[j] == 1:
                    actions.append((6, i, j))
            if env.T[i] == 1 and env.T[j] == 1 and env.A[i, j] == 1:
                actions.append((5, i, j))

    return list(set(actions))

# =========================
# GNN Encoder (GIN)
# =========================
class GINEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        nn1 = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 128))
        nn2 = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 128))
        self.conv1 = GINConv(nn1)
        self.conv2 = GINConv(nn2)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        g = global_mean_pool(x, batch)
        return g, x

# =========================
# Policy + Value Networks
# =========================
class PolicyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(128 + 6 + 128 + 128, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, g, ai):
        g = g.expand(ai.size(0), -1)
        return self.mlp(torch.cat([g, ai], dim=1)).squeeze(-1)

class ValueNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, g):
        return self.mlp(g).squeeze(-1)

# =========================
# Utilities
# =========================
def encode_graph(A, T):
    edge_index = torch.tensor(np.array(np.nonzero(A)), dtype=torch.long)
    x = torch.zeros((len(T), 3))
    for i, t in enumerate(T):
        x[i, t+1] = 1
    return Data(x=x, edge_index=edge_index)

def encode_action(action, node_emb):
    a, i, j = action
    a_onehot = F.one_hot(torch.tensor(a-1), 6).float()
    ei = node_emb[i]
    ej = node_emb[j] if j >= 0 else torch.zeros_like(ei)
    return torch.cat([a_onehot, ei, ej])

# =========================
# PPO Training Loop
# =========================
def train(envs, episodes=300):
    encoder = GINEncoder().to(device)
    policy = PolicyNet().to(device)
    value_net = ValueNet().to(device)

    optimizer = optim.Adam(
        list(encoder.parameters()) +
        list(policy.parameters()) +
        list(value_net.parameters()),
        lr=3e-4
    )

    for ep in range(episodes):
        memory = []

        # =========================
        # Rollout (NO GRAD)
        # =========================
        with torch.no_grad():
            for env in envs:
                env = env.clone()

                while not env.done():
                    actions = enumerate_actions(env)
                    if not actions:
                        break

                    data = encode_graph(env.A, env.T)
                    batch = Batch.from_data_list([data]).to(device)

                    g, node_emb = encoder(batch)

                    action_embs = torch.stack(
                        [encode_action(a, node_emb) for a in actions]
                    ).to(device)

                    logits = policy(g, action_embs)
                    probs = F.softmax(logits, dim=0)
                    dist = torch.distributions.Categorical(probs)

                    idx = dist.sample()
                    action = actions[idx.item()]

                    reward, done = env.step(action)

                    memory.append({
                        "A": env.A.copy(),
                        "T": env.T.copy(),
                        "action": action,
                        "old_log_prob": dist.log_prob(idx).detach(),
                        "reward": reward,
                        "done": done
                    })

                    if done:
                        break

        # =========================
        # Compute returns
        # =========================
        returns = []
        G = 0
        for step in reversed(memory):
            G = step["reward"] + GAMMA * G
            returns.insert(0, G)

        returns = torch.tensor(returns, dtype=torch.float32).to(device)

        # =========================
        # PPO UPDATE (RECOMPUTE EVERYTHING)
        # =========================
        for _ in range(PPO_EPOCHS):
            total_loss = 0.0

            for step, R in zip(memory, returns):
                data = encode_graph(step["A"], step["T"])
                batch = Batch.from_data_list([data]).to(device)

                g, node_emb = encoder(batch)

                action = step["action"]
                action_emb = encode_action(action, node_emb).unsqueeze(0).to(device)

                logits = policy(g, action_emb)
                log_prob = F.log_softmax(logits, dim=0)[0]

                ratio = torch.exp(log_prob - step["old_log_prob"])

                advantage = R - value_net(g).detach()

                surr1 = ratio * advantage
                surr2 = torch.clamp(ratio, 1 - CLIP_EPS, 1 + CLIP_EPS) * advantage
                policy_loss = -torch.min(surr1, surr2)

                value_loss = F.mse_loss(value_net(g), R.unsqueeze(0))

                total_loss += policy_loss + value_loss

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

        print(f"Episode {ep} finished")

    return encoder, policy, value_net



In [2]:
@torch.no_grad()
def run_inference(env, encoder, policy, deterministic=True, max_steps=1000):
    """
    env          : GraphEnv
    encoder      : trained GINEncoder
    policy       : trained PolicyNet
    deterministic: if True -> greedy, else sample
    max_steps    : safety cap
    """

    encoder.eval()
    policy.eval()

    total_reward = 0.0
    steps = 0
    trajectory = []

    while not env.done() and steps < max_steps:
        actions = enumerate_actions(env)
        if not actions:
            break

        # Encode graph
        data = encode_graph(env.A, env.T)
        batch = Batch.from_data_list([data]).to(device)

        g, node_emb = encoder(batch)

        # Encode actions
        action_embs = torch.stack(
            [encode_action(a, node_emb) for a in actions]
        ).to(device)

        # Policy logits
        logits = policy(g, action_embs)
        probs = F.softmax(logits, dim=0)

        if deterministic:
            idx = torch.argmax(probs)
        else:
            dist = torch.distributions.Categorical(probs)
            idx = dist.sample()

        action = actions[idx.item()]
        reward, done = env.step(action)

        total_reward += reward
        trajectory.append((action, reward))

        steps += 1
        if done:
            break

    return {
        "total_reward": total_reward,
        "steps": steps,
        "trajectory": trajectory,
        "done": env.done()
    }


In [3]:
import pickle
import random
import numpy as np
import torch

# =========================
# 0. Reproducibility
# =========================
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

# =========================
# 1. Load fixed graphs
# =========================
with open("fixed_graph_dataset.pkl", "rb") as f:
    data = pickle.load(f)

TRAIN_GRAPHS = data["train"]
TEST_GRAPHS  = data["test"]

def make_envs(graphs):
    # IMPORTANT: copy so envs don't share state
    return [GraphEnv(A.copy(), T.copy()) for A, T in graphs]


# =========================
# 2. Training graphs (FIXED)
# =========================
train_envs = make_envs(TRAIN_GRAPHS)

print("Creating training graphs...")
for k, (A, _) in enumerate(TRAIN_GRAPHS):
    print(f"Train graph {k}: n={A.shape[0]}")


# =========================
# 3. Train PPO
# =========================
print("\nTraining PPO...")
encoder, policy, value = train(
    train_envs,
    episodes=300
)

Creating training graphs...
Train graph 0: n=11
Train graph 1: n=12
Train graph 2: n=8
Train graph 3: n=9
Train graph 4: n=12

Training PPO...
Episode 0 finished
Episode 1 finished
Episode 2 finished
Episode 3 finished
Episode 4 finished
Episode 5 finished
Episode 6 finished
Episode 7 finished
Episode 8 finished
Episode 9 finished
Episode 10 finished
Episode 11 finished
Episode 12 finished
Episode 13 finished
Episode 14 finished
Episode 15 finished
Episode 16 finished
Episode 17 finished
Episode 18 finished
Episode 19 finished
Episode 20 finished
Episode 21 finished
Episode 22 finished
Episode 23 finished
Episode 24 finished
Episode 25 finished
Episode 26 finished
Episode 27 finished
Episode 28 finished
Episode 29 finished
Episode 30 finished
Episode 31 finished
Episode 32 finished
Episode 33 finished
Episode 34 finished
Episode 35 finished
Episode 36 finished
Episode 37 finished
Episode 38 finished
Episode 39 finished
Episode 40 finished
Episode 41 finished
Episode 42 finished
Episode

In [4]:
# =========================
# 4. Test on ALL fixed graphs
# =========================
print("\nRunning inference on fixed test graphs...")

all_results = []

for idx, (A_test, T_test) in enumerate(TEST_GRAPHS):
    test_env = GraphEnv(A_test.copy(), T_test.copy())

    result = run_inference(
        test_env,
        encoder,
        policy,
        deterministic=True
    )

    all_results.append({
        "graph_id": idx,
        "solved": result["done"],
        "steps": result["steps"],
        "cost": -result["total_reward"]
    })

    print(f"\nTest Graph {idx}")
    print("  n:", A_test.shape[0])
    print("  Steps:", result["steps"])
    print("  Total reward:", result["total_reward"])
    print("  Final T:", test_env.T)



Running inference on fixed test graphs...

Test Graph 0
  n: 10
  Steps: 19
  Total reward: -113.80000000000001
  Final T: [0 0 0 1 0 0 0 1 0 0]

Test Graph 1
  n: 10
  Steps: 22
  Total reward: -148.79999999999995
  Final T: [0 1 0 0 0 1 0 1 1 0]

Test Graph 2
  n: 14
  Steps: 40
  Total reward: -307.0
  Final T: [0 0 0 0 0 0 1 0 0 0 0 0 1 0]

Test Graph 3
  n: 14
  Steps: 24
  Total reward: -124.80000000000003
  Final T: [0 0 0 0 0 0 0 0 0 1 1 0 0 0]

Test Graph 4
  n: 13
  Steps: 29
  Total reward: -189.39999999999995
  Final T: [0 0 0 1 0 0 0 0 0 1 1 1 0]
