In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import statsmodels.api as sm
from scipy import stats

import torch
import pyro
from pyro import distributions as dist
from pyro.infer.mcmc import NUTS, MCMC
from pyro.infer.abstract_infer import EmpiricalMarginal, TracePredictive

## プルリクマージされておらず(2019/1/31時点)

In [24]:
from torch.distributions import constraints

from pyro.distributions.torch_distribution import TorchDistribution
class CensoredDistribution(TorchDistribution):

    def __init__(self, base_distribution, upper_lim=float('inf'), lower_lim=float('-inf'), validate_args=None):
        # Log-prob only computed correctly for univariate base distribution
        assert base_distribution.event_dim == 0 or (
                base_distribution.event_dim == 1 and base_distribution.event_shape[0] == 1)
        self.base_dist = base_distribution
        self.upper_lim = upper_lim
        self.lower_lim = lower_lim

        super(CensoredDistribution, self).__init__(self.base_dist.batch_shape, self.base_dist.event_shape,
                                                   validate_args=validate_args)

    @constraints.dependent_property
    def support(self):
        return constraints.interval(self.lower_lim, self.upper_lim)

    def sample(self, sample_shape=torch.Size()):
        with torch.no_grad():
            x = self.base_dist.sample(sample_shape)
            x[x > self.upper_lim] = self.upper_lim
            x[x < self.lower_lim] = self.lower_lim
            return x

    def rsample(self, sample_shape=torch.Size()):
        x = self.base_dist.sample(sample_shape)
        x[x > self.upper_lim] = self.upper_lim
        x[x < self.lower_lim] = self.lower_lim

    def log_prob(self, value):
        """
        Scores the sample by giving a probability density relative to a new base measure.
        The new base measure places an atom at `self.upper_lim` and `self.lower_lim`, and
        has Lebesgue measure on the intervening interval.
        Thus, `log_prob(self.lower_lim)` and `log_prob(self.upper_lim)` represent probabilities
        as for discrete distributions. `log_prob(x)` in the interior represent regular
        pdfs with respect to Lebesgue measure on R.
        **Note**: `log_prob` scores from distributions with different censoring are not
        comparable.
        """
        log_prob = self.base_dist.log_prob(value)
        upper_cdf = 1. - self.base_dist.cdf(self.upper_lim)
        lower_cdf = self.base_dist.cdf(self.lower_lim)

        log_prob[value == self.upper_lim] = torch.log(upper_cdf.expand_as(log_prob)[value == self.upper_lim])
        log_prob[value > self.upper_lim] = float('-inf')
        log_prob[value == self.lower_lim] = torch.log(lower_cdf.expand_as(log_prob)[value == self.lower_lim])
        log_prob[value < self.lower_lim] = float('-inf')

        return log_prob

    def cdf(self, value):
        if self._validate_args:
            self.base_dist._validate_sample(value)
        cdf = self.base_dist.cdf(value)
        cdf[value >= self.upper_lim] = 1.
        cdf[value < self.lower_lim] = 0.

    def icdf(self, value):
        # Is this even possible?
        raise NotImplemented

In [30]:
data = pd.read_csv("input/data-protein.txt")

In [31]:
data

Unnamed: 0,Y
0,<25
1,32.3
2,<25
3,28.3
4,30.8
5,35.2


In [32]:
idx = data.Y.str.contains("<")
Y_obs = data[~idx].Y.astype(np.float)
L = data[idx].Y.str.replace("<", "").astype(np.float)

In [33]:
Y_obs

1    32.3
3    28.3
4    30.8
5    35.2
Name: Y, dtype: float64

In [34]:
L

0    25.0
2    25.0
Name: Y, dtype: float64

In [35]:
y_obs = torch.tensor(Y_obs.values).float()
l = torch.tensor(L.values).float()

## モデル式

- 打ち切りがない測定の場合

$Y[n] \sim Normal(\mu, \sigma_Y)$

- 打ち切りがある測定の場合

$y[n] \sim Normal(\mu, \sigma_Y)$，ただし$y[n] < L$

In [42]:
def model(Y_obs, L):
    mu = pyro.sample("mu", dist.Normal(0, 100))
    sigma_y = pyro.sample("sigma_y", dist.Uniform(0, 100))
    
    with pyro.plate("obs_data", len(Y_obs)):
        pyro.sample("obs", dist.Normal(mu, sigma_y), obs=Y_obs)
    with pyro.plate("cens_data", len(L)):
        pyro.sample("cens", CensoredDistribution(dist.Normal(mu, sigma_y), upper_lim=25), obs=L)

In [43]:
kernel = NUTS(model, adapt_step_size=True, jit_compile=True, ignore_jit_warnings=True)
posterior = MCMC(kernel, num_chains=4, num_samples=1000, warmup_steps=1000).run(y_obs, l)

HBox(children=(IntProgress(value=0, description='Warmup [1]', max=2000, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Warmup [2]', max=2000, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Warmup [3]', max=2000, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Warmup [4]', max=2000, style=ProgressStyle(description_width=…

[ERROR LOG CHAIN:0]
 Trace Shapes:    
  Param Sites:    
 Sample Sites:    
       mu dist   |
         value   |
  sigma_y dist   |
         value   |
 obs_data dist   |
         value 4 |
      obs dist 4 |
         value 4 |
cens_data dist   |
         value 2 |
Traceback (most recent call last):
  File "/Users/makora/.pyenv/versions/miniconda3-latest/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py", line 147, in __call__
    ret = self.fn(*args, **kwargs)
  File "<ipython-input-42-d2bcda8a2a67>", line 8, in model
    pyro.sample("cens", CensoredDistribution(dist.Normal(mu, sigma_y), upper_lim=25), obs=L)
  File "/Users/makora/.pyenv/versions/miniconda3-latest/lib/python3.7/site-packages/pyro/primitives.py", line 98, in sample
    apply_stack(msg)
  File "/Users/makora/.pyenv/versions/miniconda3-latest/lib/python3.7/site-packages/pyro/poutine/runtime.py", line 190, in apply_stack
    frame._process_message(msg)
  File "/Users/makora/.pyenv/versions/miniconda3-latest/lib

NotImplementedError: 
 Trace Shapes:    
  Param Sites:    
 Sample Sites:    
       mu dist   |
         value   |
  sigma_y dist   |
         value   |
 obs_data dist   |
         value 4 |
      obs dist 4 |
         value 4 |
cens_data dist   |
         value 2 |