update code
YingtongDou committed Aug 28, 2023
1 parent fecec02 commit cb2706d
Showing 3 changed files with 302 additions and 82 deletions.
151 changes: 129 additions & 22 deletions pygod/detector/gadnr.py
@@ -5,14 +5,16 @@
# Author: Yingtong Dou <ytongdou@gmail.com>
# License: BSD 2 clause

import os
import warnings
import time

import torch
import torch.nn.functional as F
from torch_geometric.loader import NeighborLoader
from torch_geometric import compile

from . import DeepDetector
from ..nn import GADNRBase
from ..utils import logger


class GADNR(DeepDetector):
@@ -44,7 +46,7 @@ class GADNR(DeepDetector):
dropout : float, optional
Dropout rate. Default: ``0.``.
weight_decay : float, optional
- Weight decay (L2 penalty). Default: ``0.``.
+ Weight decay (L2 penalty). Default: ``0.0003``.
act : callable activation function or None, optional
Activation function if not None.
Default: ``torch.nn.functional.relu``.
@@ -59,7 +61,7 @@ class GADNR(DeepDetector):
the proportion of outliers in the dataset. Used when fitting to
define the threshold on the decision function. Default: ``0.1``.
lr : float, optional
- Learning rate. Default: ``0.004``.
+ Learning rate. Default: ``0.01``.
epoch : int, optional
Maximum number of training epoch. Default: ``100``.
gpu : int
@@ -115,12 +117,12 @@ def __init__(self,
hid_s=4,
num_layers=4,
dropout=0.,
- weight_decay=0.,
+ weight_decay=0.0003,
act=F.relu,
backbone=None,
alpha=0.5,
contamination=0.1,
- lr=0.004,
+ lr=0.01,
epoch=100,
gpu=-1,
batch_size=0,
@@ -148,35 +150,140 @@ def __init__(self,
compile_model=compile_model,
**kwargs)

self.dim_s = None
self.neighbor_dict = None
self.neighbor_num_list = None
self.alpha = alpha
self.verbose = verbose
self.cache_dir = cache_dir

def process_graph(self, data):
- GADNRBase.process_graph()
+ self.neighbor_dict, self.neighbor_num_list = \
+     GADNRBase.process_graph(data)
+ # Tensor.to is not in-place, so keep the returned tensor
+ self.neighbor_num_list = self.neighbor_num_list.to(self.device)

def init_model(self, **kwargs):
if self.save_emb:
self.emb = (torch.zeros(self.num_nodes, self.hid_dim[0]),
torch.zeros(self.num_nodes, self.hid_dim[1]))

- return GADNRBase(in_dim, hid_dim, hid_dim, 2, sample_size, device=device,
-                  neighbor_num_list=neighbor_num_list, GNN_name=encoder,
-                  lambda_loss1=lambda_loss1, lambda_loss2=lambda_loss2, lambda_loss3=lambda_loss3)

+ # TODO update argument
+ return GADNRBase(in_dim, hid_dim, hid_dim, 2, sample_size,
+                  device=device, neighbor_num_list=neighbor_num_list,
+                  GNN_name=encoder, lambda_loss1=lambda_loss1,
+                  lambda_loss2=lambda_loss2,
+                  lambda_loss3=lambda_loss3).to(self.device)

def forward_model(self, data):

- l1, h0 = self.model(data.x, data.edge_index)
+ h0, l1, degree_logits, feat_recon_list, neigh_recon_list = \
+     self.model(data.x, data.edge_index)

loss, loss_per_node, h_loss, degree_loss, feature_loss = \
self.model.loss_func(h0,
l1,
degree_logits,
feat_recon_list,
neigh_recon_list,
self.neighbor_num_list,
self.neighbor_dict)

return loss, loss_per_node.cpu().detach(), h_loss.cpu().detach(), \
degree_loss.cpu().detach(), feature_loss.cpu().detach()

# TODO update the fit function parameters
def fit(self, data, label=None, real_loss=False,
        h_loss_weight=1., degree_loss_weight=1., feature_loss_weight=1.):
"""
Overwrite the base fit function since GAD-NR uses
multiple customized loss functions. The three loss weight arguments
combine the normalized hidden, degree, and feature reconstruction
losses below; their neutral defaults are placeholders pending the
TODO above.
"""

self.num_nodes, self.in_dim = data.x.shape
self.process_graph(data)
if self.batch_size == 0:
self.batch_size = data.x.shape[0]
loader = NeighborLoader(data,
self.num_neigh,
batch_size=self.batch_size)
self.model = self.init_model(**self.kwargs)
if self.compile_model:
self.model = compile(self.model)

- losses = self.model.loss_func(l1,
-                               ground_truth_degree_matrix,
-                               h0,
-                               neighbor_dict,
-                               device,
-                               data.x,
-                               data.edge_index)
degree_params = list(map(id, self.model.degree_decoder.parameters()))
base_params = filter(lambda p: id(p) not in degree_params,
self.model.parameters())
optimizer = torch.optim.Adam([{'params': base_params},
{'params': self.model.degree_decoder.
parameters(), 'lr': 1e-2}],
lr=self.lr,
weight_decay=self.weight_decay)

min_loss = float('inf')

arg_min_loss_per_node = None

self.model.train()
self.decision_score_ = torch.zeros(data.x.shape[0])
for epoch in range(self.epoch):
start_time = time.time()
epoch_loss = 0
epoch_loss_per_node = torch.zeros(data.x.shape[0])
for sampled_data in loader:
batch_size = sampled_data.batch_size
node_idx = sampled_data.n_id

loss, loss_per_node, h_loss, degree_loss, feature_loss = \
self.forward_model(sampled_data)

h_loss_norm = h_loss / (torch.max(h_loss) - torch.min(h_loss))
degree_loss_norm = degree_loss / (torch.max(degree_loss) \
- torch.min(degree_loss))
feature_loss_norm = feature_loss / (torch.max(feature_loss) \
- torch.min(feature_loss))

comb_loss = h_loss_weight * h_loss_norm \
+ degree_loss_weight * degree_loss_norm \
+ feature_loss_weight * feature_loss_norm

if real_loss:
comp_loss = loss_per_node
else:
comp_loss = comb_loss

self.decision_score_[node_idx[:batch_size]] = comp_loss

# TODO update embedding computation
if self.save_emb:
if isinstance(self.emb, tuple):
self.emb[0][node_idx[:batch_size]] = \
self.model.emb[0][:batch_size].cpu()
self.emb[1][node_idx[:batch_size]] = \
self.model.emb[1][:batch_size].cpu()
else:
self.emb[node_idx[:batch_size]] = \
self.model.emb[:batch_size].cpu()

optimizer.zero_grad()
loss.backward()
optimizer.step()

epoch_loss += loss.item() * batch_size
epoch_loss_per_node[node_idx[:batch_size]] = loss_per_node

loss_value = epoch_loss / data.x.shape[0]

- loss, loss_per_node, h_loss, degree_loss, feature_loss = losses
if loss_value < min_loss:
min_loss = loss_value
arg_min_loss_per_node = epoch_loss_per_node

logger(epoch=epoch,
loss=loss_value,
min_loss=min_loss,
arg_min_loss_per_node=arg_min_loss_per_node,
score=self.decision_score_,
target=label,
time=time.time() - start_time,
verbose=self.verbose,
train=True)

- return loss, loss_per_node,h_loss,degree_loss,feature_loss
+ self._process_decision_score()
+ return self
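For orientation, a minimal usage sketch of the updated detector on a toy graph (not part of this commit; the random data and the epoch=10 setting are illustrative, and only parameters documented above are used):

# Illustrative only: fit GADNR on a small random attributed graph.
import torch
from torch_geometric.data import Data
from pygod.detector import GADNR

data = Data(x=torch.randn(100, 16),                      # 100 nodes, 16 features
            edge_index=torch.randint(0, 100, (2, 400)))  # random edges

detector = GADNR(lr=0.01, weight_decay=0.0003, epoch=10)
detector.fit(data)                    # populates detector.decision_score_
print(detector.decision_score_[:5])   # per-node outlier scores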
2 changes: 2 additions & 0 deletions pygod/generator/outlier_generator.py
@@ -150,3 +150,5 @@ def gen_contextual_outlier(data, n, k, seed=None):
y_outlier[outlier_idx] = 1

return data, y_outlier

# TODO add new generator from GAD-NR
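For reference, a usage sketch of the generator touched above (not part of this commit; the toy graph and the n=10, k=20 values are illustrative):

# Illustrative only: inject 10 contextual outliers into a random graph.
import torch
from torch_geometric.data import Data
from pygod.generator import gen_contextual_outlier

data = Data(x=torch.randn(100, 16),
            edge_index=torch.randint(0, 100, (2, 400)))
data, y_outlier = gen_contextual_outlier(data, n=10, k=20, seed=42)
print(int(y_outlier.sum()))   # 10 nodes flagged as injected outliers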
