In [287]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [288]:
import argparse
import collections
import numbers

import torch
from search_inference import HashingMarginal, Search, memoize

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer import config_enumerate, TraceEnum_ELBO
from pyro.ops.indexing import Vindex

torch.set_default_dtype(torch.float64)  # double precision for numerical stability
torch.manual_seed(42)

<torch._C.Generator at 0x7f8d3c5e2e70>

In [307]:
from vectorized_search import VectoredSearch, VectoredHashingMarginal

def Marginal(fn):
    marginal = VectoredHashingMarginal(VectoredSearch(config_enumerate(fn)))
    return memoize(lambda *args: marginal.run(*args)))

In [290]:
utterances = [
    "generic is true", "generic is false",
    "mu", "some", "most", "all",
]

In [308]:
Params = collections.namedtuple("Params", ["theta", "gamma", "delta"])

beta_bins = torch.tensor([0., 0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.99])

@Marginal
def structured_prior(params: Params) -> torch.Tensor:
    # computing the Beta pdf for discretized bins above for enumerated Search
    shape_alpha = params.gamma * params.delta - 1
    shape_beta  = (1. - params.gamma) * params.delta - 1
    discrete_bins = (beta_bins ** shape_alpha) * ((1. - beta_bins) ** shape_beta) * params.theta
    discrete_bins[0] = (1 - params.theta)
    idx = pyro.sample("bin", dist.Categorical(probs=discrete_bins / discrete_bins.sum()))

    return beta_bins[idx]

In [309]:
wings_prior_params = Params(theta=0.5, gamma=0.99, delta=10.0)
wings_prior = structured_prior(wings_prior_params)

for el in wings_prior.enumerate_support():
    print(el.item(), wings_prior.log_prob(el).exp().item())

0.0 0.015988904498871442
0.01 2.2204460492503185e-16
0.1 2.213097002755768e-11
0.2 1.175451410947732e-08
0.3 4.893387852027631e-07
0.4 7.274723747978211e-06
0.5 6.245665819871659e-05
0.6 0.00038682171379413966
0.7 0.00197597813839096
0.8 0.009340983321304765
0.9 0.0497252576984154
0.99 0.922511822131846


In [293]:
def utterance_prior() -> torch.Tensor:
    utts = torch.arange(0, len(utterances), 1)
    probs = torch.ones_like(utts) / len(utts)
    idx = pyro.sample("utterance", dist.Categorical(probs=probs))
    return utts[idx]

In [294]:
def threshold_prior() -> torch.Tensor:
    bins = torch.arange(0.0, 1.0, 0.1)
    idx = pyro.sample("threshold", dist.Categorical(logits=torch.zeros_like(bins)))
    return bins[idx]

In [295]:
def meaning(utterance: torch.Tensor, state: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor:
    possible_evals = {
        "as_genT": (state > threshold),
        "as_genF": (state <= threshold),
        "is_mu"  : torch.full_like(state, True, dtype=bool),
        "is_some": (state > 0),
        "is_most": (state >= 0.5),
        "is_all" : (state >= 0.99),
        "as_num" : (state == utterance),
        "default": torch.full_like(state, True, dtype=bool),
    }

    meanings = torch.stack(list(possible_evals.values()))

    while utterance.ndim < meanings.ndim:  # expand utterance to be used as an indexer
        utterance = utterance[None]

    return torch.gather(meanings, dim=0, index=utterance.long()).float().squeeze()

# Listener 0

In [314]:
@Marginal
def listener0(utterances: torch.Tensor, thresholds: torch.Tensor, prior: HashingMarginal) -> torch.Tensor:
    state = pyro.sample(f"state", prior)
    means = meaning(utterances, state, thresholds)
    pyro.factor(f"listener0-true", torch.where(means == 1., 0., -99_999.))
    return state

In [315]:
wings_posterior = listener0(torch.tensor([1]), torch.tensor([0.1]), wings_prior)
for el in wings_posterior.enumerate_support():
    print(el, wings_posterior.log_prob(el).exp().item())


tensor(0.9900) 1.0


In [301]:
wings_posterior.trace_dist.exec_traces[0].nodes

OrderedDict([('_INPUT',
              {'name': '_INPUT',
               'type': 'args',
               'args': (tensor([1]), tensor([0.1000])),
               'kwargs': {}}),
             ('state',
              {'type': 'sample',
               'name': 'state',
               'fn': <vectorized_search.VectoredHashingMarginal at 0x7f8d24a94d90>,
               'is_observed': False,
               'args': (),
               'kwargs': {},
               'value': tensor([0.9000]),
               'infer': {'enumerate': 'parallel', 'expand': True},
               'scale': 1.0,
               'mask': None,
               'cond_indep_stack': (),
               'done': True,
               'stop': True,
               'continuation': None,
               'unscaled_log_prob': tensor([-3.0012]),
               'log_prob': tensor([-3.0012]),
               'log_prob_sum': tensor(-3.0012)}),
             ('listener0-true',
              {'type': 'sample',
               'name': 'listener0-true',
  