Skip to content

Commit

Permalink
Merge pull request #64 from probcomp/20151016-fsaad-jointpdf
Browse files Browse the repository at this point in the history
20151016 fsaad jointpdf
  • Loading branch information
F Saad committed Oct 23, 2015
2 parents 7e65886 + 517c5a3 commit 373b884
Show file tree
Hide file tree
Showing 4 changed files with 196 additions and 12 deletions.
5 changes: 4 additions & 1 deletion src/EngineTemplate.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,10 @@ def row_structural_typicality(self, X_L_list, X_D_list, row_id):
def column_structural_typicality(self, X_L_list, col_id):
    # Template stub: performs no computation and always returns None.
    # NOTE(review): presumably concrete engines are expected to override
    # this method -- confirm against the engine implementations.
    return None

def predictive_probability(self, M_c, X_L_list, X_D_list, T, q, n=1):
def predictive_probability(self, M_c, X_L, X_D, T, Q, n=1):
    # Template stub for the single-state predictive probability: performs no
    # computation and always returns None.
    # NOTE(review): LocalEngine.predictive_probability is declared as
    # (self, M_c, X_L, X_D, Y, Q) -- it has no T and no n parameter. Confirm
    # which signature is the intended interface.
    return None

def predictive_probability_multistate(self, M_c, X_L_list, X_D_list, T, Q, n=1):
    # Template stub for the multi-state predictive probability: performs no
    # computation and always returns None.
    # NOTE(review): LocalEngine.predictive_probability_multistate is declared
    # as (self, M_c, X_L_list, X_D_list, Y, Q) -- it has no T and no n
    # parameter. Confirm which signature is the intended interface.
    return None

