<a href="https://colab.research.google.com/github/sawlani/GNN-anomaly/blob/master/GraphDeepSVDD.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import argparse
import gzip
import os
import pickle

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import auc, average_precision_score
from torch.nn.utils.rnn import pad_sequence
from tqdm.notebook import tqdm

In [None]:
class S2VGraph(object):
    '''One graph together with the per-node data the GNN consumes.

    Constructor arguments:
        g: the networkx graph
        label: integer graph label
        node_tags: list of integer node tags (one per node), or None
        node_features: torch float tensor with a one-hot row per node tag, or None

    ``edge_mat`` starts as None; the data loader later stores a torch long
    tensor edge index there, laid out as
        [[u1 u2 u3 ... um]
         [v1 v2 v3 ... vm]]
    NOTE(review): ``neighbors`` / ``max_neighbor`` are not set here.
    '''

    def __init__(self, g, label, node_tags=None, node_features=None):
        self.label, self.g = label, g
        self.node_tags, self.node_features = node_tags, node_features
        self.edge_mat = None

In [None]:
def random_split_counts(total_number, no_of_splits):
    '''Randomly split ``total_number`` into ``no_of_splits`` positive integer counts.

    Uses numpy's global RNG. Returns an integer numpy array of length
    ``no_of_splits`` whose entries are all >= 1 and sum to ``total_number``.

    Raises ValueError unless 1 <= no_of_splits <= total_number.
    '''
    if not 1 <= no_of_splits <= total_number:
        raise ValueError(
            "no_of_splits must satisfy 1 <= no_of_splits <= total_number "
            "(got no_of_splits={}, total_number={})".format(no_of_splits, total_number))
    # Distinct interior cut points in 1..total_number-1. Allowing a cut at
    # total_number itself would produce an empty (zero) final split.
    cut_points = np.sort(np.random.choice(total_number - 1, no_of_splits - 1, replace=False) + 1)
    boundaries = np.concatenate(([0], cut_points, [total_number]))
    return np.diff(boundaries)

In [None]:
# Dedicated RNG used to shuffle the class-color sequences of generated graphs.
# Seeded so synthetic datasets are reproducible: an unseeded RandomState draws
# its seed from OS entropy and is not affected by np.random.seed(...).
random_state = np.random.RandomState(0)  # type: np.random.RandomState

In [None]:
class GraphGenerator:
    '''Base class: stores the number of classes and offers save helpers.'''

    def __init__(self, numClass):
        self.numClass = numClass

    def format_path(self, G: "nx.Graph", savePath, graphName, **kwargs):
        '''Fill the placeholders in ``graphName`` and then ``savePath``.

        Available fields: ``{numNode}``, ``{numEdge}``, ``{numClass}`` and any
        ``kwargs``; ``savePath`` may additionally use ``{graphName}`` (the
        already-formatted name). Returns ``(savePath, graphName)``.
        '''
        fields = dict(numNode=G.number_of_nodes(), numEdge=G.number_of_edges(),
                      numClass=self.numClass, **kwargs)
        graphName = graphName.format(**fields)
        savePath = savePath.format(graphName=graphName, **fields)
        return savePath, graphName

    def save_graph(self, G: "nx.Graph", savePath, graphName, **kwargs):
        '''Pickle ``G`` as a dict of adjacency lists to ``<savePath>/<graphName>.graph``.'''
        savePath, graphName = self.format_path(G, savePath, graphName, **kwargs)
        G_d = nx.to_dict_of_lists(G)
        out_path = os.path.join(savePath, graphName + ".graph")
        print("Saving graph to {}".format(out_path))
        with open(out_path, "wb") as out_file:  # closed even if pickling fails
            pickle.dump(G_d, out_file)

    def save_y(self, G: "nx.Graph", savePath, graphName, **kwargs):
        '''Pickle a one-hot (num_nodes, numClass) label matrix to ``<savePath>/<graphName>.ally``.

        Row ``v`` is one-hot at column ``G.nodes[v]['color'] - 1``, so node ids
        are assumed to be 0..num_nodes-1 and colors 1..numClass.
        '''
        savePath, graphName = self.format_path(G, savePath, graphName, **kwargs)
        ally = np.zeros((len(G.nodes()), self.numClass))
        for v in G.nodes():
            ally[v][G.nodes[v]['color'] - 1] = 1
        out_path = os.path.join(savePath, graphName + ".ally")
        print("Saving labels to {}".format(out_path))
        with open(out_path, "wb") as out_file:
            pickle.dump(ally, out_file)

    def save_nx_graph(self, G: "nx.Graph", savePath, graphName, **kwargs):
        '''Write ``G`` as a gzip-compressed pickle to ``<savePath>/<graphName>.gpickle.gz``.'''
        savePath, graphName = self.format_path(G, savePath, graphName, **kwargs)
        out_path = os.path.join(savePath, graphName + ".gpickle.gz")
        print("Pickling networkx graph to {}".format(out_path))
        # nx.write_gpickle was removed in networkx 3.0; this writes the same
        # gzip-compressed, highest-protocol pickle it used to produce.
        with gzip.open(out_path, "wb") as out_file:
            pickle.dump(G, out_file, pickle.HIGHEST_PROTOCOL)

