1) Install required libraries

In [None]:
!pip install torch torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m814.0 kB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuf

2) Import libraries

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
import numpy as np
import random
import torch.nn as nn

3) Load the Cora dataset and move it to the GPU (or the CPU if no GPU is available)

In [None]:
# Seed the three RNGs this notebook draws from: PyTorch, NumPy and Python's
# `random` module (the latter is used when sampling attack candidates).
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# Load Cora dataset (files are stored under /tmp/Cora)
dataset = Planetoid(root='/tmp/Cora', name='Cora')
# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Take the dataset's first entry and move it to the selected device.
data = dataset[0].to(device)

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!


4) Define GCN Model

In [None]:
class GCN(torch.nn.Module):
    """Two-layer graph convolutional network for node classification.

    Args:
        in_feats: Number of input features per node. When ``None`` (default),
            ``dataset.num_node_features`` of the notebook-level ``dataset`` is used.
        hidden_feats: Width of the hidden layer (default 16).
        out_feats: Number of output classes. When ``None`` (default),
            ``dataset.num_classes`` of the notebook-level ``dataset`` is used.

    The signature mirrors ``RGCN(in_feats, hidden_feats, out_feats, ...)`` so
    both models can be built the same way, while ``GCN()`` keeps working.
    """

    def __init__(self, in_feats=None, hidden_feats=16, out_feats=None):
        super(GCN, self).__init__()
        # Only touch the global ``dataset`` when a size was not given explicitly,
        # so the model can also be built for other graphs.
        if in_feats is None:
            in_feats = dataset.num_node_features
        if out_feats is None:
            out_feats = dataset.num_classes
        self.conv1 = GCNConv(in_feats, hidden_feats)
        self.conv2 = GCNConv(hidden_feats, out_feats)

    def forward(self, data):
        """Return one row of raw class logits (no softmax) per node of ``data``.

        ``data`` must expose ``x`` (node features) and ``edge_index``.
        """
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return x


5) Define RGCN Model

In [None]:
class RGCN(nn.Module):
    """Two-layer GCN regularised with random edge dropping and Gaussian noise.

    Both regularisers are stochastic and are applied only in training mode;
    in ``eval()`` mode the forward pass uses the full graph and adds no noise,
    so inference is deterministic.

    Args:
        in_feats: Number of input features per node.
        hidden_feats: Width of the hidden layer.
        out_feats: Number of output classes.
        dropedge_rate: Fraction of edge columns removed on each training
            forward pass. Must lie in ``[0, 1]``.
        noise_std: Standard deviation of the zero-mean Gaussian noise added to
            the hidden activations during training.

    Raises:
        ValueError: If ``dropedge_rate`` is outside ``[0, 1]``.
    """

    def __init__(self, in_feats, hidden_feats, out_feats, dropedge_rate=0.5, noise_std=0.4):
        super(RGCN, self).__init__()
        # A rate outside [0, 1] would make ``drop_edge`` slice with a negative
        # or oversized count and silently keep the wrong number of edges.
        if not 0.0 <= dropedge_rate <= 1.0:
            raise ValueError(f"dropedge_rate must be in [0, 1], got {dropedge_rate}")
        self.conv1 = GCNConv(in_feats, hidden_feats)
        self.conv2 = GCNConv(hidden_feats, out_feats)
        self.dropedge_rate = dropedge_rate
        self.noise_std = noise_std

    def forward(self, data):
        """Return one row of raw class logits per node of ``data``.

        ``data`` must expose ``x`` (node features) and ``edge_index``.
        """
        x, edge_index = data.x, data.edge_index

        # Edge dropping is a training-time regulariser, exactly like the noise
        # below. Applying it at inference would randomly discard real edges and
        # make predictions (and any evaluation built on them) non-deterministic.
        if self.training:
            edge_index = self.drop_edge(edge_index)

        x = F.relu(self.conv1(x, edge_index))

        if self.training:
            x = x + torch.randn_like(x) * self.noise_std

        return self.conv2(x, edge_index)

    def drop_edge(self, edge_index):
        """Return a random subset of the columns of ``edge_index``.

        ``int(num_edges * (1 - dropedge_rate))`` columns are kept, in shuffled
        order. Each column is sampled on its own, so the two directions of an
        undirected edge may be dropped independently.
        """
        num_edges = edge_index.size(1)
        perm = torch.randperm(num_edges, device=edge_index.device)
        num_keep = int(num_edges * (1.0 - self.dropedge_rate))
        return edge_index[:, perm[:num_keep]]


