In [1]:
# Reload modules automatically
# https://ipython.readthedocs.io/en/stable/config/extensions/autoreload.html
%load_ext autoreload
%autoreload 2

In [2]:
import importlib
from src.graph_models import create_graph_model
import numpy as np
import torch
import exp_eval_robustness
import exp_train
from torch_sparse import SparseTensor



In [37]:
data_params = dict(
    graph_model = 'CSBM',
    classes = 2,
    n = 1000,
    n_per_class_trn = 400,
    K = 5,
    sigma = 1,
    avg_within_class_degree = 1.58 * 2,
    avg_between_class_degree = 0.37 * 2,
    inductive_samples = 10
)

model_params = dict(
    label="LinearGCN",
    model="LinearGCN",
    n_filter=64,
    dropout=0.5,
    use_label_propagation=False,
)

train_params = dict(
    lr=0.1,
    weight_decay=1e-4,
    patience=300,
    max_epochs=3000,
    inductive=True,
)

attack_params = dict(
    attack = "nettack"
)

verbosity_params = dict(
    display_steps = 100,
    debug_lvl = "info"
)  

other_params = dict(
    device = 0,
    allow_tf32 = False,
    sacred_metrics = True
)

seed = 0

In [38]:
importlib.reload(exp_eval_robustness)
importlib.reload(exp_train)
#result = exp_train.run(data_params, model_params, train_params, verbosity_params, other_params, seed, None)
result = exp_eval_robustness.run(data_params, model_params, train_params, attack_params, verbosity_params, other_params, seed, None)

2022-09-06 17:09:25 (INFO): Starting experiment exp_eval_robustness with configuration:
2022-09-06 17:09:25 (INFO): data_params: {'graph_model': 'CSBM', 'classes': 2, 'n': 1000, 'n_per_class_trn': 400, 'K': 5, 'sigma': 1, 'avg_within_class_degree': 3.16, 'avg_between_class_degree': 0.74, 'inductive_samples': 10}
2022-09-06 17:09:25 (INFO): model_params: {'label': 'LinearGCN', 'model': 'LinearGCN', 'n_filter': 64, 'dropout': 0.5, 'use_label_propagation': False}
2022-09-06 17:09:25 (INFO): train_params: {'lr': 0.1, 'weight_decay': 0.0001, 'patience': 300, 'max_epochs': 3000, 'inductive': True}
2022-09-06 17:09:25 (INFO): attack_params: {'attack': 'nettack'}
2022-09-06 17:09:25 (INFO): verbosity_params: {'display_steps': 100, 'debug_lvl': 'info'}
2022-09-06 17:09:25 (INFO): other_params: {'device': 0, 'allow_tf32': False, 'sacred_metrics': True}
2022-09-06 17:09:25 (INFO): seed: 0
2022-09-06 17:09:25 (INFO): db_collection: None
2022-09-06 17:09:25 (INFO): Currently on gpu device cuda:0
20

Deg: 4; true_class: 0
Bayes_sep: 1; GNN_sep: 1
Loss before attack: 9.8170
Loss after attack: 4.8922
Edge Added: (1000, 211); Classes: (0, 1), Feature Distance: 10.105
Bayes_sep: 1; GNN_sep: 1
Loss before attack: 4.8922
Loss after attack: 1.9262
Edge Added: (1000, 1); Classes: (0, 1), Feature Distance: 10.354
Bayes_sep: 1; GNN_sep: 1
Loss before attack: 0.7578
Loss after attack: -1.5035
Edge Added: (1000, 648); Classes: (0, 1), Feature Distance: 8.282
Bayes_sep: 1; GNN_sep: 0
Loss before attack: -1.5035
Loss after attack: -3.1694
Edge Added: (1000, 481); Classes: (0, 1), Feature Distance: 9.000
Bayes_sep: 1; GNN_sep: 0
Loss before attack: -3.9182
Loss after attack: -5.3208
Edge Added: (1000, 550); Classes: (0, 1), Feature Distance: 8.198
Bayes_sep: 1; GNN_sep: 0
Loss before attack: -5.3208
Loss after attack: -6.5397
Edge Removed: (1000, 400); Classes: (0, 0), Feature Distance: 7.239
Bayes_sep: 1; GNN_sep: 0
Loss before attack: -8.0998
Loss after attack: -9.4322
Edge Removed: (1000, 979)

 10%|█         | 1/10 [00:00<00:07,  1.15it/s]