In [None]:
class MixhopGraphGenerator(GraphGenerator):
    '''Preferential-attachment graph generator with tunable class homophily.

    Every node carries a class "color" in 1..numClass. Each new node is joined
    to ``m`` existing nodes drawn with probability proportional to
    ``max(degree, 1)``, scaled by ``h`` for same-color targets and by
    ``(1 - h) * color_weight`` for other colors (see ``get_neighbors``).
    '''

    def __init__(self, classRatio, heteroClsWeight="circularDist", **kwargs):
        '''
        classRatio: per-class sizes (or weights); its length is the number of classes
        heteroClsWeight: "circularDist" weights other classes by circular class
            distance and requires kwargs["heteroWeightsExponent"]; "uniform"
            weights every other class equally
        '''
        super().__init__(len(classRatio))
        self.classRatio = classRatio
        self.heteroWeightsDict = dict()
        # Set by generate_graph*; None means colors are sampled from classRatio.
        self.__coloriter = None

        if heteroClsWeight == "circularDist":
            # Count how many classes sit at each circular distance from a class.
            for i in range(2, self.numClass + 1):
                circularDist = min(i - 1, self.numClass - (i - 1))
                self.heteroWeightsDict[circularDist] = self.heteroWeightsDict.get(circularDist, 0) + 1

            if self.heteroWeightsDict:  # empty when there is a single class
                maxDist = max(self.heteroWeightsDict.keys())
                weightSum = 0
                for dist, times in self.heteroWeightsDict.items():
                    self.heteroWeightsDict[dist] = kwargs["heteroWeightsExponent"] ** (maxDist - dist)
                    weightSum += self.heteroWeightsDict[dist] * times
                # Normalise so the weights over all other classes sum to 1.
                self.heteroWeightsDict = {dist: weight / weightSum for dist, weight in self.heteroWeightsDict.items()}

        elif heteroClsWeight == "uniform":
            for i in range(2, self.numClass + 1):
                circularDist = min(i - 1, self.numClass - (i - 1))
                self.heteroWeightsDict[circularDist] = 1.0 / len(range(2, self.numClass + 1))

    def get_color(self, class_ratio):
        '''Return the (1-based) class of the next new node.

        Uses the pre-shuffled color sequence when one is active, otherwise
        samples a class with probability proportional to ``class_ratio``.
        '''
        if self.__coloriter is not None:
            return next(self.__coloriter)
        # np.random.choice needs probabilities summing to 1; class_ratio may
        # hold raw class sizes, so normalise it.
        probs = np.asarray(class_ratio, dtype=float)
        probs = probs / probs.sum()
        return np.random.choice(list(range(1, len(class_ratio) + 1)), 1, False, probs)[0]

    def color_weight(self, col1, col2):
        '''Attachment weight between two different classes, by circular class distance.'''
        dist = abs(col1 - col2)
        dist = min(dist, len(self.classRatio) - dist)
        return self.heteroWeightsDict[dist]

    def get_neighbors(self, G, m, col, h):
        '''Sample ``m`` distinct existing nodes for a new node of class ``col``.

        Returns None when every candidate has zero probability.
        '''
        pr = dict()
        for v in G.nodes():
            degree_v = float(max(G.degree[v], 1))  # at least 1 so pr is not all 0
            if G.nodes[v]['color'] == col:
                pr[v] = float(degree_v) * h
            else:
                pr[v] = float(degree_v) * ((1 - h) * self.color_weight(col, G.nodes[v]['color']))

        norm_pr = float(sum(pr.values()))
        if norm_pr == 0:
            return None
        for v in list(pr.keys()):
            pr[v] = float(pr[v]) / norm_pr
        return np.random.choice(list(pr.keys()), m, False, list(pr.values()))

    def _validate_sizes(self, n, m, m0):
        '''Raise ValueError when (n, m, m0) cannot be generated.'''
        if m * self.numClass > m0:
            raise ValueError("Barabasi-Albert model requires m * numClass to be less or equal to m0")
        if m > n:
            raise ValueError("m <= n should be satisfied")

    def _setup_color_iterator(self, n, m):
        '''Pre-build a shuffled color sequence when classRatio holds exact class sizes summing to n.

        The first ``numClass * m`` colors contain every class ``m`` times.
        '''
        if n > 1 and np.sum(self.classRatio) == n:
            colorlist = []
            for classID, classSize in enumerate(self.classRatio):
                colorlist += [classID + 1] * int(classSize - m)
            random_state.shuffle(colorlist)
            head_list = list(range(1, self.numClass + 1)) * m
            random_state.shuffle(head_list)
            self.__colorlist = head_list + colorlist
            self.__coloriter = iter(self.__colorlist)
        else:
            self.__coloriter = None

    def _seed_graph(self, m0, h):
        '''Create the initial ``m0`` nodes, each from index 2 on linked to at most one earlier node.'''
        G = nx.Graph()
        for v in range(m0):
            next_color = self.get_color(self.classRatio)
            next_neighbor = None
            if v > 1:
                if h != 0 and h != 1:
                    next_neighbor = v - 1
                else:
                    next_n = self.get_neighbors(G, 1, next_color, h)
                    if next_n is not None:
                        next_neighbor = next_n[0]
            G.add_node(v, color=next_color)
            if next_neighbor is not None:
                G.add_edge(v, next_neighbor)
        return G

    def _grow_graph(self, G, n, m0, m, node_homophily):
        '''Add nodes m0..n-1, each attached to ``m`` sampled nodes.

        ``node_homophily()`` is called once per node (after its color is drawn)
        and returns the homophily used for that node.
        '''
        for v in range(m0, n):
            if v % 1000 == 0:
                print("Generating graph... Now processing v = {}".format(v))
            col = self.get_color(self.classRatio)
            us = self.get_neighbors(G, m, col, node_homophily())
            G.add_node(v, color=col)
            if us is None:
                raise RuntimeError("No attachment candidates with non-zero probability for node {}".format(v))
            for u in us:
                G.add_edge(v, u)

        assert len(list(nx.selfloop_edges(G))) == 0
        return G

    def generate_graph(self, n, m, m0, h):
        '''
        n: Target size for the generated network
        m: number of edges added with each new node
        m0: number of nodes to begin with
        h: homophily
        '''
        self._validate_sizes(n, m, m0)
        self._setup_color_iterator(n, m)
        G = self._seed_graph(m0, h)
        return self._grow_graph(G, n, m0, m, lambda: h)

    def generate_graph_contaminated(self, n, m, m0, h, contamination = 0.2):
        '''
        n: Target size for the generated network
        m: number of edges added with each new node
        m0: number of nodes to begin with
        h: homophily
        contamination: fraction of grown nodes whose homophily is shifted by
            +0.25 (first half) or -0.25 (second half), clipped to [0, 1]
        '''
        def perturbed_h():
            r = np.random.uniform()
            if r < contamination / 2:
                changed_h = h + 0.25
            elif r < contamination:
                changed_h = h - 0.25
            else:
                changed_h = h
            # Outside [0, 1] get_neighbors would produce negative weights and
            # np.random.choice would reject them.
            return min(max(changed_h, 0.0), 1.0)

        self._validate_sizes(n, m, m0)
        self._setup_color_iterator(n, m)
        G = self._seed_graph(m0, h)
        return self._grow_graph(G, n, m0, m, perturbed_h)

    def __call__(self, n, m, m0, h):
        return self.generate_graph(n, m, m0, h)

    def save_graph(self, G:nx.Graph, savePath="{graphName}", graphName="{method}-n{numNode}-h{h}-c{numClass}", **kwargs):
        super().save_graph(G, savePath, graphName, method="mixhop", **kwargs)

    def save_y(self, G:nx.Graph, savePath="{graphName}", graphName="{method}-n{numNode}-h{h}-c{numClass}", **kwargs):
        super().save_y(G, savePath, graphName, method="mixhop", **kwargs)

