In [18]:
import torch
import torch.nn.functional as F
import numpy as np
import random
from deeprobust.graph.targeted_attack import BaseAttack
from scipy.sparse import csr_matrix

class ClassConditionalFeatureAttack(BaseAttack):
    def __init__(self, model=None, nnodes=None, attack_structure=False, attack_features=True, device='cpu'):
        super(ClassConditionalFeatureAttack, self).__init__(model, nnodes, attack_structure, attack_features)
        self.device = device

    def attack(self, features: csr_matrix, adj: csr_matrix, labels: np.ndarray, 
               target_node=None, n_perturbations=10, swap_fraction=0.1, max_per_class=50):
        """
        Perform class-conditional feature swapping (suitable for CSR input).

        Args:
            features (csr_matrix): Node feature matrix (N x D)
            adj (csr_matrix): Adjacency matrix (not modified)
            labels (np.ndarray): Node labels
            target_node (int, optional): Target node (unused here)
            n_perturbations (int): Max number of feature swaps
            swap_fraction (float): Proportion of total nodes to swap
            max_per_class (int): Max number of nodes to swap per class
        """
        assert isinstance(features, csr_matrix), "Expected features to be a CSR sparse matrix"
        assert isinstance(adj, csr_matrix), "Expected adj to be a CSR sparse matrix"

        # Convert features to dense tensor for manipulation
        features_dense = torch.FloatTensor(features.todense()).to(self.device)
        labels = torch.LongTensor(labels).to(self.device)

        self.modified_features = features_dense.clone()

        # group node indices by class
        unique_labels = torch.unique(labels).tolist()
        class_to_indices = {int(c): (labels == c).nonzero(as_tuple=True)[0].tolist() for c in unique_labels}
        num_swaps = min(n_perturbations, int(swap_fraction * features.shape[0]))

        swaps_done = 0

        for c1 in unique_labels:
            if swaps_done >= num_swaps:
                break

            c1_indices = class_to_indices[c1]
            if len(c1_indices) == 0:
                continue

            for idx in random.sample(c1_indices, min(len(c1_indices), max_per_class)):
                if swaps_done >= num_swaps:
                    break

                other_classes = [c for c in unique_labels if c != c1]
                c2 = random.choice(other_classes)
                c2_indices = class_to_indices[c2]
                if len(c2_indices) == 0:
                    continue

                f1 = self.modified_features[idx].unsqueeze(0)  # shape: (1, D)
                f2s = self.modified_features[c2_indices]       # shape: (M, D)
                sims = F.cosine_similarity(f1, f2s)
                closest_idx = c2_indices[torch.argmax(sims).item()]

                # Swap features
                temp = self.modified_features[idx].clone()
                self.modified_features[idx] = self.modified_features[closest_idx]
                self.modified_features[closest_idx] = temp

                swaps_done += 1

        # Convert back to CSR sparse matrix if needed
        modified_features_np = self.modified_features.cpu().numpy()
        self.modified_features = csr_matrix(modified_features_np)

        return self.modified_features


In [22]:
from deeprobust.graph.data import Dataset
from deeprobust.graph.defense import GCN
from deeprobust.graph.utils import *

# Load Cora
data = Dataset(root='/tmp/', name='cora', setting='nettack')
adj, features, labels = data.adj, data.features, data.labels
idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test

# Setup
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# ---- 1. Train CLEAN model ----
model_clean = GCN(nfeat=features.shape[1], nclass=labels.max().item()+1, nhid=16, device=device)
model_clean = model_clean.to(device)
model_clean.fit(features, adj, labels, idx_train)
output_clean = model_clean.predict()
acc_clean = accuracy(output_clean[idx_test], labels[idx_test])
print(f"[CLEAN]  Test accuracy: {acc_clean:.4f}")

# Apply CCFS attack
attacker = ClassConditionalFeatureAttack(model=model, nnodes=features.shape[0], device=device)
modified_features = attacker.attack(features, adj, labels, n_perturbations=100, swap_fraction=0.1)

# ---- 3. Train POISONED model ----
model_poisoned = GCN(nfeat=features.shape[1], nclass=labels.max().item()+1, nhid=16, device=device)
model_poisoned = model_poisoned.to(device)
model_poisoned.fit(modified_features, adj, labels, idx_train)
output_poisoned = model_poisoned.predict()
acc_poisoned = accuracy(output_poisoned[idx_test], labels[idx_test])
print(f"[POISONED] Test accuracy: {acc_poisoned:.4f}")


Loading cora dataset...
Selecting 1 largest connected components
[CLEAN]  Test accuracy: 0.8310
[POISONED] Test accuracy: 0.8214