6) Training and Evaluation Functions

In [None]:
def train(model, data, optimizer):
    """Perform one full-batch optimisation step on the training nodes.

    Puts ``model`` in training mode, computes the cross-entropy between the
    logits and labels selected by ``data.train_mask``, back-propagates it and
    applies a single ``optimizer`` update. Returns ``None``.
    """
    model.train()
    optimizer.zero_grad()

    train_nodes = data.train_mask
    predictions = model(data)
    F.cross_entropy(predictions[train_nodes], data.y[train_nodes]).backward()

    optimizer.step()

def test(model, data):
    model.eval()
    logits = model(data)
    loss = F.cross_entropy(logits[data.test_mask], data.y[data.test_mask]).item()
    pred = logits[data.test_mask].max(1)[1]
    acc = pred.eq(data.y[data.test_mask]).sum().item() / data.test_mask.sum().item()
    return acc, loss

7) Edge and Feature Flipping Functions

In [None]:
def flip_edge(edge_index, u, v):
    """Toggle the connection between nodes ``u`` and ``v``.

    If neither ``(u, v)`` nor ``(v, u)`` occurs as a column of ``edge_index``,
    both directions are appended; otherwise every matching column is removed.
    The input tensor is left untouched and a new tensor is returned.
    """
    sources, targets = edge_index[0], edge_index[1]
    is_uv = (sources == u) & (targets == v)
    is_vu = (sources == v) & (targets == u)
    untouched = ~(is_uv | is_vu)

    # Some column matched -> the edge exists, so remove it.
    if untouched.sum() != edge_index.size(1):
        return edge_index[:, untouched]

    # No column matched -> insert both directions.
    both_directions = torch.tensor(
        [[u, v], [v, u]], dtype=torch.long, device=edge_index.device
    )
    return torch.cat([edge_index, both_directions], dim=1)

def flip_feature(features, u, idx):
    """Replace ``features[u, idx]`` by ``1 - features[u, idx]`` in place.

    The same tensor object is returned, so callers that need the original
    values must pass a clone.
    """
    flipped_value = 1 - features[u, idx]
    features[u, idx] = flipped_value
    return features


8) Build Nettack Attack

In [None]:
def _forward_on(model, scratch, x, edge_index):
    """Point the scratch graph at ``x`` / ``edge_index`` and return the model's logits."""
    scratch.x = x
    scratch.edge_index = edge_index
    return model(scratch)


def _pick_best_flip(model, scratch, data, node, x_adv, edge_adv):
    """Return the sampled flip, ``('edge' | 'feature', index)``, with the highest loss on ``node``.

    Returns ``None`` when no candidate produced a comparable (non-NaN) loss.
    """
    label = data.y[[node]]
    best_score = -np.inf
    best_action = None

    # Both samples are drawn up front, neighbours first, so the consumption of
    # Python's ``random`` state is fixed per budget step.
    possible_neighbors = random.sample(range(data.num_nodes), min(20, data.num_nodes))
    possible_features = random.sample(range(data.num_node_features), min(20, data.num_node_features))

    # Strict ``>`` keeps the first-seen candidate on ties (edges before features).
    for neighbor in possible_neighbors:
        if neighbor == node:
            continue
        logits = _forward_on(model, scratch, x_adv, flip_edge(edge_adv, node, neighbor))
        score = F.cross_entropy(logits[[node]], label).item()
        if score > best_score:
            best_score, best_action = score, ('edge', neighbor)

    for idx in possible_features:
        # flip_feature works in place, so evaluate the flip on a copy.
        logits = _forward_on(model, scratch, flip_feature(x_adv.clone(), node, idx), edge_adv)
        score = F.cross_entropy(logits[[node]], label).item()
        if score > best_score:
            best_score, best_action = score, ('feature', idx)

    return best_action


