Skip to content
Permalink
Browse files

is_undirected flag for eigenvalue transform

  • Loading branch information...
rusty1s committed Aug 13, 2019
1 parent 21b693c commit eb3bd025c9b930d8a70f357e2dc00ef6f8cb93a2
@@ -12,17 +12,17 @@ def test_laplacian_lambda_max():
edge_attr = torch.tensor([1, 1, 2, 2], dtype=torch.float)

data = Data(edge_index=edge_index, edge_attr=edge_attr)
out = LaplacianLambdaMax(normalization=None)(data)
out = LaplacianLambdaMax(normalization=None, is_undirected=True)(data)
assert len(out) == 3
assert torch.allclose(torch.tensor(out.lambda_max), torch.tensor(4.732049))

data = Data(edge_index=edge_index, edge_attr=edge_attr)
out = LaplacianLambdaMax(normalization='sym')(data)
out = LaplacianLambdaMax(normalization='sym', is_undirected=True)(data)
assert len(out) == 3
assert torch.allclose(torch.tensor(out.lambda_max), torch.tensor(2.0))

data = Data(edge_index=edge_index, edge_attr=edge_attr)
out = LaplacianLambdaMax(normalization='rw')(data)
out = LaplacianLambdaMax(normalization='rw', is_undirected=True)(data)
assert len(out) == 3
assert torch.allclose(torch.tensor(out.lambda_max), torch.tensor(2.0))

@@ -19,11 +19,15 @@ class LaplacianLambdaMax(object):
3. :obj:`"rw"`: Random-walk normalization
:math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}`
is_undirected (bool, optional): If set to :obj:`True`, this transform
expects undirected graphs as input, and can hence speed up the
computation of the largest eigenvalue. (default: :obj:`False`)
"""

def __init__(self, normalization=None):
def __init__(self, normalization=None, is_undirected=False):
assert normalization in [None, 'sym', 'rw'], 'Invalid normalization'
self.normalization = normalization
self.is_undirected = is_undirected

def __call__(self, data):
edge_weight = data.edge_attr
@@ -36,7 +40,10 @@ def __call__(self, data):

L = to_scipy_sparse_matrix(edge_index, edge_weight, data.num_nodes)

eig_fn = eigsh if self.normalization == 'sym' else eigs
eig_fn = eigs
if self.is_undirected and self.normalization != 'rw':
eig_fn = eigsh

lambda_max = eig_fn(L, k=1, which='LM', return_eigenvectors=False)
data.lambda_max = float(lambda_max.real)

0 comments on commit eb3bd02

Please sign in to comment.
You can’t perform that action at this time.