def similarity(self, M_c, X_L_list, X_D_list, given_row_id, target_row_id, target_columns=None):
Expand Down
52 changes: 48 additions & 4 deletions src/LocalEngine.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,11 +325,11 @@ def simple_predictive_probability(self, M_c, X_L, X_D, Y, Q):
:type X_L: dict
:param X_D: the particular cluster assignments of each row in each view
:type X_D: list of lists
:param Y: A list of constraints to apply when sampling. Each constraint
:param Y: A list of constraints to apply when querying. Each constraint
is a triplet of (r, d, v): r is the row index, d is the column
index and v is the value of the constraint
:type Y: list of lists
:param Q: A list of values to sample. Each value is triplet of (r, d, v):
:param Q: A list of values to query. Each value is triplet of (r, d, v):
r is the row index, d is the column index, and v is the value at
which the density is evaluated.
:type Q: list of lists
Expand All @@ -347,11 +347,11 @@ def simple_predictive_probability_multistate(self, M_c, X_L_list, X_D_list, Y, Q
:type X_L_list: list of dict
:param X_D_list: list of the particular cluster assignments of each row in each view
:type X_D_list: list of list of lists
:param Y: A list of constraints to apply when sampling. Each constraint
:param Y: A list of constraints to apply when querying. Each constraint
is a triplet of (r,d,v): r is the row index, d is the column
index and v is the value of the constraint
:type Y: list of lists
:param Q: A list of values to sample. Each value is triplet of (r,d,v):
:param Q: A list of values to query. Each value is triplet of (r,d,v):
r is the row index, d is the column index, and v is the value at
which the density is evaluated.
:type Q: list of lists
Expand All @@ -360,6 +360,50 @@ def simple_predictive_probability_multistate(self, M_c, X_L_list, X_D_list, Y, Q
"""
return su.simple_predictive_probability_multistate(M_c, X_L_list, X_D_list, Y, Q)

def predictive_probability(self, M_c, X_L, X_D, Y, Q):
    """Evaluate the joint log probability of several cells under one latent state

    All arguments are forwarded unchanged to ``su.predictive_probability``.

    :param M_c: The column metadata
    :type M_c: dict
    :param X_L: the latent variables associated with the latent state
    :type X_L: dict
    :param X_D: the particular cluster assignments of each row in each view
    :type X_D: list of lists
    :param Y: A list of constraints to condition on when querying. Each
        constraint is a triplet of (r, d, v): r is the row index, d is the
        column index and v is the value of the constraint
    :type Y: list of lists
    :param Q: A list of cells to query. Each entry is a triplet of
        (r, d, v): r is the row index, d is the column index, and v is the
        value at which the density is evaluated.
    :type Q: list of lists
    :returns: float -- joint log probability of the values specified by Q
    """
    joint_logp = su.predictive_probability(M_c, X_L, X_D, Y, Q)
    return joint_logp

def predictive_probability_multistate(self, M_c, X_L_list, X_D_list, Y, Q):
    """Evaluate the joint log probability of several cells over a list of latent states

    All arguments are forwarded unchanged to
    ``su.predictive_probability_multistate``.

    :param M_c: The column metadata
    :type M_c: dict
    :param X_L_list: list of the latent variables associated with each latent state
    :type X_L_list: list of dict
    :param X_D_list: list of the particular cluster assignments of each row in each view
    :type X_D_list: list of list of lists
    :param Y: A list of constraints to condition on when querying. Each
        constraint is a triplet of (r,d,v): r is the row index, d is the
        column index and v is the value of the constraint
    :type Y: list of lists
    :param Q: A list of cells to query. Each entry is a triplet of (r,d,v):
        r is the row index, d is the column index, and v is the value at
        which the density is evaluated.
    :type Q: list of lists
    :returns: float -- joint log probability of the values specified by Q,
        as computed by ``su.predictive_probability_multistate``
    """
    joint_logp = su.predictive_probability_multistate(
        M_c, X_L_list, X_D_list, Y, Q)
    return joint_logp

def mutual_information(self, M_c, X_L_list, X_D_list, Q, n_samples=1000):
"""
Return the estimated mutual information for each pair of columns on Q given
Expand Down
86 changes: 79 additions & 7 deletions src/tests/unit_tests/test_pred_prob.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,10 @@
# limitations under the License.
#

import argparse
import pytest
import random
import sys
import tempfile

import numpy

import crosscat.tests.synthetic_data_generator as sdg
from crosscat.LocalEngine import LocalEngine
import crosscat.cython_code.State as State

'''
This test suite ensures that invoking simple_predictive_probability_observed
Expand Down Expand Up @@ -157,3 +150,82 @@ def test_simple_predictive_probability_unobserved(seed=0):
Y = [(N_ROWS, 2, 4), (N_ROWS+1, 2, 5)]
with pytest.raises(IndexError):
vals = engine.simple_predictive_probability(M_c, X_L, X_D, Y, Q[1])

def test_predictive_probability_observed(seed=0):
    # TODO: not implemented yet. This placeholder asserts nothing, so it
    # always passes and provides no coverage.
    # NOTE(review): presumably meant to exercise predictive_probability on
    # observed rows, as a counterpart to the unobserved test -- confirm.
    pass

def test_predictive_probability_unobserved(seed=0):
    # This function tests the predictive probability for the joint
    # distribution. Throughout, we will check that the result is the same for
    # the joint and simple calls. Every query targets row N_ROWS (presumably
    # the first row index past the observed table -- see quick_le).
    T, M_r, M_c, X_L, X_D, engine = quick_le(seed)

    # Hypothetical column number should throw an error.
    Q = [(N_ROWS, 1, 1.5), (N_ROWS, 10, 2)]
    Y = []
    with pytest.raises(ValueError):
        engine.predictive_probability(M_c, X_L, X_D, Y, Q)

    # Inconsistent row numbers should throw an error. Both columns are valid
    # here so that the row mismatch is the only possible cause of failure.
    Q = [(N_ROWS, 1, 1.5), (N_ROWS-1, 0, 2)]
    Y = []
    with pytest.raises(ValueError):
        engine.predictive_probability(M_c, X_L, X_D, Y, Q)

    # Duplicate column numbers should throw an error.
    Q = [(N_ROWS, 1, 1.5), (N_ROWS, 1, 2)]
    Y = []
    with pytest.raises(ValueError):
        engine.predictive_probability(M_c, X_L, X_D, Y, Q)

    # Different row numbers should throw an error, also when constraints are
    # present. The whole query list must be passed: a single tuple such as
    # Q[0] only fails while being unpacked, which never reaches the row check.
    Q = [(N_ROWS, 0, 1.5), (N_ROWS+1, 1, 2)]
    Y = [(N_ROWS, 1, 1.5), (N_ROWS, 2, 3)]
    with pytest.raises(ValueError):
        engine.predictive_probability(M_c, X_L, X_D, Y, Q)

    # Inconsistent with constraints should be negative infinity.
    Q = [(N_ROWS, 1, 1.5), (N_ROWS, 0, 1.3)]
    Y = [(N_ROWS, 1, 1.6)]
    val = engine.predictive_probability(M_c, X_L, X_D, Y, Q)
    assert val == -float('inf')
    assert isinstance(val, float)

    # Consistent with constraints should be log(1) == 0.
    Q = [(N_ROWS, 0, 1.3)]
    Y = [(N_ROWS, 0, 1.3)]
    val = engine.predictive_probability(M_c, X_L, X_D, Y, Q)
    assert val == 0

    # Consistent with constraints should not impact other queries.
    Q = [(N_ROWS, 1, 1.5), (N_ROWS, 0, 1.3)]
    Y = [(N_ROWS, 1, 1.5), (N_ROWS, 2, 3)]
    val_0 = engine.predictive_probability(M_c, X_L, X_D, Y, Q)
    val_1 = engine.predictive_probability(M_c, X_L, X_D, Y, Q[1:])
    assert val_0 == val_1

    # Predictive and simple should be the same in univariate case (cont).
    # Floor division keeps the row index an integer under true division.
    Q = [(N_ROWS, 0, 0.5)]
    Y = [(0, 0, 1), (N_ROWS // 2, 4, 5), (N_ROWS, 1, 0.5), (N_ROWS+1, 0, 1.2)]
    val_0 = engine.predictive_probability(M_c, X_L, X_D, Y, Q)
    val_1 = engine.simple_predictive_probability(M_c, X_L, X_D, Y, Q)
    assert val_0 == val_1

    # Predictive and simple should be the same in univariate case (disc).
    Q = [(N_ROWS, 2, 1)]
    Y = [(0, 0, 1), (N_ROWS // 2, 4, 5), (N_ROWS, 1, 0.5), (N_ROWS+1, 0, 1.2)]
    val_0 = engine.predictive_probability(M_c, X_L, X_D, Y, Q)
    val_1 = engine.simple_predictive_probability(M_c, X_L, X_D, Y, Q)
    assert val_0 == val_1

    # Do some full joint queries, all on the same row.
    Q = [(N_ROWS, 3, 4), (N_ROWS, 4, 1.3)]
    Y = [(N_ROWS, 0, 1), (N_ROWS, 1, -0.7), (N_ROWS, 2, 3)]
    val = engine.predictive_probability(M_c, X_L, X_D, Y, Q)
    assert isinstance(val, float)

    # Same cells with the roles of query and constraints swapped.
    Q = [(N_ROWS, 0, 1), (N_ROWS, 1, -0.7), (N_ROWS, 2, 3)]
    Y = [(N_ROWS, 3, 4), (N_ROWS, 4, 1.3)]
    val = engine.predictive_probability(M_c, X_L, X_D, Y, Q)
    assert isinstance(val, float)
65 changes: 65 additions & 0 deletions src/utils/sample_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import copy
from collections import Counter
import numpy
import itertools
#
import crosscat.cython_code.ContinuousComponentModel as CCM
import crosscat.cython_code.MultinomialComponentModel as MCM
Expand All @@ -40,6 +41,57 @@ def __setattr__(self, key, value):

Constraints = Bunch

def predictive_probability(M_c, X_L, X_D, Y, Q):
    """Evaluate the joint logpdf of crosscat columns for one latent state.

    This is achieved by invoking simple_predictive_probability on one query
    cell at a time while cascading the constraints (the chain rule): each
    query is conditioned on Y plus every query already evaluated, and the
    individual log values are summed.

    Q (query): list of three element tuples where each tuple, (r, c, x),
        contains a row r; column c; value x. All rows must be the same.
    Y (constraints): follows an identical format.

    Returns a float: the joint log probability of Q given Y. The result is
    0.0 when Q is empty or when every query cell is pinned to the same value
    by a constraint, and -inf when a query cell contradicts a constraint on
    the same cell.

    Raises ValueError if the queries span several rows, repeat a column,
    reference a column index not present in M_c['column_metadata'], or if Y
    repeats a (row, column) pair.
    """
    # The current interface does not allow the query columns to have different
    # row numbers, so this function will ensure the same constraint, pending
    # a formalization of the semantic meaning of predictive_probability of
    # arbitrary patterns of cells.
    queries = dict()
    for (row, col, val) in Q:
        if row != Q[0][0]:
            raise ValueError('Cannot specify different query rows.')
        if (row, col) in queries:
            raise ValueError('Cannot specify duplicate query columns.')
        if len(M_c['column_metadata']) <= col:
            raise ValueError('Cannot specify hypothetical query column.')
        queries[(row, col)] = val
    # Ensure consistency for nodes in both query and constraints.
    # This behavior is correct, even for real-valued datatypes. Conditional
    # probability is itself a complex topic, but consider random
    # variable X continuous. Then the conditional density of X f(s|X=t) is
    # 1 if s==t and 0 otherwise. Note change of the dominating measure from
    # Lebesgue to counting. The argument is not rigorous but correct.
    ignore = set()
    constraints = set()
    for (row, col, val) in Y:
        if (row, col) in constraints:
            raise ValueError('Cannot specify duplicate constraint row, column.')
        if (row, col) in queries:
            if queries[(row, col)] == val:
                # Tracking the column alone is enough: all queries share a row.
                ignore.add(col)
            else:
                return float('-inf')
        constraints.add((row, col))
    # Work on a copy so the caller's constraint list is never mutated.
    Y_prime = list(Y)
    # Chain rule. Start from a float so the result is a float even when no
    # query ends up being evaluated.
    prob = 0.0
    for query in Q:
        if query[1] in ignore:
            continue
        r = simple_predictive_probability(M_c, X_L, X_D, Y_prime, [query])
        prob += float(r)
        Y_prime.append(query)
    return prob


# Q is a list of three element tuples where each tuple, (r,c,x) contains a
# row, r; a column, c; and a value x. The constraints, Y, follow an identical format.
# Returns a numpy array where each entry, A[i] is the probability for query i given
Expand Down Expand Up @@ -177,6 +229,19 @@ def simple_predictive_probability_multistate(M_c, X_L_list, X_D_list, Y, Q):
# = logsumexp(logprobs) - log(len(logprobs))
return logsumexp(logprobs) - numpy.log(len(logprobs))

def predictive_probability_multistate(M_c, X_L_list, X_D_list, Y, Q):
    """
    Returns the log of the predictive probability averaged over the latent
    states paired up from X_L_list and X_D_list.
    """
    logprobs = []
    for state_L, state_D in zip(X_L_list, X_D_list):
        state_logp = predictive_probability(M_c, state_L, state_D, Y, Q)
        logprobs.append(float(state_logp))
    # The average is taken in probability space but computed in log space:
    # log(mean(exp(logprobs))) = log(sum(exp(logprobs)) / len(logprobs))
    #                          = logsumexp(logprobs) - log(len(logprobs))
    n_states = len(logprobs)
    return logsumexp(logprobs) - numpy.log(n_states)


#############################################################################

Expand Down

0 comments on commit 373b884

Please sign in to comment.