In [None]:
def load_synthetic_data(number_of_graphs = 100, h_inlier=0, h_outlier=1, outlier_ratio=0.5, n_min = 50, n_max = 150, no_of_tags = 5, type1 = "mixhop", type2 = "mixhop"):
    '''Generate a labelled set of synthetic graphs for graph-level anomaly detection.

    The first ``number_of_graphs - int(number_of_graphs * outlier_ratio)``
    graphs are inliers (label 0, homophily ``h_inlier``); the remaining ones are
    outliers (label 1, homophily ``h_outlier``). Each graph has
    ``np.random.randint(n_min, n_max)`` nodes and ``no_of_tags`` node classes.
    Node tags are one-hot encoded into ``node_features``, and ``edge_mat``
    holds both directions of every edge.

    type1 / type2: generator for inliers / outliers; only "mixhop" is supported.
    Returns (list of S2VGraph, 2).
    Raises ValueError for an unsupported generator type.
    '''
    print('generating data')

    number_of_outliers = int(number_of_graphs*outlier_ratio)

    def _make_graph(graph_type, h, label):
        # Only the mixhop generator is implemented.
        if graph_type != "mixhop":
            raise ValueError("Unsupported graph type: {}".format(graph_type))
        n = np.random.randint(n_min, n_max)
        tag_counts = random_split_counts(n, no_of_tags)
        g = MixhopGraphGenerator(tag_counts, heteroWeightsExponent=1.0)(n, 2, 10, h)
        tags = [g.nodes[v]['color'] for v in g.nodes]
        return S2VGraph(g, label, node_tags=tags)

    g_list = [_make_graph(type1, h_inlier, 0) for _ in range(number_of_graphs - number_of_outliers)]
    g_list += [_make_graph(type2, h_outlier, 1) for _ in range(number_of_outliers)]

    for graph in g_list:
        edges = [list(pair) for pair in graph.g.edges()]
        edges.extend([[v, u] for u, v in edges])  # add the reverse direction
        if edges:
            graph.edge_mat = torch.LongTensor(edges).transpose(0, 1)
        else:
            # An edgeless graph still needs a (2, 0) edge index.
            graph.edge_mat = torch.zeros((2, 0), dtype=torch.long)

    # Check every graph (not just the last one) for missing tags.
    if any(graph.node_tags is None for graph in g_list):
        print("no node tags provided, using degrees as tags")
        for graph in g_list:
            if graph.node_tags is None:
                graph.node_tags = list(dict(graph.g.degree).values())

    # Extracting unique tags and converting to one-hot features; sorted so the
    # one-hot column order is deterministic.
    tagset = sorted(set().union(*(graph.node_tags for graph in g_list)))
    tag2index = {tag: i for i, tag in enumerate(tagset)}

    for graph in g_list:
        graph.node_features = torch.zeros(len(graph.node_tags), len(tagset))
        graph.node_features[range(len(graph.node_tags)), [tag2index[tag] for tag in graph.node_tags]] = 1

    print('Maximum node tag: %d' % len(tagset))
    print("Number of graphs generated: %d" % (number_of_graphs))
    print("Number of outlier graphs generated: %d" % (number_of_outliers))

    return g_list, 2



In [None]:
def compute_gamma(embeddings, device=torch.device("cpu")):
    '''Median-heuristic RBF bandwidth for a collection of node embeddings.

    embeddings: list of 2-D tensors sharing the feature dimension; their rows
        are pooled together (and detached from the autograd graph).
    Returns gamma = 1 / median(pairwise squared distance) as a tensor. The
    median is floored at 1e-4 so gamma stays finite.
    '''
    all_vertex_embeddings = torch.cat(embeddings, dim=0).detach().to(device)
    all_vertex_distances = torch.cdist(all_vertex_embeddings, all_vertex_embeddings) ** 2
    # Clamp instead of assigning a Python float, so the return value is always
    # a tensor (callers invoke .detach() on it).
    median_of_distances = torch.clamp(torch.median(all_vertex_distances), min=1e-4)
    return 1 / median_of_distances

In [None]:
def compute_mmd_gram_matrix(X_embeddings, Y_embeddings=None, gamma=None, type="SMM", device=torch.device("cpu")):
    '''Gram matrix of mean RBF-kernel values between sets of node embeddings.

    Entry [i, j] is the mean of exp(-gamma * ||x - y||^2) over every pair of
    rows x in X_embeddings[i], y in Y_embeddings[j].

    X_embeddings, Y_embeddings: lists of 2-D tensors with a common feature
        dimension; Y defaults to X when None or empty.
    gamma: RBF bandwidth; computed with compute_gamma(Y_embeddings) when None.
    type: only "SMM" is supported.
    Returns a (len(X_embeddings), len(Y_embeddings)) tensor.
    Raises ValueError for gamma == 0 or an unsupported ``type``.
    '''
    if Y_embeddings is None or len(Y_embeddings) == 0:
        Y_embeddings = X_embeddings

    if gamma is None:
        gamma = compute_gamma(Y_embeddings)
    if gamma == 0:
        raise ValueError("Gamma value appears to be 0")
    if type != "SMM":
        raise ValueError("This type is not supported (yet)")

    # Pad with 0s to (num_sets, max_nodes, dim) tensors.
    X_padded = pad_sequence(X_embeddings, batch_first=True).to(device)
    Y_padded = pad_sequence(Y_embeddings, batch_first=True).to(device)

    # Mask of real (non-padding) node pairs, shape (Bx, By, Nx, Ny).
    X_ones_padded = pad_sequence([torch.ones(emb.shape[0]) for emb in X_embeddings], batch_first=True).to(device)
    Y_ones_padded = pad_sequence([torch.ones(emb.shape[0]) for emb in Y_embeddings], batch_first=True).to(device)
    mask = X_ones_padded[:, None, :, None] * Y_ones_padded[None, :, None, :]

    XY = torch.matmul(X_padded[:, None, :, :], torch.transpose(Y_padded[None, :, :, :], -1, -2))

    # Squared norms summed over the feature axis only. A bare torch.squeeze
    # would also drop the set/node axes when there is a single set or only
    # single-node sets, breaking the broadcast below.
    X_sq = (X_padded ** 2).sum(-1)
    Y_sq = (Y_padded ** 2).sum(-1)

    # Clamp away tiny negative distances caused by floating-point cancellation.
    sq_dists = torch.clamp(-2 * XY + X_sq[:, None, :, None] + Y_sq[None, :, None, :], min=0)
    K_XY = torch.exp(-gamma * sq_dists)

    return torch.true_divide(torch.sum(K_XY * mask, (2, 3)), torch.sum(mask, (2, 3)))

In [None]:
class MLP2layer(nn.Module):
    '''Two-layer perceptron: Linear -> BatchNorm1d -> ReLU -> Linear.

    input_dim: size of each input feature vector
    hidden_dim: width of the hidden layer
    output_dim: size of each output vector
    bias: whether both Linear layers have a bias term
    '''

    def __init__(self, input_dim, hidden_dim, output_dim, bias=False):
        super().__init__()
        # Construction order is kept fixed so parameter initialisation consumes
        # the RNG in the same sequence.
        self.linear_in = nn.Linear(input_dim, hidden_dim, bias=bias)
        self.batchnorm = nn.BatchNorm1d(hidden_dim)
        self.linear_out = nn.Linear(hidden_dim, output_dim, bias=bias)

    def forward(self, x):
        hidden = F.relu(self.batchnorm(self.linear_in(x)))
        return self.linear_out(hidden)

