In [1]:
"""
Interpreting generic statements with RSA models of pragmatics.

Taken from:
[0] http://forestdb.org/models/generics.html
[1] https://gscontras.github.io/probLang/chapters/07-generics.html
"""

import argparse
import collections
import numbers

import torch
from search_inference import HashingMarginal, Search, memoize

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

# Double precision for numerical stability: floating-point tensors created
# after this point without an explicit dtype default to float64.
torch.set_default_dtype(torch.float64)

# Models

In [2]:
def Marginal(fn):
    """
    Decorator: wrap ``fn`` so that calling it returns a marginal distribution.

    A call with positional arguments runs ``Search(fn)`` on those arguments
    and wraps the outcome in a ``HashingMarginal``. The wrapper is passed
    through ``memoize`` (presumably so that repeated calls with the same
    arguments reuse the earlier result -- confirm in ``search_inference``).
    Only positional arguments are accepted.
    """

    @memoize
    def marginal_of(*args):
        search_result = Search(fn).run(*args)
        return HashingMarginal(search_result)

    return marginal_of

In [3]:
# Hashable bundle of prior parameters.
#   theta        -- probability given to the Bernoulli "propertyIsPresent" draw
#   gamma, delta -- combined into the Beta shape parameters
#                   gamma * delta and (1 - gamma) * delta
Params = collections.namedtuple("Params", "theta gamma delta")

In [4]:
def discretize_beta_pdf(bins, gamma, delta):
    """
    Evaluate an unnormalized Beta density at every point in ``bins``.

    The Beta shape parameters are ``gamma * delta`` and
    ``(1 - gamma) * delta``. The returned tensor holds one density value per
    bin and is used to approximately integrate over the Beta via Search.
    """
    exponent_x = gamma * delta - 1
    exponent_one_minus_x = (1.0 - gamma) * delta - 1
    densities = [
        (x**exponent_x) * ((1.0 - x) ** exponent_one_minus_x) for x in bins
    ]
    return torch.tensor(densities)

In [5]:
@Marginal
def structured_prior_model(params):
    """
    Prior over prevalence values.

    With probability ``params.theta`` the property is present and the
    prevalence is drawn from a discretized Beta built from ``params.gamma``
    and ``params.delta``; otherwise the prevalence is 0.
    """
    present_draw = pyro.sample("propertyIsPresent", dist.Bernoulli(params.theta))
    if present_draw.item() != 1:
        # Property absent: prevalence is exactly zero.
        return 0

    # approximately integrate over a beta by enumerating over bins
    prevalence_bins = [0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.99]
    bin_weights = discretize_beta_pdf(prevalence_bins, params.gamma, params.delta)
    bin_ix = pyro.sample("bin", dist.Categorical(probs=bin_weights))
    return prevalence_bins[bin_ix]

In [6]:
def threshold_prior():
    """
    Sample a threshold from the values 0.0, 0.1, ..., 0.9.

    All candidates get equal (zero) logits, so the draw is uniform.
    """
    candidates = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
    uniform = dist.Categorical(logits=torch.zeros(len(candidates)))
    choice = pyro.sample("threshold", uniform)
    return candidates[choice]

In [7]:
def utterance_prior():
    """
    Sample one of the two available utterances, "generic is true" or "mu".

    Both candidates get equal (zero) logits, so the draw is uniform.
    """
    candidates = ["generic is true", "mu"]
    uniform = dist.Categorical(logits=torch.zeros(len(candidates)))
    choice = pyro.sample("utterance", uniform)
    return candidates[choice]

In [8]:
def meaning(utterance, state, threshold):
    """
    Truth value of ``utterance`` for a given ``state`` and ``threshold``.

    A numeric utterance is true when the state equals it. Named utterances
    compare the state against the threshold ("generic is true" /
    "generic is false") or against a fixed cut-off ("some", "most", "all").
    "mu" and any unrecognized utterance are always true.
    """
    # Numeric utterances assert an exact state.
    if isinstance(utterance, numbers.Number):
        return state == utterance

    # Ordered (utterance, truth condition) pairs; the first match wins.
    truth_conditions = (
        ("generic is true", lambda: state > threshold),
        ("generic is false", lambda: state <= threshold),
        ("mu", lambda: True),
        ("some", lambda: state > 0),
        ("most", lambda: state >= 0.5),
        ("all", lambda: state >= 0.99),
    )
    for label, holds in truth_conditions:
        if utterance == label:
            return holds()

    # Unknown utterances carry no information.
    return True

## Listener0

In [9]:
@Marginal
def listener0(utterance, threshold, prior):
    """
    Literal listener: a state drawn from ``prior``, weighted by whether
    ``utterance`` is true of it at the given ``threshold``.
    """
    state = pyro.sample("state", prior)
    is_true = meaning(utterance, state, threshold)
    # States where the utterance is false get a very large negative
    # log-weight instead of being excluded outright.
    log_weight = 0.0 if is_true else -99999.0
    pyro.factor("listener0_true", log_weight)
    return state

