Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend normalization to other p norms #225

Merged
merged 14 commits into from
Dec 22, 2020
48 changes: 38 additions & 10 deletions src/pykeen/regularizers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# -*- coding: utf-8 -*-

"""Regularization in PyKEEN."""

import functools
import math
from abc import ABC, abstractmethod
from typing import Any, ClassVar, Collection, Iterable, Mapping, Optional, Type, Union

Expand Down Expand Up @@ -103,6 +104,41 @@ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: # noqa: D102
return torch.zeros(1, dtype=x.dtype, device=x.device)


@functools.lru_cache(maxsize=1)
mberr marked this conversation as resolved.
Show resolved Hide resolved
def _get_expected_norm(
p: Union[int, float, str],
d: int,
) -> float:
r"""
Compute the expected value of the L_p norm.

.. math ::
E[\|x\|_p] = d^{1/p} E[|x_1|^p]^{1/p}

under the assumption that :math:`x_i \sim N(0, 1)`, i.e.

.. math ::
E[|x_1|^p] = 2^{p/2} \cdot \Gamma(\frac{p+1}{2} \cdot \pi^{-1/2}

:param p:
The parameter p of the norm.
:param d:
The dimension of the vector.

:return:
The expected value.

.. seealso ::
https://math.stackexchange.com/questions/229033/lp-norm-of-multivariate-standard-normal-random-variable
https://www.wolframalpha.com/input/?i=expected+value+of+%7Cx%7C%5Ep
"""
if isinstance(p, str) or not math.isfinite(p):
cthoyt marked this conversation as resolved.
Show resolved Hide resolved
# TODO: Use https://en.wikipedia.org/wiki/Gumbel_distribution for p = +inf
raise NotImplementedError(f"{p} norm not implemented")
exp_abs_norm_p = math.pow(2, p / 2) * math.gamma((p + 1) / 2) / math.sqrt(math.pi)
return math.pow(exp_abs_norm_p * d, 1 / p)


class LpRegularizer(Regularizer):
"""A simple L_p norm based regularizer."""

Expand Down Expand Up @@ -136,15 +172,7 @@ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: # noqa: D102
value = x.norm(p=self.p, dim=self.dim).mean()
if not self.normalize:
return value
dim = torch.as_tensor(x.shape[-1], dtype=torch.float, device=x.device)
if self.p == 1:
# expected value of |x|_1 = d*E[x_i] for x_i i.i.d.
return value / dim
if self.p == 2:
# expected value of |x|_2 when x_i are normally distributed
# cf. https://arxiv.org/pdf/1012.0621.pdf chapter 3.1
return value / dim.sqrt()
raise NotImplementedError(f'Lp regularization not implemented for p={self.p}')
return value / _get_expected_norm(p=self.p, d=x.shape[-1])


class PowerSumRegularizer(Regularizer):
Expand Down
20 changes: 19 additions & 1 deletion tests/test_regularizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
import unittest
from typing import Any, ClassVar, Dict, Optional, Type

import pytest
import torch
from torch.nn import functional

from pykeen.datasets import Nations
from pykeen.models import ConvKB, RESCAL, TransH
from pykeen.regularizers import (
CombinedRegularizer, LpRegularizer, NoRegularizer, PowerSumRegularizer, Regularizer,
TransHRegularizer,
TransHRegularizer, _get_expected_norm,
)
from pykeen.triples import TriplesFactory
from pykeen.typing import MappedTriples
Expand Down Expand Up @@ -161,6 +162,23 @@ class NormedL2RegularizerTest(_LpRegularizerTest, unittest.TestCase):

regularizer_kwargs = {'normalize': True, 'p': 2}

@pytest.mark.slow
mberr marked this conversation as resolved.
Show resolved Hide resolved
def test_expected_norm(self):
"""Numerically check expected norm."""
n = 100
for p in (1, 2, 3):
for d in (2, 8, 64):
e_norm = _get_expected_norm(p=p, d=d)
norm = torch.randn(n, d).norm(p=p, dim=-1).numpy()
norm_mean = norm.mean()
norm_std = norm.std()
# check if within 0.5 std of observed
assert (abs(norm_mean - e_norm) / norm_std) < 0.5

# test error is raised
with pytest.raises(NotImplementedError):
_get_expected_norm(p=float('inf'), d=d)


class CombinedRegularizerTest(_RegularizerTestCase, unittest.TestCase):
"""Test the combined regularizer."""
Expand Down