In [None]:
class GraphCNN_SVDD(nn.Module):
    '''GIN-style message-passing encoder returning per-graph node embeddings.'''

    def __init__(self, num_layers, input_dim, hidden_dim, output_dim, learn_eps, neighbor_pooling_type, bias, device):
        '''
            num_layers: number of layers in the neural networks (INCLUDING the input layer)
            input_dim: dimensionality of input features
            hidden_dim: dimensionality of hidden units at ALL layers
            output_dim: currently unused (kept for interface compatibility)
            learn_eps: If True, learn epsilon to distinguish center nodes from neighboring nodes. If False, aggregate neighbors and center nodes altogether.
            neighbor_pooling_type: how to aggregate neighbors ("sum", "average", or "max")
            bias: whether the per-layer MLPs use bias terms
            device: which device to use
        '''
        super(GraphCNN_SVDD, self).__init__()

        self.device = device
        self.num_layers = num_layers
        self.neighbor_pooling_type = neighbor_pooling_type
        self.learn_eps = learn_eps
        # One epsilon per message-passing layer (only used when learn_eps).
        self.eps = nn.Parameter(torch.zeros(self.num_layers-1))

        # Per-layer MLPs and the batch norms applied to their outputs.
        self.mlps = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()
        for layer in range(self.num_layers-1):
            layer_input_dim = input_dim if layer == 0 else hidden_dim
            self.mlps.append(MLP2layer(layer_input_dim, hidden_dim, hidden_dim, bias=bias))
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

    @staticmethod
    def _neighbor_lists(graph):
        '''Per-node neighbour lists of ``graph`` in local node indices (no self-loops).

        Uses ``graph.neighbors`` when it is set. Otherwise the lists are derived
        from ``graph.edge_mat``, which is expected to contain both directions of
        every undirected edge (graphs built by load_synthetic_data set edge_mat
        but not neighbors).
        '''
        neighbors = getattr(graph, "neighbors", None)
        if neighbors is not None:
            return neighbors
        derived = [[] for _ in range(len(graph.g))]
        if graph.edge_mat is not None:
            for u, v in graph.edge_mat.t().tolist():
                if u != v:
                    derived[u].append(v)
        return derived

    def __preprocess_neighbors_maxpool(self, batch_graph):
        '''Padded neighbour-index matrix over the concatenated batch for max pooling.

        Row r lists the global indices of node r's neighbours, padded with -1
        (which maxpool maps to a dummy row). When learn_eps is False the node
        itself is appended so it takes part in the max.
        '''
        neighbor_lists = [self._neighbor_lists(graph) for graph in batch_graph]
        # At least one slot per row so torch.max never reduces over an empty axis.
        max_deg = max([len(nbrs) for lists in neighbor_lists for nbrs in lists] + [1])

        padded_neighbor_list = []
        offset = 0
        for graph, lists in zip(batch_graph, neighbor_lists):
            for j, nbrs in enumerate(lists):
                pad = [n + offset for n in nbrs]
                pad.extend([-1] * (max_deg - len(pad)))
                if not self.learn_eps:
                    pad.append(j + offset)
                padded_neighbor_list.append(pad)
            offset += len(graph.g)

        return torch.LongTensor(padded_neighbor_list).to(self.device)

    def __preprocess_neighbors_sumavepool(self, batch_graph):
        '''Block-diagonal sparse adjacency for the batch (plus self-loops when learn_eps is False).'''
        edge_mat_list = []
        start_idx = [0]
        for i, graph in enumerate(batch_graph):
            start_idx.append(start_idx[i] + len(graph.g))
            edge_mat_list.append(graph.edge_mat + start_idx[i])
        Adj_block_idx = torch.cat(edge_mat_list, 1)
        Adj_block_elem = torch.ones(Adj_block_idx.shape[1])

        num_node = start_idx[-1]
        if not self.learn_eps:
            self_loop_edge = torch.arange(num_node).repeat(2, 1)
            Adj_block_idx = torch.cat([Adj_block_idx, self_loop_edge], 1)
            Adj_block_elem = torch.cat([Adj_block_elem, torch.ones(num_node)], 0)

        # torch.sparse_coo_tensor replaces the deprecated torch.sparse.FloatTensor.
        Adj_block = torch.sparse_coo_tensor(Adj_block_idx, Adj_block_elem, (num_node, num_node))
        return Adj_block.to(self.device)

    def maxpool(self, h, padded_neighbor_list):
        '''Element-wise max over each row's neighbours; -1 entries hit a dummy row.'''
        # The dummy row is the column-wise minimum, so it never exceeds a real entry.
        dummy = torch.min(h, dim = 0)[0]
        h_with_dummy = torch.cat([h, dummy.reshape((1, -1)).to(self.device)])
        return torch.max(h_with_dummy[padded_neighbor_list], dim = 1)[0]

    def _pool(self, h, padded_neighbor_list=None, Adj_block=None):
        '''Aggregate neighbour features according to ``neighbor_pooling_type``.'''
        if self.neighbor_pooling_type == "max":
            return self.maxpool(h, padded_neighbor_list)
        pooled = torch.spmm(Adj_block, h)
        if self.neighbor_pooling_type == "average":
            degree = torch.spmm(Adj_block, torch.ones((Adj_block.shape[0], 1), device=self.device))
            # Isolated nodes (possible when no self-loops are added) have degree 0;
            # clamp so they average to 0 instead of 0/0 = NaN.
            pooled = pooled / degree.clamp(min=1)
        return pooled

    def next_layer_eps(self, h, layer, padded_neighbor_list = None, Adj_block = None):
        '''One layer pooling neighbours and the (1 + eps)-weighted center node separately.'''
        pooled = self._pool(h, padded_neighbor_list, Adj_block)
        pooled = pooled + (1 + self.eps[layer]) * h
        h = self.batch_norms[layer](self.mlps[layer](pooled))
        return F.relu(h)

    def next_layer(self, h, layer, padded_neighbor_list = None, Adj_block = None):
        '''One layer pooling neighbours and center nodes together.'''
        pooled = self._pool(h, padded_neighbor_list, Adj_block)
        h = self.batch_norms[layer](self.mlps[layer](pooled))
        return F.relu(h)

    def forward(self, batch_graph, output_layer):
        '''Return a list with one node-embedding tensor per graph in ``batch_graph``.

        output_layer: "all" concatenates the input and every hidden layer along
            the feature axis; an integer selects that entry (0 = input features).
        '''
        X_concat = torch.cat([graph.node_features for graph in batch_graph], 0).to(self.device)

        padded_neighbor_list = None
        Adj_block = None
        if self.neighbor_pooling_type == "max":
            padded_neighbor_list = self.__preprocess_neighbors_maxpool(batch_graph)
        else:
            Adj_block = self.__preprocess_neighbors_sumavepool(batch_graph)

        step = self.next_layer_eps if self.learn_eps else self.next_layer

        # hidden representation at each layer (including input)
        hidden_rep = [X_concat]
        h = X_concat
        for layer in range(self.num_layers-1):
            h = step(h, layer, padded_neighbor_list=padded_neighbor_list, Adj_block=Adj_block)
            hidden_rep.append(h)

        if output_layer == "all":
            hidden_rep = torch.cat(hidden_rep, dim=1)
        else:
            hidden_rep = hidden_rep[output_layer]

        # Split the concatenated rows back into per-graph blocks.
        embeddings = []
        index = 0
        for graph in batch_graph:
            embeddings.append(hidden_rep[index:index+len(graph.g)])
            index += len(graph.g)
        return embeddings

