In [1]:
import numpy as np
from numbers import Number
import torch
from torch.distributions import constraints, Gamma
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all
from scipy import stats

## Implement the generalized normal distribution

In [2]:
class GeneralizedNormal(Distribution):
    r"""
    Creates a Generalized Normal distribution parameterized by :attr:`loc`, :attr:`scale`, and :attr:`beta`.

    The density is

    .. math::
        p(x) = \frac{\beta}{2\,\mathrm{scale}\,\Gamma(1/\beta)}
               \exp\left(-\left(\frac{|x - \mathrm{loc}|}{\mathrm{scale}}\right)^{\beta}\right)

    ``beta=1`` is a Laplace distribution with scale ``scale``; ``beta=2`` is a Normal
    distribution with standard deviation ``scale / sqrt(2)``.

    Example::

        >>> m = GeneralizedNormal(torch.tensor([0.0]), torch.tensor([1.0]), torch.tensor(0.5))
        >>> m.sample()  # GeneralizedNormal distributed with loc=0, scale=1, beta=0.5
        tensor([ 0.1337])

    Args:
        loc (float or Tensor): mean of the distribution
        scale (float or Tensor): scale of the distribution
        beta (float or Tensor): shape parameter of the distribution
        validate_args (bool, optional): forwarded to :class:`~torch.distributions.Distribution`

    All three parameters are broadcast against each other; the broadcast shape is the
    batch shape. :meth:`rsample` is reparameterized (gradients flow to ``loc``,
    ``scale`` and ``beta``) and draws from torch's RNG. :meth:`cdf` and :meth:`icdf`
    are evaluated with :mod:`scipy.stats` and therefore carry no gradient.
    """
    arg_constraints = {'loc': constraints.real, 'scale': constraints.positive, 'beta': constraints.positive}
    support = constraints.real
    has_rsample = True

    @property
    def mean(self):
        """Mean of the distribution, which is ``loc`` (the density is symmetric about it)."""
        return self.loc

    @property
    def variance(self):
        """Variance ``scale**2 * Gamma(3/beta) / Gamma(1/beta)``."""
        # Take the gamma ratio in log space: exponentiating both lgamma terms
        # separately overflows to inf/inf = nan for small beta.
        return self.scale.pow(2) * torch.exp(torch.lgamma(3 / self.beta) - torch.lgamma(1 / self.beta))

    @property
    def stddev(self):
        """Standard deviation, the square root of :attr:`variance`."""
        # `variance` is a property, not a method, so it must not be called.
        return self.variance.sqrt()

    def __init__(self, loc, scale, beta, validate_args=None):
        # Broadcast all three parameters together so that beta always agrees with
        # the batch shape (e.g. scalar loc/scale with a vector of betas).
        self.loc, self.scale, self.beta = broadcast_all(loc, scale, beta)
        self.scipy_dist = self._make_scipy_dist()
        # broadcast_all turns Python numbers into 0-dim tensors, so this is
        # torch.Size() for purely scalar parameters.
        batch_shape = self.loc.size()
        super(GeneralizedNormal, self).__init__(batch_shape, validate_args=validate_args)

    def _make_scipy_dist(self):
        """Build a frozen ``scipy.stats.gennorm`` from the current (detached, CPU) parameters."""
        return stats.gennorm(loc=self.loc.detach().cpu().numpy(),
                             scale=self.scale.detach().cpu().numpy(),
                             beta=self.beta.detach().cpu().numpy())

    @staticmethod
    def _to_numpy(value):
        """Convert a tensor (possibly requiring grad or on GPU) to numpy; pass other values through."""
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return value

    def expand(self, batch_shape, _instance=None):
        """Return a copy of this distribution broadcast to ``batch_shape``."""
        new = self._get_checked_instance(GeneralizedNormal, _instance)
        batch_shape = torch.Size(batch_shape)
        new.loc = self.loc.expand(batch_shape)
        new.scale = self.scale.expand(batch_shape)
        # beta and the scipy mirror must be carried over as well, otherwise
        # log_prob / entropy / cdf on the expanded instance fail.
        new.beta = self.beta.expand(batch_shape)
        new.scipy_dist = new._make_scipy_dist()
        super(GeneralizedNormal, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def rsample(self, sample_shape=torch.Size()):
        """Draw reparameterized samples of shape ``sample_shape + batch_shape``.

        Uses the identity ``(|X - loc| / scale) ** beta ~ Gamma(1/beta, 1)``, with a
        random sign drawn independently with probability 1/2 each way.
        """
        shape = self._extended_shape(sample_shape)
        concentration = self.beta.reciprocal().expand(shape)
        magnitude = Gamma(concentration, torch.ones_like(concentration)).rsample()
        sign = torch.empty(shape, dtype=self.loc.dtype, device=self.loc.device).bernoulli_(0.5).mul_(2).sub_(1)
        return self.loc + self.scale * sign * magnitude.pow(1 / self.beta)

    def log_prob(self, value):
        """Log-density ``log(beta) - log(2*scale) - lgamma(1/beta) - (|value - loc| / scale) ** beta``."""
        if not isinstance(value, torch.Tensor):
            # Accept Python numbers even when argument validation is enabled.
            value = torch.as_tensor(value, dtype=self.loc.dtype, device=self.loc.device)
        if self._validate_args:
            self._validate_sample(value)
        return (-torch.log(2 * self.scale) - torch.lgamma(1/self.beta) + torch.log(self.beta)
                - torch.pow((torch.abs(value - self.loc) / self.scale), self.beta))

    def cdf(self, value):
        """Cumulative distribution function, evaluated by scipy (no gradient)."""
        return torch.tensor(self.scipy_dist.cdf(self._to_numpy(value))).to(self.loc)

    def icdf(self, value):
        """Inverse CDF (quantile function), evaluated by scipy (no gradient)."""
        return torch.tensor(self.scipy_dist.ppf(self._to_numpy(value))).to(self.loc)

    def entropy(self):
        """Differential entropy ``1/beta - log(beta) + log(2*scale) + lgamma(1/beta)``."""
        return (1/self.beta) - torch.log(self.beta) + torch.log(2*self.scale) + torch.lgamma(1/self.beta)

In [3]:
# Scalar parameters loc=0, scale=1, beta=0.5, so the batch shape is empty.
gennorm = GeneralizedNormal(0.,1.,0.5)

In [4]:
# Closed form at x=1: -log(2) - lgamma(2) + log(0.5) - 1 = -2.3863.
gennorm.log_prob(1.)

tensor(-2.3863)

In [5]:
# 2x2 array of draws; the values change on every run because no random seed is set.
gennorm.rsample([2,2])

tensor([[-0.0510, -0.4189],
        [ 0.6363,  2.6203]])

In [6]:
# 1/beta - log(beta) + log(2*scale) + lgamma(1/beta) = 2 + 2*log(2) = 3.3863.
gennorm.entropy()

tensor(3.3863)

In [7]:
# P(X <= 1), computed by scipy.stats.gennorm (returned as a fresh tensor, so no gradient).
gennorm.cdf(1.)

tensor(0.6321)

In [8]:
# Gradient check: log_prob should be differentiable with respect to its argument.
param = torch.nn.Parameter(torch.tensor([1.]))

In [9]:
param

Parameter containing:
tensor([1.], requires_grad=True)

In [10]:
# No backward pass has run yet, so there is no gradient (the cell displays nothing).
param.grad

In [11]:
log_prob = gennorm.log_prob(param)

In [12]:
log_prob.backward()

In [13]:
# d/dx of -|x|**beta at x=1 is -beta = -0.5.
param.grad

tensor([-0.5000])

In [14]:
# Single draw; differs between runs because no random seed is set.
gennorm.sample()

tensor(-0.0859)

## Test the packaged implementation (`bnn_priors`)

In [15]:
# NOTE: this import shadows the prototype GeneralizedNormal class defined above;
# the cells below exercise the packaged version from bnn_priors.
from bnn_priors.prior.distributions import GeneralizedNormal

In [16]:
# Same parameters as the prototype instance so the outputs can be compared (rebinds `gennorm`).
gennorm = GeneralizedNormal(loc=0., scale=1., beta=0.5)

In [17]:
# Random draw; no seed is set, so the value differs between runs.
gennorm.sample()

tensor(0.7658)

In [18]:
# Matches the prototype's log_prob(1.) output (-2.3863).
gennorm.log_prob(1.)

tensor(-2.3863)

In [19]:
# Matches the prototype's entropy (3.3863).
gennorm.entropy()

tensor(3.3863)

In [20]:
# Matches the prototype's cdf(1.) (0.6321).
gennorm.cdf(1.)

tensor(0.6321)

In [21]:
# Prior module from bnn_priors; its implementation is not shown in this notebook.
from bnn_priors.prior import GenNorm

In [22]:
# Prior over a 2x2 parameter with loc=0, scale=1, beta=0.5.
prior_gn = GenNorm(shape=[2,2], loc=0., scale=1., beta=0.5)

In [23]:
# The wrapped 2x2 parameter tensor `p`.
prior_gn.p

Parameter containing:
tensor([[ 0.1803, -0.6137],
        [ 5.2962, -2.9347]], requires_grad=True)

In [24]:
# Scalar total (SumBackward0): equals the sum of the four beta=0.5 log-densities of p,
# i.e. 4*(-1.3863) - sum(sqrt(|p|)) = -10.7676.
prior_gn.log_prob()

tensor(-10.7676, grad_fn=<SumBackward0>)

In [25]:
# No backward pass has run yet, so there is no gradient (the cell displays nothing).
prior_gn.p.grad

In [26]:
lp = prior_gn.log_prob()
lp.backward()

In [27]:
# Each entry equals -0.5*sign(p)/sqrt(|p|), the derivative of -|p|**0.5
# (e.g. p=0.1803 -> -1.1775).
prior_gn.p.grad

tensor([[-1.1775,  0.6383],
        [-0.2173,  0.2919]])

In [28]:
# GenNorm variant whose beta is itself a module (`prior_gnu.beta` below). The name
# suggests a uniform hyperprior on beta -- NOTE(review): confirm in bnn_priors.prior.
from bnn_priors.prior import GenNormUniform

In [29]:
# Positional args presumably mirror GenNorm(shape, loc, scale) -- TODO confirm.
prior_gnu = GenNormUniform([2,2], 0., 1., beta=0.5)

In [30]:
prior_gnu.p

Parameter containing:
tensor([[ 2.7537, -1.6523],
        [-3.3908, 16.9701]], requires_grad=True)

In [31]:
# Raw parameter behind beta: it is 0. here, while beta() below evaluates to 0.5.
prior_gnu.beta.p

Parameter containing:
tensor(0., requires_grad=True)

In [32]:
# AddBackward0 suggests beta is an offset/transform of beta.p -- NOTE(review): confirm mapping.
prior_gnu.beta()

tensor(0.5000, grad_fn=<AddBackward0>)

In [33]:
# Equals the summed beta=0.5 log-densities of p (4*(-1.3863) - sum(sqrt(|p|)) = -14.4509);
# any hyperprior term on beta appears to contribute 0 at this value.
prior_gnu.log_prob()

tensor(-14.4509, grad_fn=<SumBackward0>)

In [34]:
# No backward pass has run yet, so there is no gradient (the cell displays nothing).
prior_gnu.p.grad

In [35]:
prior_gnu.beta.p.grad

In [36]:
lp = prior_gnu.log_prob()
lp.backward()

In [37]:
# Same -0.5*sign(p)/sqrt(|p|) pattern as the GenNorm gradients above.
prior_gnu.p.grad

tensor([[-0.3013,  0.3890],
        [ 0.2715, -0.1214]])

In [38]:
# The gradient also reaches the raw parameter behind beta.
prior_gnu.beta.p.grad

tensor(-0.5882)