Loss after attack: -14.6832
Edge Removed: (1000, 958); Classes: (0, 0), Feature Distance: 5.279
Bayes_sep: 1; GNN_sep: 0
Loss before attack: -15.2398
Loss after attack: -15.7202
Edge Added: (1000, 476); Classes: (0, 1), Feature Distance: 9.594
Bayes_sep: 1; GNN_sep: 0
Loss before attack: -16.4698
Loss after attack: -16.8458
Edge Added: (1000, 290); Classes: (0, 1), Feature Distance: 9.289
Deg: 4; true_class: 0
Bayes_sep: 1; GNN_sep: 1
Loss before attack: 13.3086
Loss after attack: 8.4497
Edge Added: (1000, 211); Classes: (0, 1), Feature Distance: 8.978
Bayes_sep: 1; GNN_sep: 1
Loss before attack: 8.4497
Loss after attack: 5.3550
Edge Added: (1000, 1); Classes: (0, 1), Feature Distance: 10.159
Bayes_sep: 1; GNN_sep: 1
Loss before attack: 4.1865
Loss after attack: 2.0004
Edge Added: (1000, 648); Classes: (0, 1), Feature Distance: 9.081
Bayes_sep: 1; GNN_sep: 1
Loss before attack: 2.0004
Loss after attack: -0.0019
Edge Removed: (1000, 699); Classes: (0, 0), Feature Distance: 5.799
Bayes_s

 20%|██        | 2/10 [00:02<00:08,  1.08s/it]

Loss after attack: -16.3322
Edge Added: (1000, 760); Classes: (0, 1), Feature Distance: 9.058
Bayes_sep: 1; GNN_sep: 0
Loss before attack: -16.6568
Loss after attack: -16.8495
Edge Added: (1000, 560); Classes: (0, 1), Feature Distance: 8.154
Deg: 5; true_class: 0
Bayes_sep: 1; GNN_sep: 1
Loss before attack: 5.1339
Loss after attack: 1.7003
Edge Added: (1000, 211); Classes: (0, 1), Feature Distance: 9.388
Bayes_sep: 1; GNN_sep: 1
Loss before attack: 1.7003
Loss after attack: -0.4587
Edge Added: (1000, 1); Classes: (0, 1), Feature Distance: 9.767
Bayes_sep: 1; GNN_sep: 0
Loss before attack: -1.5516
Loss after attack: -3.1696
Edge Added: (1000, 648); Classes: (0, 1), Feature Distance: 7.839
Bayes_sep: 1; GNN_sep: 0
Loss before attack: -3.1696
Loss after attack: -4.5816
Edge Added: (1000, 319); Classes: (0, 1), Feature Distance: 10.680
Bayes_sep: 1; GNN_sep: 0
Loss before attack: -3.5235
Loss after attack: -4.8191
Edge Removed: (1000, 227); Classes: (0, 0), Feature Distance: 5.559
Bayes_se

 30%|███       | 3/10 [00:03<00:09,  1.35s/it]

Loss after attack: -17.0667
Edge Added: (1000, 331); Classes: (0, 1), Feature Distance: 9.053
Deg: 4; true_class: 1
Bayes_sep: 1; GNN_sep: 1
Loss before attack: 3.7314
Loss after attack: -0.2829
Edge Added: (1000, 516); Classes: (1, 0), Feature Distance: 9.050
Bayes_sep: 1; GNN_sep: 0
Loss before attack: -0.2829
Loss after attack: -2.8875
Edge Added: (1000, 17); Classes: (1, 0), Feature Distance: 8.088
Bayes_sep: 1; GNN_sep: 0
Loss before attack: -2.8875
Loss after attack: -5.0102
Edge Added: (1000, 751); Classes: (1, 0), Feature Distance: 8.228
Bayes_sep: 1; GNN_sep: 0
Loss before attack: -5.0102
Loss after attack: -6.8999
Edge Removed: (1000, 331); Classes: (1, 1), Feature Distance: 6.937
Bayes_sep: 1; GNN_sep: 0
Loss before attack: -6.5384
Loss after attack: -8.2982
Edge Removed: (1000, 541); Classes: (1, 1), Feature Distance: 6.459
Bayes_sep: 1; GNN_sep: 0
Loss before attack: -9.1444


 40%|████      | 4/10 [00:04<00:06,  1.08s/it]

