# AoI Graph Freshness with GNN-DQN
Notebook for Part 3a and Part 3b

In [1]:
import os, random, math
import numpy as np
import networkx as nx

import torch
import torch.nn as nn
import torch.optim as optim

SEED = 42
# Seed the Python, NumPy and PyTorch global RNGs, in that order, for reproducibility.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(SEED)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device:", device)

device: cuda


In [3]:
import torch, sys, platform

print("Python:", sys.version)
print("Platform:", platform.platform())
print("Torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("torch.version.cuda:", torch.version.cuda)


def pyg_wheel_url():
    """Return the PyG wheel index URL that matches the installed torch build.

    The torch release is taken without its local ``+...`` suffix. The wheel tag
    is ``cu<version digits>`` when CUDA is available and torch reports a CUDA
    version; otherwise it is ``cpu``.
    """
    torch_release = torch.__version__.partition("+")[0]
    cuda_ver = torch.version.cuda
    use_cuda = torch.cuda.is_available() and cuda_ver is not None
    tag = ("cu" + cuda_ver.replace(".", "")) if use_cuda else "cpu"
    return f"https://data.pyg.org/whl/torch-{torch_release}+{tag}.html"


url = pyg_wheel_url()
print("PyG wheel URL:", url)

Python: 3.11.13 (main, Jun  4 2025, 08:57:29) [GCC 11.4.0]
Platform: Linux-6.6.105+-x86_64-with-glibc2.35
Torch: 2.6.0+cu124
CUDA available: True
torch.version.cuda: 12.4
PyG wheel URL: https://data.pyg.org/whl/torch-2.6.0+cu124.html


In [4]:
import sys
!{sys.executable} -m pip install -q torch_geometric -f {url}

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m25.3 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25h

In [5]:
import torch
import torch_geometric
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data

print("torch_geometric:", torch_geometric.__version__)

# Smoke test on a 3-node path 0-1-2; each undirected edge is listed in both directions.
src = [0, 1, 1, 2]
dst = [1, 0, 2, 1]
edge_index = torch.tensor([src, dst], dtype=torch.long)
x = torch.randn(3, 4)  # 4 random features per node

data = Data(x=x, edge_index=edge_index)

conv = GCNConv(4, 8)
out = conv(data.x, data.edge_index)
print("GCNConv OK, out shape:", out.shape)

torch_geometric: 2.7.0
GCNConv OK, out shape: torch.Size([3, 8])


## Common setup and environment

In [6]:
from dataclasses import dataclass
from typing import Optional, Tuple
import numpy as np
import networkx as nx

@dataclass
class AoIEnvConfig:
    """Configuration for ``AoIEnv``; every field has a default."""
    num_nodes: int = 10  # number of graph nodes; the env exposes one action per node
    p: float = 0.8  # probability that each neighbour of the updated node is also reset to age 0
    lambda_cost: float = 0.0  # constant subtracted from the reward on every step
    max_age: int = 100  # ages are clipped to [0, max_age]
    max_steps: int = 50  # episode length: done once this many steps have been taken
    init_age_low: int = 0  # inclusive lower bound of random initial ages
    init_age_high: int = 10  # inclusive upper bound of random initial ages
    graph_type: str = "line"  # one of "line", "star", "erdos"
    seed: int = 42  # seeds the env's RandomState (and, in the original builder, the Erdos-Renyi graph)

class AoIEnv:
    """Age-of-Information environment on a fixed undirected graph.

    The state is a float32 vector of per-node ages. On every ``step(action)``
    all ages grow by 1 and node ``action`` is reset to 0. Each neighbour of
    ``action`` is also reset to 0, independently, with probability ``cfg.p``.
    Ages are clipped to ``[0, cfg.max_age]``. The reward is
    ``-mean(age) - cfg.lambda_cost``, and the episode ends after
    ``cfg.max_steps`` steps.
    """

    # Cap on the number of seeds tried when sampling a connected Erdos-Renyi graph.
    _MAX_GRAPH_TRIES = 1000

    def __init__(self, config: "AoIEnvConfig"):
        self.cfg = config
        # One RandomState drives both the initial ages and the neighbour-refresh draws.
        self.rng = np.random.RandomState(config.seed)
        self._build_graph()
        self.num_nodes = self.graph.number_of_nodes()
        self.state = None
        self.step_count = 0
        self.action_n = self.num_nodes

    def _build_graph(self):
        """Build ``self.graph`` according to ``cfg.graph_type``.

        Raises:
            ValueError: unknown ``graph_type``.
            RuntimeError: no connected Erdos-Renyi graph within ``_MAX_GRAPH_TRIES`` seeds.
        """
        n = self.cfg.num_nodes
        if self.cfg.graph_type == "line":
            G = nx.path_graph(n)
        elif self.cfg.graph_type == "star":
            # star_graph(k) has k + 1 nodes (hub 0 plus k leaves), so this yields n nodes.
            G = nx.star_graph(n - 1)
        elif self.cfg.graph_type == "erdos":
            # Use a different seed on each attempt. Re-using cfg.seed regenerates
            # the identical graph, so a disconnected first draw looped forever.
            # Attempt 0 still uses cfg.seed, so configs that used to work get the same graph.
            base = self.cfg.seed
            for attempt in range(self._MAX_GRAPH_TRIES):
                seed = None if base is None else base + attempt
                G = nx.erdos_renyi_graph(n, 0.3, seed=seed)
                if nx.is_connected(G):
                    break
            else:
                raise RuntimeError(
                    f"no connected erdos_renyi_graph({n}, 0.3) found in "
                    f"{self._MAX_GRAPH_TRIES} attempts"
                )
        else:
            raise ValueError(f"Unknown graph_type: {self.cfg.graph_type}")
        self.graph = G

    def reset(self, init_ages: Optional[np.ndarray] = None) -> np.ndarray:
        """Start a new episode and return a copy of the initial ages.

        Args:
            init_ages: optional explicit initial ages, one per node. When omitted,
                integer ages are drawn uniformly from ``[init_age_low, init_age_high]``
                with ``self.rng``.

        Raises:
            ValueError: ``init_ages`` does not contain exactly one entry per node.
        """
        self.step_count = 0
        if init_ages is None:
            ages = self.rng.randint(self.cfg.init_age_low, self.cfg.init_age_high + 1, size=self.num_nodes)
        else:
            ages = np.array(init_ages, dtype=np.float32)
            if ages.shape != (self.num_nodes,):
                raise ValueError(
                    f"init_ages must have shape ({self.num_nodes},), got {ages.shape}"
                )
        self.state = ages.astype(np.float32)
        return self.state.copy()

    def step(self, action: int) -> Tuple[np.ndarray, float, bool, dict]:
        """Refresh node ``action`` and advance time by one step.

        Returns:
            ``(next_ages, reward, done, info)``. ``info`` holds ``avg_aoi``,
            ``step`` and ``action``.

        Raises:
            RuntimeError: called before ``reset()``.
            ValueError: ``action`` is outside ``[0, num_nodes)``.
        """
        if self.state is None:
            raise RuntimeError("reset() must be called before step()")
        # Use an explicit check rather than assert: asserts are stripped under `python -O`.
        if not 0 <= action < self.num_nodes:
            raise ValueError(f"action must be in [0, {self.num_nodes}), got {action}")

        # astype returns a copy, so self.state is not mutated in place.
        ages = self.state.astype(np.float32) + 1.0
        ages[action] = 0.0

        # Each neighbour is refreshed independently with probability p.
        for j in self.graph.neighbors(action):
            if self.rng.rand() < self.cfg.p:
                ages[j] = 0.0

        ages = np.clip(ages, 0.0, float(self.cfg.max_age))
        self.state = ages
        self.step_count += 1

        avg_aoi = float(np.mean(self.state))
        reward = -avg_aoi - self.cfg.lambda_cost
        done = self.step_count >= self.cfg.max_steps
        info = {"avg_aoi": avg_aoi, "step": self.step_count, "action": action}
        return self.state.copy(), reward, done, info

In [7]:
import numpy as np

def random_policy(state, env):
    """Pick a node uniformly at random, drawing from the env's own ``rng``.

    ``state`` is ignored. Because the draw uses ``env.rng``, it shares that
    random stream with the environment's transitions.
    """
    n_actions = env.action_n
    return env.rng.randint(n_actions)

def greedy_stale_policy(state, env):
    """Refresh the stalest node; ties go to the lowest index. ``env`` is unused."""
    ages = np.asarray(state)
    return int(ages.argmax())

def degree_weighted_policy(state, env):
    """Refresh the node with the largest ``age * degree``; ties go to the lowest index."""
    degree_of = env.graph.degree
    n = env.num_nodes
    weights = np.fromiter((degree_of[node] for node in range(n)), dtype=np.float32, count=n)
    return int(np.argmax(state * weights))

In [None]:
def evaluate(env, policy_fn, episodes=50):
    """Roll out ``policy_fn`` on ``env`` and summarise the per-episode mean AoI.

    Each episode runs until ``done`` or ``env.cfg.max_steps`` steps. The
    episode's score is ``info["avg_aoi"]`` averaged over the steps actually
    taken, so episodes that end early are not diluted.

    Args:
        env: environment exposing ``reset()``, ``step(a)`` and ``cfg.max_steps``.
        policy_fn: callable ``(state, env) -> action``.
        episodes: number of episodes to run.

    Returns:
        ``(mean, std)`` of the per-episode scores, or ``(nan, nan)`` when
        ``episodes`` is 0.
    """
    per_episode = []
    for _ in range(episodes):
        s = env.reset()
        aoi_sum = 0.0
        steps = 0
        for _ in range(env.cfg.max_steps):
            a = policy_fn(s, env)
            s, r, done, info = env.step(a)
            aoi_sum += info["avg_aoi"]
            steps += 1
            if done:
                break
        per_episode.append(aoi_sum / max(steps, 1))
    if not per_episode:
        # np.mean/np.std of an empty list would warn and return nan anyway.
        return float("nan"), float("nan")
    return float(np.mean(per_episode)), float(np.std(per_episode))

In [15]:
import torch
from torch_geometric.data import Data

def nx_to_edge_index(G: "nx.Graph") -> torch.Tensor:
    """Convert an undirected graph to a PyG-style ``edge_index`` tensor.

    Each undirected edge ``(u, v)`` is emitted in both directions. The result
    is a ``[2, 2 * num_edges]`` long tensor: row 0 holds sources, row 1 holds
    targets. A graph with no edges gives an empty ``[2, 0]`` tensor.
    ``torch.tensor([]).t()`` would instead give a 1-D ``[0]`` tensor, which
    message-passing layers reject.
    """
    pairs = []
    for u, v in G.edges():
        pairs.append((u, v))
        pairs.append((v, u))  # message passing needs both directions
    if not pairs:
        return torch.empty((2, 0), dtype=torch.long)
    return torch.tensor(pairs, dtype=torch.long).t().contiguous()

def make_pyg_data(env: AoIEnv, state: np.ndarray, device: torch.device) -> Data:
    """Wrap one AoI state as a PyG ``Data`` graph.

    Node features form an ``[N, 2]`` tensor: column 0 is ``age / max_age`` and
    column 1 is ``env._deg``. ``edge_index`` is ``env._edge_index``. Both
    attributes are set by ``cache_graph_tensors``, which must have been called
    on ``env`` first.
    """
    n_nodes = env.num_nodes
    scaled_age = state / float(env.cfg.max_age)
    age_col = torch.tensor(scaled_age, dtype=torch.float32, device=device).view(n_nodes, 1)
    features = torch.cat([age_col, env._deg], dim=1)
    return Data(x=features, edge_index=env._edge_index)

In [16]:
def cache_graph_tensors(env, device):
    """Precompute the graph tensors that ``make_pyg_data`` reuses, and attach them to ``env``.

    Sets ``env._edge_index`` (both edge directions, on ``device``). Also sets
    ``env._deg``, an ``[N, 1]`` float32 column of each node's degree divided
    by ``max(1, N - 1)``.
    """
    # Computed once: the graph does not change after the env is built.
    env._edge_index = nx_to_edge_index(env.graph).to(device)

    n_nodes = env.num_nodes
    denom = float(max(1, n_nodes - 1))
    degrees = [env.graph.degree[node] / denom for node in range(n_nodes)]
    env._deg = torch.tensor(degrees, dtype=torch.float32, device=device).view(n_nodes, 1)

In [17]:
import torch.nn as nn
from torch_geometric.nn import GCNConv

class GNNQNet(nn.Module):
    """Two GCN layers followed by a linear head producing one Q-value per node.

    Args:
        in_dim: number of input node features.
        hidden: width of both GCN layers.
    """

    def __init__(self, in_dim=2, hidden=64):
        super(GNNQNet, self).__init__()
        # Keep the creation order (conv1, conv2, head). Parameter initialisation
        # then consumes the torch RNG identically, and state_dict keys still
        # match saved checkpoints.
        self.conv1 = GCNConv(in_dim, hidden)
        self.conv2 = GCNConv(hidden, hidden)
        self.head = nn.Linear(hidden, 1)

    def forward(self, data: Data) -> torch.Tensor:
        """Return a 1-D tensor of per-node Q-values covering every node in ``data``."""
        edges = data.edge_index
        h = data.x
        for conv in (self.conv1, self.conv2):
            h = conv(h, edges).relu()
        return self.head(h).squeeze(-1)

In [18]:
import random
from collections import deque

class ReplayBuffer:
    """Fixed-capacity FIFO store of ``(s, a, r, ns, done)`` transitions with uniform sampling."""

    def __init__(self, capacity=20000):
        # The oldest transitions are evicted automatically once capacity is reached.
        self.buf = deque(maxlen=capacity)

    def push(self, s, a, r, ns, done):
        """Append one transition."""
        transition = (s, a, r, ns, done)
        self.buf.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions with ``random.sample``.

        Returns five NumPy arrays: states, actions, rewards, next states and done flags.
        """
        picked = random.sample(self.buf, batch_size)
        states, actions, rewards, next_states, dones = (np.array(col) for col in zip(*picked))
        return states, actions, rewards, next_states, dones

    def __len__(self):
        return len(self.buf)

## Part 3a — Stable GNN-DQN on fixed graph types

In [25]:
import torch
import torch.optim as optim
from torch_geometric.data import Batch

def make_pyg_batch(env, states_np, device):
    """Collate several AoI states on ``env``'s graph into one PyG ``Batch``."""
    return Batch.from_data_list([make_pyg_data(env, st, device) for st in states_np])

def train_gnn_dqn_stable(env: AoIEnv,
                         episodes=300,
                         gamma=0.99,
                         lr=3e-4,
                         batch_size=64,
                         buffer_size=20000,
                         warmup=800,
                         target_update=400,
                         eps_start=1.0,
                         eps_end=0.05,
                         eps_decay=0.995,
                         device=torch.device("cpu")):
    """Train a GNN-DQN on one fixed-graph ``env`` and return the online network.

    This is DQN with:
    - a replay buffer;
    - a target network hard-copied from the online net every ``target_update`` env steps;
    - Huber (SmoothL1) loss;
    - gradient-norm clipping at 5.0.

    Actions are epsilon-greedy. Epsilon is multiplied by ``eps_decay`` once per
    episode and floored at ``eps_end``.

    Args:
        env: environment whose graph tensors were cached with ``cache_graph_tensors``.
        episodes: number of training episodes.
        gamma: discount factor in the TD target.
        lr: Adam learning rate.
        batch_size: transitions per gradient update.
        buffer_size: replay capacity.
        warmup: updates start once the buffer holds ``max(warmup, batch_size)`` transitions.
        target_update: env steps between target-network syncs.
        eps_start, eps_end, eps_decay: per-episode epsilon-greedy schedule.
        device: device for both networks and the batched inputs.

    Returns:
        The trained online ``GNNQNet``.

    NOTE(review): with the defaults, epsilon after 300 episodes is 0.995**300 ≈ 0.22
    (matching the printed log), so ``eps_end`` is never reached.
    NOTE(review): ``done`` only marks the ``max_steps`` time limit, yet the target
    multiplies the bootstrap term by ``(1 - done)``. This treats truncation as
    true termination.
    """
    q_net = GNNQNet(in_dim=2, hidden=64).to(device)
    target_net = GNNQNet(in_dim=2, hidden=64).to(device)
    target_net.load_state_dict(q_net.state_dict())
    target_net.eval()

    opt = optim.Adam(q_net.parameters(), lr=lr)
    loss_fn = torch.nn.SmoothL1Loss()  # Huber
    rb = ReplayBuffer(buffer_size)

    total_steps = 0
    eps = eps_start
    N = env.num_nodes

    for ep in range(episodes):
        s = env.reset()
        for t in range(env.cfg.max_steps):
            # Epsilon-greedy action selection (Python `random`, separate from env.rng).
            if random.random() < eps:
                a = random.randrange(env.action_n)
            else:
                with torch.no_grad():
                    data = make_pyg_data(env, s, device)
                    q = q_net(data)
                    a = int(torch.argmax(q).item())

            ns, r, done, info = env.step(a)
            rb.push(s, a, r, ns, done)
            s = ns

            if len(rb) >= max(warmup, batch_size):
                bs, ba, br, bns, bd = rb.sample(batch_size)

                # Every state shares env's graph, so the batched output reshapes to [batch, N].
                # NOTE(review): this calls the module-level make_pyg_batch, which a later
                # cell redefines with a different signature.
                batch_s = make_pyg_batch(env, bs, device)
                q_all = q_net(batch_s).view(batch_size, N)
                ba_t = torch.tensor(ba, dtype=torch.long, device=device)
                q_sa = q_all[torch.arange(batch_size, device=device), ba_t]

                # TD target from the frozen target network.
                with torch.no_grad():
                    batch_ns = make_pyg_batch(env, bns, device)
                    q_next = target_net(batch_ns).view(batch_size, N)
                    q_next_max = q_next.max(dim=1).values

                    br_t = torch.tensor(br, dtype=torch.float32, device=device)
                    bd_t = torch.tensor(bd.astype(np.float32), dtype=torch.float32, device=device)
                    target = br_t + gamma * (1.0 - bd_t) * q_next_max

                loss = loss_fn(q_sa, target)
                opt.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(q_net.parameters(), 5.0)
                opt.step()

            # Hard target sync, counted in env steps across all episodes.
            total_steps += 1
            if total_steps % target_update == 0:
                target_net.load_state_dict(q_net.state_dict())

            if done:
                break

        eps = max(eps_end, eps * eps_decay)
        if (ep + 1) % 25 == 0:
            print("episode", ep + 1, "eps", round(eps, 3))

    return q_net

In [26]:
def gnn_policy_from_net(q_net, device):
    """Wrap ``q_net`` as a greedy policy ``pol(state, env) -> node index``."""
    def pol(s, env):
        with torch.no_grad():
            q_values = q_net(make_pyg_data(env, s, device))
        return int(q_values.argmax().item())
    return pol

def action_hist(env, policy_fn, episodes=30):
    """Print how often ``policy_fn`` picks each node over ``episodes`` rollouts.

    Prints raw counts, then the normalised frequencies rounded to 3 decimals.
    """
    counts = np.zeros(env.num_nodes, dtype=int)
    for _episode in range(episodes):
        state = env.reset()
        step = 0
        done = False
        while step < env.cfg.max_steps and not done:
            action = policy_fn(state, env)
            counts[action] += 1
            state, _r, done, _info = env.step(action)
            step += 1
    print("action counts:", counts)
    freqs = counts / counts.sum()
    print("action probs :", np.round(freqs, 3))

### What happens here
We train a GNN-DQN on a fixed graph type and compare it against simple baselines.

In [27]:
# Part 3a driver: for each graph type, train a stable GNN-DQN, compare it with the
# heuristic baselines, print its action histogram and save its weights.
to_run = ["line", "erdos"]

for gtype in to_run:
    cfg = AoIEnvConfig(num_nodes=10, graph_type=gtype, p=0.8, max_steps=50, seed=42)
    env = AoIEnv(cfg)
    cache_graph_tensors(env, device)  # must run before make_pyg_data / training

    print(f"\nTraining STABLE GNN-DQN on {gtype}...")
    q_net_stable = train_gnn_dqn_stable(env, episodes=300, device=device)

    pols = {
        "random": random_policy,
        "greedy": greedy_stale_policy,
        "degree_weighted": degree_weighted_policy,
        "gnn_dqn_stable": gnn_policy_from_net(q_net_stable, device),
    }

    # NOTE(review): policies are evaluated one after another on the same env, so
    # each starts from the env.rng state the previous evaluation left behind.
    for name, pol in pols.items():
        m, sd = evaluate(env, pol, episodes=50)
        print(f"{gtype:5s} | {name:16s} mean_AoI={m:.3f} ± {sd:.3f}")

    print("\nAction histogram for stable policy:")
    action_hist(env, gnn_policy_from_net(q_net_stable, device), episodes=30)

    # optional save (this `import os` is redundant: os is already imported in the first cell)
    import os
    os.makedirs("artifacts", exist_ok=True)
    torch.save(q_net_stable.state_dict(), f"artifacts/gnn_dqn_stable_{gtype}.pt")
    print("Saved weights to:", f"artifacts/gnn_dqn_stable_{gtype}.pt")


Training STABLE GNN-DQN on line...
episode 25 eps 0.882
episode 50 eps 0.778
episode 75 eps 0.687
episode 100 eps 0.606
episode 125 eps 0.534
episode 150 eps 0.471
episode 175 eps 0.416
episode 200 eps 0.367
episode 225 eps 0.324
episode 250 eps 0.286
episode 275 eps 0.252
episode 300 eps 0.222
line  | random           mean_AoI=3.335 ± 0.376
line  | greedy           mean_AoI=2.097 ± 0.105
line  | degree_weighted  mean_AoI=2.018 ± 0.113
line  | gnn_dqn_stable   mean_AoI=2.390 ± 0.167

Action histogram for stable policy:
action counts: [214   0   0 488  46  73 464   0   0 215]
action probs : [0.143 0.    0.    0.325 0.031 0.049 0.309 0.    0.    0.143]
Saved weights to: artifacts/gnn_dqn_stable_line.pt

Training STABLE GNN-DQN on erdos...
episode 25 eps 0.882
episode 50 eps 0.778
episode 75 eps 0.687
episode 100 eps 0.606
episode 125 eps 0.534
episode 150 eps 0.471
episode 175 eps 0.416
episode 200 eps 0.367
episode 225 eps 0.324
episode 250 eps 0.286
episode 275 eps 0.252
episode 300 e

In [33]:
import torch, time
# Device sanity check before the heavier Part 3b runs.
print("device:", device)
has_cuda = torch.cuda.is_available()
print("cuda available:", has_cuda)
if has_cuda:
    print("GPU:", torch.cuda.get_device_name(0))
    allocated_mb = torch.cuda.memory_allocated() / (1024 ** 2)
    print("cuda mem allocated (MB):", allocated_mb)

device: cuda
cuda available: True
GPU: Tesla T4
cuda mem allocated (MB): 16.39697265625


In [34]:
def make_env(graph_type, seed, num_nodes=10, p=0.8, max_steps=50):
    """Build an ``AoIEnv`` for one graph family and seed, with graph tensors cached on ``device``."""
    env = AoIEnv(AoIEnvConfig(
        num_nodes=num_nodes,
        graph_type=graph_type,
        p=p,
        max_steps=max_steps,
        seed=seed,
    ))
    cache_graph_tensors(env, device)
    return env

## Part 3b — Generalization across unseen graphs

In [41]:
import networkx as nx
import numpy as np

def _build_graph_fixed(self):
    """Replacement for ``AoIEnv._build_graph`` that draws Erdos-Renyi seeds from ``self.rng``.

    For "erdos" it tries up to 200 seeds drawn from ``self.rng`` and keeps the
    first connected G(n, 0.3). If none is connected, it samples one more graph
    with a fresh seed and joins each consecutive pair of components with one
    random edge. "line" and "star" are built as in the original method.

    Raises:
        ValueError: unknown ``graph_type``.
    """
    n = self.cfg.num_nodes
    kind = self.cfg.graph_type

    if kind == "line":
        self.graph = nx.path_graph(n)
        return
    if kind == "star":
        self.graph = nx.star_graph(n - 1)
        return
    if kind != "erdos":
        raise ValueError(f"Unknown graph_type: {self.cfg.graph_type}")

    p_edge = 0.3

    def draw_seed():
        return int(self.rng.randint(0, 2**31 - 1))

    # Try many different seeds rather than regenerating the same one.
    for _attempt in range(200):
        candidate = nx.erdos_renyi_graph(n, p_edge, seed=draw_seed())
        if nx.is_connected(candidate):
            self.graph = candidate
            return

    # Safety net: if no connected graph was found, stitch the components together by hand.
    G = nx.erdos_renyi_graph(n, p_edge, seed=draw_seed())
    components = list(nx.connected_components(G))
    for left, right in zip(components[:-1], components[1:]):
        u = int(self.rng.choice(list(left)))
        v = int(self.rng.choice(list(right)))
        G.add_edge(u, v)
    self.graph = G

# Apply the patch: every AoIEnv constructed from here on uses the new builder.
setattr(AoIEnv, "_build_graph", _build_graph_fixed)
print("Patched AoIEnv._build_graph")

Patched AoIEnv._build_graph


In [55]:
import numpy as np
import random
import torch
import torch.optim as optim
from collections import deque
from torch_geometric.data import Batch
import os

def make_pyg_batch(env, states_np, device=None):
    """Collate several AoI states on ``env``'s graph into one PyG ``Batch``.

    This definition replaces the earlier 3-argument ``make_pyg_batch``.
    ``device`` is therefore optional, so both call styles keep working:
    ``make_pyg_batch(env, states, device)`` from Part 3a and
    ``make_pyg_batch(env, states)`` from Part 3b. With the old 2-arg
    signature, re-running Part 3a code raised ``TypeError``.

    Args:
        env: environment whose graph tensors were cached by ``cache_graph_tensors``.
        states_np: iterable of per-node age vectors.
        device: device for the node features. Defaults to the device holding
            ``env._edge_index``, so features and the cached graph tensors can
            always be concatenated.
    """
    if device is None:
        device = env._edge_index.device
    data_list = [make_pyg_data(env, s, device) for s in states_np]
    return Batch.from_data_list(data_list)

def train_gnn_dqn_generalize_v5(
    episodes=800,
    train_graph_type="erdos",
    num_nodes=10,
    p=0.8,
    max_steps=50,          # can be left at 50
    gamma=0.99,
    lr=3e-4,
    batch_size=32,         # <= max_steps
    warmup_steps=32,       # <= max_steps
    target_update=400,
    eps_start=1.0,
    eps_end=0.1,
    eps_decay=0.9985,
    reward_scale=None,     # if None, env.cfg.max_age is used
    ep_replay_max=2000,    # deque maxlen; see NOTE in the docstring
    base_seed=3000,
    log_every=50,
    save_every=200,
    save_prefix="artifacts/gnn_dqn_generalized_erdos_v5"
):
    """Train one GNN-DQN across many random graphs and return the online network.

    Each episode builds a fresh env with
    ``make_env(train_graph_type, seed=base_seed + ep, ...)``, so the network sees
    a different graph every episode. Rewards are divided by ``reward_scale``
    (default ``env.cfg.max_age``) before they are stored. Checkpoints are written
    to ``{save_prefix}_ep{ep}.pt`` every ``save_every`` episodes, and the final
    weights to ``{save_prefix}.pt``. Networks and batches use the module-level
    ``device``.

    NOTE(review): the replay deque is recreated at the start of every episode.
    - It never holds more than ``max_steps`` transitions, so ``ep_replay_max``
      has no effect.
    - With max_steps=50 and warmup_steps=32, only 19 gradient updates happen per
      episode (the "updates 19" log), all sampled from one correlated trajectory.
    - Sharing a buffer across episodes would require storing each transition's
      graph, because ``make_pyg_batch`` builds batches from the current ``env`` only.

    NOTE(review): with the defaults, epsilon after 800 episodes is
    0.9985**800 ≈ 0.30, so ``eps_end`` is never reached.

    NOTE(review): ``done`` marks only the time limit but still zeroes the bootstrap term.
    """
    q_net = GNNQNet(in_dim=2, hidden=64).to(device)
    target_net = GNNQNet(in_dim=2, hidden=64).to(device)
    target_net.load_state_dict(q_net.state_dict())
    target_net.eval()

    opt = optim.Adam(q_net.parameters(), lr=lr)
    loss_fn = torch.nn.SmoothL1Loss()

    print("q_net device:", next(q_net.parameters()).device)
    os.makedirs("artifacts", exist_ok=True)

    eps = eps_start
    total_steps = 0

    for ep in range(episodes):
        # A new graph every episode, seeded by base_seed + ep.
        env = make_env(train_graph_type, seed=base_seed + ep, num_nodes=num_nodes, p=p, max_steps=max_steps)
        scale = float(env.cfg.max_age) if reward_scale is None else float(reward_scale)

        # Fresh per-episode replay (see NOTE in the docstring).
        ep_replay = deque(maxlen=ep_replay_max)
        s = env.reset()
        ep_aoi_sum = 0.0
        updates = 0

        for t in range(env.cfg.max_steps):
            # Epsilon-greedy action selection.
            if random.random() < eps:
                a = random.randrange(env.action_n)
            else:
                with torch.no_grad():
                    data = make_pyg_data(env, s, device)
                    q = q_net(data)
                    a = int(torch.argmax(q).item())

            ns, r, done, info = env.step(a)
            ep_aoi_sum += info["avg_aoi"]

            r_norm = r / scale
            ep_replay.append((s, a, r_norm, ns, done))
            s = ns

            if len(ep_replay) >= max(warmup_steps, batch_size):
                batch = random.sample(ep_replay, batch_size)
                bs  = np.array([x[0] for x in batch])
                ba  = np.array([x[1] for x in batch])
                br  = np.array([x[2] for x in batch])
                bns = np.array([x[3] for x in batch])
                bd  = np.array([x[4] for x in batch])

                # All sampled states belong to the current env's graph; relies on the
                # 2-argument make_pyg_batch defined in this cell.
                batch_s = make_pyg_batch(env, bs)
                N = env.num_nodes
                q_all = q_net(batch_s).view(batch_size, N)
                ba_t = torch.tensor(ba, dtype=torch.long, device=device)
                q_sa = q_all[torch.arange(batch_size, device=device), ba_t]

                # TD target from the frozen target network.
                with torch.no_grad():
                    batch_ns = make_pyg_batch(env, bns)
                    q_next = target_net(batch_ns).view(batch_size, N)
                    q_next_max = q_next.max(dim=1).values

                    br_t = torch.tensor(br, dtype=torch.float32, device=device)
                    bd_t = torch.tensor(bd.astype(np.float32), dtype=torch.float32, device=device)
                    target = br_t + gamma * (1.0 - bd_t) * q_next_max

                loss = loss_fn(q_sa, target)
                opt.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(q_net.parameters(), 5.0)
                opt.step()

                updates += 1

            # Hard target sync; total_steps keeps counting across episodes.
            total_steps += 1
            if total_steps % target_update == 0:
                target_net.load_state_dict(q_net.state_dict())

            if done:
                break

        # The logged eps is the value after this episode's decay.
        eps = max(eps_end, eps * eps_decay)

        if (ep % log_every) == 0:
            print("ep", ep,
                  "eps", round(eps, 3),
                  "mean_aoi_ep", round(ep_aoi_sum / max_steps, 3),
                  "updates", updates)

        if (ep > 0) and (ep % save_every == 0):
            path = f"{save_prefix}_ep{ep}.pt"
            torch.save(q_net.state_dict(), path)
            print("checkpoint saved:", path)

    final_path = f"{save_prefix}.pt"
    torch.save(q_net.state_dict(), final_path)
    print("final saved:", final_path)
    return q_net

# Run the Part 3b training with the default hyperparameters.
print("Training generalized GNN-DQN v5 on ERDOS distribution...")
q_net_gen_v5 = train_gnn_dqn_generalize_v5()

Training generalized GNN-DQN v5 on ERDOS distribution...
q_net device: cuda:0
ep 0 eps 0.999 mean_aoi_ep 1.796 updates 19
ep 50 eps 0.926 mean_aoi_ep 2.782 updates 19
ep 100 eps 0.859 mean_aoi_ep 3.19 updates 19
ep 150 eps 0.797 mean_aoi_ep 3.37 updates 19
ep 200 eps 0.74 mean_aoi_ep 1.698 updates 19
checkpoint saved: artifacts/gnn_dqn_generalized_erdos_v5_ep200.pt
ep 250 eps 0.686 mean_aoi_ep 2.874 updates 19
ep 300 eps 0.636 mean_aoi_ep 3.108 updates 19
ep 350 eps 0.59 mean_aoi_ep 2.062 updates 19
ep 400 eps 0.548 mean_aoi_ep 2.774 updates 19
checkpoint saved: artifacts/gnn_dqn_generalized_erdos_v5_ep400.pt
ep 450 eps 0.508 mean_aoi_ep 5.0 updates 19
ep 500 eps 0.471 mean_aoi_ep 3.604 updates 19
ep 550 eps 0.437 mean_aoi_ep 2.252 updates 19
ep 600 eps 0.406 mean_aoi_ep 2.908 updates 19
checkpoint saved: artifacts/gnn_dqn_generalized_erdos_v5_ep600.pt
ep 650 eps 0.376 mean_aoi_ep 3.982 updates 19
ep 700 eps 0.349 mean_aoi_ep 5.052 updates 19
ep 750 eps 0.324 mean_aoi_ep 2.79 updates 1

### Evaluation and selection
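We evaluate several checkpoints on unseen Erdos graphs (seeds 900–904) and keep the one with the lowest mean AoI. The same seeds are reported again as the `erdos_new` test below, so that test result is optimistic rather than a clean held-out estimate.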

In [65]:
import torch, numpy as np

def load_qnet(path):
    """Load a ``GNNQNet`` state_dict from ``path`` onto ``device`` and return it in eval mode.

    ``weights_only=True`` limits unpickling to tensors and plain containers, so a
    tampered checkpoint file cannot run arbitrary code. A plain state_dict
    loads unchanged under this restriction. Stating it explicitly keeps the
    safe behaviour on torch versions whose default is ``weights_only=False``.
    """
    q = GNNQNet(in_dim=2, hidden=64).to(device)
    sd = torch.load(path, map_location=device, weights_only=True)
    q.load_state_dict(sd)
    q.eval()
    return q

def eval_family(q_net, graph_type, seeds, episodes_each=40, num_nodes=10, p=0.8, max_steps=50):
    """Compare the baselines and the GNN policy on one graph family.

    For each policy and each seed, builds a fresh env with ``make_env`` and runs
    ``evaluate`` for ``episodes_each`` episodes.

    Returns:
        ``{policy_name: (mean over seeds, std over seeds)}`` of the per-seed mean AoI.
    """
    policies = {
        "random": random_policy,
        "greedy": greedy_stale_policy,
        "degree_weighted": degree_weighted_policy,
        "gnn_generalized": gnn_policy_from_net(q_net, device),
    }

    def per_seed_means(policy):
        means = []
        for sd in seeds:
            env = make_env(graph_type, seed=sd, num_nodes=num_nodes, p=p, max_steps=max_steps)
            means.append(evaluate(env, policy, episodes=episodes_each)[0])
        return means

    results = {}
    for name, policy in policies.items():
        means = per_seed_means(policy)
        results[name] = (float(np.mean(means)), float(np.std(means)))
    return results

checkpoints = [
    "artifacts/gnn_dqn_generalized_erdos_v5_ep200.pt",
    "artifacts/gnn_dqn_generalized_erdos_v5_ep400.pt",
    "artifacts/gnn_dqn_generalized_erdos_v5_ep600.pt",
    "artifacts/gnn_dqn_generalized_erdos_v5.pt",
]

# Evaluation families: label -> (graph_type, env seeds).
tests = {
    "erdos_new": ("erdos", [900, 901, 902, 903, 904]),
    "line":      ("line",  [910, 911, 912]),
    "star":      ("star",  [920, 921, 922]),
}

# Select on the erdos_new family defined above instead of repeating its seeds inline.
# NOTE: the next cell reports erdos_new again as a "test" result, so that number is
# optimistic: it comes from the same data used to pick the checkpoint.
sel_gtype, sel_seeds = tests["erdos_new"]

summary = []

for ckpt in checkpoints:
    if not os.path.exists(ckpt):
        # e.g. training stopped before this checkpoint was written
        print("\nSKIP missing checkpoint:", ckpt)
        continue
    q = load_qnet(ckpt)
    res_erdos = eval_family(q, sel_gtype, sel_seeds, episodes_each=40)
    m = res_erdos["gnn_generalized"][0]
    summary.append((ckpt, m))
    print("\nCKPT:", ckpt)
    print("erdos_new | gnn_generalized mean_AoI=", round(m, 3),
          "| degree_weighted=", round(res_erdos["degree_weighted"][0], 3),
          "| greedy=", round(res_erdos["greedy"][0], 3))

if not summary:
    raise RuntimeError("none of the listed checkpoints exist; run the training cell first")

# min() returns the first minimum, just as sorted(...)[0] did.
best_ckpt = min(summary, key=lambda item: item[1])[0]
print("\nBEST checkpoint by erdos_new:", best_ckpt)


CKPT: artifacts/gnn_dqn_generalized_erdos_v5_ep200.pt
erdos_new | gnn_generalized mean_AoI= 5.902 | degree_weighted= 1.431 | greedy= 1.622

CKPT: artifacts/gnn_dqn_generalized_erdos_v5_ep400.pt
erdos_new | gnn_generalized mean_AoI= 7.272 | degree_weighted= 1.431 | greedy= 1.622

CKPT: artifacts/gnn_dqn_generalized_erdos_v5_ep600.pt
erdos_new | gnn_generalized mean_AoI= 5.407 | degree_weighted= 1.431 | greedy= 1.622

CKPT: artifacts/gnn_dqn_generalized_erdos_v5.pt
erdos_new | gnn_generalized mean_AoI= 4.194 | degree_weighted= 1.431 | greedy= 1.622

BEST checkpoint by erdos_new: artifacts/gnn_dqn_generalized_erdos_v5.pt


In [66]:
best_q = load_qnet(best_ckpt)
policy_order = ("random", "greedy", "degree_weighted", "gnn_generalized")

# Evaluate the selected checkpoint on every family in `tests`.
for label, (gtype, seeds) in tests.items():
    res = eval_family(best_q, gtype, seeds, episodes_each=40)
    print("\nTEST", label, "graph", gtype)
    for k in policy_order:
        m, s = res[k]
        print(f"{k:15s} mean_AoI={m:.3f}  (std over seeds {s:.3f})")


TEST erdos_new graph erdos
random          mean_AoI=2.345  (std over seeds 0.380)
greedy          mean_AoI=1.622  (std over seeds 0.257)
degree_weighted mean_AoI=1.431  (std over seeds 0.162)
gnn_generalized mean_AoI=4.194  (std over seeds 1.377)

TEST line graph line
random          mean_AoI=3.298  (std over seeds 0.034)
greedy          mean_AoI=2.105  (std over seeds 0.009)
degree_weighted mean_AoI=2.008  (std over seeds 0.018)
gnn_generalized mean_AoI=1.856  (std over seeds 0.019)

TEST star graph star
random          mean_AoI=4.090  (std over seeds 0.173)
greedy          mean_AoI=3.640  (std over seeds 0.025)
degree_weighted mean_AoI=2.069  (std over seeds 0.018)
gnn_generalized mean_AoI=0.250  (std over seeds 0.005)


In [67]:
import shutil, os
# Keep a stable filename for the selected checkpoint so later cells need not know which one won.
os.makedirs("artifacts", exist_ok=True)
shutil.copy(src=best_ckpt, dst="artifacts/gnn_dqn_generalized_best.pt")
print("Saved best checkpoint to artifacts/gnn_dqn_generalized_best.pt")

Saved best checkpoint to artifacts/gnn_dqn_generalized_best.pt


In [68]:
import numpy as np

def action_histogram(q_net, graph_type="star", seed=123, steps=1000):
    """Print the greedy GNN policy's per-node action distribution on one graph.

    Runs ``steps`` consecutive steps on a ``make_env(graph_type, seed=seed)``
    environment with 10 nodes, p=0.8 and max_steps=50, resetting whenever an
    episode ends.
    """
    env = make_env(graph_type, seed=seed, num_nodes=10, p=0.8, max_steps=50)
    policy = gnn_policy_from_net(q_net, device)

    counts = np.zeros(env.action_n, dtype=int)
    state = env.reset()
    for _ in range(steps):
        choice = policy(state, env)
        counts[choice] += 1
        state, _r, done, _info = env.step(choice)
        if done:
            state = env.reset()

    probs = counts / counts.sum()
    print("graph:", graph_type)
    print("action counts:", counts)
    print("action probs :", np.round(probs, 3))


best_q = load_qnet("artifacts/gnn_dqn_generalized_best.pt")
for gtype, sd in (("star", 777), ("line", 888)):
    action_histogram(best_q, gtype, seed=sd, steps=2000)

graph: star
action counts: [2000    0    0    0    0    0    0    0    0    0]
action probs : [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
graph: line
action counts: [  1 623  10 182 187 198 165  24 610   0]
action probs : [0.    0.312 0.005 0.091 0.094 0.099 0.082 0.012 0.305 0.   ]


## Reflection and observations

In this project, I studied Age of Information minimization on graphs using reinforcement learning.

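In Part 3a, a GNN-DQN agent was trained on fixed graph types. On the line graph the learned policy clearly beat the random baseline (mean AoI 2.390 vs 3.335) but still trailed the greedy (2.097) and degree-weighted (2.018) heuristics, so it was not yet competitive with them. The recorded output of the Erdos run is truncated before its evaluation results.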

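In Part 3b, I trained a generalized GNN-DQN on a distribution of Erdos graphs and evaluated it on unseen graphs. The learned policy beat every baseline on line graphs (mean AoI 1.856 vs 2.008 for degree-weighted) and star graphs (0.250 vs 2.069), where it always updates the hub node. However, it consistently underperformed the simple heuristics on random Erdos graphs (4.194 vs 1.431). The checkpoint was selected on the same Erdos seeds used for the `erdos_new` test, so even that Erdos number is optimistic.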

This indicates that generalization across graph families is non-trivial and that simple heuristics remain strong baselines for AoI minimization. The experiments highlight the importance of careful evaluation and show that increased model complexity does not guarantee better performance.