<a href="https://colab.research.google.com/github/seanreed1111/colab-demos/blob/master/class_truncated_double_exponential_wip.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install pyro-ppl=='1.8.0'

Collecting pyro-ppl==1.8.0
  Downloading pyro_ppl-1.8.0-py3-none-any.whl (713 kB)
[?25l[K     |▌                               | 10 kB 16.4 MB/s eta 0:00:01[K     |█                               | 20 kB 8.9 MB/s eta 0:00:01[K     |█▍                              | 30 kB 7.4 MB/s eta 0:00:01[K     |█▉                              | 40 kB 6.6 MB/s eta 0:00:01[K     |██▎                             | 51 kB 4.0 MB/s eta 0:00:01[K     |██▊                             | 61 kB 4.7 MB/s eta 0:00:01[K     |███▏                            | 71 kB 5.2 MB/s eta 0:00:01[K     |███▊                            | 81 kB 4.1 MB/s eta 0:00:01[K     |████▏                           | 92 kB 4.5 MB/s eta 0:00:01[K     |████▋                           | 102 kB 5.0 MB/s eta 0:00:01[K     |█████                           | 112 kB 5.0 MB/s eta 0:00:01[K     |█████▌                          | 122 kB 5.0 MB/s eta 0:00:01[K     |██████                          | 133 kB 5.0 MB/s eta 0:00

Modeled on the following Pyro distribution implementations:
- https://github.com/pyro-ppl/pyro/blob/dev/pyro/distributions/multivariate_studentt.py
- https://github.com/pyro-ppl/pyro/blob/dev/pyro/distributions/logistic.py
- https://github.com/pyro-ppl/pyro/blob/dev/pyro/distributions/inverse_gamma.py

In [None]:
import math

import torch
from torch.distributions import constraints
from torch.distributions.transforms import AffineTransform, ExpTransform
from torch.distributions.utils import lazy_property

from pyro.distributions.torch import Exponential, TransformedDistribution, Uniform
from pyro.distributions.torch_distribution import TorchDistribution
from pyro.distributions.util import broadcast_shape, broadcast_all

In [8]:
# Truncated Exponential implemented as a TransformedDistribution, in the same
# spirit as pyro's InverseGamma: a Uniform(0, 1) base is pushed through the
# inverse CDF so that sampling is reparameterized.
class TruncatedExponential(TransformedDistribution):
    r"""
    Exponential distribution truncated to the interval ``[low, high]``.

    With :math:`\lambda` = :attr:`rate`, the density is

    .. math::

        f(x) = \frac{\lambda \exp(-\lambda (x - \mathrm{low}))}
                    {1 - \exp(-\lambda (\mathrm{high} - \mathrm{low}))}
        \quad \text{for } \mathrm{low} \le x \le \mathrm{high},

    and zero elsewhere. The defaults ``low=0`` and ``high=inf`` recover an
    ordinary ``Exponential(rate)``.

    Samples are produced by mapping ``Uniform(0, 1)`` noise through the inverse
    CDF with torch transforms, so :meth:`rsample` is reparameterized with
    respect to ``rate``, ``low`` and ``high``. :meth:`log_prob`, :meth:`cdf`
    and :meth:`icdf` are evaluated in closed form.

    :param rate: rate of the untruncated Exponential (``rate > 0``).
    :type rate: float or ~torch.Tensor
    :param low: lower truncation bound (``low >= 0``). Defaults to ``0``.
    :type low: float or ~torch.Tensor
    :param high: upper truncation bound (``high > low``). Defaults to ``inf``.
    :type high: float or ~torch.Tensor
    :param validate_args: whether to validate arguments and samples. When
        validation is enabled, the cross-parameter check ``low < high`` is
        also performed.
    :raises ValueError: if validation is enabled and a parameter violates its
        constraint (including ``low < high``).
    """

    arg_constraints = {
        "rate": constraints.positive,
        "low": constraints.greater_than_eq(0.0),
        "high": constraints.positive,
    }
    has_rsample = True

    def __init__(self, rate, low=0.0, high=math.inf, validate_args=None):
        # Broadcast all parameters to a common batch shape (and a common
        # dtype/device taken from whichever argument is a tensor). They must be
        # set before super().__init__, which validates arg_constraints.
        self.rate, self.low, self.high = broadcast_all(rate, low, high)
        mass = self._truncation_mass()
        base_dist = Uniform(
            torch.zeros_like(self.rate),
            torch.ones_like(self.rate),
            validate_args=False,
        )
        # Inverse CDF as a transform chain, applied to u ~ Uniform[0, 1):
        #   u -> 1 - mass * u                    in (exp(-rate * (high - low)), 1]
        #     -> log(1 - mass * u)               in (-rate * (high - low), 0]
        #     -> low - log(1 - mass * u) / rate  in [low, high)
        # NOTE: this evaluates log(1 - mass*u) rather than log1p(-mass*u), so
        # samples lose relative precision when rate * (high - low) approaches
        # machine epsilon; icdf() uses the log1p form.
        transforms = [
            AffineTransform(loc=1.0, scale=-mass),
            ExpTransform().inv,
            AffineTransform(loc=self.low, scale=-1.0 / self.rate),
        ]
        super().__init__(base_dist, transforms, validate_args=validate_args)
        # arg_constraints cannot express a relation between two parameters.
        if self._validate_args and not bool((self.low < self.high).all()):
            raise ValueError(
                "Expected parameter low to be strictly smaller than high, "
                f"but got low={self.low} and high={self.high}"
            )

    def expand(self, batch_shape, _instance=None):
        """Return a view of this distribution with a larger batch shape."""
        new = self._get_checked_instance(TruncatedExponential, _instance)
        batch_shape = torch.Size(batch_shape)
        new.rate = self.rate.expand(batch_shape)
        new.low = self.low.expand(batch_shape)
        new.high = self.high.expand(batch_shape)
        # The parent expands the Uniform base and reuses the transforms, whose
        # parameters broadcast against the new batch shape.
        return super().expand(batch_shape, _instance=new)

    def _truncation_mass(self):
        """Return ``1 - exp(-rate * (high - low))``, the normalizing constant.

        ``-expm1`` keeps precision when ``rate * (high - low)`` is small, and
        evaluates to 1 when ``high`` is infinite (for finite rate and low).
        """
        return -torch.expm1(-self.rate * (self.high - self.low))

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self):
        return constraints.interval(self.low, self.high)

    def log_prob(self, value):
        """Log density at ``value``; ``-inf`` outside ``[low, high]``."""
        if self._validate_args:
            self._validate_sample(value)
        log_density = (
            self.rate.log()
            - self.rate * (value - self.low)
            - self._truncation_mass().log()
        )
        # Only reachable without validation: give zero density off-support.
        outside = (value < self.low) | (value > self.high)
        return log_density.masked_fill(outside, -math.inf)

    def cdf(self, value):
        """Return ``P(X <= value)``: 0 below ``low`` and 1 above ``high``."""
        if self._validate_args:
            self._validate_sample(value)
        # Clamping the shift at 0 yields 0 below `low`; clamping the ratio at 1
        # covers values above `high` (and rounding just below it).
        shifted = (value - self.low).clamp(min=0)
        return (-torch.expm1(-self.rate * shifted) / self._truncation_mass()).clamp(max=1)

    def icdf(self, value):
        """Inverse CDF: the point ``x`` with ``P(X <= x) = value``.

        :param value: probabilities in ``[0, 1]`` (not points of the support).
        :raises ValueError: if validation is enabled and ``value`` is outside
            ``[0, 1]``.
        """
        value = torch.as_tensor(value)
        if self._validate_args and not bool(((value >= 0) & (value <= 1)).all()):
            raise ValueError("Expected icdf argument to be a probability in [0, 1]")
        return self.low - torch.log1p(-value * self._truncation_mass()) / self.rate