Loss after attack: -10.6035
Edge Removed: (1000, 127); Classes: (1, 1), Feature Distance: 7.009
Bayes_sep: 1; GNN_sep: 0
Loss before attack: -11.3461
Loss after attack: -12.5388
Edge Added: (1000, 762); Classes: (1, 0), Feature Distance: 8.791
Bayes_sep: 1; GNN_sep: 0
Loss before attack: -12.8758
Loss after attack: -13.8302
Edge Added: (1000, 960); Classes: (1, 0), Feature Distance: 9.247
Bayes_sep: 1; GNN_sep: 0
Loss before attack: -14.6620
Loss after attack: -15.4963
Edge Added: (1000, 915); Classes: (1, 0), Feature Distance: 7.824
Deg: 3; true_class: 1
Bayes_sep: 1; GNN_sep: 1
Loss before attack: 10.1688
Loss after attack: 5.1610
Edge Added: (1000, 516); Classes: (1, 0), Feature Distance: 9.882
Bayes_sep: 1; GNN_sep: 1
Loss before attack: 5.1610
Loss after attack: 2.1128
Edge Added: (1000, 17); Classes: (1, 0), Feature Distance: 7.972
Bayes_sep: 1; GNN_sep: 1
Loss before attack: 2.1128
Loss after attack: -0.2668
Edge Added: (1000, 751); Classes: (1, 0), Feature Distance: 8.266
Bayes

 50%|█████     | 5/10 [00:05<00:05,  1.14s/it]

Loss after attack: -14.8211
Edge Added: (1000, 114); Classes: (1, 0), Feature Distance: 7.698
Deg: 7; true_class: 1
Bayes_sep: 1; GNN_sep: 1
Loss before attack: 5.7537
Loss after attack: 3.0241
Edge Added: (1000, 516); Classes: (1, 0), Feature Distance: 8.451
Bayes_sep: 1; GNN_sep: 1
Loss before attack: 3.0241
Loss after attack: 1.0730
Edge Added: (1000, 17); Classes: (1, 0), Feature Distance: 8.517
Bayes_sep: 1; GNN_sep: 1
Loss before attack: 1.0730
Loss after attack: -0.6357
Edge Added: (1000, 751); Classes: (1, 0), Feature Distance: 8.669
Bayes_sep: 1; GNN_sep: 0
Loss before attack: -0.6357
Loss after attack: -2.0428
Edge Added: (1000, 920); Classes: (1, 0), Feature Distance: 8.958
Bayes_sep: 1; GNN_sep: 0
Loss before attack: -1.4427
Loss after attack: -2.7154
Edge Added: (1000, 762); Classes: (1, 0), Feature Distance: 9.568
Bayes_sep: 1; GNN_sep: 0
Loss before attack: -2.9444
Loss after attack: -4.2176
Edge Added: (1000, 972); Classes: (1, 0), Feature Distance: 8.922
Bayes_sep: 1; 

 60%|██████    | 6/10 [00:07<00:04,  1.25s/it]

Loss after attack: -9.3888
Edge Added: (1000, 798); Classes: (1, 0), Feature Distance: 8.845
Bayes_sep: 1; GNN_sep: 0
Loss before attack: -9.9606
Loss after attack: -10.7347
Edge Removed: (1000, 9); Classes: (1, 1), Feature Distance: 5.277
Deg: 4; true_class: 0
Bayes_sep: 1; GNN_sep: 1
Loss before attack: 6.1674
Loss after attack: 2.2967
Edge Added: (1000, 211); Classes: (0, 1), Feature Distance: 9.294
Bayes_sep: 1; GNN_sep: 1
Loss before attack: 2.2967
Loss after attack: -0.0979
Edge Added: (1000, 1); Classes: (0, 1), Feature Distance: 9.977
Bayes_sep: 1; GNN_sep: 0
Loss before attack: -1.2663
Loss after attack: -2.9216
Edge Added: (1000, 648); Classes: (0, 1), Feature Distance: 8.427
Bayes_sep: 1; GNN_sep: 0
Loss before attack: -2.9216
Loss after attack: -4.2877
Edge Removed: (1000, 297); Classes: (0, 0), Feature Distance: 6.942
Bayes_sep: 1; GNN_sep: 0
Loss before attack: -5.8126
Loss after attack: -7.1901
Edge Removed: (1000, 630); Classes: (0, 0), Feature Distance: 7.844
Bayes_sep

 70%|███████   | 7/10 [00:08<00:03,  1.22s/it]