In [None]:
def train_fullbatch(args, model, train_graphs, optimizer, epoch, Z, center=None, linear_layer=None, linear_layer_optimizer=None):
    '''Run ``args.iters_per_epoch`` full-batch SVDD updates on ``train_graphs``.

    Graphs are embedded with ``model`` and mapped to kernel features
    ``K(R, Z) @ U diag(lambda ** -0.5)``. The eigenpairs (lambda, U) come from
    the landmark Gram matrix ``K(Z, Z)``, so this is a Nystroem-style feature
    map. The features are then pulled towards ``center``.

    The loss is the SVDD term minus ``args.reg_weight`` times a regularizer:
    the variance of the features, or random-label classification via
    ``linear_layer``. ``args.reg_weight`` is then updated in place.

    epoch: unused here; kept for the caller's signature.
    center: set to the feature-wise median (detached) on the first iteration when None.
    Returns (average_loss, average_svdd_loss, average_reg_loss, center).
    Raises ValueError for an unrecognised ``args.regularizer``.
    '''
    model.eval()  # batch-norm uses running statistics and does not update them

    loss_accum = 0
    svdd_loss_accum = 0
    reg_loss_accum = 0
    total_iters = args.iters_per_epoch

    for _ in range(total_iters):
        Z_embeddings = model(Z, args.layer)
        gamma = compute_gamma(Z_embeddings, device=args.device).detach()  # no backpropagation for gamma

        K_Z = compute_mmd_gram_matrix(Z_embeddings, gamma=gamma, device=args.device).to(args.device)
        # torch.symeig was removed in PyTorch 2; linalg.eigh returns the same
        # ascending eigenvalues and eigenvectors for a symmetric matrix.
        eigenvalues, U_Z = torch.linalg.eigh(K_Z)
        # Round-off can make tiny eigenvalues negative, and **-0.5 would turn them into NaN.
        T = torch.matmul(U_Z, torch.diag(eigenvalues.clamp(min=1e-8) ** -0.5))

        R_embeddings = model(train_graphs, args.layer)
        K_RZ = compute_mmd_gram_matrix(R_embeddings, Z_embeddings, gamma=gamma, device=args.device).to(args.device)
        # Named `features` so it does not shadow torch.nn.functional (F).
        features = torch.matmul(K_RZ, T)

        if center is None:
            center = torch.median(features, dim=0).values.detach()  # no backpropagation for center

        dists = torch.sum((features - center) ** 2, dim=1).cpu()

        # The radius stays fixed at args.radius; it is not updated during training.
        scores = torch.clamp(dists - (args.radius ** 2), min=0)
        svdd_loss = (1 / args.nu) * torch.mean(scores)

        if args.regularizer == "variance":
            reg_loss = (1 / (features.shape[0] - 1)) * torch.sum(torch.var(features, dim=0))

        elif args.regularizer == "classification":
            random_labels = torch.empty(len(train_graphs), args.no_of_random_labels).random_(2).to(args.device)
            raw_scores = linear_layer(features).to(args.device)

            criterion = nn.BCELoss()
            m = nn.Sigmoid()

            reg_loss = criterion(m(raw_scores), random_labels).cpu()

        else:
            raise ValueError("Unrecognized regularization type")

        loss = svdd_loss - args.reg_weight * reg_loss
        # Moving average of the regularizer weight towards beta * svdd_loss / reg_loss.
        args.reg_weight = (args.alpha * args.reg_weight + (1 - args.alpha) * args.beta * (svdd_loss / reg_loss)).detach()

        optimizer.zero_grad()
        if args.regularizer == "classification":
            linear_layer_optimizer.zero_grad()

        loss.backward()

        optimizer.step()
        if args.regularizer == "classification":
            linear_layer_optimizer.step()

        loss_accum += loss.detach().cpu().numpy()
        svdd_loss_accum += svdd_loss.detach().cpu().numpy()
        reg_loss_accum += reg_loss.detach().cpu().numpy()

    average_loss = loss_accum / total_iters
    average_svdd_loss = svdd_loss_accum / total_iters
    average_reg_loss = reg_loss_accum / total_iters

    return average_loss, average_svdd_loss, average_reg_loss, center

In [None]:
def test(args, model, test_graphs, Z, center=None):
    '''Score ``test_graphs`` by squared feature distance to ``center`` (no gradients).

    Uses the same landmark kernel feature map as training: ``K(R, Z)`` times
    the eigenvectors of ``K(Z, Z)``, scaled by the eigenvalues ** -0.5.

    center: the feature-wise median of the test features when None.
    Returns (average precision of the distances against graph labels, distances).
    '''
    model.eval()

    with torch.no_grad():
        Z_embeddings = model(Z, args.layer)
        gamma = compute_gamma(Z_embeddings, args.device)

        K_Z = compute_mmd_gram_matrix(Z_embeddings, gamma=gamma, device=args.device).to(args.device)
        # torch.symeig was removed in PyTorch 2; linalg.eigh is its replacement.
        eigenvalues, U_Z = torch.linalg.eigh(K_Z)
        # Round-off can make tiny eigenvalues negative, and **-0.5 would turn them into NaN.
        T = torch.matmul(U_Z, torch.diag(eigenvalues.clamp(min=1e-8) ** -0.5))

        R_embeddings = model(test_graphs, args.layer)
        K_RZ = compute_mmd_gram_matrix(R_embeddings, Z_embeddings, gamma=gamma, device=args.device).to(args.device)
        features = torch.matmul(K_RZ, T)  # not `F`: avoids shadowing torch.nn.functional

        if center is None:
            center = torch.median(features, dim=0).values
        dists = torch.sum((features - center) ** 2, dim=1).cpu()

        labels = torch.LongTensor([graph.label for graph in test_graphs])

        score = average_precision_score(labels.numpy(), dists.numpy())
        return score, dists

In [None]:
# Command-line options for the graph-level anomaly-detection experiment.
parser = argparse.ArgumentParser(description='GNN-based Deep SVDD for graph-level anomaly detection')
parser.add_argument('--device', type=int, default=0,
                    help='which gpu to use if any (default: 0)')
parser.add_argument('--batch_size', type=int, default=100,
                    help='input batch size for training (default: 100)')
parser.add_argument('--iters_per_epoch', type=int, default=1,
                    help='number of iterations per each epoch (default: 1)')
