In [1]:
# Reload modules automatically
# https://ipython.readthedocs.io/en/stable/config/extensions/autoreload.html
# `%autoreload 2` reloads imported modules before each cell runs, so edits to the
# project packages (strn_and_rbstness, common) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [56]:
import logging
import math
import sys

import numpy as np
import torch
from strn_and_rbstness.attacks import Attack, create_attack
from strn_and_rbstness.data import GraphDataset, split
from strn_and_rbstness.helper.utils import accuracy, count_edges
from strn_and_rbstness.models import create_model
from strn_and_rbstness.train import _train
from common import CSBM

# Root logger at INFO level so the training / device messages below are emitted.
logger = logging.getLogger(name=None)
logger.setLevel(level=logging.INFO)

## Sample CSBM

In [39]:
seed = 0
# NOTE(review): the saved outputs in this notebook ("Dim: 35", 2000x2000 SparseTensor) match
# n = 2000, not n = 5000 (5000 / ln(5000)^2 ≈ 69) — they look stale; re-run to refresh them.
n = 5000
# Expected degrees; every edge counts towards the degree of both of its endpoints.
avg_intra_degree = 1.5 * 2  # intra_edges_per_node * 2
avg_inter_degree = 0.5 * 2  # inter_edges_per_node * 2
# Edge probabilities. The extra factor 2 presumably reflects that, with two balanced
# classes, each node has only ~n/2 same-class (resp. other-class) candidates — TODO confirm.
p = avg_intra_degree * 2 / (n - 1)
q = avg_inter_degree * 2 / (n - 1)
K = 0.5  # Defines distance between means of the Gaussians in sigma-units
sigma = 0.1
d = round(n / math.log(n) ** 2)  # feature dimension
# Constant mean vector: every entry is K*sigma / (2*sqrt(d)), hence ||mu|| = K*sigma / 2.
mu = np.full(d, K * sigma / (2 * d ** 0.5), dtype=np.float32)
cov = sigma ** 2 * np.identity(d, dtype=np.float32)  # isotropic covariance

# X, A ~ CSBM(n, p, q, mu, cov)
csbm = CSBM(p, q, mu, cov)
# Pass the configured `seed` so the sample actually changes when `seed` is changed.
X, A, y = csbm.sample(n, seed=seed)
# csbm.check_separabilities(X, A, y)
print(f"Dim: {d}")

Feature Separability:
n_corr: 1211
n_wrong: 789
Structure Separability:
n_corr: 1665
n_wrong: 335
Likelihood Separability:
n_corr: 1691
n_wrong: 309
Dim: 35


## Configure GNN & Attack

In [65]:
# GNN architecture handed to `create_model`.
model_params = {
    "label": "GCN",
    "model": "DenseGCN",
    "n_filters": 64,
}
# Optimisation settings handed to `_train`; self-training and adversarial training disabled.
train_params = {
    "loss_type": "CE",
    "lr": 1e-2,
    "weight_decay": 1e-3,
    "patience": 300,
    "max_epochs": 3000,
    "use_selftrain": False,
    "use_advtrain": False,
}
# Attack method name and its keyword arguments for `create_attack`.
attack = "PGD"
attack_params = {
    "epochs": 200,
    "base_lr": 1e-2,
    "scale_lr_with_n_attacked_edges": True,
    "loss_type": "tanhMargin",  # alternatives presumably include "CW" — TODO confirm
}
# Perturbation budget: fraction of the counted edges the attack may flip.
epsilon = 0.05

# Other
split_params = {
    "strategy": "normal",  # or "custom"
    "p_trn": 1,
    "p_tst": 0,  # "normal" uses 1 - p_trn; this value only matters for the custom strategy
    "p_selftrn": 0,  # unlabeled (non-test) data; only used by the custom strategy
}
verbosity_params = {"display_steps": 100}
# Device: use the first CUDA GPU when available, otherwise fall back to the CPU.
device = 0  # CUDA device index used when a GPU is present
if not torch.cuda.is_available():
    # Must be an assignment: later cells pass `device` to torch.tensor(...) / .to(...),
    # and leaving it at 0 would make them request CUDA device 0 on a CPU-only machine.
    device = torch.device("cpu")
    logging.warning("CUDA is not available, set device to 'cpu'")
else:
    device = torch.device(f"cuda:{device}")
    logging.info(f"Currently on gpu device {device}")

INFO:root:Currently on gpu device cuda:0


### Train GNN

In [66]:
# Seed the global torch and numpy RNGs before splitting and model construction.
torch.manual_seed(seed)
np.random.seed(seed)
split_ids = split(y, split_params, seed)
# Move the sampled graph onto the selected device.
X_gpu = torch.tensor(X, dtype=torch.float32, device=device)
A_gpu = torch.tensor(A, dtype=torch.float32, device=device)
y_gpu = torch.tensor(y, device=device)
graph = GraphDataset((X_gpu, A_gpu, y_gpu), split_ids)
# `{**d, key: value}` overrides keys that are already present, so this cell can be re-run.
# `dict(**model_params, n_features=...)` raised TypeError (duplicate keyword) on a second run
# because `model_params` already contained n_features / n_classes from the first run.
model_params = {
    **model_params,
    "n_features": graph.get_n_features(),
    "n_classes": graph.get_n_classes(),
}
model = create_model(model_params).to(device)
statistics = _train(model, graph, train_params, verbosity_params, None)
# NOTE(review): assumes statistics[1] holds the per-epoch validation loss — confirm against `_train`.
best_epoch = np.argmin(statistics[1])

