Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Extend normalization to other p norms #225

Merged
Merged 14 commits into the base branch on Dec 22, 2020
57 changes: 47 additions & 10 deletions src/pykeen/regularizers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# -*- coding: utf-8 -*-

"""Regularization in PyKEEN."""

import functools
import math
from abc import ABC, abstractmethod
from typing import Any, ClassVar, Collection, Iterable, Mapping, Optional, Type, Union

Expand Down Expand Up @@ -103,6 +104,50 @@ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: # noqa: D102
return torch.zeros(1, dtype=x.dtype, device=x.device)


@functools.lru_cache(maxsize=1)
mberr marked this conversation as resolved.
Show resolved Hide resolved
def _get_expected_norm(
p: Union[int, float, str],
d: int,
) -> float:
r"""
Compute the expected value of the L_p norm.

.. math ::
E[\|x\|_p] = d^{1/p} E[|x_1|^p]^{1/p}

under the assumption that :math:`x_i \sim N(0, 1)`, i.e.

.. math ::
E[|x_1|^p] = 2^{p/2} \cdot \Gamma(\frac{p+1}{2} \cdot \pi^{-1/2}

:param p:
The parameter p of the norm.
:param d:
The dimension of the vector.

:return:
The expected value.

.. seealso ::
https://math.stackexchange.com/questions/229033/lp-norm-of-multivariate-standard-normal-random-variable
https://www.wolframalpha.com/input/?i=expected+value+of+%7Cx%7C%5Ep
"""
if isinstance(p, str):
p = float(p)
if math.isinf(p) and p > 0: # max norm
# TODO: this only works for x ~ N(0, 1), but not for |x|
raise NotImplementedError("Normalization for inf norm is not implemented")
# cf. https://en.wikipedia.org/wiki/Generalized_extreme_value_distribution
# mean = scipy.stats.norm.ppf(1 - 1/d)
# scale = scipy.stats.norm.ppf(1 - 1/d * 1/math.e) - mean
# return scipy.stats.gumbel_r.mean(loc=mean, scale=scale)
elif math.isfinite(p):
exp_abs_norm_p = math.pow(2, p / 2) * math.gamma((p + 1) / 2) / math.sqrt(math.pi)
return math.pow(exp_abs_norm_p * d, 1 / p)
else:
raise NotImplementedError(f"{p} norm not implemented")


class LpRegularizer(Regularizer):
"""A simple L_p norm based regularizer."""

Expand Down Expand Up @@ -136,15 +181,7 @@ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: # noqa: D102
value = x.norm(p=self.p, dim=self.dim).mean()
if not self.normalize:
return value
dim = torch.as_tensor(x.shape[-1], dtype=torch.float, device=x.device)
if self.p == 1:
# expected value of |x|_1 = d*E[x_i] for x_i i.i.d.
return value / dim
if self.p == 2:
# expected value of |x|_2 when x_i are normally distributed
# cf. https://arxiv.org/pdf/1012.0621.pdf chapter 3.1
return value / dim.sqrt()
raise NotImplementedError(f'Lp regularization not implemented for p={self.p}')
return value / _get_expected_norm(p=self.p, d=x.shape[-1])


class PowerSumRegularizer(Regularizer):
Expand Down
20 changes: 19 additions & 1 deletion tests/test_regularizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
import unittest
from typing import Any, ClassVar, Dict, Optional, Type

import pytest
import torch
from torch.nn import functional

from pykeen.datasets import Nations
from pykeen.models import ConvKB, RESCAL, TransH
from pykeen.regularizers import (
CombinedRegularizer, LpRegularizer, NoRegularizer, PowerSumRegularizer, Regularizer,
TransHRegularizer,
TransHRegularizer, _get_expected_norm,
)
from pykeen.triples import TriplesFactory
from pykeen.typing import MappedTriples
Expand Down Expand Up @@ -161,6 +162,23 @@ class NormedL2RegularizerTest(_LpRegularizerTest, unittest.TestCase):

regularizer_kwargs = {'normalize': True, 'p': 2}

@pytest.mark.slow
mberr marked this conversation as resolved.
Show resolved Hide resolved
def test_expected_norm(self):
    """Numerically check expected norm against sampled Gaussian vectors."""
    num_samples = 100
    d = None
    for p in (1, 2, 3):
        for d in (2, 8, 64):
            expected = _get_expected_norm(p=p, d=d)
            observed = torch.randn(num_samples, d).norm(p=p, dim=-1).numpy()
            # the analytic value should lie within half an (empirical)
            # standard deviation of the sample mean
            assert abs(observed.mean() - expected) / observed.std() < 0.5

    # the inf norm is explicitly unsupported
    with pytest.raises(NotImplementedError):
        _get_expected_norm(p=float('inf'), d=d)


class CombinedRegularizerTest(_RegularizerTestCase, unittest.TestCase):
"""Test the combined regularizer."""
Expand Down