## Speaker1

In [10]:
@Marginal
def speaker1(state, threshold, prior):
    """
    Pragmatic speaker: an utterance drawn from the utterance prior, scored by
    how much probability the literal listener assigns to ``state``.

    The score is scaled by a fixed optimality of 5.0.
    """
    optimality = torch.tensor(5.0)
    utterance = utterance_prior()
    literal_listener = listener0(utterance, threshold, prior)
    with poutine.scale(scale=optimality):
        pyro.sample("L0_score", literal_listener, obs=state)
    return utterance

## Listener1

In [11]:
@Marginal
def listener1(utterance, prior):
    """
    Pragmatic listener: a state drawn from ``prior`` together with a sampled
    threshold, scored by how likely ``speaker1`` is to produce ``utterance``
    for that state and threshold. Returns the state.
    """
    state = pyro.sample("state", prior)
    threshold = threshold_prior()
    speaker = speaker1(state, threshold, prior)
    pyro.sample("S1_score", speaker, obs=utterance)
    return state

## Speaker2

In [12]:
@Marginal
def speaker2(prevalence, prior):
    """
    Second-level speaker: an utterance drawn from the utterance prior, scored
    by the probability ``listener1`` assigns to ``prevalence`` after hearing it.
    """
    utterance = utterance_prior()
    pragmatic_listener = listener1(utterance, prior)
    pyro.sample("wL1_score", pragmatic_listener, obs=prevalence)
    return utterance

# Playground

In [13]:
def inspect_support(model: HashingMarginal, name: str) -> None:
    """
    Print ``name``, then one line per support value of ``model`` showing the
    value and its probability (the exponentiated ``log_prob``).
    """
    print(name)
    for value in model.enumerate_support():
        probability = model.log_prob(value).exp().item()
        print("  ", value, probability)

In [14]:
# Prevalence prior for "has wings", then what listener1 infers from the generic.
hasWingsERP = structured_prior_model(Params(theta=0.5, gamma=0.99, delta=10.0))
inspect_support(hasWingsERP, "hasWingsERP")

# Marginal-wrapped models take positional arguments only.
wingsPosterior = listener1("generic is true", hasWingsERP)
inspect_support(wingsPosterior, "wingsPosterior")

Params(theta=0.5, gamma=0.99, delta=10.0)
Params(theta=0.5, gamma=0.99, delta=10.0)
Params(theta=0.5, gamma=0.99, delta=10.0)
Params(theta=0.5, gamma=0.99, delta=10.0)
Params(theta=0.5, gamma=0.99, delta=10.0)
Params(theta=0.5, gamma=0.99, delta=10.0)
Params(theta=0.5, gamma=0.99, delta=10.0)
Params(theta=0.5, gamma=0.99, delta=10.0)
Params(theta=0.5, gamma=0.99, delta=10.0)
Params(theta=0.5, gamma=0.99, delta=10.0)
Params(theta=0.5, gamma=0.99, delta=10.0)
Params(theta=0.5, gamma=0.99, delta=10.0)
hasWingsERP
   0 0.5
   0.01 1.1102230246251573e-16
   0.1 1.1245284798484416e-11
   0.2 5.972754861819482e-09
   0.3 2.486449530091711e-07
   0.4 3.6964642884811183e-06
   0.5 3.173574895865843e-05
   0.6 0.00019655353255805638
   0.7 0.0010040426106093096
   0.8 0.004746381094690639
   0.9 0.0252666143327844
   0.99 0.4687507215871573
generic is true <search_inference.HashingMarginal object at 0x7f8244287190>
0 0.0 <search_inference.HashingMarginal object at 0x7f8244287190>
generic is true

In [None]:
# Prevalence prior for "lays eggs", then what listener1 infers from the generic.
laysEggsERP = structured_prior_model(Params(theta=0.5, gamma=0.5, delta=10.0))
inspect_support(laysEggsERP, "laysEggsERP")

# Marginal-wrapped models take positional arguments only.
eggsPosterior = listener1("generic is true", laysEggsERP)
inspect_support(eggsPosterior, "eggsPosterior")

In [None]:
# Prevalence prior for "carries malaria", then what listener1 infers from the generic.
carriesMalariaERP = structured_prior_model(Params(theta=0.1, gamma=0.01, delta=2.0))
inspect_support(carriesMalariaERP, "carriesMalariaERP")

# Marginal-wrapped models take positional arguments only.
malariaPosterior = listener1("generic is true", carriesMalariaERP)
inspect_support(malariaPosterior, "malariaPosterior")

In [None]:
# Prevalence prior for "are female", then what listener1 infers from the generic.
areFemaleERP = structured_prior_model(Params(theta=0.99, gamma=0.5, delta=50.0))
inspect_support(areFemaleERP, "areFemaleERP")

# Marginal-wrapped models take positional arguments only.
femalePosterior = listener1("generic is true", areFemaleERP)
inspect_support(femalePosterior, "femalePosterior")