In [None]:
import os

# Remember the launch directory the first time this cell runs. `os.chdir('../')` on its
# own is relative to the *current* directory, so re-running the cell without a kernel
# restart would climb one directory further on every execution.
if 'NOTEBOOK_DIR' not in globals():
    NOTEBOOK_DIR = os.getcwd()

# Work from the parent of the launch directory (same effect as a single `os.chdir('../')`).
# NOTE(review): presumably the parent is the repository root, so that the `notebooks` and
# `rgnn_at_scale` packages imported further down can be resolved -- confirm.
os.chdir(os.path.join(NOTEBOOK_DIR, os.pardir))

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Alternative matplotlib backend, currently disabled:
#%matplotlib notebook

# Load IPython's autoreload extension; mode 2 reloads imported modules before each cell
# is executed, so edits to the project packages are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
from copy import deepcopy
from decimal import Decimal
from typing import List, Tuple
from warnings import warn

from cycler import cycler
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
import seaborn as sns
import numpy as np
import pandas as pd
import scipy.sparse as sp
from sklearn.decomposition import PCA
import scipy.stats as stats
import torch
from torch import nn
import torch.nn.functional as F
import seml
from ogb.nodeproppred import PygNodePropPredDataset

import tqdm
# Register tqdm's pandas integration (per the tqdm documentation this adds
# `progress_apply`-style methods to pandas objects). None of the cells visible here use it.
tqdm.tqdm.pandas()

In [None]:
from notebooks import mpl_latex

In [None]:
#mpl_latex.enable_production_mode()

In [None]:
from rgnn_at_scale.data import prep_graph, split
from rgnn_at_scale.attacks import create_attack, SPARSE_ATTACKS
from rgnn_at_scale.io import Storage
from rgnn_at_scale.models import DenseGCN, GCN
from rgnn_at_scale.train import train
from rgnn_at_scale.utils import accuracy

In [None]:
# --- Experiment configuration -------------------------------------------------------
# Everything a reader might want to tune lives in this cell.

dataset = 'ogbn-arxiv'   # forwarded to `prep_graph`; an 'ogbn' prefix requests the original split
binary_attr = False      # forwarded to `prep_graph`
attack = 'PRBCD'         # attack name handed to `create_attack`
seed = 0                 # seed for both the torch and the numpy global RNGs

# Use the first CUDA device when one is available and fall back to the CPU otherwise, so
# the notebook does not fail at the first `.to(device)` on a machine without a GPU.
device = 0 if torch.cuda.is_available() else 'cpu'

# Keyword arguments for the surrogate `GCN`; the nested `train_params` are unpacked
# into the `train(...)` call.
surrogate_params = {
    'n_filters': [256, 256],
    'dropout': 0.5,
    'with_batchnorm': True,
    'train_params': {
        'lr': 1e-2,
        'weight_decay': 0,
        'patience': 100,
        'max_epochs': 3000
    }
}

# Extra keyword arguments passed to `create_attack` (a `search_space_size` entry is added
# per run in the attack cell).
attack_params = {
    'keep_heuristic': 'WeightOnly',
    'loss_type': 'tanhCW'
}

# Search-space sizes to sweep (shown as "Block size $b$" in the plots).
# Full sweep: [350_000, 600_000, 1_000_000, 10_000_000, 50_000_000]
search_space_sizes = [50_000_000]

# Perturbation budget as a fraction of `adj._nnz() / 2` (see the attack cell).
epsilon = 0.25
# Passed to `train` as `display_step`.
display_steps = 10

# Seed the global RNGs up front so a fresh "Restart & Run All" is reproducible.
torch.manual_seed(seed)
np.random.seed(seed)

In [None]:
# Load the graph on the CPU. For OGB datasets (names starting with 'ogbn') the original
# split is requested as well, in which case `prep_graph` returns a fourth element.
data = prep_graph(dataset, device='cpu', binary_attr=binary_attr,
                  return_original_split=dataset.startswith('ogbn'))

if len(data) == 3:
    attr, adj, labels = data
    # No predefined split came back: create one with the `split` helper imported from
    # `rgnn_at_scale.data`. (`data` is the returned tuple and has no `.split` attribute.)
    # NOTE(review): assumes `split` needs only the label array -- confirm its signature.
    idx_train, idx_val, idx_test = split(labels.cpu().numpy())