Loss after attack: -14.9520
Edge Added: (1000, 938); Classes: (0, 1), Feature Distance: 9.051
Bayes_sep: 1; GNN_sep: 0
Loss before attack: -15.1396
Loss after attack: -15.3771
Edge Added: (1000, 731); Classes: (0, 1), Feature Distance: 9.241
Bayes_sep: 1; GNN_sep: 0
Loss before attack: -16.0676
Loss after attack: -16.2666
Edge Added: (1000, 760); Classes: (0, 1), Feature Distance: 8.632
Deg: 3; true_class: 1
Bayes_sep: 1; GNN_sep: 1
Loss before attack: 9.9540
Loss after attack: 4.8025
Edge Added: (1000, 516); Classes: (1, 0), Feature Distance: 9.278
Bayes_sep: 1; GNN_sep: 1
Loss before attack: 4.8025
Loss after attack: 1.6606
Edge Added: (1000, 17); Classes: (1, 0), Feature Distance: 7.620
Bayes_sep: 1; GNN_sep: 1
Loss before attack: 1.6606
Loss after attack: -0.7846
Edge Added: (1000, 751); Classes: (1, 0), Feature Distance: 9.435
Bayes_sep: 1; GNN_sep: 0
Loss before attack: -0.7846
Loss after attack: -2.6505
Edge Added: (1000, 762); Classes: (1, 0), Feature Distance: 9.063
Bayes_sep:

 80%|████████  | 8/10 [00:09<00:02,  1.20s/it]

Loss after attack: -7.7949
Edge Removed: (1000, 906); Classes: (1, 1), Feature Distance: 7.666
Bayes_sep: 1; GNN_sep: 0
Loss before attack: -7.8627
Loss after attack: -8.8665
Edge Added: (1000, 114); Classes: (1, 0), Feature Distance: 7.835
Deg: 2; true_class: 0
Bayes_sep: 1; GNN_sep: 1
Loss before attack: 6.5250
Loss after attack: 0.4360
Edge Added: (1000, 211); Classes: (0, 1), Feature Distance: 9.548
Bayes_sep: 1; GNN_sep: 1
Loss before attack: 0.4360
Loss after attack: -2.7381
Edge Added: (1000, 1); Classes: (0, 1), Feature Distance: 9.677
Bayes_sep: 1; GNN_sep: 0
Loss before attack: -4.1207
Loss after attack: -5.9409
Edge Removed: (1000, 615); Classes: (0, 0), Feature Distance: 6.440
Bayes_sep: 1; GNN_sep: 0
Loss before attack: -6.0678
Loss after attack: -7.8568
Edge Removed: (1000, 378); Classes: (0, 0), Feature Distance: 6.920
Bayes_sep: 1; GNN_sep: 0
Loss before attack: -9.1746
Loss after attack: -10.3969
Edge Added: (1000, 481); Classes: (0, 1), Feature Distance: 9.176
Bayes_s

 90%|█████████ | 9/10 [00:10<00:01,  1.13s/it]

Loss after attack: -17.1923
Edge Added: (1000, 560); Classes: (0, 1), Feature Distance: 7.940
Deg: 1; true_class: 0
Bayes_sep: 1; GNN_sep: 1
Loss before attack: 8.2371
Loss after attack: -0.6896
Edge Added: (1000, 211); Classes: (0, 1), Feature Distance: 10.419
Bayes_sep: 1; GNN_sep: 0
Loss before attack: -0.6896
Loss after attack: -4.3977
Edge Added: (1000, 1); Classes: (0, 1), Feature Distance: 10.350
Bayes_sep: 1; GNN_sep: 0
Loss before attack: -5.9434
Loss after attack: -8.6535
Edge Removed: (1000, 993); Classes: (0, 0), Feature Distance: 8.538
Bayes_sep: 1; GNN_sep: 0
Loss before attack: -10.4017
Loss after attack: -11.5174
Edge Added: (1000, 648); Classes: (0, 1), Feature Distance: 7.861
Bayes_sep: 1; GNN_sep: 0
Loss before attack: -11.5174
Loss after attack: -12.3804
Edge Added: (1000, 481); Classes: (0, 1), Feature Distance: 9.528
Bayes_sep: 1; GNN_sep: 0
Loss before attack: -13.3849
Loss after attack: -13.8775
Edge Added: (1000, 476); Classes: (0, 1), Feature Distance: 10.093