def run_attack(model, data, num_target_nodes=200, budget=15):
    """Greedy, randomly sampled edge/feature-flip attack on low-confidence test nodes.

    Targets are the ``num_target_nodes`` correctly classified test nodes with
    the lowest softmax confidence. For each target, up to ``budget`` flips are
    applied; every step samples at most 20 candidate neighbours and 20
    candidate feature indices and commits the single flip that maximises the
    model's cross-entropy on the target. The attack on a target stops as soon
    as its prediction differs from its label. Perturbations accumulate across
    targets.

    Side effect: ``data.x`` and ``data.edge_index`` are overwritten with the
    final perturbed graph, so pass a clone to keep the clean graph.

    Args:
        model: Module called as ``model(data)`` and returning per-node logits.
        data: Graph exposing ``x``, ``edge_index``, ``y``, ``test_mask``,
            ``num_nodes``, ``num_node_features`` and ``clone()``.
        num_target_nodes: Maximum number of nodes to attack.
        budget: Maximum number of flips per target node.

    Returns:
        ``(acc_after, loss_after, ASR, AML, acc_target_after)`` where
        ``acc_after`` / ``loss_after`` come from ``test`` on the perturbed
        graph, ``ASR`` is the percentage of targets whose prediction was
        flipped, ``AML`` is the mean number of edge flips per target, and
        ``acc_target_after`` is the fraction of targets still classified
        correctly. With no targets, ``ASR`` and ``AML`` are ``0.0`` and
        ``acc_target_after`` is ``nan``.
    """
    model.eval()

    # The attack only compares forward-pass losses; disabling autograd avoids
    # building a graph for each of the many candidate evaluations.
    with torch.no_grad():
        logits = model(data)
        conf = F.softmax(logits[data.test_mask], dim=1)
        conf_max, pred = conf.max(dim=1)
        correct_mask = pred == data.y[data.test_mask]

        target_candidates = data.test_mask.nonzero(as_tuple=False).view(-1)[correct_mask]
        target_conf = conf_max[correct_mask]
        # Least confident first.
        target_nodes = target_candidates[target_conf.argsort()[:num_target_nodes]].tolist()

        successful_attacks = 0
        modified_links = 0
        correct_after_target = []  # 1 = target still correct, 0 = misclassified

        x_adv_global = data.x.clone()
        edge_adv_global = data.edge_index.clone()

        # One reusable copy of the graph instead of a full clone per candidate.
        scratch = data.clone()

        for node in target_nodes:
            true_label = data.y[node].item()

            # Re-check on the clean graph: a stochastic model may no longer
            # classify the node correctly. Such a node is not attacked, but it
            # is still a misclassified target and must count as one.
            if model(data)[node].argmax().item() != true_label:
                correct_after_target.append(0)
                continue

            x_adv = x_adv_global.clone()
            edge_adv = edge_adv_global.clone()
            fooled = False

            for _ in range(budget):
                action = _pick_best_flip(model, scratch, data, node, x_adv, edge_adv)
                if action is None:
                    # No usable candidate (e.g. all losses NaN): give up on this node.
                    break

                kind, index = action
                if kind == 'edge':
                    edge_adv = flip_edge(edge_adv, node, index)
                    modified_links += 1
                else:
                    x_adv = flip_feature(x_adv, node, index)

                pred_after = _forward_on(model, scratch, x_adv, edge_adv)[node].argmax().item()
                if pred_after != true_label:
                    fooled = True
                    break

            successful_attacks += int(fooled)
            correct_after_target.append(0 if fooled else 1)
            # Perturbations are kept whether or not the attack succeeded.
            edge_adv_global, x_adv_global = edge_adv, x_adv

        data.edge_index = edge_adv_global
        data.x = x_adv_global

        acc_after, loss_after = test(model, data)

    num_targets = len(target_nodes)
    if num_targets == 0:
        return acc_after, loss_after, 0.0, 0.0, float('nan')

    ASR = (successful_attacks / num_targets) * 100
    AML = modified_links / num_targets
    # Mean of the "still correct" flags is the accuracy on the target nodes.
    acc_target_after = float(np.mean(correct_after_target))

    return acc_after, loss_after, ASR, AML, acc_target_after