parser.add_argument('--epochs', type=int, default=500,
                    help='number of epochs to train (default: 500)')
parser.add_argument('--lr', type=float, default=0.01,
                    help='learning rate (default: 0.01)')
parser.add_argument('--weight_decay', type=float, default=0,
                    help='weight_decay constant (lambda), default=0.')

# Encoder architecture.
parser.add_argument('--num_layers', type=int, default=5,
                    help='number of layers INCLUDING the input one (default: 5)')
parser.add_argument('--hidden_dim', type=int, default=64,
                    help='number of hidden units (default: 64)')
parser.add_argument('--neighbor_pooling_type', type=str, default="sum", choices=["sum", "average", "max"],
                    help='Pooling for over neighboring nodes: sum, average or max')
parser.add_argument('--dont_learn_eps', action="store_true",
                    help='Do not learn the epsilon weighting for the center nodes; aggregate center and neighbor nodes together.')
parser.add_argument('--bias', action="store_true",
                    help='Whether to use bias terms in the GNN.')
parser.add_argument('--degree_as_tag', action="store_true",
                    help='let the input node features be the degree of nodes (heuristics for unlabeled graph)')
parser.add_argument('--layer', type = str, default = "all",
                    help='index of the layer used as embedding, or "all" to concatenate every layer (default: all)')

# Data.
parser.add_argument('--dataset', type = str, default = "mixhop", choices=["mixhop", "chem", "contaminated", "saved"],
                    help='dataset used')
parser.add_argument('--no_of_graphs', type = int, default = 100,
                    help='no of graphs generated')
parser.add_argument('--h_inlier', type=float, default=0.3,
                    help='inlier homophily (default: 0.3)')
parser.add_argument('--h_outlier', type=float, default=0.7,
                    help='outlier homophily (default: 0.7)')
parser.add_argument('--nu', type=float, default=0.05,
                    help='expected fraction of outliers (default: 0.05)')
parser.add_argument('--k_frac', type=float, default=0.4,
                    help='fraction of landmark points (default: 0.4)')

# SVDD objective.
parser.add_argument('--radius', type=str, default="0",
                    help='hypersphere radius, or "dynamic" (default: 0)')
parser.add_argument('--train_only_inlier', action="store_true",
                    help='Train only using inlier data')

# Regularizer.
parser.add_argument('--regularizer', type=str, default="variance", choices=["variance", "classification"],
                    help='type of regularizer (default: variance)')
parser.add_argument('--reg_weight', type=float, default=0,
                    help='initial weight for the regularizer (adapted during training)')
parser.add_argument('--alpha', type=float, default=0.9,
                    help='regularizer: speed of adaptivity')
parser.add_argument('--beta', type=float, default=0.5,
                    help='regularizer loss multiplier (ratio)')

parser.add_argument('--no_of_random_labels', type=int, default=10,
                    help='number of random labels (default: 10)')


In [None]:
# Notebook run: parse a fixed argument string rather than sys.argv.
args = parser.parse_args("--train_only_inlier --layer=1 --dataset=mixhop --alpha=1.0 --beta=0.75".split())

# Seed torch and numpy's global RNG.
torch.manual_seed(0)
np.random.seed(0)

# Name of a .png output file, built from the regularizer schedule and flags.
filename = f"beta_{args.beta}_alpha_{args.alpha}"
if args.bias:
    filename += "_biasallowed"
if args.train_only_inlier:
    filename += "_inlieronly"
filename += ".png"

# Replace the integer --device with a torch.device (GPU index if CUDA exists).
if torch.cuda.is_available():
    args.device = torch.device("cuda:" + str(args.device))
    torch.cuda.manual_seed_all(0)
else:
    args.device = torch.device("cpu")

if args.radius == "dynamic":
    args.radius = 0
else:
    args.radius = float(args.radius)
    args.warm_up_n_epochs = args.epochs  # never update radius dynamically

# --layer is either the string "all" or a layer index.
args.layer = args.layer if args.layer == "all" else int(args.layer)

In [None]:
# Build or load the graph dataset chosen by --dataset.
# NOTE(review): load_synthetic_data_contaminated and load_chem_data are not
# defined in the cells shown here; confirm they exist before using those options.
if args.dataset == "saved":
    # pickle can execute arbitrary code: only load a graphs.pkl written by this notebook.
    with open('graphs.pkl', 'rb') as f:
        graphs, num_classes = pickle.load(f)
elif args.dataset == "chem":
    graphs, num_classes = load_chem_data()
elif args.dataset == "contaminated":
    graphs, num_classes = load_synthetic_data_contaminated(number_of_graphs=args.no_of_graphs, outlier_ratio=args.nu)
elif args.dataset == "mixhop":
    graphs, num_classes = load_synthetic_data(number_of_graphs=args.no_of_graphs, h_inlier=args.h_inlier, h_outlier=args.h_outlier, outlier_ratio=args.nu)
    # Cache the generated graphs so --dataset=saved can reuse them.
    with open('graphs.pkl', 'wb') as f:
        pickle.dump((graphs, num_classes), f)

In [None]:
# Evaluation always uses every graph. With --train_only_inlier, training keeps
# only the leading int(no_of_graphs * (1 - nu)) graphs, assumed to be the
# inliers (load_synthetic_data places outliers last).
test_graphs = graphs
train_graphs = graphs[:int(args.no_of_graphs*(1-args.nu))] if args.train_only_inlier else graphs

In [None]:
# Landmark set Z: k randomly chosen inlier graphs, the reference set for the
# kernel features (inliers are assumed to be the leading graphs).
n_inliers = min(len(graphs), int(args.no_of_graphs*(1-args.nu)))
# Clip k to the available inliers so that k matches len(Z) everywhere k is
# used later (e.g. nn.Linear(k, ...) for the classification regularizer).
k = min(int(args.k_frac*args.no_of_graphs), n_inliers)
np.random.seed(0)
# Permute indices rather than the graph objects, so Z stays a plain list of
# graphs instead of a numpy object array.
Z = [graphs[i] for i in np.random.permutation(n_inliers)[:k]]
no_of_node_features = graphs[0].node_features.shape[1]

In [None]:
# Build the SVDD graph network and its optimiser on the selected device.
model = GraphCNN_SVDD(args.num_layers, no_of_node_features, args.hidden_dim, num_classes, (not args.dont_learn_eps), args.neighbor_pooling_type, args.bias, args.device).to(args.device)
optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

# Only the "classification" regulariser needs an extra trainable head
# (k inputs -> no_of_random_labels outputs) with its own optimiser.
if args.regularizer == "classification":
    linear_layer = nn.Linear(k, args.no_of_random_labels).to(args.device)
    linear_layer_optimizer = optim.SGD(linear_layer.parameters(), lr=args.lr, weight_decay=args.weight_decay)
else:
    linear_layer = None
    linear_layer_optimizer = None

