In [None]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules
# before each cell runs, so edits to library code are picked up live.
%load_ext autoreload
%autoreload 2

## Setup

In [None]:
# Notebook-level settings. Later cells merge these with other config
# sources via merge_configs() and hand the result to the pipeline.
config = dict(
    domain='spouse',
    postgres=False,
    parallelism=1,
    db_name='babble_spouse_demo',
    debug=False,
    babbler_candidate_split=1,
    babbler_label_split=1,
    disc_model_search_space=1,
    gen_model_search_space=1,
    supervision='majority',
)

In [None]:
# Get DB connection string and add to globals
# NOTE: $SNORKELDB must be set before any snorkel imports
import os

# Fall back to a name derived from the domain (with a '_debug' suffix in
# debug mode) when the config does not name a database explicitly.
default_db_name = 'babble_' + config['domain'] + ('_debug' if config.get('debug', False) else '')
DB_NAME = config.get('db_name', default_db_name)
if config.get('postgres'):
    DB_TYPE = 'postgres'
    # A host:port component only applies to a server-backed database.
    DB_ADDR = "localhost:{0}".format(config['db_port']) if 'db_port' in config else ""
else:
    DB_TYPE = 'sqlite'
    DB_NAME += '.db'
    # SQLite URLs point at a file and take no host component, so a stray
    # 'db_port' setting must not leak into the connection string.
    DB_ADDR = ""
# With an empty DB_ADDR this yields the three-slash form, e.g. sqlite:///name.db
os.environ['SNORKELDB'] = '{0}://{1}/{2}'.format(DB_TYPE, DB_ADDR, DB_NAME)
print("$SNORKELDB = {0}".format(os.environ['SNORKELDB']))

In [None]:
from snorkel import SnorkelSession
session = SnorkelSession()

# Resolve config conflicts (nb_config > local_config > global_config)
from snorkel.contrib.babble.pipelines import merge_configs
config = merge_configs(config)

if config['debug']:
    print("NOTE: --debug=True: modifying parameters...")
    # Top-level overrides applied in debug mode.
    config.update({
        'max_docs': 100,
        'gen_model_search_space': 2,
        'disc_model_search_space': 2,
    })
    # Nested overrides: epoch counts inside the default parameter dicts.
    config['gen_params_default']['epochs'] = 25
    config['disc_params_default']['n_epochs'] = 5

In [None]:
from snorkel.models import candidate_subclass
from tutorials.babble.spouse import SpousePipeline

# Candidate type named 'Spouse' with two arguments, person1 and person2;
# `candidate_class` is a second name bound to the same class.
candidate_class = Spouse = candidate_subclass('Spouse', ['person1', 'person2'])
pipe = SpousePipeline(session, candidate_class, config)

## Parse, Extract, Load

In [None]:
# %time pipe.parse()

In [None]:
# %time pipe.extract()

In [None]:
# %time pipe.load_gold()

## Now the real work begins...

In [None]:
from snorkel.contrib.babble import BabbleStream

# Stream over Spouse candidates, constructed with a fixed seed.
bs = BabbleStream(
    session,
    candidate_class=Spouse,
    balanced=True,
    seed=123,
)

In [None]:
from tutorials.babble.spouse.spouse_examples import get_explanations, get_user_lists

# All Spouse candidates whose split is 0.
candidates = (
    session.query(Spouse)
    .filter(Spouse.split == 0)
    .all()
)

# Explanations and user lists shipped with the spouse tutorial.
spouse_explanations = get_explanations()
spouse_user_lists = get_user_lists()

In [None]:
# Hand the tutorial's explanations and user lists to the stream via preload().
bs.preload(explanations=spouse_explanations, user_lists=spouse_user_lists)

In [None]:
# Pull one item from the stream; later cells use it as the candidate shown
# in the viewer and attached to each Explanation.
c = bs.next()

In [None]:
from snorkel.viewer import SentenceNgramViewer
# View the single candidate `c`, one per page; the bare `sv` on the last
# line makes the viewer the cell's output.
sv = SentenceNgramViewer([c], session, n_per_page=1, height=200)
sv

In [None]:
from snorkel.contrib.babble import Explanation
# Build an Explanation from a natural-language condition and a label,
# attached to candidate `c`.
label = True
condition = "married is within two words to the left of arg 2"
explanation = Explanation(condition, label, candidate=c)
# Bare expression: display the explanation as the cell's output.
explanation

In [None]:
# Apply the explanation through the stream and time the call. The three
# returned lists are indexed at [0] in the cells that follow
# (presumably one entry per resulting parse -- TODO confirm).
%time parse_list, conf_matrix_list, stats_list = bs.apply(explanation)

In [None]:
# Accuracy and class coverage reported for the first entry returned by bs.apply().
print(stats_list[0].accuracy)
print(stats_list[0].class_coverage)

In [None]:
from snorkel.viewer import SentenceNgramViewer
# Show up to 10 items from the first confusion matrix, three per page.
# NOTE(review): the variable is named `error_set` but reads the `.correct`
# attribute -- confirm whether the misclassified set was intended instead.
error_set = conf_matrix_list[0].correct
sv = SentenceNgramViewer(list(error_set)[:10], session, n_per_page=3, height=300)
sv

In [None]:
# Snapshot of the stream's global stats before committing anything in this section.
global_coverage = bs.get_global_stats()
print(global_coverage)

In [None]:
# NOTE(review): an empty index list is passed here, while the later cell calls
# bs.commit() with no argument. If an empty list selects no parses, the
# coverage check that follows will not change -- confirm the intended indices.
bs.commit([]) # Permanently adds the parses corresponding to these idxs

Confirm that after committing, global coverage goes up.

In [None]:
# Re-read the global stats after the commit so they can be compared with
# the values printed before it.
global_coverage = bs.get_global_stats()
print(global_coverage)

In [None]:
# Fetch the stream's label matrix; the bare `L_train` displays it as the cell's output.
L_train = bs.get_label_matrix()
L_train

### Add another explanation

In [None]:
from snorkel.contrib.babble import Explanation
# Second explanation, this time with a False label; it reuses the same
# candidate `c` and rebinds `label`, `condition` and `explanation`.
label = False
condition = "'where' is within two words to the right of arg 1"
# NOTE(review): `name` is passed as an empty string -- confirm that is intended.
explanation = Explanation(condition, label, candidate=c, name='')

In [None]:
# Apply the second explanation (timed); this rebinds parse_list,
# conf_matrix_list and stats_list from the earlier apply() call.
%time parse_list, conf_matrix_list, stats_list = bs.apply(explanation)

In [None]:
# Accuracy and class coverage for the first entry of the second apply() call.
print(stats_list[0].accuracy)
print(stats_list[0].class_coverage)

In [None]:
# Commit with no explicit indices (presumably commits every pending parse -- TODO confirm).
bs.commit()

In [None]:
# Inspect the first parse returned by the latest apply() call; the bare
# expression displays its `semantics` attribute.
parse = parse_list[0]
parse.semantics

In [None]:
# Pass the parse's semantics through the grammar's translate() and display the result.
bs.semparser.grammar.translate(parse.semantics)

In [None]:
# Use the `function` of every parse held by the stream as the pipeline's
# labeling functions, then run the pipeline's labeling step.
# The loop variable is deliberately not named `parse`: under Python 2 a list
# comprehension's variable leaks into the enclosing (notebook-global) scope
# and would overwrite the `parse` object inspected in the earlier cells.
pipe.lfs = [stream_parse.function for stream_parse in bs.parses]
pipe.label()

In [None]:
# %time pipe.supervise()

In [None]:
# %time pipe.classify()