Variational/weibull (#500)
* added Weibull distribution

* Update distributions.py

* added test for Weibull distribution

* Update CHANGELOG.md

* Update distributions.py

* added SimpleWeibullSimpleWeibullKL

* Update CHANGELOG.md

* Fix breaking test
Ddaniela13 authored and ethanwharris committed Jan 29, 2019
1 parent 78b939c commit 1a0bcf0
Showing 4 changed files with 52 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for rounding 1D lists to the Tqdm callback
- Added SimpleWeibull distribution
- Added support for Python 2.7
- Added SimpleWeibullSimpleWeibullKL
### Changed
- Changed the default behaviour of the std metric to compute the sample std, in line with torch.std
- Tqdm precision argument now rounds to decimal places rather than significant figures
10 changes: 9 additions & 1 deletion tests/variational/test_divergence.py
@@ -4,7 +4,7 @@
import torch

import torchbearer
-from torchbearer.variational import DivergenceBase, SimpleNormalUnitNormalKL, SimpleNormalSimpleNormalKL, SimpleNormal
+from torchbearer.variational import DivergenceBase, SimpleNormalUnitNormalKL, SimpleNormalSimpleNormalKL, SimpleNormal, SimpleWeibull, SimpleWeibullSimpleWeibullKL

key = torchbearer.state_key('divergence_test')

@@ -156,3 +156,11 @@ def test_divergence(self):
        input = SimpleNormal(torch.zeros(2, 2), torch.ones(2, 2) * -1.3863)
        target = SimpleNormal(torch.ones(2, 2), torch.ones(2, 2) * 1.3863)
        self.assertTrue(((callback.compute(input, target) - 1.0425).abs() < 0.0001).all())

class TestSimpleWeibullSimpleWeibullKL(unittest.TestCase):
    def test_divergence(self):
        callback = SimpleWeibullSimpleWeibullKL(key, key)
        input = SimpleWeibull(torch.ones(2, 2), torch.zeros(2, 2) + 0.5)
        target = SimpleWeibull(torch.ones(2, 2), torch.ones(2, 2) * 5)
        self.assertTrue(((callback.compute(input, target) - 3628803.7500).abs() < 0.0001).all())
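
As a usage reference, the new divergence can also be exercised directly, outside the test harness. A minimal sketch, assuming the `SimpleWeibull(l, k)` constructor implied by the test above; the two state keys are hypothetical and stand in for wherever a model would store its posterior and prior during training:

```python
import torch
import torchbearer
from torchbearer.variational import SimpleWeibull, SimpleWeibullSimpleWeibullKL

# Hypothetical state keys: during a training run, DivergenceBase reads the
# input and target distribution objects from state under these keys.
Q_KEY = torchbearer.state_key('weibull_posterior')
P_KEY = torchbearer.state_key('weibull_prior')

kl = SimpleWeibullSimpleWeibullKL(Q_KEY, P_KEY)

# Direct call, mirroring the test: KL(W(l=1, k=0.5) || W(l=1, k=5))
q = SimpleWeibull(torch.ones(2, 2), torch.zeros(2, 2) + 0.5)
p = SimpleWeibull(torch.ones(2, 2), torch.ones(2, 2) * 5)
print(kl.compute(q, p))  # a scalar, since compute returns the mean over elements
```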

1 change: 0 additions & 1 deletion torchbearer/variational/distributions.py
@@ -72,7 +72,6 @@ def rsample(self, sample_shape=torch.Size()):

    def log_prob(self, value):
        """Returns the log of the probability density/mass function evaluated at `value`.
        Args:
            value (torch.Tensor, Number): Value at which to evaluate log probability
        """
42 changes: 42 additions & 0 deletions torchbearer/variational/divergence.py
@@ -1,5 +1,6 @@
import functools

import torch
import torchbearer
from torchbearer import cite
import torchbearer.callbacks as callbacks
@@ -21,6 +22,17 @@
}
"""

weibullKL = """
@article{DBLP:journals/corr/Bauckhage14,
  author  = {Christian Bauckhage},
  title   = {Computing the Kullback-Leibler Divergence between two Generalized
             Gamma Distributions},
  journal = {CoRR},
  volume  = {abs/1401.6853},
  year    = {2014}
}
"""


class DivergenceBase(callbacks.Callback):
"""The :class:`DivergenceBase` class is an abstract base class which defines a series of useful methods for dealing
@@ -193,3 +205,33 @@ def compute(self, input, target):
        mu_1, logvar_1 = input.mu, input.logvar
        mu_2, logvar_2 = target.mu, target.logvar
        return 0.5 * (logvar_1.exp() / logvar_2.exp() + (mu_2 - mu_1).pow(2) / logvar_2.exp() + logvar_2 - logvar_1 - 1)

@cite(weibullKL)
class SimpleWeibullSimpleWeibullKL(DivergenceBase):
    """A KL divergence between two SimpleWeibull (or similar) distributions.

    .. note::

        The distribution objects must have l (scale, lambda) and k (shape) attributes.

    Args:
        input_key: :class:`.StateKey` instance which will be mapped to the input distribution object.
        target_key: :class:`.StateKey` instance which will be mapped to the target distribution object.
        state_key: If not None, the value output by :meth:`compute` is stored in state with the given key.
    """
    def __init__(self, input_key, target_key, state_key=None):
        super(SimpleWeibullSimpleWeibullKL, self).__init__({'input': input_key, 'target': target_key}, state_key=state_key)
        self.gamma = 0.5772  # Euler-Mascheroni constant: E[log X] = log(l) - gamma / k for X ~ Weibull(l, k)

    def compute(self, input, target):
        lambda_1, k_1 = input.l, input.k
        lambda_2, k_2 = target.l, target.k
        a = torch.log(k_1 / torch.pow(lambda_1, k_1))  # log(k1 / lambda1^k1)
        b = torch.log(k_2 / torch.pow(lambda_2, k_2))  # log(k2 / lambda2^k2)
        c = torch.mul((k_1 - k_2), (torch.log(lambda_1) - self.gamma / k_1))  # (k1 - k2) * E_1[log x]
        n = k_2 / k_1 + 1
        gammaf = torch.exp(torch.lgamma(n))  # Gamma(k2 / k1 + 1), via log-gamma for numerical stability
        d = torch.mul(torch.pow(torch.div(lambda_1, lambda_2), k_2), gammaf)  # (lambda1 / lambda2)^k2 * Gamma(n)
        loss = torch.mean(a - b + c + d - 1)
        return loss
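
For reference, a sketch of the closed forms involved. The context lines at the top of this hunk are the standard KL between diagonal Gaussians, and the new `compute` implements the Weibull-Weibull KL from the cited Bauckhage report term by term, with the gamma / k1 term arising from E_1[log x] = log(lambda1) - gamma / k1; the Weibull terms below are labelled to match the variables `a`, `b`, `c` and `d` in the code:

```latex
% Diagonal-Gaussian KL, as in SimpleNormalSimpleNormalKL.compute above
% (element-wise, with logvar_i = log sigma_i^2):
\[
D_{\mathrm{KL}}\left(\mathcal{N}_1 \,\|\, \mathcal{N}_2\right)
  = \frac{1}{2}\left(\frac{\sigma_1^2}{\sigma_2^2}
  + \frac{(\mu_2 - \mu_1)^2}{\sigma_2^2}
  + \log\sigma_2^2 - \log\sigma_1^2 - 1\right)
\]

% Weibull-Weibull KL (Bauckhage, 2014), matching loss = mean(a - b + c + d - 1):
\[
D_{\mathrm{KL}}\left(W_1 \,\|\, W_2\right)
  = \underbrace{\log\frac{k_1}{\lambda_1^{k_1}}}_{a}
  - \underbrace{\log\frac{k_2}{\lambda_2^{k_2}}}_{b}
  + \underbrace{(k_1 - k_2)\left(\log\lambda_1 - \frac{\gamma}{k_1}\right)}_{c}
  + \underbrace{\left(\frac{\lambda_1}{\lambda_2}\right)^{k_2}
      \Gamma\left(\frac{k_2}{k_1} + 1\right)}_{d}
  - 1
\]
```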