9) GCN Train and Attack Evaluation

In [None]:
# ======== Run Experiment GCN ========
# Train a plain GCN, record its clean test accuracy, then attack it.
print("\n=== Training GCN ===")
model_gcn = GCN().to(device)
optimizer_gcn = torch.optim.Adam(model_gcn.parameters(), lr=0.01, weight_decay=5e-4)
# 200 full-batch steps; train() performs exactly one optimiser update per call.
for epoch in range(200):
    train(model_gcn, data, optimizer_gcn)

# Baseline on the unperturbed graph.
acc_gcn_before, loss_gcn_before = test(model_gcn, data)
print(f"[GCN] Accuracy BEFORE attack: {acc_gcn_before:.4f}, Loss: {loss_gcn_before:.4f}")

# run_attack overwrites x / edge_index of the graph it receives, so hand it a
# clone and keep `data` clean for the following cells.
acc_gcn_after, loss_gcn_after, ASR_gcn, AML_gcn, acc_target_gcn_after = run_attack(model_gcn, data.clone())
print("\n=== GCN under NETTACK ===")
print(f"[GCN] Accuracy AFTER  attack: {acc_gcn_after:.4f}, Loss: {loss_gcn_after:.4f}")
print(f"[GCN] ASR: {ASR_gcn:.2f}%, AML: {AML_gcn:.4f}")
print(f"[GCN] Target Nodes Accuracy AFTER attack: {acc_target_gcn_after:.4f}")


=== Training GCN ===
[GCN] Accuracy BEFORE attack: 0.8040, Loss: 0.6082

=== GCN under NETTACK ===
[GCN] Accuracy AFTER  attack: 0.6150, Loss: 0.8244
[GCN] ASR: 100.00%, AML: 1.2750
[GCN] Target Nodes Accuracy AFTER attack: 1.0000


10) RGCN Train and Attack Evaluation

In [None]:
# ======== Run Experiment RGCN ========
# Same protocol as the GCN experiment, using the RGCN model.
model_rgcn = RGCN(
    dataset.num_node_features,
    16,  # hidden width, same as the GCN's hidden layer
    dataset.num_classes,
    dropedge_rate=0.5,
    noise_std=0.4
).to(device)

# Same optimiser settings as the GCN run.
optimizer_rgcn = torch.optim.Adam(model_rgcn.parameters(), lr=0.01, weight_decay=5e-4)

# Train: 200 full-batch steps, one optimiser update per train() call.
for epoch in range(200):
    train(model_rgcn, data, optimizer_rgcn)

# Baseline on the unperturbed graph.
acc_rgcn_before, loss_rgcn_before = test(model_rgcn, data)
print(f"[RGCN] Accuracy BEFORE attack: {acc_rgcn_before:.4f}, Loss: {loss_rgcn_before:.4f}")

# run_attack overwrites x / edge_index of the graph it receives, so pass a clone.
acc_rgcn_after, loss_rgcn_after, ASR_rgcn, AML_rgcn, acc_target_rgcn_after = run_attack(model_rgcn, data.clone())
print("\n=== RGCN under NETTACK ===")
print(f"[RGCN] Accuracy AFTER  attack: {acc_rgcn_after:.4f}, Loss: {loss_rgcn_after:.4f}")
print(f"[RGCN] ASR: {ASR_rgcn:.2f}%, AML: {AML_rgcn:.4f}")
print(f"[RGCN] Target Nodes Accuracy AFTER attack: {acc_target_rgcn_after:.4f}")

[RGCN] Accuracy BEFORE attack: 0.7550, Loss: 0.8876

=== RGCN under NETTACK ===
[RGCN] Accuracy AFTER  attack: 0.6880, Loss: 1.0857
[RGCN] ASR: 71.00%, AML: 1.2550
[RGCN] Target Nodes Accuracy AFTER attack: 0.9930
