In [1]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
import os

# Task this notebook runs; it selects both the database and the pipeline below.
DOMAIN = 'drink'

# One Postgres database per babble task. The URLs name only a database (no
# role, host or password).
db_dict = dict(
    test='postgres:///babble_model_unittest',
    spouse='postgres:///babble_model_spouse',
    bike='postgres:///babble_model_bike',
    drink='postgres:///babble_model_drink',
    cdr='postgres:///babble_model_cdr',
)

# NOTE(review): presumably read by snorkel when it is imported, so this must be
# set before the `from snorkel import ...` cell runs — confirm.
os.environ['SNORKELDB'] = db_dict[DOMAIN]

In [3]:
# Open the Snorkel database session used by the later cells.
# NOTE(review): presumably connects via the SNORKELDB URL set in the previous
# cell. Those URLs name no Postgres role, and the saved output shows this call
# failing with 'role ... does not exist' — create that role locally or put a
# user into the URL.
from snorkel import SnorkelSession
session = SnorkelSession()

OperationalError: (psycopg2.OperationalError) FATAL:  role "paroma" does not exist


In [None]:
from snorkel.contrib.babble.pipelines import config as pipelines_config

# Pipeline settings for this run. `configuration` was never imported, so the old
# `config = configuration` raised NameError. `pipelines.config` may be the
# settings mapping itself or a module exposing it as `configuration`
# (NOTE(review): confirm which); accept either. This is an alias, not a copy,
# matching the original intent, so the overrides below modify those defaults.
config = getattr(pipelines_config, 'configuration', pipelines_config)
config['domain'] = DOMAIN
config['splits'] = [0, 1]
config['babbler_split'] = 0
config['max_docs'] = None
config['parallelism'] = 1
config['traditional'] = False
config['majority_vote'] = False
config['verbose'] = True
config['display_marginals'] = True
config['display_accuracies'] = True
config['display_learned_accuracies'] = True

In [None]:
from snorkel.models import candidate_subclass
from tutorials.babble import MTurkHelper
from snorkel.contrib.babble import ExplanationIO

# Build the candidate class, load the explanations and construct the
# domain-specific pipeline for DOMAIN. Later cells use `candidate_class`,
# `explanations`, `sm` and `user_lists` defined here.
if DOMAIN == 'spouse':
    from tutorials.babble.spouse import SpousePipeline
    Spouse = candidate_subclass('Spouse', ['person1', 'person2'])
    candidate_class = Spouse

    # Spouse explanations are read from a TSV file via ExplanationIO.
    expio = ExplanationIO()
    fpath = os.path.join(os.environ['SNORKELHOME'], 'tutorials', 'babble',
                         'spouse', 'data', 'mturk_explanations_all.tsv')
    explanations = expio.read(fpath)

    sm = SpousePipeline(session, Spouse, config)
elif DOMAIN == 'bike':
    from tutorials.babble.bike import BikePipeline
    Biker = candidate_subclass('Biker', ['person', 'bike'])
    candidate_class = Biker

    # Visual-domain explanations are post-processed from the MTurk output CSV,
    # requesting the 'train' set.
    helper = MTurkHelper()
    output_csv_path = os.path.join(os.environ['SNORKELHOME'], 'tutorials', 'babble',
                                   'bike', 'data', 'VisualGenome_all_out.csv')
    explanations = helper.postprocess_visual(output_csv_path, set_name='train',
                                             verbose=False)

    sm = BikePipeline(session, Biker, config)
elif DOMAIN == 'drink':
    from tutorials.babble.drink import DrinkPipeline
    Drinker = candidate_subclass('Drinker', ['person', 'cup'])
    candidate_class = Drinker

    # Same MTurk post-processing as the bike domain, on the drink CSV.
    helper = MTurkHelper()
    output_csv_path = os.path.join(os.environ['SNORKELHOME'], 'tutorials', 'babble',
                                   'drink', 'data', 'Reach_Explanation_out.csv')
    explanations = helper.postprocess_visual(output_csv_path, set_name='train',
                                             verbose=False)

    sm = DrinkPipeline(session, Drinker, config)
else:
    # Only these three domains have a pipeline here, even though db_dict also
    # lists 'test' and 'cdr'.
    raise ValueError("Invalid domain: {}; expected one of 'spouse', 'bike', 'drink'"
                     .format(DOMAIN))
user_lists = {}
print("Total explanations: {}".format(len(explanations)))

In [None]:
anns_folder = os.environ['SNORKELHOME'] + '/tutorials/babble/drink/data/'
%time sm.parse(anns_folder)

In [None]:
# Run the pipeline's extract step (presumably candidate extraction), timed.
%time sm.extract()

In [None]:
# Load gold labels from the data directory `anns_folder` used for parsing;
# presumably the source of the 'gold' annotations read back in the next cells.
%time sm.load_gold(anns_folder)

In [None]:
from snorkel.annotations import load_gold_labels

