In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os

# Select the database backend for this notebook through the SNORKELDB
# environment variable. Leaving it unset means SQLite is used; a non-SQLite
# database such as this Postgres one is needed for parallel execution,
# amongst other things.
# NOTE(review): the variable is presumably read when `snorkel` is imported, so
# keep this assignment above the import below -- confirm before reordering.
os.environ['SNORKELDB'] = 'postgres:///babble_test_bike'

from snorkel import SnorkelSession
# Database session used by the queries and writes in the rest of the notebook.
session = SnorkelSession()

In [2]:
import numpy as np

# Location of the bike tutorial annotation files, resolved relative to the
# SNORKELHOME environment variable (a KeyError is raised if it is not set).
# The trailing slash matters: file names are appended by plain concatenation.
# anns_folder = '/dfs/scratch0/paroma/coco/annotations/'
anns_folder = os.environ['SNORKELHOME'] + '/tutorials/babble/bike/data/'
train_path = anns_folder + 'train_anns.npy'
val_path = anns_folder + 'val_anns.npy'

# allow_pickle=True is passed explicitly because NumPy >= 1.16.3 refuses to
# load object-dtype arrays by default; the flag is harmless for plain numeric
# arrays. Unpickling can execute arbitrary code, so only point these paths at
# trusted annotation files.
# NOTE(review): assumes the .npy files may hold pickled Python objects -- confirm.
train_anns = np.load(train_path, allow_pickle=True).tolist()
val_anns = np.load(val_path, allow_pickle=True).tolist()

In [3]:
from snorkel.models import candidate_subclass

# Candidate class for this task, created under the name 'Biker' with two
# arguments called 'person' and 'bike'. The other cells use it both as the
# extractor's candidate_class and as the model queried through `session`.
Biker = candidate_subclass('Biker', ['person', 'bike'])

In [4]:
from snorkel.parser import ImageCorpusExtractor, CocoPreprocessor

corpus_extractor = ImageCorpusExtractor(candidate_class=Biker)

coco_preprocessor = CocoPreprocessor(train_path, source=0)
%time corpus_extractor.apply(coco_preprocessor)

coco_preprocessor = CocoPreprocessor(val_path, source=1)
%time corpus_extractor.apply(coco_preprocessor, clear=False)

for split in [0, 1]:
    num_candidates = session.query(Biker).filter(Biker.split == split).count()
    print("Split {} candidates: {}".format(split, num_candidates))

Clearing existing...
Running UDF...
CPU times: user 3.98 s, sys: 228 ms, total: 4.21 s
Wall time: 7.25 s
Running UDF...
CPU times: user 1.8 s, sys: 103 ms, total: 1.9 s
Wall time: 3.35 s
Split 0 candidates: 2406
Split 1 candidates: 1037


In [5]:
# Mapping from candidate hash to gold label, stored as a pickled Python object
# inside a .npy file (the result is used as a dict via .items() further down).
# allow_pickle=True is required for such files on NumPy >= 1.16.3, where
# np.load rejects object arrays by default. Unpickling can execute arbitrary
# code, so only load this from a trusted data folder.
labels_by_candidate = np.load(anns_folder + 'labels_by_candidate.npy', allow_pickle=True).tolist()

In [6]:
from snorkel.models import StableLabel
from snorkel.db_helpers import reload_annotator_labels

candidate_class = Biker
annotator_name = 'gold'

# Numeric source id for each set-name prefix found in the candidate hashes.
source_by_set = {'train': 0, 'val': 1}

# Store one StableLabel per labelled candidate unless the same
# (context_stable_ids, annotator_name) pair is already in the database.
for candidate_hash, label in labels_by_candidate.items():
    # Hash layout: "<set_name>:<image_idx>:<bbox1_idx>:<bbox2_idx>".
    set_name, image_idx, bbox1_idx, bbox2_idx = candidate_hash.split(':')
    source = source_by_set[set_name]

    # Both bounding boxes share the source and image index; their stable ids
    # are joined with "~~" to identify the candidate.
    context_stable_ids = "~~".join(
        "{}:{}::bbox:{}".format(source, image_idx, bbox_idx)
        for bbox_idx in (bbox1_idx, bbox2_idx))

    already_stored = (
        session.query(StableLabel)
        .filter(StableLabel.context_stable_ids == context_stable_ids)
        .filter(StableLabel.annotator_name == annotator_name)
        .count())
    if already_stored:
        continue

    # Truthy labels are stored as 1, everything else as -1.
    session.add(StableLabel(
        context_stable_ids=context_stable_ids,
        annotator_name=annotator_name,
        value=1 if label else -1))

session.commit()
# Reload the annotator labels for split 1 only (the recorded run reports
# "AnnotatorLabels created: 906").
reload_annotator_labels(session, candidate_class, annotator_name, split=1, filter_label_split=False)

AnnotatorLabels created: 906


In [7]:
# Sanity check: fetch every StableLabel stored for the current annotator and
# display how many there are.
gold_label_query = session.query(StableLabel).filter(
    StableLabel.annotator_name == annotator_name)
stable_labels = gold_label_query.all()
len(stable_labels)

906

In [8]:
# Fetch every Biker candidate in split 1 and show the first one as a spot check.
# The list is bound to `candidates` only: it is deliberately not aliased to
# `num_candidates`, which holds an integer count from the extraction cell and
# would otherwise end up referring to a list of candidate objects.
candidates = session.query(Biker).filter(Biker.split == 1).all()
print(candidates[0])

Biker(Bbox(val:0:3:person:(249, 306, 439, 456)), Bbox(val:0:1:bike:(304, 359, 455, 533)))


In [9]:
from snorkel.annotations import load_gold_labels

# Load the labels of the 'gold' annotator for split 1; the bare expression on
# the last line displays a summary of the returned matrix.
L_gold_dev = load_gold_labels(session, split=1, annotator_name='gold')
L_gold_dev

<1037x1 sparse matrix of type '<type 'numpy.int64'>'
	with 906 stored elements in Compressed Sparse Row format>