INFO:root:
Epoch    0: loss_train: 0.69319, loss_val: 0.69304, acc_train: 0.51051, acc_val: 0.49251
INFO:root:
Epoch  100: loss_train: 0.56284, loss_val: 0.61014, acc_train: 0.73073, acc_val: 0.67333
INFO:root:
Epoch  200: loss_train: 0.52982, loss_val: 0.60279, acc_train: 0.76476, acc_val: 0.68531
INFO:root:
Epoch  300: loss_train: 0.51179, loss_val: 0.60599, acc_train: 0.78278, acc_val: 0.68432
INFO:root:
Epoch  400: loss_train: 0.51493, loss_val: 0.60102, acc_train: 0.77277, acc_val: 0.70330
INFO:root:
Epoch  500: loss_train: 0.51157, loss_val: 0.60137, acc_train: 0.78078, acc_val: 0.70130
INFO:root:
Epoch  600: loss_train: 0.51748, loss_val: 0.60674, acc_train: 0.77778, acc_val: 0.68831
INFO:root:
Epoch  700: loss_train: 0.51202, loss_val: 0.59716, acc_train: 0.76877, acc_val: 0.69131
INFO:root:
Epoch  800: loss_train: 0.51189, loss_val: 0.59494, acc_train: 0.78979, acc_val: 0.69630
INFO:root:
Epoch  584: loss_train: 0.51144, loss_val: 0.58559, acc_train: 0.78278, acc_val: 0.70829


In [78]:
idx_all = np.arange(len(y))  # indices of all nodes
model.eval()
# Inference only: no autograd graph is needed for the clean forward pass.
with torch.no_grad():
    logits = model(X_gpu, A_gpu)
# NOTE(review): this is stored as `acc_trn`, but `_train` logs both train and val accuracy —
# confirm which split `split_ids[1]` refers to.
acc_trn = accuracy(logits, y_gpu, split_ids[1])
print(acc_trn)

0.6973026990890503


### Attack GNN

In [55]:
# Build the adversary configured by `attack` / `attack_params`, attacking every node.
adversary = create_attack(attack, attr=X_gpu, adj=A_gpu, labels=y_gpu,
                          model=model, idx_attack=idx_all, device=device,
                          binary_attr=False,
                          make_undirected=True,
                          **attack_params)
# Budget: `epsilon` times the edge count reported by `count_edges`.
m = count_edges(A_gpu, idx_all)
n_perturbations = int(round(epsilon * m))
print(f"#Edges: {m} -> budget: {n_perturbations}")
adversary.attack(n_perturbations, _run=None)
A_pert, X_pert = adversary.get_pertubations()  # (sic) name as defined by the attack class
# Store the result as `acc_pert`, not `accuracy`: rebinding `accuracy` would shadow the
# imported `accuracy()` helper and break the clean-evaluation cell when it is re-run.
logits, acc_pert = Attack.evaluate_global(model, X_pert, A_pert, y_gpu, idx_all)
print(f"Accuracy: {acc_pert}")

#Edges: 4026 -> budget: 201
Accuracy: 0.6485000252723694


In [53]:
# Display the perturbed adjacency returned by the attack (rendered as a SparseTensor).
A_pert

SparseTensor(row=tensor([   0,    1,    1,  ..., 1999, 1999, 1999], device='cuda:0'),
             col=tensor([1033,  369,  929,  ..., 1334, 1413, 1869], device='cuda:0'),
             val=tensor([1., 1., 1.,  ..., 1., 1., 1.], device='cuda:0'),
             size=(2000, 2000), nnz=8444, density=0.21%)

### Evaluate Separability

In [44]:
# Feature / structure / likelihood separability of the clean CSBM sample (X, A, y).
# NOTE(review): the saved output below sums to 999 nodes per check, whereas the sampling
# cell's saved output sums to 2000 — it appears to come from a different run; re-run to refresh.
csbm.check_separabilities(X, A, y)

Feature Separability:
n_corr: 594
n_wrong: 405
Structure Separability:
n_corr: 725
n_wrong: 274
Likelihood Separability:
n_corr: 754
n_wrong: 245


In [48]:
# Print the sorted node indices of the first two splits returned by `split`.
for split_pos in (0, 1):
    print(np.sort(np.array(split_ids[split_pos])))

[   2    3    6    7    8   11   13   15   16   18   21   22   23   26
   29   30   32   33   34   36   40   41   42   44   45   46   50   51
   52   55   56   57   62   66   69   70   72   77   78   79   82   84
   85   86   91   92   93   94   97   98  102  103  106  107  108  109
  111  112  113  115  116  117  118  120  123  124  127  131  132  134
  137  139  141  142  143  144  148  150  153  154  156  157  159  163
  164  165  166  167  168  169  171  172  174  175  176  180  181  184
  188  190  191  198  200  202  203  204  208  209  210  212  215  216
  217  222  229  231  232  234  238  242  243  244  245  246  248  249
  250  251  252  254  255  258  265  267  268  269  270  271  272  273
  278  284  285  286  287  288  289  291  293  294  295  298  299  300
  302  304  305  307  310  314  316  317  319  320  321  322  324  325
  326  328  330  331  332  333  335  337  338  339  341  344  345  348
  349  353  356  358  359  361  362  363  364  366  368  370  373  381
  382 