# Categorical Variables in Snorkel

This is a short tutorial on how to use categorical variables (i.e. variables that take more than two values) in Snorkel. We'll use a toy scenario with a handful of sentences and three LFs just to demonstrate the mechanics. Please see the main tutorial for a more comprehensive intro!

We'll **highlight in bold all parts focusing on the categorical aspect.**

### Notes on Current Categorical Support:
* The `Viewer` works in the categorical setting, _but labeling `Candidate`s in the `Viewer` does not._
    - Instead, you can import test / dev set labels from another annotation tool, e.g. BRAT
* The `LogisticRegression` and `SparseLogisticRegression` end models have been extended to the categorical setting, but other end models in `contrib` may not have been
    - _Note: It's simple to make this change, so feel free to post an issue with requests for other end models!_

In [2]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os
import numpy as np

# Point Snorkel at the database before importing SnorkelSession
os.environ['SNORKELDB'] = "postgresql:///spouse"

from snorkel import SnorkelSession
session = SnorkelSession()

print(os.environ['SNORKELHOME'] )

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
/Users/xinq/Desktop/snorkel


## Step 1: Preprocessing the data

In [3]:
from snorkel.parser import TSVDocPreprocessor
from snorkel.parser.spacy_parser import Spacy
from snorkel.parser import CorpusParser


doc_preprocessor = TSVDocPreprocessor('data/categorical_example.tsv')
corpus_parser = CorpusParser(parser=Spacy())  # same parser as the Spouse example
corpus_parser.apply(doc_preprocessor)

Clearing existing...
Running UDF...


## Step 2: Defining candidates

We'll define candidate relations between person mentions **that now can take on one of three values:**
```python
['Married', 'Employs', 'Others']
```
Note the importance of including a value for "not a relation of interest"; here we've used `'Others'`, but any value could do.
Also note that `None` is a protected value (it denotes a labeling function abstaining), so it cannot be used as a value.
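
To make the encoding concrete, here is a minimal sketch of how categorical LF outputs can be mapped to integers. It assumes, as the label matrix printed later in this tutorial suggests, that values are indexed from 1 in the order they were declared and that 0 is reserved for abstaining. The helper `encode_label` is hypothetical and is not part of Snorkel.

```python
VALUES = ['Married', 'Employs', 'Others']

def encode_label(label, values=VALUES):
    """Map an LF output to an integer: 0 for abstain, 1..k for the k values.

    Hypothetical helper for illustration only (not a Snorkel function).
    """
    if label is None or label == 0:
        return 0  # abstain
    if isinstance(label, int):
        if 1 <= label <= len(values):
            return label  # already an integer index
        raise ValueError("index out of range: %r" % label)
    # list.index raises ValueError for a value that was never declared
    return values.index(label) + 1

print([encode_label(x) for x in ['Married', 'Employs', 'Others', None]])
```

Because `None` (and 0) already mean "abstain" in this scheme, neither can double as a class value, which is why an explicit catch-all such as `'Others'` is needed.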

In [4]:
from snorkel.models import candidate_subclass
Relationship = candidate_subclass('Relationship', ['person1', 'person2'], values=['Married', 'Employs', 'Others'])

Now we extract candidates the same as in the Intro Tutorial (simplified here slightly):

In [5]:
from snorkel.candidates import Ngrams, CandidateExtractor
from snorkel.matchers import PersonMatcher
from snorkel.models import Sentence

# Define a Person-Person candidate extractor
ngrams = Ngrams(n_max=3)
person_matcher = PersonMatcher(longest_match_only=True)
cand_extractor = CandidateExtractor(
    Relationship, 
    [ngrams, ngrams],
    [person_matcher, person_matcher],
    symmetric_relations=False 
)

# Apply to all of the sentences for this simple example
sents = session.query(Sentence).all()

# Run the candidate extractor
%time cand_extractor.apply(sents, split=0)


print(sents)

Clearing existing...
Running UDF...

CPU times: user 64.2 ms, sys: 7.08 ms, total: 71.3 ms
Wall time: 106 ms
[Sentence(Document d01,0,b'John is married to Susan.'), Sentence(Document d01,1,b'Susan also employs John at her company.'), Sentence(Document d01,2,b'John enjoys being married to his boss Susan.'), Sentence(Document d01,3,b'John is a friend of Mary.'), Sentence(Document d01,4,b'Unlike traditional dimension reducing methods of most current automatic detection, the proposed approach adopts biclustering to perform an unsupervised dimension reduction, which is more suitable to the characteristic of EEG signals.'), Sentence(Document d01,5,b'To verify the performance of the presented approach, experiments have been carried out in the epileptic EEG data.'), Sentence(Document d01,6,b'The average sensitivity, specificity and recognition accuracy obtained by our method are 96.67%, 100.00% and 98.00%.'), Sentence(Document d01,7,b'The study might be meaningful for improving the diagnostic 

In [6]:


sentid=-5


print(" ".join(sents[sentid].__dict__['lemmas']))
print()

print(" ".join(sents[sentid].__dict__['pos_tags']))
print()

print(sents[sentid].__dict__['dep_parents'])

print(sents[sentid])


the study may be meaningful for improve the diagnostic accuracy of epileptic disease , relieve the workload of doctor and reduce the medical cost .

DT NN MD VB JJ IN VBG DT JJ NN IN JJ NN , VBG DT NN IN NNS CC VBG DT JJ NN .

[2, 4, 4, 0, 4, 5, 6, 10, 10, 7, 10, 13, 11, 13, 7, 17, 15, 17, 18, 17, 17, 24, 24, 21, 4]
Sentence(Document d01,7,b'The study might be meaningful for improving the diagnostic accuracy of epileptic disease, relieving the workload of doctors and reducing the medical cost.')


In [7]:
train_cands = session.query(Relationship).filter(Relationship.split == 0).all()
print("Number of candidates:", len(train_cands))
print(train_cands)

print(train_cands[0].get_contexts()[0].sentence.__dict__['pos_tags'])
# segment_cue.sentence.__dict__)

Number of candidates: 4
[Relationship(Span("b'John'", sentence=2, chars=[0,3], words=[0,0]), Span("b'Susan'", sentence=2, chars=[19,23], words=[4,4])), Relationship(Span("b'Susan'", sentence=3, chars=[0,4], words=[0,0]), Span("b'John'", sentence=3, chars=[19,22], words=[3,3])), Relationship(Span("b'John'", sentence=4, chars=[0,3], words=[0,0]), Span("b'Susan'", sentence=4, chars=[38,42], words=[7,7])), Relationship(Span("b'John'", sentence=5, chars=[0,3], words=[0,0]), Span("b'Mary'", sentence=5, chars=[20,23], words=[5,5]))]
['NNP', 'VBZ', 'JJ', 'IN', 'NNP', '.']


In [8]:
from snorkel.viewer import SentenceNgramViewer

# NOTE: This if-then statement is only to avoid opening the viewer during automated testing of this notebook
# You should ignore this!
import os
if 'CI' not in os.environ:
    sv = SentenceNgramViewer(train_cands, session)
else:
    sv = None

<IPython.core.display.Javascript object>

In [9]:
sv

SentenceNgramViewer(cids=[[[0], [1], [2]], [[3]]], html='<head>\n<style>\nspan.candidate {\n    background-col…

## Step 3: Writing Labeling Functions

**The _categorical_ labeling functions (LFs) we now write can output the following values:**

* Abstain: `None` OR 0
* Categorical values: The literal values in `Relationship.values` OR their integer indices.

We'll write two simple LFs to illustrate.

*Tip: we can get a random candidate (see below), or the example highlighted in the viewer above via `sv.get_selected()`, and then use this to test as we write the LFs!*

In [10]:
print(Relationship.values)
import re
from snorkel.lf_helpers import get_between_tokens

# Getting an example candidate (here simply the first training candidate)
c = train_cands[0]

# Traversing the context hierarchy...
print(c.get_contexts()[0].get_parent().text)

# Using a helper function
list(get_between_tokens(c))





['Married', 'Employs', 'Others']




John is married to Susan.


['is', 'married', 'to']

In [11]:
def LF_married(c):
    return 'Married' if 'married' in get_between_tokens(c) else None


WORKPLACE_RGX = r'employ|boss|company'

def LF_workplace(c):
    # Search the whole sentence containing the first span for workplace cues
    sent = c.get_contexts()[0].get_parent()
    matches = re.search(WORKPLACE_RGX, sent.text)
    return 'Employs' if matches else None


FRIENDSHIP_RGX = r'friend'

def LF_friendship(c):
    # Friendship is not a relation of interest, so it maps to 'Others'
    sent = c.get_contexts()[0].get_parent()
    matches = re.search(FRIENDSHIP_RGX, sent.text)
    return 'Others' if matches else None

LFs = [
    LF_married,
    LF_workplace,
    LF_friendship
]
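
A quick way to sanity-check the patterns above without a database is to run them on raw strings. The sketch below is a simplification under stated assumptions: the `*_text` functions are hypothetical stand-ins that look at the whole sentence (split on whitespace) rather than at the tokens between the two person spans, and the four sentences are copied from the parsed corpus shown earlier.

```python
import re

WORKPLACE_RGX = r'employ|boss|company'
FRIENDSHIP_RGX = r'friend'

# Hypothetical string-based stand-ins for the LFs (not Snorkel candidates)
def lf_married_text(text):
    return 'Married' if 'married' in text.split() else None

def lf_workplace_text(text):
    return 'Employs' if re.search(WORKPLACE_RGX, text) else None

def lf_friendship_text(text):
    return 'Others' if re.search(FRIENDSHIP_RGX, text) else None

text_lfs = [lf_married_text, lf_workplace_text, lf_friendship_text]

sentences = [
    'John is married to Susan.',
    'Susan also employs John at her company.',
    'John enjoys being married to his boss Susan.',
    'John is a friend of Mary.',
]

# One row per sentence, one column per LF; None means the LF abstained
results = [[lf(s) for lf in text_lfs] for s in sentences]
for row, s in zip(results, sentences):
    print(row, s)
```

Only the third sentence triggers two LFs with different values, so it is the one case where the LFs disagree.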

Now we apply the LFs to the candidates to produce our label matrix $L$:

In [15]:
from snorkel.annotations import LabelAnnotator

labeler = LabelAnnotator(lfs=LFs)
%time L_train = labeler.apply(split=0)
L_train

cids_count 4
key_group 0
Clearing existing...
Running UDF...

CPU times: user 74.2 ms, sys: 8.48 ms, total: 82.7 ms
Wall time: 130 ms


<4x3 sparse matrix of type '<class 'numpy.int64'>'
	with 5 stored elements in Compressed Sparse Row format>

In [16]:
# snorkel_conn_string = os.environ['SNORKELDB'] if 'SNORKELDB' in os.environ and os.environ['SNORKELDB'] != '' \
#     else 'sqlite:///' + os.getcwd() + os.sep + 'snorkel.db'

print(os.environ['SNORKELDB'])

postgresql:///spouse


In [92]:
L_train.todense()

matrix([[1, 0, 0],
        [0, 2, 0],
        [1, 2, 0],
        [0, 0, 3]], dtype=int64)
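
Reading the matrix: rows are candidates, columns are LFs, 0 means the LF abstained, and 1 to 3 index into `Relationship.values`. As a rough sketch (plain NumPy on a copy of the dense matrix above, not a Snorkel API), per-LF coverage and per-candidate conflicts can be computed like this:

```python
import numpy as np

# Dense label matrix copied from the output above
# (rows: candidates, columns: LFs, 0 = abstain)
L = np.array([
    [1, 0, 0],
    [0, 2, 0],
    [1, 2, 0],
    [0, 0, 3],
])

labeled = L != 0                  # True where an LF voted
coverage = labeled.mean(axis=0)   # fraction of candidates each LF labels
n_votes = labeled.sum(axis=1)     # number of non-abstain votes per candidate

def has_conflict(row):
    """A candidate is in conflict when its non-abstain votes disagree."""
    votes = row[row != 0]
    return len(set(votes.tolist())) > 1

conflicts = np.array([has_conflict(row) for row in L])

print("Coverage per LF:", coverage)
print("Votes per candidate:", n_votes)
print("Conflicting candidates:", np.flatnonzero(conflicts))
```

The five non-zero entries match the "5 stored elements" reported for the sparse matrix, and the only conflict is the third candidate, which received both `Married` and `Employs`.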

## Step 4: Training the Generative Model

In [104]:
from snorkel.learning import GenerativeModel

gen_model = GenerativeModel(lf_propensity=True)

# Note: We pass cardinality explicitly here to be safe.
# It can usually be inferred from the labels, but not if some value (e.g. value=3) never appears in L_train
gen_model.train(L_train, cardinality=3, epochs=10, verbose=False)

FACTOR 0: STARTED BURN-IN...
FACTOR 0: DONE WITH BURN-IN
FACTOR 0: STARTED LEARNING
FACTOR 0: EPOCH #0
Current stepsize = 0.0001
Learning epoch took 0.000 sec.
Weights:
    weightId: 0
        isFixed: True
        weight:  1.0

    weightId: 1
        isFixed: False
        weight:  0.0

    weightId: 2
        isFixed: True
        weight:  1.0

    weightId: 3
        isFixed: False
        weight:  0.0

    weightId: 4
        isFixed: True
        weight:  1.0

    weightId: 5
        isFixed: False
        weight:  0.0

    weightId: 6
        isFixed: False
        weight:  0.0

    weightId: 7
        isFixed: False
        weight:  0.0

    weightId: 8
        isFixed: False
        weight:  0.0

FACTOR 0: EPOCH #1
Current stepsize = 0.0001
Learning epoch took 0.000 sec.
Weights:
    weightId: 0
        isFixed: True
        weight:  1.0

    weightId: 1
        isFixed: False
        weight:  -0.00019998700054998293

    weightId: 2
        isFixed: True
        weight:  1.0


In [105]:
train_marginals = gen_model.marginals(L_train)


# Each row of marginals is a distribution over the values, so it should sum to 1
assert np.allclose(train_marginals.sum(axis=1), 1.0)

print(train_marginals.shape)
print(train_marginals)





(4, 3)
[[0.78624771 0.10687615 0.10687615]
 [0.10699376 0.78601248 0.10699376]
 [0.46848637 0.46783138 0.06368224]
 [0.10690966 0.10690966 0.78618067]]
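
Each row of the marginals is a probability distribution over `Relationship.values` for one candidate. A minimal sketch of turning them into hard labels, assuming the 1-based value indexing used in the label matrix (the array below is copied from the output above):

```python
import numpy as np

values = ['Married', 'Employs', 'Others']

# Training marginals copied from the printed output above
marginals = np.array([
    [0.78624771, 0.10687615, 0.10687615],
    [0.10699376, 0.78601248, 0.10699376],
    [0.46848637, 0.46783138, 0.06368224],
    [0.10690966, 0.10690966, 0.78618067],
])

# argmax gives a 0-based column; add 1 to get the 1-based value index
pred_idx = marginals.argmax(axis=1) + 1
pred_names = [values[i - 1] for i in pred_idx]

print(pred_idx)
print(pred_names)
```

Note that the third candidate is close to a tie between `Married` and `Employs`, reflecting the two LFs that disagreed on it, so its hard label should not be trusted much.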


Next, we can save the training marginals:

In [42]:
from snorkel.annotations import save_marginals, load_marginals

save_marginals(session, L_train, train_marginals)

Saved 4 marginals


And then reload (e.g. in another notebook):

In [43]:
load_marginals(session, L_train)

array([[0.7841923 , 0.10790385, 0.10790385],
       [0.10771717, 0.78456567, 0.10771717],
       [0.46733148, 0.46836431, 0.06430421],
       [0.10776812, 0.10776812, 0.78446377]])

## Step 5: Training the End Model

Now we train an LSTM. Note that this is just to demonstrate the mechanics: since we only have four training candidates, don't expect anything spectacular!

In [44]:
from snorkel.learning.disc_models.rnn import reRNN

train_kwargs = {
    'lr':         0.01,
    'dim':        50,
    'n_epochs':   10,
    'dropout':    0.25,
    'print_freq': 1,
    'max_sentence_length': 100
}

lstm = reRNN(seed=1701, n_threads=None, cardinality=Relationship.cardinality)
lstm.train(train_cands, train_marginals, **train_kwargs)

  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "


[reRNN] Training model
[reRNN] n_train=4  #epochs=10  batch size=4
[reRNN] Epoch 0 (0.38s)	Average loss=1.096820
[reRNN] Epoch 1 (0.39s)	Average loss=1.065301
[reRNN] Epoch 2 (0.40s)	Average loss=1.030042
[reRNN] Epoch 3 (0.41s)	Average loss=0.970748
[reRNN] Epoch 4 (0.42s)	Average loss=0.886815
[reRNN] Epoch 5 (0.43s)	Average loss=0.812141
[reRNN] Epoch 6 (0.44s)	Average loss=0.768691
[reRNN] Epoch 7 (0.45s)	Average loss=0.740026
[reRNN] Epoch 8 (0.47s)	Average loss=0.771290
[reRNN] Epoch 9 (0.48s)	Average loss=0.764788
[reRNN] Training done (0.48s)


In [47]:
train_labels = [1, 2, 2, 3]
correct, incorrect = lstm.error_analysis(session, train_cands, train_labels)
print(correct, incorrect)
# print(gen_model.error_analysis(session, L_train, train_labels))

Accuracy: 0.75
{Relationship(Span("b'Susan'", sentence=3, chars=[0,4], words=[0,0]), Span("b'John'", sentence=3, chars=[19,22], words=[3,3])), Relationship(Span("b'John'", sentence=2, chars=[0,3], words=[0,0]), Span("b'Susan'", sentence=2, chars=[19,23], words=[4,4])), Relationship(Span("b'John'", sentence=5, chars=[0,3], words=[0,0]), Span("b'Mary'", sentence=5, chars=[20,23], words=[5,5]))} {Relationship(Span("b'John'", sentence=4, chars=[0,3], words=[0,0]), Span("b'Susan'", sentence=4, chars=[38,42], words=[7,7]))}


In [48]:
print("Accuracy:", lstm.score(train_cands, train_labels))

Accuracy: 0.75


In [49]:
# There is no separate test set in this toy example, so we predict on the training candidates
test_marginals = lstm.marginals(train_cands)
test_marginals

array([[0.91320086, 0.03214792, 0.05465119],
       [0.05433137, 0.88214684, 0.0635217 ],
       [0.882513  , 0.09651681, 0.02097021],
       [0.07051656, 0.06210546, 0.867378  ]], dtype=float32)
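
As a cross-check, the reported accuracy can be reproduced by hand from these marginals. The sketch below assumes that predictions are taken as the argmax over each row, shifted to 1-based indices. Whether `lstm.score` computes it in exactly this way is an assumption, but the result agrees with the 0.75 printed above.

```python
import numpy as np

# End-model marginals copied from the output above
end_marginals = np.array([
    [0.91320086, 0.03214792, 0.05465119],
    [0.05433137, 0.88214684, 0.0635217],
    [0.882513, 0.09651681, 0.02097021],
    [0.07051656, 0.06210546, 0.867378],
], dtype=np.float32)

gold = np.array([1, 2, 2, 3])             # same labels as train_labels above
pred = end_marginals.argmax(axis=1) + 1   # 1-based predicted value index

accuracy = float(np.mean(pred == gold))
wrong = np.flatnonzero(pred != gold)

print("Predictions:", pred)
print("Accuracy:", accuracy)
print("Misclassified candidates:", wrong)
```

The single error is the third candidate ("John enjoys being married to his boss Susan."), which is labeled `Employs` in `train_labels` but predicted as `Married`.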