# Gold labels written by annotator 'gold' for the training split (split 0).
L_gold_train = load_gold_labels(
    session,
    annotator_name='gold',
    split=0,
)
L_gold_train

In [None]:
import numpy as np

# Number of train gold-label entries equal to 1 (presumably the positive class).
np.count_nonzero(np.asarray(L_gold_train.todense()) == 1)

In [None]:
from snorkel.annotations import load_gold_labels

# Gold labels written by annotator 'gold' for the dev split (split 1).
L_gold_dev = load_gold_labels(
    session,
    annotator_name='gold',
    split=1,
)
L_gold_dev

In [None]:
import numpy as np

# Number of dev gold-label entries equal to 1 (presumably the positive class).
np.count_nonzero(np.asarray(L_gold_dev.todense()) == 1)

In [None]:
# All candidates from the babbler's split (config['babbler_split']); the
# explanations are linked against this list in the next cell.
candidates = (session.query(candidate_class)
              .filter(candidate_class.split == config['babbler_split'])
              .all())
print(len(candidates))

In [None]:
from snorkel.contrib.babble import link_explanation_candidates

# Presumably attaches each explanation to its matching object in `candidates`;
# the returned list replaces `explanations`.
explanations = link_explanation_candidates(explanations, candidates)

In [None]:
# Run the pipeline's babble step on the linked explanations with the notebook's
# config; user_lists is the empty dict created in the setup cell.
%time sm.babble(explanations, user_lists=user_lists, config=config)

In [None]:
# Run the pipeline's labeling step, timed.
%time sm.label()
import time
# NOTE(review): unexplained fixed 10 s pause (repeated in the following cells);
# presumably lets background work or DB writes settle — confirm or remove.
time.sleep(10)

In [None]:
# Run the pipeline's supervise step with the notebook's config, timed.
%time sm.supervise(config=config)
# NOTE(review): relies on `time` imported in the previous cell; the purpose of
# this fixed 10 s pause is undocumented.
time.sleep(10)

In [None]:
# Load the dev-split (split=1) label matrix from the pipeline's labeler.
L_dev = sm.labeler.load_matrix(session, split=1)
# NOTE(review): unexplained fixed 10 s pause — confirm it is needed.
time.sleep(10)
L_dev

In [None]:
import time

# Per-labeling-function statistics on the dev split, scored against the dev gold
# labels. The result is kept in a variable and shown last, because a trailing
# `time.sleep(...)` (which returns None) would otherwise be the cell's output
# and hide the table.
lf_stats_dev = L_dev.lf_stats(session, L_gold_dev)
# NOTE(review): purpose of this fixed 10 s pause is undocumented — presumably
# lets background work settle; remove if unnecessary.
time.sleep(10)
lf_stats_dev

In [None]:
# Error analysis of sm.gen_model on the dev label matrix against the dev gold
# labels; unpacked into tp, fp, tn, fn.
# NOTE(review): b=0.9 looks like a decision threshold on the marginals — confirm,
# and consider moving it into `config` instead of a magic number.
tp, fp, tn, fn = sm.gen_model.error_analysis(session, L_dev, L_gold_dev, b=0.9)

In [None]:
# Disabled experiments kept for reference: every cell from here to the end is
# commented out and does not run.
# %time sm.classify()

In [None]:
# L_train = sm.labeler.load_matrix(session, split=0)
# L_train

In [None]:
# NOTE(review): the GenerativePipeline import on the next line is commented out
# twice, so uncommenting this cell once still leaves GenerativePipeline undefined.
# The GenerativePipeline construction cell further down needs it as well.
# # from snorkel.learning import GenerativePipeline
# from snorkel.learning import RandomSearch, ListParameter, RangeParameter

# epochs_param    = ListParameter('epochs', [5, 10, 20])
# decay_param     = ListParameter('decay', [0.9, 0.95])
# step_size_param = RangeParameter('step_size', 1e-6, 1e-5, step=1, log_base=10)
# reg_param       = ListParameter('reg_param', [1e-6])

# searcher = RandomSearch(GenerativePipeline, 
#                         [step_size_param, decay_param, epochs_param, reg_param],
#                         L_train, n=5)

In [None]:
# %%time
# gen_model_best, run_stats = searcher.fit(L_dev, L_gold_dev)
# run_stats

In [None]:
# gen_model_best = GenerativePipeline(class_prior=False, lf_prior=False, 
#                                  lf_propensity=False, lf_class_propensity=False)

In [None]:
# gen_model_best.train(L_train, epochs=10, decay=0.95, step_size=1e-6, reg_param=1e-6)

In [None]:
# tp, fp, tn, fn = gen_model_best.error_analysis(session, L_dev, L_gold_dev, b=0.8)

In [None]:
# import matplotlib.pyplot as plt

# train_marginals = gen_model_best.marginals(L_train)
# plt.hist(train_marginals, bins=30)
# plt.show()