In [1]:
import sys
from itertools import count
from torch import autograd
import copy

sys.path.append('../')
from models.gcn import *
from utils.datasets import *

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = get_cora(device)

In [3]:
num_classes = torch.unique(data.y).shape[0]

In [4]:
def _sym(triu):
    triu = triu.triu(diagonal=1)
    return triu + triu.T

In [5]:
def greedy_metattack(
        flip_shape,
        budget,
        device,
        grad_fn,
        symmetric = True,
        *, flips_per_iteration = 1,
        max_iterations = None,
        progress = True
):
    flip = torch.zeros(flip_shape, device=device, requires_grad=True)
    used_budget = 0
    perts = []

    pbar = tqdm(total=budget, leave=False) if progress and max_iterations != 1 else None
    for _ in range(max_iterations) if max_iterations is not None else count():
        if symmetric:
            flip_sym = _sym(flip)
            grad = autograd.grad(flip_sym, flip, grad_outputs=grad_fn(flip_sym))[0]
        else:
            grad = grad_fn(flip)

        with torch.no_grad():
            # Note: If we wanted to maximize the loss, the != would be a ==, but as we want to minimize it, we have to
            # take the "opposite" gradient.
            grad[(grad < 0) != (flip == 0)] = 0
            flt = grad.abs().flatten()
            # Note: When we only look for one entry to flip, use max() instead of topk() as it's a lot faster.
            for v, linear_idx in [flt.max(dim=0)] if flips_per_iteration == 1 else zip(*flt.topk(flips_per_iteration)):
                if v == 0:
                    break
                linear_idx = linear_idx.item()
                idx_2d = (linear_idx // flip.shape[1], linear_idx % flip.shape[1])
                # Case 1: The edge has not been flipped previously.
                if flip[idx_2d] == 0:
                    flip[idx_2d] = 1
                    used_budget += 1
                    # If we have reached the next higher budget, save its perturbation and drop the budget.
                    if used_budget == budget:
                        break
                # Case 2: The edge has been flipped previously, so flip it back.
                else:
                    flip[idx_2d] = 0
                    used_budget -= 1
        if pbar:
            pbar.update(used_budget - pbar.n)
        # Stop if we have found perturbations for all budgets.
        if used_budget == budget:
            break
    if pbar:
        pbar.close()
    return perts

In [6]:
A = edge_index_to_A(data.edge_index, data.y.shape[0], device)

In [7]:
def make_model():
    return GCN(data.x.shape[1], 64, num_classes).to(device)
model = make_model()
model.fit(data)

Training Epochs:   0%|          | 0/200 [00:00<?, ?it/s]

In [8]:
data

Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])

In [31]:
def loss_fn(A_flip):
    A_pert = A + A_flip * (1 - 2 * A)
    pert_edge_index = A_to_edge_index(A_pert)

    data_copy = data.clone() 
    data_copy.edge_index = pert_edge_index
    model = make_model()
    
    # model.fit(data_copy, differentiable=A_pert.requires_grad)
    model.fit(data_copy)

    scores = model(data_copy)
    # loss = -F.cross_entropy(scores[data_copy.test_mask, :], data_copy.y[data_copy.test_mask])
    # loss = loss.mean()
    loss = - F.cross_entropy(scores[data_copy.test_mask], data_copy.y[data_copy.test_mask], reduction='mean')
    print(f"Loss: {loss}, Shape: {loss.shape}")  # Should be a single scalar, shape should be torch.Size([]) for a scalar

    return loss

def grad_fn(A_flip):
    return torch.autograd.grad(loss_fn(A_flip), A_flip, allow_unused=True)[0]

In [32]:
budget = 300
pert = greedy_metattack(A.shape, budget, A.device, grad_fn)[0]

  0%|          | 0/300 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/200 [00:00<?, ?it/s]

Loss: -0.9727743864059448, Shape: torch.Size([])


RuntimeError: grad can be implicitly created only for scalar outputs