In [7]:
# Reload modules automatically
# https://ipython.readthedocs.io/en/stable/config/extensions/autoreload.html
%load_ext autoreload
%autoreload 2
import logging
import math

import numpy as np
import torch
from strn_and_rbstness.attacks import Attack, create_attack
from strn_and_rbstness.data import GraphDataset, split
from strn_and_rbstness.helper.utils import accuracy, count_edges
from strn_and_rbstness.models import create_model
from strn_and_rbstness.train import _train
from common import CSBM

# Configure the root logger to emit INFO-level records; the `logging.info`
# call in the device-setup cell relies on this to show up in the cell output.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Sample CSBM

In [47]:
# Hyperparameters of the sampled CSBM graph.
seed = 0  # single source of truth for all seeding (sampling, split, training)
n = 1000  # passed to `csbm.sample`; presumably the number of nodes -- TODO confirm
avg_intra_degree = 1.5 * 2 # intra_edges_per_node * 2
avg_inter_degree = 0.5 * 2
# p / q are handed to CSBM as its first two arguments; presumably the
# intra- and inter-class edge probabilities -- verify against `common.CSBM`.
p = avg_intra_degree * 2 / (n - 1)
q = avg_inter_degree * 2 / (n - 1)
K = 0.5 # Defines distance between means of the Gaussians in sigma-units
sigma = 0.1
d = round(n / math.log(n)**2)  # feature dimension (21 for n = 1000)
# Constant mean vector: every entry is K*sigma / (2*sqrt(d)), so its
# Euclidean norm is K*sigma / 2.
mu = np.full(d, K * sigma / (2 * d**0.5), dtype=np.float32)
# Isotropic covariance sigma^2 * I.
cov = sigma**2 * np.identity(d, dtype=np.float32)

# X, A ~ CSBM(n, p, q, mu, cov)
csbm = CSBM(p, q, mu, cov)
# Use the `seed` variable instead of a hard-coded 0 so that changing the seed
# above actually changes the sampled graph.
X, A, y = csbm.sample(n, seed=seed)
#csbm.check_separabilities(X, A, y)
print(f"Dim: {d}")

Dim: 21


### Model & Training Specifications

In [50]:
# Keyword arguments describing the model; merged with the dataset-dependent
# `n_features` / `n_classes` before being handed to `create_model`.
model_params = dict(
    label="GAT",
    model="GAT", #GCN or DenseGCN or APPNP
    dropout=0,
    dropout_neighourhood=0,
    n_heads=8
)

# Optimisation settings forwarded to `_train`.
train_params = dict(
    loss_type="CE",
    lr=1e-2,
    weight_decay=1e-3,
    patience=300,
    max_epochs=1000,
    use_selftrain = False, 
    use_advtrain = False,
)

# Attack configuration (consumed by attack code that is not part of this cell).
attack = 'PRBCD'
attack_params = dict(
    epochs=500,
    fine_tune_epochs=100,
    keep_heuristic="WeightOnly",
    search_space_size=1_000,
    do_synchronize=True,
    loss_type="tanhMargin"
)
# NOTE(review): presumably the perturbation budget of the attack -- confirm
# against the cell that calls the attack.
epsilon = 0.05

# Other
split_params = {
    "strategy": "normal", # or "custom"
    "p_trn": 1,
    "p_tst": 0, # "normal" uses 1 - p_trn, only for custom split strategy
    "p_selftrn": 0 # Refers to unlabeled data, which is not test data, 
                    # only for custom split strategy
}
verbosity_params = dict(
    display_steps = 100
)   
# Device
# `device` starts out as the CUDA device index and is replaced below by
# either the string "cpu" or a `torch.device` for that index.
device = 0
if not torch.cuda.is_available():
    # The original line was a no-op comparison (`device == "cpu", "..."`),
    # which left `device` as the integer 0 on CPU-only machines. Actually
    # fall back to the CPU and tell the user about it.
    logging.warning("CUDA is not available, set device to 'cpu'")
    device = "cpu"
else:
    device = torch.device(f"cuda:{device}")
    logging.info(f"Currently on gpu device {device}")
# The attack receives the same device that the data is placed on.
attack_params["data_device"] = device

INFO:root:Currently on gpu device cuda:0


### Train Model

In [51]:
# Seed NumPy and PyTorch before anything stochastic happens in this cell.
np.random.seed(seed)
torch.manual_seed(seed)

# Node split according to `split_params`.
split_ids = split(y, split_params, seed)

# Copy the sampled features / adjacency (as float32) and the labels onto the
# selected device and wrap them, together with the split, in a GraphDataset.
X_gpu, A_gpu = (
    torch.tensor(array, dtype=torch.float32, device=device) for array in (X, A)
)
y_gpu = torch.tensor(y, device=device)
graph = GraphDataset((X_gpu, A_gpu, y_gpu), split_ids)

# The model additionally needs the input / output sizes, which are read from
# the dataset rather than configured by hand.
final_model_params = dict(
    **model_params,
    n_features=graph.get_n_features(),
    n_classes=graph.get_n_classes(),
)
model = create_model(final_model_params).to(device)

# NOTE(review): the meaning of the trailing `None` argument is not visible
# from this notebook -- check the signature of `_train`.
statistics = _train(
    model,
    graph,
    train_params,
    verbosity_params,
    None,
)
#best_epoch = np.argmin(statistics[1])

INFO:root:
Epoch    0: loss_train: 0.69207, loss_val: 0.69447, acc_train: 0.52000, acc_val: 0.47200
INFO:root:
Epoch  100: loss_train: 0.65466, loss_val: 0.65814, acc_train: 0.67400, acc_val: 0.67800
INFO:root:
Epoch  200: loss_train: 0.61063, loss_val: 0.61687, acc_train: 0.67800, acc_val: 0.67600
INFO:root:
Epoch  300: loss_train: 0.59680, loss_val: 0.60792, acc_train: 0.68200, acc_val: 0.68600
INFO:root:
Epoch  400: loss_train: 0.59407, loss_val: 0.60764, acc_train: 0.68800, acc_val: 0.67400
INFO:root:
Epoch  500: loss_train: 0.59253, loss_val: 0.60677, acc_train: 0.69000, acc_val: 0.67400
INFO:root:
Epoch  600: loss_train: 0.58973, loss_val: 0.60553, acc_train: 0.70000, acc_val: 0.67400
INFO:root:
Epoch  700: loss_train: 0.58056, loss_val: 0.59861, acc_train: 0.71600, acc_val: 0.67400
INFO:root:
Epoch  800: loss_train: 0.55283, loss_val: 0.58366, acc_train: 0.74000, acc_val: 0.68400
INFO:root:
Epoch  900: loss_train: 0.51483, loss_val: 0.59068, acc_train: 0.78400, acc_val: 0.68600