else:
    # Bind the predefined split to `split_idx` rather than `split`, so the imported
    # `split` function is not shadowed when this cell is re-run.
    attr, adj, labels, split_idx = data
    idx_train, idx_val, idx_test = split_idx['train'], split_idx['valid'], split_idx['test']

# Model dimensions derived from the data.
n_features = attr.shape[1]
n_classes = int(labels.max() + 1)

In [None]:
# Build the surrogate GCN from the configured hyper-parameters and move it to the
# target device.
gcn = GCN(n_classes=n_classes, n_features=n_features, **surrogate_params)
gcn = gcn.to(device)

# Fit the surrogate; the graph tensors are moved to the device only for this call.
train(
    model=gcn,
    attr=attr.to(device),
    adj=adj.to(device),
    labels=labels.to(device),
    idx_train=idx_train,
    idx_val=idx_val,
    display_step=display_steps,
    **surrogate_params['train_params']
)

In [None]:
# Put the surrogate into evaluation mode before measuring accuracy: it is configured with
# dropout and batch norm, both of which behave differently in training mode.
# NOTE(review): whether `train` already leaves the model in eval mode is not visible here;
# calling `eval()` explicitly is harmless if it does.
gcn.eval()

# Forward pass without gradient tracking -- only the predictions are needed.
with torch.no_grad():
    pred_logits_surr = gcn(attr.to(device), adj.to(device))

# Clean test accuracy of the surrogate (displayed as the cell output).
accuracy(pred_logits_surr, labels.to(device), idx_test)

In [None]:
# The perturbation budget does not depend on the search-space size, so compute it once.
# NOTE(review): `adj._nnz() / 2` presumably counts undirected edges, assuming the
# adjacency stores both directions of every edge -- confirm.
m = adj._nnz() / 2
n_perturbations = int(round(epsilon * m))

# One entry of attack statistics per search-space size, in the order of `search_space_sizes`.
results = []
for search_space_size in search_space_sizes:
    # Copy the shared attack parameters and add the size for this run.
    temp_attack_params = dict(attack_params, search_space_size=search_space_size)

    # Use the configured `binary_attr` (not a hard-coded False) so the attack always
    # agrees with how the data was prepared.
    adversary = create_attack(attack, adj=adj, attr=attr, binary_attr=binary_attr, labels=labels,
                              model=gcn, idx_attack=idx_test, device=device, **temp_attack_params)

    # Re-seed right before attacking so every run starts from the same RNG state.
    torch.manual_seed(seed)
    np.random.seed(seed)

    adversary.attack(n_perturbations)

    results.append(adversary.attack_statistics)

In [None]:
# Attack loss per epoch, one curve per search-space size.
fig, ax = mpl_latex.newfig(width=0.30, ratio_yx=0.6)
for result, search_space_size in zip(results, search_space_sizes):
    plt.plot(result['loss'], label=f'{search_space_size:.1E}')

# Axis labels and legend are figure-level settings: apply them once after all curves are
# drawn instead of on every loop iteration.
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend(title='Block size $b$');  # trailing ';' hides the Legend repr in the cell output

In [None]:
# Accuracy per epoch recorded by the attack, one curve per search-space size.
fig, ax = mpl_latex.newfig(width=0.30, ratio_yx=0.6)
for result, search_space_size in zip(results, search_space_sizes):
    plt.plot(result['accuracy'], label=f'{search_space_size:.1E}')

# Axis labels and legend are figure-level settings: apply them once after all curves are
# drawn instead of on every loop iteration.
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend(title='Block size $b$');  # trailing ';' hides the Legend repr in the cell output

In [None]:
# Number of non-zero weights per epoch recorded by the attack, one curve per
# search-space size.
fig, ax = mpl_latex.newfig(width=0.30, ratio_yx=0.6)
for result, search_space_size in zip(results, search_space_sizes):
    # Same label format as the loss and accuracy figures.
    plt.plot(result['nonzero_weights'], label=f'{search_space_size:.1E}')

# Axis labels are set once after the loop. The curves carry labels, so draw the legend
# too -- without it the labels are never shown and the curves cannot be told apart.
plt.xlabel('Epochs')
plt.ylabel('non-zero weights')
plt.legend(title='Block size $b$');  # trailing ';' hides the Legend repr in the cell output

In [None]:
# Show which statistics the first attack run recorded (the keys usable in the plots).
first_run_statistics = results[0]
list(first_run_statistics.keys())