In [1]:
# `IPython.display` is the public home of `display`/`HTML`; importing them from
# `IPython.core.display` is a deprecated path.
from IPython.display import display, HTML

# Inject a CSS rule forcing elements with class `container` to 100% width
# (presumably the notebook page wrapper, so cells use the full browser width).
display(HTML("<style>.container { width:100% !important; }</style>"))

In [2]:
import altair as alt
import pandas as pd

In [3]:
import torch
import numpy as np
from torch import distributions as pyd
from torch import nn
import torch.nn.functional as F
from torch import optim
import math

In [4]:
class TanhTransform(pyd.transforms.Transform):
    """Bijective transform ``y = tanh(x)`` from the reals onto ``(-1, 1)``.

    The forward/inverse pair is cached by the parent class (``cache_size=1``
    by default), so inverting a value that was just produced by the forward
    pass does not go through ``atanh``.
    """

    domain = pyd.constraints.real
    codomain = pyd.constraints.interval(-1.0, 1.0)
    bijective = True
    sign = +1

    def __init__(self, cache_size=1):
        super().__init__(cache_size=cache_size)

    def __eq__(self, other):
        # Any two TanhTransform instances are interchangeable: there is no state to compare.
        return isinstance(other, TanhTransform)

    def _call(self, x):
        """Forward pass: element-wise hyperbolic tangent."""
        squashed = x.tanh()
        return squashed

    def _inverse(self, y):
        """Inverse pass: element-wise ``atanh`` of ``y``."""
        # Deliberately no clamping at the boundary, since that may degrade the
        # performance of certain algorithms; rely on `cache_size=1` instead.
        return self.atanh(y)

    @staticmethod
    def atanh(x):
        """Inverse hyperbolic tangent, ``0.5 * (log1p(x) - log1p(-x))``."""
        log_one_plus = x.log1p()
        log_one_minus = (-x).log1p()
        return 0.5 * (log_one_plus - log_one_minus)

    def log_abs_det_jacobian(self, x, y):
        """Return ``log|d tanh(x)/dx|``, i.e. ``log(1 - tanh(x)**2)``.

        Only ``x`` is used. The expression is the numerically stabler form
        ``2 * (log 2 - x - softplus(-2x))``; see
        https://github.com/tensorflow/probability/commit/ef6bb176e0ebd1cf6e25c6b5cecdd2428c22963f#diff-e120f70e92e6741bca649f04fcd907b7
        """
        softplus_term = F.softplus(-2. * x)
        return 2. * (math.log(2.) - x - softplus_term)


class SquashedNormal(pyd.transformed_distribution.TransformedDistribution):
    """A Normal distribution pushed through ``tanh``, supported on ``(-1, 1)``.

    Args:
        loc: location of the underlying Normal (Python number or tensor).
        scale: scale (standard deviation) of the underlying Normal
            (Python number or tensor).
    """

    def __init__(self, loc, scale):
        base_dist = pyd.Normal(loc, scale)
        # Keep the parameters as stored by `Normal` (broadcast, and promoted to
        # tensors when plain Python numbers were passed) rather than the raw
        # arguments, so that `mean` can call tensor methods on `loc`.
        self.loc = base_dist.loc
        self.scale = base_dist.scale

        self.base_dist = base_dist
        transforms = [TanhTransform()]
        super().__init__(self.base_dist, transforms)

    @property
    def mean(self):
        """``loc`` passed through every transform, i.e. ``tanh(loc)``.

        NOTE: this is the image of the base distribution's mean, not the exact
        expectation of the squashed variable.
        """
        mu = self.loc
        for tr in self.transforms:
            mu = tr(mu)
        return mu

In [5]:
# Evaluation grid: 1000 evenly spaced points from -0.999 to 0.999 (inside the open interval (-1, 1)).
x = torch.linspace(start=-0.999, end=0.999, steps=1000)

In [6]:
def render_dist(dist):
    """Plot the probability density of ``dist`` over the module-level grid ``x``.

    Args:
        dist: object exposing ``log_prob``; it is evaluated at every point of
            the global tensor ``x``.

    Returns:
        An Altair line chart with ``x`` on the horizontal axis and the density
        ``exp(log_prob(x))`` on the vertical axis.
    """
    density = dist.log_prob(x).exp()
    # Hand pandas plain NumPy arrays instead of torch tensors, and name the
    # column for what it holds: the density, not its logarithm.
    frame = pd.DataFrame({
        'x': x.detach().cpu().numpy(),
        'prob': density.detach().cpu().numpy(),
    })
    return alt.Chart(frame).mark_line().encode(x='x', y='prob')

In [7]:
# Density of a tanh-squashed Normal with loc=0 and scale=0.2.
render_dist(SquashedNormal(loc=0, scale=0.2))

In [9]:
# Seed torch's global RNG so the synthetic dataset (and therefore the fit
# below) is reproducible on Restart & Run All.
torch.manual_seed(0)

# Synthetic dataset: 1000 draws from a tanh-squashed Normal(loc=2, scale=0.1).
data = SquashedNormal(2, 0.1).sample((1000,))

In [10]:
# Maximum-likelihood fit of a SquashedNormal to `data` with Adam.
N_STEPS = 1000
LEARNING_RATE = 1e-2

# Start from loc = 0 and log-std = 0 (i.e. std = 1); the scale is optimised
# in log space and exponentiated when the distribution is built.
mu = torch.zeros((), requires_grad=True)
log_std = torch.zeros((), requires_grad=True)

optimizer = optim.Adam([mu, log_std], lr=LEARNING_RATE)
for _ in range(N_STEPS):
    optimizer.zero_grad()

    # Loss is the mean negative log-likelihood of the data under the current parameters.
    train_dist = SquashedNormal(mu, log_std.exp())
    loss = train_dist.log_prob(data).mean().neg()

    loss.backward()
    optimizer.step()

In [11]:
# Side by side: SquashedNormal(0, 1) on the left, the fitted distribution on the right.
initial_chart = render_dist(SquashedNormal(0, 1))
fitted_chart = render_dist(SquashedNormal(mu, log_std.exp()))
initial_chart | fitted_chart

In [12]:
# Fitted location and scale (the scale is the exponential of the log-std parameter).
fitted_loc, fitted_scale = mu, log_std.exp()
fitted_loc, fitted_scale

(tensor(2.0072, requires_grad=True), tensor(0.0999, grad_fn=<ExpBackward>))