aps = []            # AP score per evaluation; index 0 is before training
svdd_losses = []    # per-epoch losses, epochs 1..no_epochs
reg_losses = []
total_losses = []

# Pre-training evaluation, so that aps/distlist start at "epoch 0".
score, dists = test(args, model, test_graphs, Z)
print("Pre-Training AP Score: %f" % score)
aps.append(score)
# test() output is square-rooted here (presumably squared distances) and may
# live on the GPU; .cpu() is required before .numpy() for CUDA tensors.
dists = (dists**0.5).detach().cpu().numpy()

distlist = [dists]  # one distance array per evaluation (epoch 0..no_epochs)

no_epochs = args.epochs
train = train_fullbatch
center = None  # train() returns the centre it used; it is fed back in each epoch

for epoch in range(1, no_epochs + 1):
    loss, svdd_loss, reg_loss, center = train(args, model, train_graphs, optimizer, epoch, Z, center, linear_layer, linear_layer_optimizer)

    score, dists = test(args, model, test_graphs, Z, center)
    dists = (dists**0.5).detach().cpu().numpy()
    distlist.append(dists)

    print("Epoch %d" % epoch, end="\t")
    print("SVDD loss: %f" % (svdd_loss), end="\t")
    print("Regularizer loss: %f" % (reg_loss), end="\t")
    print("Avg Precision Score: %f" % score)

    aps.append(score)
    # Store plain floats: if the losses are tensors, keeping them would retain
    # autograd graphs / device memory for the whole run, and break plotting of CUDA tensors.
    svdd_losses.append(float(svdd_loss))
    reg_losses.append(float(reg_loss))
    total_losses.append(float(loss))

In [None]:
# Keep one distance snapshot every `gap` evaluations and split each snapshot
# into inlier points (xs, ys) and outlier points (xs2, ys2) for the scatter
# plots. distlist[0] is the pre-training evaluation, so snapshot i belongs to
# epoch gap*i.
gap = 10

intermittent_distlist = distlist[::gap]

# The inliers are the first int(no_of_graphs * (1 - nu)) graphs: the same
# prefix used for inlier-only training and landmark selection. This is used
# instead of a fixed count, so the colouring follows --no_of_graphs / --nu.
n_inliers = int(args.no_of_graphs*(1-args.nu))

xs = []
ys = []
xs2 = []
ys2 = []

for i, dists in enumerate(intermittent_distlist):
    inlier_dists = dists[:n_inliers]
    outlier_dists = dists[n_inliers:]
    xs.extend([gap*i] * len(inlier_dists))
    ys.extend(inlier_dists)
    xs2.extend([gap*i] * len(outlier_dists))
    ys2.extend(outlier_dists)

In [None]:
from google.colab import files  # Colab-only helper used to download saved figures

fig, axs = plt.subplots(2)
fig.suptitle("SVDD with Nystrom for fixed radius=%f" % args.radius)

# Top panel: average precision per evaluation (x = 0 is before training).
axs[0].set(ylabel='Average Precision', ylim=(0, 1))
axs[0].plot(list(range(0, no_epochs + 1)), aps)
axs[0].grid()

# Bottom panel: sampled distances, inliers in blue and outliers in red.
axs[1].set(xlabel='Epochs', ylabel='Distances')
axs[1].scatter(xs, ys, s=1, color='blue')
axs[1].scatter(xs2, ys2, s=1, color='red')
axs[1].grid()

# tight_layout must run after the labels exist, otherwise it cannot make room
# for them. Save before plt.show(), since some backends close the figure once it is shown.
fig.tight_layout(pad=2, h_pad=2, w_pad=2)
fig.savefig(filename, dpi=1000)
plt.show()
#files.download(filename)

In [None]:
# Optional per-epoch loss curves. Disabled by default; set PLOT_LOSSES = True to
# draw and save them. Losses exist only for epochs 1..no_epochs.
PLOT_LOSSES = False

if PLOT_LOSSES:
    fig2, axs2 = plt.subplots(3)
    fig2.suptitle("Losses (regularizer = %s) for alpha=%f and beta=%f" % (args.regularizer, args.alpha, args.beta))

    loss_epochs = list(range(1, no_epochs + 1))
    panels = (
        ('SVDD Loss', svdd_losses),
        ('Regularizer Loss', reg_losses),
        ('Total Loss', total_losses),
    )
    for ax, (ylabel, values) in zip(axs2, panels):
        ax.set(ylabel=ylabel)
        ax.plot(loss_epochs, values)
        ax.grid()

    # Lay out after labelling, and save before showing.
    fig2.tight_layout(pad=2, h_pad=2, w_pad=2)
    fig2.savefig("losses_" + filename, dpi=1000)
    plt.show()
    #files.download("losses_" + filename)

In [None]:
def em(t, t_max, volume_support, s_unif, s_X, n_generated):
    """Excess-Mass curve of a 1-D score and the area under its head.

    Level sets are taken as {score < u} for every distinct sample score u. The
    curve is EM(t) = max_u [ P_n(s_X < u) - t * Leb(score < u) ], where the
    Lebesgue measure is estimated from the uniform sample s_unif. EM(t[0]) is
    pinned to 1.

    Args:
        t: grid of EM abscissae.
        t_max: level below which the curve is truncated for the AUC.
        volume_support: volume (length) of the region s_unif was drawn from.
        s_unif: scores/positions of n_generated uniform reference points.
        s_X: scores of the real samples.
        n_generated: number of uniform reference points.

    Returns:
        (AUC, EM_t, amax). AUC is taken over t[:amax]. amax is one past the
        first index with EM_t <= t_max, or -1 (after printing a warning) if
        that never happens before index 1.
    """
    n_samples = s_X.shape[0]
    em_curve = np.zeros(t.shape[0])
    em_curve[0] = 1.
    for level in np.unique(s_X):
        sample_mass = 1. / n_samples * (s_X < level).sum()
        em_curve = np.maximum(
            em_curve,
            sample_mass - t * (s_unif < level).sum() / n_generated * volume_support,
        )

    # argmax on a boolean array gives the first True; +1 turns it into a slice end.
    amax = np.argmax(em_curve <= t_max) + 1
    if amax == 1:
        print('\n failed to achieve t_max \n')
        amax = -1
    area = auc(t[:amax], em_curve[:amax])
    return area, em_curve, amax


def mv(axis_alpha, volume_support, s_unif, s_X, n_generated):
    """Mass-Volume curve of a 1-D score and its AUC.

    For each target mass alpha, the lowest-scoring samples are absorbed one at
    a time until their empirical mass reaches alpha. The volume of
    {score <= threshold} is then estimated from the uniform reference sample.

    Args:
        axis_alpha: increasing grid of target masses.
        volume_support: volume (length) of the region s_unif was drawn from.
        s_unif: scores/positions of n_generated uniform reference points.
        s_X: scores of the real samples.
        n_generated: number of uniform reference points.

    Returns:
        (auc(axis_alpha, volumes), volumes).
    """
    n_samples = s_X.shape[0]
    order = s_X.argsort()
    volumes = np.zeros(axis_alpha.shape[0])

    taken = 0                     # number of lowest-scoring samples absorbed so far
    mass = 0
    threshold = s_X[order[0]]
    for idx, alpha in enumerate(axis_alpha):
        while mass < alpha:
            taken += 1
            threshold = s_X[order[taken - 1]]
            mass = 1. / n_samples * taken
        volumes[idx] = float((s_unif <= threshold).sum()) / n_generated * volume_support
    return auc(axis_alpha, volumes), volumes

