<a href="https://colab.research.google.com/github/vinhqdang/attack_on_graph_link_prediction/blob/main/gcn_adv_vs_nettack.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install torch torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuff

In [3]:
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
import numpy as np
import random

In [4]:
# Set random seed
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load Cora dataset
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0].to(device)


Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!


In [5]:
# GCN model class
class GCN(torch.nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return x


In [6]:
# Train function
def train(model, data, optimizer):
    model.train()
    optimizer.zero_grad()
    out = model(data)
    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

# Test function
def test(model, data):
    model.eval()
    logits = model(data)
    loss = F.cross_entropy(logits[data.test_mask], data.y[data.test_mask]).item()
    pred = logits[data.test_mask].max(1)[1]
    acc = pred.eq(data.y[data.test_mask]).sum().item() / data.test_mask.sum().item()
    return acc, loss

In [7]:
# Projection function
def projection(b, eps):
    b = torch.clamp(b, 0, 1)
    if b.sum() > eps:
        b = b * (eps / b.sum())
    return b

In [12]:
# Algorithm 1 adversarial training
def adversarial_train(model, data, T1=10, T2=5, K=5, eta=5e-2, gamma=1e-2, eps=20.0):
    num_edges = data.edge_index.shape[1]
    b = torch.ones(num_edges, requires_grad=True, device=device) * 0.5
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

    for t1 in range(T1):
        b = b.detach().clone().requires_grad_(True)

        # Inner loop: update b
        for t2 in range(T2):
            edge_mask = torch.bernoulli(b).bool()  # sample perturbation
            perturbed_edge_index = data.edge_index[:, edge_mask]

            temp_data = data.clone()
            temp_data.edge_index = perturbed_edge_index

            out = model(temp_data)
            loss = F.cross_entropy(out[temp_data.train_mask], temp_data.y[temp_data.train_mask])

            grad_b = torch.autograd.grad(loss, b, retain_graph=True, allow_unused=True)[0]
            if grad_b is None:
                grad_b = torch.zeros_like(b)

            b = b + eta * grad_b
            b = projection(b, eps)

        # Sample K binary perturbations
        u_samples = []
        for k in range(K):
            u_k = torch.bernoulli(b).detach()
            u_samples.append(u_k)

        # Find b* that maximizes loss
        max_loss = -float('inf')
        b_star = None
        for u in u_samples:
            edge_mask = u.bool()
            perturbed_edge_index = data.edge_index[:, edge_mask]

            temp_data = data.clone()
            temp_data.edge_index = perturbed_edge_index

            out = model(temp_data)
            loss = F.cross_entropy(out[temp_data.train_mask], temp_data.y[temp_data.train_mask])

            if loss.item() > max_loss:
                max_loss = loss.item()
                b_star = u

        # Outer loop: update model parameters W
        optimizer.zero_grad()
        perturbed_edge_index = data.edge_index[:, b_star.bool()]

        temp_data = data.clone()
        temp_data.edge_index = perturbed_edge_index

        out = model(temp_data)
        loss = F.cross_entropy(out[temp_data.train_mask], temp_data.y[temp_data.train_mask])
        loss.backward()
        optimizer.step()

        print(f'Iter {t1+1}/{T1}, Loss: {loss.item():.4f}, b.sum(): {b.sum().item():.2f}')

    return model


In [9]:
# Functions to flip edge & feature
def flip_edge(edge_index, u, v):
    mask = ~(((edge_index[0] == u) & (edge_index[1] == v)) |
             ((edge_index[0] == v) & (edge_index[1] == u)))
    if mask.sum() == edge_index.size(1):
        new_edges = torch.cat([edge_index, torch.tensor([[u, v], [v, u]], dtype=torch.long, device=edge_index.device)], dim=1)
    else:
        new_edges = edge_index[:, mask]
    return new_edges

def flip_feature(features, u, idx):
    features[u, idx] = 1 - features[u, idx]
    return features

In [13]:
# GCN Baseline
model_gcn = GCN().to(device)
optimizer = torch.optim.Adam(model_gcn.parameters(), lr=0.01, weight_decay=5e-4)

print("\n=== Training GCN ===")
for epoch in range(200):
    train(model_gcn, data, optimizer)

acc_before_gcn, loss_before_gcn = test(model_gcn, data)
print(f"Accuracy BEFORE attack: {acc_before_gcn:.4f}, Loss: {loss_before_gcn:.4f}")



=== Training GCN ===
Accuracy BEFORE attack: 0.8120, Loss: 0.6189


In [14]:
# Adversarial Training
model_adv = GCN().to(device)
print("\n=== Adversarial Training (Algorithm 1) ===")
model_adv = adversarial_train(model_adv, data)
acc_before_adv, loss_before_adv = test(model_adv, data)
print(f"Accuracy BEFORE attack: {acc_before_adv:.4f}, Loss: {loss_before_adv:.4f}")



=== Adversarial Training (Algorithm 1) ===
Iter 1/10, Loss: 1.9449, b.sum(): 20.00
Iter 2/10, Loss: 1.7708, b.sum(): 20.00
Iter 3/10, Loss: 1.6154, b.sum(): 20.00
Iter 4/10, Loss: 1.4379, b.sum(): 20.00
Iter 5/10, Loss: 1.2448, b.sum(): 20.00
Iter 6/10, Loss: 1.0471, b.sum(): 20.00
Iter 7/10, Loss: 0.8732, b.sum(): 20.00
Iter 8/10, Loss: 0.7183, b.sum(): 20.00
Iter 9/10, Loss: 0.5822, b.sum(): 20.00
Iter 10/10, Loss: 0.4690, b.sum(): 20.00
Accuracy BEFORE attack: 0.6760, Loss: 1.4168


In [15]:
# Prepare target nodes
model_gcn.eval()
logits = model_gcn(data)
conf = F.softmax(logits[data.test_mask], dim=1)
conf_max, pred = conf.max(dim=1)
true = data.y[data.test_mask]
correct_mask = (pred == true)

# Chọn top 200 node dễ attack
num_target_nodes = 200
sorted_idx = conf_max[correct_mask].argsort()
target_candidates = data.test_mask.nonzero(as_tuple=False).view(-1)[correct_mask]
target_nodes = target_candidates[sorted_idx[:num_target_nodes]]

# Parameters for attack
budget = 15
num_edge_candidates = 20
num_feature_candidates = 20

# Attack Function (reuse for GCN & ADV)
def run_attack(model, data, target_nodes):
    successful_attacks = 0
    modified_links = 0
    correct_after_target = []

    x_adv_global = data.x.clone()
    edge_adv_global = data.edge_index.clone()

    for node in target_nodes:
        node = node.item()
        x_adv = x_adv_global.clone()
        edge_adv = edge_adv_global.clone()

        logits = model(data)
        pred = logits[node].argmax().item()
        true_label = data.y[node].item()

        if pred != true_label:
            continue

        for _ in range(budget):
            best_score = -np.inf
            best_action = None

            possible_neighbors = random.sample(range(data.num_nodes), min(num_edge_candidates, data.num_nodes))
            possible_features = random.sample(range(data.num_node_features), min(num_feature_candidates, data.num_node_features))

            for neighbor in possible_neighbors:
                if neighbor == node:
                    continue
                temp_edge = flip_edge(edge_adv, node, neighbor)
                temp_data = data.clone()
                temp_data.edge_index = temp_edge
                temp_data.x = x_adv
                temp_logits = model(temp_data)
                score = F.cross_entropy(temp_logits[[node]], data.y[[node]]).item()
                if score > best_score:
                    best_score = score
                    best_action = ('edge', neighbor)

            for idx in possible_features:
                temp_x = x_adv.clone()
                temp_x = flip_feature(temp_x, node, idx)
                temp_data = data.clone()
                temp_data.edge_index = edge_adv
                temp_data.x = temp_x
                temp_logits = model(temp_data)
                score = F.cross_entropy(temp_logits[[node]], data.y[[node]]).item()
                if score > best_score:
                    best_score = score
                    best_action = ('feature', idx)

            if best_action[0] == 'edge':
                neighbor = best_action[1]
                edge_adv = flip_edge(edge_adv, node, neighbor)
                modified_links += 1
            else:
                idx = best_action[1]
                x_adv = flip_feature(x_adv, node, idx)

            temp_data = data.clone()
            temp_data.edge_index = edge_adv
            temp_data.x = x_adv
            temp_logits = model(temp_data)
            pred_after = temp_logits[node].argmax().item()

            if pred_after != true_label:
                successful_attacks += 1
                correct_after_target.append(0)
                edge_adv_global = edge_adv
                x_adv_global = x_adv
                break
        else:
            correct_after_target.append(1)
            edge_adv_global = edge_adv
            x_adv_global = x_adv

    data.edge_index = edge_adv_global
    data.x = x_adv_global
    acc_after, loss_after = test(model, data)
    ASR = (successful_attacks / len(target_nodes)) * 100
    AML = modified_links / len(target_nodes)
    acc_target_after = 1 - np.mean(correct_after_target)

    return acc_after, loss_after, ASR, AML, acc_target_after

In [16]:
# Run attack for GCN
acc_after_gcn, loss_after_gcn, ASR_gcn, AML_gcn, acc_target_gcn = run_attack(model_gcn, data, target_nodes)

# Run attack for GCN + ADV
acc_after_adv, loss_after_adv, ASR_adv, AML_adv, acc_target_adv = run_attack(model_adv, data, target_nodes)

# Summary
print("GCN")
print(f"Accuracy BEFORE Attack: {acc_before_gcn:.4f}")
print(f"Accuracy AFTER  Attack: {acc_after_gcn:.4f}")
print(f"ASR: {ASR_gcn:.2f}%, AML: {AML_gcn:.4f}")

print("Adversarial Training")
print(f"Accuracy BEFORE Attack: {acc_before_adv:.4f}")
print(f"Accuracy AFTER  Attack: {acc_after_adv:.4f}")
print(f"ASR: {ASR_adv:.2f}%, AML: {AML_adv:.4f}")



=== GCN under NETTACK ===

=== GCN + Algorithm 1 under NETTACK ===

Accuracy BEFORE Attack: 0.8120
Accuracy AFTER  Attack: 0.6210
ASR: 100.00%, AML: 1.1900
Accuracy BEFORE Attack: 0.6760
Accuracy AFTER  Attack: 0.5730
ASR: 26.50%, AML: 0.6700