100%|██████████| 10/10 [00:11<00:00,  1.10s/it]
2022-09-06 17:09:38 (INFO): Prediction Statistics:
2022-09-06 17:09:38 (INFO): Count BC: 10.0 GNN: 10.0
2022-09-06 17:09:38 (INFO): Count Structure BC: 10.0 Feature BC: 10.0
2022-09-06 17:09:38 (INFO): Count BC and GNN: 10.0
2022-09-06 17:09:38 (INFO): Count BC not GNN: 0.0 GNN not BC: 0.0
2022-09-06 17:09:38 (INFO): Robustness Statistics:
2022-09-06 17:09:38 (INFO): BC more robust than GNN: 10.0
2022-09-06 17:09:38 (INFO): BC & GNN equal robustness: 0.0
2022-09-06 17:09:38 (INFO): BC less robust than GNN: 0.0
2022-09-06 17:09:38 (INFO): Degree 0: <BC robust>: -1.00; <GNN robust>: -1.00; 
2022-09-06 17:09:38 (INFO): Degree 1: <BC robust>: 8.00; <GNN robust>: 0.00; 
2022-09-06 17:09:38 (INFO): Degree 2: <BC robust>: 14.00; <GNN robust>: 1.00; 
2022-09-06 17:09:38 (INFO): Degree 3: <BC robust>: 12.00; <GNN robust>: 2.00; 
2022-09-06 17:09:38 (INFO): Degree 4: <BC robust>: 12.25; <GNN robust>: 1.50; 
2022-09-06 17:09:38 (INFO): Degree 5: <BC

Loss after attack: -15.0479
Edge Added: (1000, 290); Classes: (0, 1), Feature Distance: 10.529
Bayes_sep: 1; GNN_sep: 0
Loss before attack: -15.8574
Loss after attack: -16.1991
Edge Added: (1000, 550); Classes: (0, 1), Feature Distance: 7.680
Bayes_sep: 1; GNN_sep: 0
Loss before attack: -16.1991
Loss after attack: -16.4437
Edge Added: (1000, 938); Classes: (0, 1), Feature Distance: 9.143


In [12]:
from src.data import split
from src.models import create_model
from src.train import train_inductive, train_transductive

model_params = dict(
    label="GCN",
    model="LinearGCN",
    n_filter=64,
    dropout=0.5,
    use_label_propagation=False,
)
train_params = dict(
    lr=0.1,
    weight_decay=1e-3,
    patience=300,
    max_epochs=3000,
    inductive=True,
)
_run = None
device = torch.device(f"cuda:{0}")

# Sample Graph
graph_model = create_graph_model(data_params)
X_np, A_np, y_np = graph_model.sample(data_params["n"], seed)
X = torch.tensor(X_np, dtype=torch.float32, device=device)
A = torch.tensor(A_np, dtype=torch.float32, device=device)
y = torch.tensor(y_np, device=device)
split_trn, split_val = split(y_np, data_params)

# Create Model
model_params_trn = dict(**model_params, 
                        n_features=X_np.shape[1], 
                        n_classes=data_params["classes"])
model = create_model(model_params_trn).to(device)
#logging.info(model)

# Train Model
if train_params["inductive"]:
    train = train_inductive
else:
    train = train_transductive
trn_tracker = train(model, None, X, A, y, split_trn, split_val, train_params,
                    verbosity_params, _run)

2022-09-06 14:57:49 (INFO): 
Epoch    0: loss_train: 0.70049, loss_val: 0.69125, acc_train: 0.51250, acc_val: 0.54000
2022-09-06 14:57:49 (INFO): 
Epoch  100: loss_train: 0.66957, loss_val: 0.70510, acc_train: 0.59875, acc_val: 0.55000
2022-09-06 14:57:50 (INFO): 
Epoch  200: loss_train: 0.69992, loss_val: 0.69783, acc_train: 0.57000, acc_val: 0.59000
2022-09-06 14:57:50 (INFO): 
Epoch  300: loss_train: 0.67012, loss_val: 0.70146, acc_train: 0.59500, acc_val: 0.53000
2022-09-06 14:57:51 (INFO): 
Epoch  400: loss_train: 0.67534, loss_val: 0.70552, acc_train: 0.57750, acc_val: 0.53000
2022-09-06 14:57:51 (INFO): 
Epoch  500: loss_train: 0.69036, loss_val: 0.70550, acc_train: 0.58875, acc_val: 0.53500
2022-09-06 14:57:51 (INFO): 
Epoch  600: loss_train: 0.67120, loss_val: 0.70989, acc_train: 0.60250, acc_val: 0.53000
2022-09-06 14:57:52 (INFO): 
Epoch  700: loss_train: 0.68360, loss_val: 0.70071, acc_train: 0.56875, acc_val: 0.53500
2022-09-06 14:57:52 (INFO): 
Epoch  430: loss_train: 0.6