In [None]:
def compute_em_mv_aucs(X):
    """Return (EM AUC, MV AUC) for a 1-D array of anomaly scores.

    The reference region is the interval [X.min(), X.max()]. Its volume is
    estimated with 1000 uniform points drawn from NumPy's *global* RNG, so the
    result depends on (and advances) np.random's state.

    Args:
        X: 1-D array-like of scores (e.g. one entry of distlist).

    Returns:
        (auc_em, auc_mv). Both are nan when all scores are identical: the
        reference interval then has zero length and neither curve is defined.
    """
    X = np.asarray(X)

    n_generated = 1000
    alpha_min = 0.001
    alpha_max = 0.999
    t_max = 0.9
    n_features = 1

    lim_inf = X.min()
    lim_sup = X.max()
    volume_support = lim_sup - lim_inf
    # With identical scores, the t-grid below would divide by zero and
    # np.arange would fail. Report "undefined" instead of crashing the caller.
    if not volume_support > 0:
        return np.nan, np.nan

    t = np.arange(0, 100 / volume_support, 0.01 / volume_support)
    axis_alpha = np.arange(alpha_min, alpha_max, 0.001)
    unif = np.random.uniform(lim_inf, lim_sup, size=(n_generated, n_features))

    auc_em, _, _ = em(t, t_max, volume_support, unif, X, n_generated)
    auc_mv, _ = mv(axis_alpha, volume_support, unif, X, n_generated)
    return auc_em, auc_mv

In [None]:
def customscore(X, inlier_frac=0.95):
    """Ratio of the spread of all scores to the spread of the presumed inliers.

    The lowest int(inlier_frac * len(X)) values are treated as inliers. Each
    spread is the mean squared deviation from the corresponding median.

    Args:
        X: 1-D array-like of scores (e.g. distances).
        inlier_frac: fraction of the lowest values treated as inliers
            (default 0.95, the split this metric always used).

    Returns:
        The ratio as a float. It is nan/inf (with NumPy warnings) if the inlier
        set is empty or has zero spread.
    """
    X = np.asarray(X)
    inliers = np.sort(X)[:int(inlier_frac*len(X))]
    inlier_median = np.median(inliers)
    full_median = np.median(X)

    return np.mean((X - full_median)**2) / np.mean((inliers - inlier_median)**2)

In [None]:
def rsquared(X, inlier_frac=0.95):
    """R^2-style separation score for a low/high split of the scores.

    The sorted scores are split into the lowest int(inlier_frac * len(X))
    values (inliers) and the rest (outliers). The result is
    1 - (within-group squared deviation from each group's median) /
        (total squared deviation from the overall median).

    Args:
        X: 1-D array-like of scores (e.g. distances).
        inlier_frac: fraction of the lowest values treated as inliers
            (default 0.95, the split this metric always used).

    Returns:
        A float; values closer to 1 mean the two groups are tighter relative
        to the overall spread.
    """
    X = np.asarray(X)
    split = int(inlier_frac*len(X))
    ordered = np.sort(X)  # sort once, then slice both groups from it
    inliers, outliers = ordered[:split], ordered[split:]

    within = np.sum((inliers - np.median(inliers))**2) + np.sum((outliers - np.median(outliers))**2)
    total = np.sum((X - np.median(X))**2)
    return 1 - within/total

In [None]:
def xb(X, inlier_frac=0.95):
    """Compactness / separation index for a low/high split of the scores.

    The index is presumably modelled on Xie-Beni. The sorted scores are split
    into the lowest int(inlier_frac * len(X)) values (inliers) and the rest
    (outliers). The result is
    (within-group squared deviation from each group's median) /
    (len(X) * (inlier median - outlier median)^2).
    Lower values mean tighter, better-separated groups.

    Args:
        X: 1-D array-like of scores (e.g. distances).
        inlier_frac: fraction of the lowest values treated as inliers
            (default 0.95, the split this metric always used).

    Returns:
        The index as a float (inf/nan with NumPy warnings if the two medians coincide).
    """
    X = np.asarray(X)
    split = int(inlier_frac*len(X))
    ordered = np.sort(X)  # sort once, then slice both groups from it
    inliers, outliers = ordered[:split], ordered[split:]
    inlier_median = np.median(inliers)
    outlier_median = np.median(outliers)

    within = np.sum((inliers - inlier_median)**2) + np.sum((outliers - outlier_median)**2)
    return within / (len(X)*(inlier_median - outlier_median)**2)

In [None]:
# Unsupervised quality metrics for every distance snapshot (epoch 0..no_epochs).
# compute_em_mv_aucs draws from np.random, so the iteration order matters.
ems, mvs = [], []
customs, rsquareds, xbs = [], [], []
for dists in distlist:
    em_auc, mv_auc = compute_em_mv_aucs(dists)
    ems.append(em_auc)
    mvs.append(mv_auc)
    customs.append(customscore(dists))
    rsquareds.append(rsquared(dists))
    xbs.append(xb(dists))


In [None]:
# Diagnostics figure: AP, sampled distances, total loss and the XB index per epoch.
# (`customs` and `rsquareds` are computed as well and can be added as extra panels.)
fig3, axs3 = plt.subplots(4)

# AP has one value per evaluation, with x = 0 before training.
axs3[0].set(ylabel='Average Precision', ylim=(0, 1))
axs3[0].plot(list(range(0, no_epochs + 1)), aps)
axs3[0].grid()

axs3[1].set(xlabel='Epochs', ylabel='Distances')
axs3[1].scatter(xs, ys, s=1, color='blue')
axs3[1].scatter(xs2, ys2, s=1, color='red')
axs3[1].grid()

# Losses exist only for epochs 1..no_epochs (there is no pre-training loss),
# so the x-axis starts at 1 to line up with the AP and XB panels.
axs3[2].set(ylabel='Total Loss')
axs3[2].plot(list(range(1, len(total_losses) + 1)), total_losses)
axs3[2].grid()

# XB has one value per distlist entry, i.e. epochs 0..no_epochs.
axs3[3].set(ylabel='XB')
axs3[3].plot(list(range(len(xbs))), xbs)
axs3[3].grid()

# Lay out after labelling (four stacked panels otherwise overlap), and save
# before showing.
fig3.tight_layout()
fig3.savefig("cluster"+filename, dpi=1000)
plt.show()

files.download("cluster"+filename)