Skip to content

Commit

Permalink
Add subsampling in Crosscat.
Browse files Browse the repository at this point in the history
The generator schema accepts the directive SUBSAMPLE(n), e.g. SUBSAMPLE(1000), or SUBSAMPLE(OFF) to disable subsampling.

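The first n rows of the table in rowid order (n defaults to 1000) are fed
to Crosscat.  When a BQL function is evaluated on a row outside the
subsample, that row is ephemerally inserted into Crosscat each time, just
long enough to compute the function.  There is currently no way to
permanently add rows to the subsample, or to do analysis on rows that were
not inserted.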
  • Loading branch information
riastradh-probcomp committed Jun 4, 2015
1 parent 58c667e commit 0780f7c
Show file tree
Hide file tree
Showing 3 changed files with 253 additions and 24 deletions.
206 changes: 186 additions & 20 deletions src/crosscat.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,31 @@
);
'''

# Schema migration from version 2 to version 3 of the crosscat metamodel's
# tables.  It bumps the version recorded in bayesdb_metamodel and adds the
# subsample bookkeeping:
#
#   bayesdb_crosscat_subsampled -- one row per generator that uses a
#       subsample; presence of a row is what the metamodel tests for.
#   bayesdb_crosscat_subsample -- for such a generator, pairs each table
#       rowid in the subsample (sql_rowid) with a Crosscat row id (cc_row_id).
#
# register() executes this text by splitting it on ';', so nothing in it --
# SQL comments included -- may contain a semicolon other than as a statement
# terminator.
crosscat_schema_2to3 = '''
UPDATE bayesdb_metamodel SET version = 3 WHERE name = 'crosscat';
CREATE TABLE bayesdb_crosscat_subsampled (
generator_id INTEGER NOT NULL PRIMARY KEY
REFERENCES bayesdb_crosscat_metadata
);
-- Generator-wide subsample, not per-model.
CREATE TABLE bayesdb_crosscat_subsample (
generator_id INTEGER NOT NULL
REFERENCES bayesdb_crosscat_subsampled,
sql_rowid INTEGER NOT NULL,
cc_row_id INTEGER NOT NULL,
PRIMARY KEY(generator_id, sql_rowid ASC),
UNIQUE(generator_id, cc_row_id ASC)
-- Can't express the desired foreign key constraint,
-- FOREIGN KEY(sql_rowid) REFERENCES <table of generator>(rowid),
-- for two reasons:
-- 1. No way for constraint to have data-dependent table.
-- 2. Can't refer to implicit rowid in sqlite3 constraints.
-- So we'll just hope nobody botches it.
);
'''

class CrosscatMetamodel(metamodel.IBayesDBMetamodel):
"""Crosscat metamodel for BayesDB.
Expand All @@ -121,8 +146,11 @@ class CrosscatMetamodel(metamodel.IBayesDBMetamodel):
with names that begin ``crosscat_``.
"""

def __init__(self, crosscat, subsample=None):
    """Create a Crosscat metamodel wrapping the engine `crosscat`.

    `subsample` is the default subsample size applied by
    create_generator when the schema has no SUBSAMPLE directive.
    None selects the built-in default of 1000; any other value is
    stored as given, so a falsy value such as False disables
    subsampling by default.
    """
    if subsample is None:
        subsample = 1000 # XXX
    self._crosscat = crosscat
    self._subsample = subsample

def _crosscat_cache_nocreate(self, bdb):
if bdb.cache is None:
Expand Down Expand Up @@ -179,7 +207,23 @@ def _crosscat_data(self, bdb, generator_id, M_c):
columns = list(bdb.sql_execute(columns_sql, (generator_id,)))
colnames = [name for name, _colno in columns]
qcns = map(sqlite3_quote_name, colnames)
cursor = bdb.sql_execute('SELECT %s FROM %s' % (','.join(qcns), qt))
subsampled = bdb.sql_execute('''
SELECT COUNT(*) FROM bayesdb_crosscat_subsampled
WHERE generator_id = ?
''', (generator_id,)).next()[0]
sql = None
params = None
if subsampled == 0:
sql = 'SELECT %s FROM %s' % (','.join(qcns), qt)
params = ()
else:
sql = '''
SELECT %s FROM %s AS t, bayesdb_crosscat_subsample AS s
WHERE s.generator_id = ?
AND s.sql_rowid = t._rowid_
''' % (','.join('t.%s' % (qcn,) for qcn in qcns), qt)
params = (generator_id,)
cursor = bdb.sql_execute(sql, params)
return [[crosscat_value_to_code(bdb, generator_id, M_c, colno, value)
for value, (_name, colno) in zip(row, columns)]
for row in cursor]
Expand Down Expand Up @@ -235,6 +279,68 @@ def _crosscat_latent_data(self, bdb, generator_id, modelno):
return [statum[1] for statum
in self._crosscat_latent_stata(bdb, generator_id, modelno)]

def _crosscat_get_row(self, bdb, generator_id, rowid, X_L_list, X_D_list):
    """Single-row convenience wrapper around _crosscat_get_rows.

    Returns (row_id, X_L_list, X_D_list), where row_id is the one
    Crosscat row id that _crosscat_get_rows reports for `rowid` and
    the latent lists are whatever it hands back.
    """
    result = self._crosscat_get_rows(
        bdb, generator_id, [rowid], X_L_list, X_D_list)
    row_ids, X_L_list, X_D_list = result
    # Exactly one rowid went in, so exactly one row id must come out.
    [row_id] = row_ids
    return row_id, X_L_list, X_D_list

def _crosscat_get_rows(self, bdb, generator_id, rowids, X_L_list,
        X_D_list):
    """Map table rowids to Crosscat row ids, inserting rows if needed.

    Returns (row_ids, X_L_list, X_D_list), where row_ids[i] is the
    Crosscat row id for rowids[i].

    If the generator has no subsample, each row id is rowid - 1 and
    the latent lists are returned untouched.  Otherwise rowids found
    in bayesdb_crosscat_subsample use their recorded cc_row_id, and
    the remaining rows are read from the table, passed to
    self._crosscat.insert, and numbered after the largest recorded
    cc_row_id.  In that case the returned latent lists are the ones
    insert produced; nothing is written back to the database here.

    The same rowid may appear more than once in `rowids` (e.g. the
    similarity of a row to itself); every occurrence gets the same
    row id.

    NOTE(review): rowids missing from the subsample are assumed to
    exist in the table.  A nonexistent rowid would yield fewer
    fetched rows than requested ids -- confirm callers check this.
    """
    subsampled = bdb.sql_execute('''
        SELECT COUNT(*) FROM bayesdb_crosscat_subsampled
            WHERE generator_id = ?
    ''', (generator_id,)).next()[0]
    if subsampled == 0:
        return [rowid - 1 for rowid in rowids], X_L_list, X_D_list
    row_ids = [None] * len(rowids)
    # Remember every output position of each rowid, not just the last
    # one, so that repeated rowids are all filled in.
    positions = {}
    for i, rowid in enumerate(rowids):
        positions.setdefault(rowid, []).append(i)
    cursor = bdb.sql_execute('''
        SELECT sql_rowid, cc_row_id FROM bayesdb_crosscat_subsample
            WHERE generator_id = ?
                AND sql_rowid IN (%s)
    ''' % (','.join('%d' % (rowid,) for rowid in rowids)),
        (generator_id,))
    for rowid, row_id in cursor:
        for i in positions.pop(rowid):
            row_ids[i] = row_id
    if positions:
        # Whatever is left is outside the subsample.  Sort so that the
        # order matches the ORDER BY _rowid_ ASC query below.
        missing = sorted(positions)
        table_name = core.bayesdb_generator_table(bdb, generator_id)
        qt = sqlite3_quote_name(table_name)
        modelled_column_names = \
            core.bayesdb_generator_column_names(bdb, generator_id)
        qcns = ','.join(map(sqlite3_quote_name, modelled_column_names))
        qrowids = ','.join('%d' % (rowid,) for rowid in missing)
        M_c = self._crosscat_metadata(bdb, generator_id)
        cursor = bdb.sql_execute('''
            SELECT %s FROM %s WHERE _rowid_ IN (%s) ORDER BY _rowid_ ASC
        ''' % (qcns, qt, qrowids))
        colnos = core.bayesdb_generator_column_numbers(bdb, generator_id)
        rows = [[crosscat_value_to_code(bdb, generator_id, M_c, colno, x)
                for colno, x in zip(colnos, row)]
            for row in cursor]
        T = self._crosscat_data(bdb, generator_id, M_c)
        X_L_list, X_D_list, T = self._crosscat.insert(
            M_c=M_c,
            T=T,
            X_L_list=X_L_list,
            X_D_list=X_D_list,
            new_rows=rows,
        )
        # Sanity check: the table insert returned must be the subsample
        # data followed by the new rows, treating NaN as equal to NaN.
        for r0, r1 in \
                zip(T, self._crosscat_data(bdb, generator_id, M_c) + rows):
            assert all(x0 == x1 or (math.isnan(x0) and math.isnan(x1))
                for x0, x1 in zip(r0, r1))
        next_row_id = bdb.sql_execute('''
            SELECT MAX(cc_row_id) + 1 FROM bayesdb_crosscat_subsample
                WHERE generator_id = ?
        ''', (generator_id,)).next()[0]
        for offset, rowid in enumerate(missing):
            for i in positions[rowid]:
                row_ids[i] = next_row_id + offset
    assert all(row_id is not None for row_id in row_ids)
    return row_ids, X_L_list, X_D_list

def name(self):
    """Return the metamodel's name, the string 'crosscat'."""
    metamodel_name = 'crosscat'
    return metamodel_name

Expand Down Expand Up @@ -288,12 +394,17 @@ def register(self, bdb):
'theta_json': theta_json,
})
version = 2
if version != 2:
if version == 2:
for stmt in crosscat_schema_2to3.split(';'):
bdb.sql_execute(stmt)
version = 3
if version != 3:
raise BQLError(bdb, 'Crosscat already installed'
' with unknown schema version: %d' % (version,))

def create_generator(self, bdb, table, schema, instantiate):
do_guess = False
do_subsample = self._subsample
columns = []
for directive in schema:
if isinstance(directive, list) and \
Expand All @@ -303,6 +414,21 @@ def create_generator(self, bdb, table, schema, instantiate):
directive[1] == ['*']:
do_guess = True
continue
if isinstance(directive, list) and \
len(directive) == 2 and \
isinstance(directive[0], (str, unicode)) and \
casefold(directive[0]) == 'subsample' and \
isinstance(directive[1], list) and \
len(directive[1]) == 1:
if isinstance(directive[1][0], (str, unicode)) and \
casefold(directive[1][0]) == 'off':
do_subsample = False
elif isinstance(directive[1][0], int):
do_subsample = directive[1][0]
else:
raise BQLError(bdb, 'Invalid subsampling: %s' %
(repr(directive[1][0]),))
continue
if isinstance(directive, list) and \
len(directive) == 2 and \
isinstance(directive[0], (str, unicode)) and \
Expand Down Expand Up @@ -379,6 +505,30 @@ def create_generator(self, bdb, table, schema, instantiate):
'value': codemap[code],
})

# If necessary, choose a subsample.
if do_subsample:
qt = sqlite3_quote_name(table)
cursor = bdb.sql_execute('SELECT COUNT(*) FROM %s' % (qt,))
nrows = cursor.next()[0]
if do_subsample < nrows:
cursor = bdb.sql_execute('''
SELECT _rowid_ FROM %s ORDER BY _rowid_ ASC LIMIT ?
''' % (qt,), (do_subsample,))
insert_subsampled_sql = '''
INSERT INTO bayesdb_crosscat_subsampled (generator_id)
VALUES (?)
'''
bdb.sql_execute(insert_subsampled_sql, (generator_id,))
insert_subsample_sql = '''
INSERT INTO bayesdb_crosscat_subsample
(generator_id, sql_rowid, cc_row_id)
VALUES (?, ?, ?)
'''
for i, row in enumerate(cursor):
rowid = row[0]
bdb.sql_execute(insert_subsample_sql,
(generator_id, rowid, i))

def drop_generator(self, bdb, generator_id):
with bdb.savepoint():
# Remove the metadata from the cache.
Expand Down Expand Up @@ -724,21 +874,31 @@ def column_value_probability(self, bdb, generator_id, modelno, colno,

def row_similarity(self, bdb, generator_id, modelno, rowid, target_rowid,
        colnos):
    """Return Crosscat's similarity of row `rowid` to `target_rowid`.

    The two table rowids are translated to Crosscat row ids by
    _crosscat_get_rows, which may hand back latent lists extended
    with rows that are outside the generator's subsample; those
    returned lists are the ones passed to the engine.  `colnos`
    are generator column numbers, converted with crosscat_cc_colno.
    """
    X_L_list = self._crosscat_latent_state(bdb, generator_id, modelno)
    X_D_list = self._crosscat_latent_data(bdb, generator_id, modelno)
    [given_row_id, target_row_id], X_L_list, X_D_list = \
        self._crosscat_get_rows(bdb, generator_id, [rowid, target_rowid],
            X_L_list, X_D_list)
    # The metadata is fetched inside the call, after the row lookup,
    # which keeps the order of the issued queries stable.
    return self._crosscat.similarity(
        M_c=self._crosscat_metadata(bdb, generator_id),
        X_L_list=X_L_list,
        X_D_list=X_D_list,
        given_row_id=given_row_id,
        target_row_id=target_row_id,
        target_columns=[crosscat_cc_colno(bdb, generator_id, colno)
            for colno in colnos],
    )

def row_typicality(self, bdb, generator_id, modelno, rowid):
    """Return Crosscat's structural typicality of table row `rowid`.

    The rowid is translated to a Crosscat row id by
    _crosscat_get_row, which may hand back latent lists extended
    with a row that is outside the generator's subsample; those
    returned lists are the ones passed to the engine.
    """
    X_L_list = self._crosscat_latent_state(bdb, generator_id, modelno)
    X_D_list = self._crosscat_latent_data(bdb, generator_id, modelno)
    row_id, X_L_list, X_D_list = \
        self._crosscat_get_row(bdb, generator_id, rowid, X_L_list,
            X_D_list)
    return self._crosscat.row_structural_typicality(
        X_L_list=X_L_list,
        X_D_list=X_D_list,
        row_id=row_id,
    )

def row_column_predictive_probability(self, bdb, generator_id, modelno,
Expand All @@ -764,12 +924,17 @@ def row_column_predictive_probability(self, bdb, generator_id, modelno,
return None
code = crosscat_value_to_code(bdb, generator_id, M_c, colno, value)
cc_colno = crosscat_cc_colno(bdb, generator_id, colno)
X_L_list = self._crosscat_latent_state(bdb, generator_id, modelno)
X_D_list = self._crosscat_latent_data(bdb, generator_id, modelno)
row_id, X_L_list, X_D_list = \
self._crosscat_get_row(bdb, generator_id, rowid, X_L_list,
X_D_list)
r = self._crosscat.simple_predictive_probability_multistate(
M_c=M_c,
X_L_list=self._crosscat_latent_state(bdb, generator_id, modelno),
X_D_list=self._crosscat_latent_data(bdb, generator_id, modelno),
X_L_list=X_L_list,
X_D_list=X_D_list,
Y=[],
Q=[(crosscat_row_id(rowid), cc_colno, code)],
Q=[(row_id, cc_colno, code)],
)
return math.exp(r)

Expand Down Expand Up @@ -801,12 +966,16 @@ def predict_confidence(self, bdb, generator_id, modelno, colno, rowid,
raise BQLError(bdb, 'More than one such row'
' in table %s for generator %s: %d' %
(repr(table_name), repr(generator), repr(rowid)))
row_id = crosscat_row_id(rowid)
X_L_list = self._crosscat_latent_state(bdb, generator_id, modelno)
X_D_list = self._crosscat_latent_data(bdb, generator_id, modelno)
row_id, X_L_list, X_D_list = \
self._crosscat_get_row(bdb, generator_id, rowid, X_L_list,
X_D_list)
cc_colno = crosscat_cc_colno(bdb, generator_id, colno)
code, confidence = self._crosscat.impute_and_confidence(
M_c=M_c,
X_L=self._crosscat_latent_state(bdb, generator_id, modelno),
X_D=self._crosscat_latent_data(bdb, generator_id, modelno),
X_L=X_L_list,
X_D=X_D_list,
Y=[(row_id,
crosscat_gen_colno(bdb, generator_id, cc_colno_),
crosscat_value_to_code(bdb, generator_id, M_c,
Expand Down Expand Up @@ -835,7 +1004,7 @@ def simulate(self, bdb, generator_id, modelno, constraints, colnos,
assert len(row) == 1
max_rowid = row[0]
fake_rowid = max_rowid + 1
fake_row_id = crosscat_row_id(fake_rowid)
fake_row_id = fake_rowid - 1
# XXX Why special-case empty constraints?
Y = None
if constraints is not None:
Expand Down Expand Up @@ -997,9 +1166,6 @@ def create_metadata_categorical(bdb, generator_id, colno):
'numerical': create_metadata_numerical,
}

def crosscat_row_id(rowid):
    """Return the Crosscat row id for a table rowid, i.e. rowid - 1."""
    row_id = rowid - 1
    return row_id

def crosscat_value_to_code(bdb, generator_id, M_c, colno, value):
stattype = core.bayesdb_generator_column_stattype(bdb, generator_id, colno)
if stattype == 'categorical':
Expand Down
12 changes: 8 additions & 4 deletions tests/test_bql.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,14 +905,16 @@ def trace(string, _bindings):
'SELECT bql_row_similarity(1, NULL, _rowid_,'
' (SELECT _rowid_ FROM "t" WHERE ("rowid" = 1)), 0) FROM "t"',
'SELECT metamodel FROM bayesdb_generator WHERE id = ?',
'SELECT metadata_json FROM bayesdb_crosscat_metadata'
' WHERE generator_id = ?',
'SELECT modelno FROM bayesdb_crosscat_theta'
' WHERE generator_id = ?',
'SELECT theta_json FROM bayesdb_crosscat_theta'
' WHERE generator_id = ? AND modelno = ?',
'SELECT modelno FROM bayesdb_crosscat_theta'
' WHERE generator_id = ?',
'SELECT COUNT(*) FROM bayesdb_crosscat_subsampled'
' WHERE generator_id = ?',
'SELECT metadata_json FROM bayesdb_crosscat_metadata'
' WHERE generator_id = ?',
'SELECT cc_colno FROM bayesdb_crosscat_column'
' WHERE generator_id = ? AND colno = ?',
]
Expand Down Expand Up @@ -949,14 +951,16 @@ def trace(string, _bindings):
'SELECT bql_row_similarity(1, NULL, _rowid_,'
' (SELECT _rowid_ FROM "t" WHERE ("rowid" = 1)), 0) FROM "t"',
'SELECT metamodel FROM bayesdb_generator WHERE id = ?',
'SELECT metadata_json FROM bayesdb_crosscat_metadata'
' WHERE generator_id = ?',
'SELECT modelno FROM bayesdb_crosscat_theta'
' WHERE generator_id = ?',
'SELECT theta_json FROM bayesdb_crosscat_theta'
' WHERE generator_id = ? AND modelno = ?',
'SELECT modelno FROM bayesdb_crosscat_theta'
' WHERE generator_id = ?',
'SELECT COUNT(*) FROM bayesdb_crosscat_subsampled'
' WHERE generator_id = ?',
'SELECT metadata_json FROM bayesdb_crosscat_metadata'
' WHERE generator_id = ?',
'SELECT cc_colno FROM bayesdb_crosscat_column'
' WHERE generator_id = ? AND colno = ?',
]
Expand Down

0 comments on commit 0780f7c

Please sign in to comment.