In [16]:
type(model)

src.models.gcn.DenseGCN

In [18]:
from src.models.gcn import DenseGCN
isinstance(model, DenseGCN)

True

In [22]:
isinstance(model.activation, torch.nn.Identity)

True

In [25]:
model.layers[0][0]._linear.weight

Parameter containing:
tensor([[-0.2845, -0.1751,  0.0541,  ...,  0.3195,  0.0939,  0.0529],
        [-0.0133, -0.1708,  0.0041,  ..., -0.0677, -0.3852, -0.0863],
        [ 0.0237, -0.1400,  0.0761,  ..., -0.0926, -0.1117, -0.0352],
        ...,
        [-0.1547, -0.6180,  0.0105,  ..., -0.4443, -0.1026, -0.1262],
        [-0.0155,  0.1586,  0.4082,  ...,  0.2353, -0.1754, -0.0009],
        [-0.1066, -0.2037, -0.0348,  ..., -0.1079, -0.1091, -0.2001]],
       device='cuda:0', requires_grad=True)

In [26]:
model_params = dict(
    label="SGC",
    model="SGC",
    K=2,
    use_label_propagation=False,
)
model_params_trn = dict(**model_params, 
                        n_features=X_np.shape[1], 
                        n_classes=data_params["classes"])
sgc = create_model(model_params_trn).to(device)

In [45]:
model.layers[0][0]

DenseGraphConvolution(
  (_linear): Linear(in_features=21, out_features=64, bias=False)
)

In [39]:
sgc.parameters

<bound method Module.parameters of SGC(
  (sgc): SGConv(21, 2, K=2)
)>

In [15]:
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"]

KeyError: 'CUBLAS_WORKSPACE_CONFIG'

In [9]:
data_params["sigma"] = 1
seed = 1

result = exp_eval_robustness.run(data_params, model_params, train_params, attack_params, verbosity_params, other_params, seed, None)

2022-08-10 14:11:26 (INFO): Starting experiment exp_eval_robustness with configuration:
2022-08-10 14:11:26 (INFO): data_params: {'graph_model': 'CSBM', 'classes': 2, 'n': 1000, 'n_per_class_trn': 400, 'K': 0.5, 'sigma': 1, 'avg_within_class_degree': 3.0, 'avg_between_class_degree': 1.0, 'inductive_samples': 1000}
2022-08-10 14:11:26 (INFO): model_params: {'label': 'GCN', 'model': 'DenseGCN', 'n_filters': 64}
2022-08-10 14:11:26 (INFO): train_params: {'lr': 0.01, 'weight_decay': 0.001, 'patience': 300, 'max_epochs': 3000, 'inductive': True}
2022-08-10 14:11:26 (INFO): attack_params: {'attack': 'l2'}
2022-08-10 14:11:26 (INFO): verbosity_params: {'display_steps': 100, 'debug_lvl': 'info'}
2022-08-10 14:11:26 (INFO): other_params: {'device': 0, 'allow_tf32': False, 'sacred_metrics': True}
2022-08-10 14:11:26 (INFO): seed: 1
2022-08-10 14:11:26 (INFO): db_collection: None
2022-08-10 14:11:26 (INFO): Currently on gpu device cuda:0
2022-08-10 14:11:27 (INFO): 
Epoch    0: loss_train: 0.7010

In [21]:
from src.data import calc_balanced_sample

In [22]:
class_counts = [10, 100, 100]
n_samples = 60
print(calc_balanced_sample(class_counts, n_samples))

[10 25 25]


In [35]:
from src.utils import accuracy

In [7]:
logits = torch.Tensor([[0.8, 0.2], [0.1, 0.9]])
labels = torch.Tensor([0, 1])
logits.argmax(1)[1]

tensor(1)