Variational/weibull (#493)
* added Weibull distribution

* Update distributions.py

* added test for Weibull distribution

* Update CHANGELOG.md

* Update distributions.py
Ddaniela13 authored and ethanwharris committed Jan 24, 2019
1 parent cbf42aa commit f8e4353
Showing 3 changed files with 91 additions and 21 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -12,6 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a var metric and decorator which can be used to calculate the variance of a metric
- Added an unbiased flag to the std and var metrics to optionally not apply Bessel's correction (consistent with torch.std / torch.var)
- Added support for rounding 1D lists to the Tqdm callback
- Added SimpleWeibull distribution
### Changed
- Changed the default behaviour of the std metric to compute the sample std, in line with torch.std
- Tqdm precision argument now rounds to decimal places rather than significant figures
38 changes: 35 additions & 3 deletions tests/variational/test_distributions.py
@@ -4,8 +4,7 @@

import torch

-from torchbearer.variational import SimpleDistribution, SimpleNormal, SimpleUniform, SimpleExponential
+from torchbearer.variational import SimpleDistribution, SimpleNormal, SimpleUniform, SimpleExponential, SimpleWeibull

class TestEmptyMethods(unittest.TestCase):
def test_methods(self):
@@ -25,7 +24,6 @@ def test_methods(self):
self.assertRaises(NotImplementedError, lambda: dist.rsample())
self.assertRaises(NotImplementedError, lambda: dist.log_prob(1))


class TestSimpleNormal(unittest.TestCase):
@patch('torchbearer.variational.distributions.torch.normal')
def test_rsample_tensor(self, normal):
@@ -140,3 +138,37 @@ def test_log_prob_number(self):
dist = SimpleExponential(math.log(0.5))

self.assertTrue(((dist.log_prob(torch.ones(2, 2)) + 1.1931).abs() < 0.0001).all())
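# Editorial sanity check (not part of the commit): with lograte = log(0.5), the
# log_prob formula in distributions.py, lograte - exp(lograte) * value, yields
# the -1.1931 constant asserted above:
import math
assert abs((math.log(0.5) - 0.5 * 1.0) + 1.1931) < 1e-4  # -0.69315 - 0.5 = -1.19315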

class TestSimpleWeibull(unittest.TestCase):
@patch('torchbearer.variational.distributions.torch.rand')
def test_rsample_tensor(self, rand):
l = torch.ones(2, 2)
k = torch.ones(2, 2)

dist = SimpleWeibull(l, k)

rand.side_effect = lambda shape, dtype, device: torch.ones(shape) / 2
self.assertTrue(((dist.rsample(sample_shape=torch.Size([2])) - 0.6931).abs() < 0.0001).all())

@patch('torchbearer.variational.distributions.torch.rand')
def test_rsample_number(self, rand):
dist = SimpleWeibull(1, 1)

rand.side_effect = lambda shape, dtype, device: torch.ones(shape) / 2
self.assertTrue(((dist.rsample(sample_shape=torch.Size([2])) - 0.6931).abs() < 0.0001).all())
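# Editorial sanity check (not part of the commit): with l = k = 1 and the mocked
# uniform draw u = 0.5, the sample is l * (-log(u)) ** (1 / k) = log(2) ~= 0.6931,
# which is the constant asserted in both rsample tests above:
import math
assert abs(1.0 * (-math.log(0.5)) ** (1.0 / 1.0) - 0.6931) < 1e-4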

def test_log_prob_tensor(self):
l = torch.ones(2, 2)
k = torch.ones(2, 2)

dist = SimpleWeibull(l, k)
self.assertTrue((dist.log_prob(torch.ones(2, 2)) < 0.0001).all())
self.assertTrue((dist.log_prob(torch.ones(2, 2) - 2) == float('-inf')).all())

def test_log_prob_number(self):
dist = SimpleWeibull(1, 1)

self.assertTrue((dist.log_prob(1) < 0.0001).all())
self.assertTrue((dist.log_prob(-1) == float('-inf')).all())
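# Editorial note (not part of the commit): the < 0.0001 bound above is loose; for
# l = k = 1 the Weibull density at 1 is exp(-1), so log_prob(1) is exactly -1 (up
# to the 1e-8 stabiliser inside the log):
import math
l_, k_, x = 1.0, 1.0, 1.0
log_prob = math.log(k_ / l_) + (k_ - 1) * math.log((x + 1e-8) / l_) - (x / l_) ** k_
assert abs(log_prob + 1.0) < 1e-6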


73 changes: 55 additions & 18 deletions torchbearer/variational/distributions.py
@@ -12,7 +12,15 @@
import torch
from torch.distributions import Distribution
from torch.distributions.utils import broadcast_all

from torchbearer import cite

steve = """
@article{squires2019a,
title={A Variational Autoencoder for Probabilistic Non-Negative Matrix Factorisation},
author={Steven Squires and Adam Prugel-Bennett and Mahesan Niranjan},
year={2019}
}
"""

class SimpleDistribution(Distribution):
"""Abstract base class for a simple distribution which only implements rsample and log_prob. If the log_prob
@@ -60,10 +68,9 @@ def rsample(self, sample_shape=torch.Size()):
Returns a reparameterized sample or batch of reparameterized samples if the distribution parameters are batched.
"""
raise NotImplementedError

def log_prob(self, value):
"""Returns the log of the probability density/mass function evaluated at `value`.
Args:
value (torch.Tensor, Number): Value at which to evaluate log probability
"""
@@ -74,11 +81,11 @@ class SimpleNormal(SimpleDistribution):
"""The SimpleNormal class is a :class:`SimpleDistribution` which implements a straightforward Normal / Gaussian
distribution. This performs significantly fewer checks than `torch.distributions.Normal`, but should be sufficient
for the purpose of implementing a VAE.
Args:
mu (torch.Tensor, Number): The mean of the distribution, numbers will be cast to tensors
logvar (torch.Tensor, Number): The log variance of the distribution, numbers will be cast to tensors
"""

def __init__(self, mu, logvar):
self.mu, self.logvar = broadcast_all(mu, logvar)
if isinstance(mu, Number) and isinstance(logvar, Number):
@@ -89,10 +96,8 @@ def __init__(self, mu, logvar):

def rsample(self, sample_shape=torch.Size()):
"""Simple rsample for a Normal distribution.
Args:
sample_shape (torch.Size, tuple): Shape of the sample (per mean / variance given)
Returns:
A reparameterized sample with gradient with respect to the distribution parameters
"""
@@ -105,10 +110,8 @@ def rsample(self, sample_shape=torch.Size()):
def log_prob(self, value):
"""Calculates the log probability that the given value was drawn from this distribution. Since the density of a
Gaussian is differentiable, this function is differentiable.
Args:
value (torch.Tensor, Number): The sampled value
Returns:
The log probability that the given value was drawn from this distribution
"""
@@ -120,11 +123,11 @@ class SimpleUniform(SimpleDistribution):
"""The SimpleUniform class is a :class:`SimpleDistribution` which implements a straightforward Uniform distribution
in the interval ``[low, high)``. This performs significantly fewer checks than `torch.distributions.Uniform`, but
should be sufficient for the purpose of implementing a VAE.
Args:
low (torch.Tensor, Number): The lower range of the distribution (inclusive), numbers will be cast to tensors
high (torch.Tensor, Number): The upper range of the distribution (exclusive), numbers will be cast to tensors
"""

def __init__(self, low, high):
super().__init__()
self.low, self.high = broadcast_all(low, high)
@@ -136,10 +139,8 @@ def __init__(self, low, high):

def rsample(self, sample_shape=torch.Size()):
"""Simple rsample for a Uniform distribution.
Args:
sample_shape (torch.Size, tuple): Shape of the sample (per low / high given)
Returns:
A reparameterized sample with gradient with respect to the distribution parameters
"""
@@ -151,10 +152,8 @@ def log_prob(self, value):
"""Calculates the log probability that the given value was drawn from this distribution. Since this distribution
is uniform, the log probability is zero for all values in the range ``[low, high)`` and -inf elsewhere. This
function is therefore non-differentiable.
Args:
value (torch.Tensor, Number): The sampled value
Returns:
The log probability that the given value was drawn from this distribution
"""
@@ -169,10 +168,10 @@ class SimpleExponential(SimpleDistribution):
distribution with the given lograte. This performs significantly fewer checks than `torch.distributions.Exponential`,
but should be sufficient for the purpose of implementing a VAE. By using a lograte, the log_prob can be computed
in a stable fashion, without taking a logarithm.
Args:
lograte (torch.Tensor, Number): The natural log of the rate of the distribution, numbers will be cast to tensors
"""

def __init__(self, lograte):
super().__init__()
self.lograte, = broadcast_all(lograte)
@@ -181,10 +180,8 @@ def __init__(self, lograte):

def rsample(self, sample_shape=torch.Size()):
"""Simple rsample for an Exponential distribution.
Args:
sample_shape (torch.Size, tuple): Shape of the sample (per lograte given)
Returns:
A reparameterized sample with gradient with respect to the distribution parameters
"""
@@ -194,11 +191,51 @@ def rsample(self, sample_shape=torch.Size()):
def log_prob(self, value):
"""Calculates the log probability that the given value was drawn from this distribution. The log_prob for this
distribution is fully differentiable and has stable gradient since we use the lograte here.
Args:
value (torch.Tensor, Number): The sampled value
Returns:
The log probability that the given value was drawn from this distribution
"""
return self.lograte - self.lograte.exp() * value
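# Editorial note (not part of the commit): with rate = exp(lograte), the line
# above is the standard exponential log-density log(rate) - rate * value,
# evaluated without ever taking the log of a learned rate. A quick check:
import math
lograte, x = math.log(2.0), 0.25
rate = math.exp(lograte)
assert abs(math.log(rate * math.exp(-rate * x)) - (lograte - rate * x)) < 1e-12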

@cite(steve)
class SimpleWeibull(SimpleDistribution):
"""The SimpleWeibull class is a :class:`SimpleDistribution` which implements a straight forward Weibull
distribution. This performs significantly fewer checks than `torch.distributions.Weibull`, but should be sufficient
for the purpose of implementing a VAE.
Args:
l (torch.Tensor, Number): The scale parameter of the distribution, numbers will be cast to tensors
k (torch.Tensor, Number): The shape parameter of the distribution, numbers will be cast to tensors
"""

def __init__(self, l, k):
self.l, self.k = broadcast_all(l, k)
self.const = 1e-8  # small stabiliser added inside the log in log_prob
if isinstance(k, Number) and isinstance(l, Number):
batch_shape = torch.Size()
else:
batch_shape = self.k.size()
super().__init__(batch_shape=batch_shape)

def rsample(self, sample_shape=torch.Size()):
"""Simple rsample for a Weibull distribution.
Args:
sample_shape (torch.Size, tuple): Shape of the sample (per k / lambda given)
Returns:
A reparameterized sample with gradient with respect to the distribution parameters
"""
shape = self._extended_shape(sample_shape)
eps = torch.rand(shape, dtype=self.k.dtype, device=self.k.device)
return self.l * torch.pow((-torch.log(eps)), (1/self.k))
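# Editorial note (not part of the commit): this is inverse-CDF sampling. The
# Weibull CDF is F(x) = 1 - exp(-(x / l) ** k), so F^-1(u) = l * (-log(1 - u)) ** (1 / k);
# since 1 - u and u are identically distributed on (0, 1), drawing
# l * (-log(u)) ** (1 / k) is equivalent and keeps the sample differentiable in l and k:
import math
l_, k_, u = 2.0, 1.5, 0.3
x = l_ * (-math.log(u)) ** (1.0 / k_)
assert abs((1.0 - math.exp(-(x / l_) ** k_)) - (1.0 - u)) < 1e-12  # F(x) recovers 1 - u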

def log_prob(self, value):
"""Calculates the log probability that the given value was drawn from this distribution. This function is differentiable
and its log probability is -inf for values less than 0.
Args:
value (torch.Tensor, Number): The sampled value
Returns:
The log probability that the given value was drawn from this distribution
"""
value = value if torch.is_tensor(value) else torch.tensor(value, dtype=torch.get_default_dtype())
lb = value.ge(torch.zeros(value.shape, dtype=self.k.dtype, device=self.k.device)).float()
return torch.log(lb) + torch.log(self.k / self.l) + (self.k - torch.ones(self.k.shape, dtype=self.k.dtype, device=self.k.device)) * torch.log((lb * value + self.const) / self.l) - torch.pow(value / self.l, self.k)

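# Editorial usage sketch (not part of the commit): the log_prob above is the
# standard Weibull log-density log(k / l) + (k - 1) * log(value / l) - (value / l) ** k
# for value >= 0, with a 1e-8 stabiliser inside the middle log and a log(0) = -inf
# mask for negative values. Assuming the import path used in the tests above:
import torch
from torchbearer.variational import SimpleWeibull

dist = SimpleWeibull(torch.ones(2, 2), 2 * torch.ones(2, 2))  # scale l = 1, shape k = 2
sample = dist.rsample()        # differentiable w.r.t. l and k
log_p = dist.log_prob(sample)  # finite for sample >= 0, -inf below 0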