# Detecting spouse mentions in sentences

In this tutorial, we will see how Snorkel can be used for Information Extraction. We will walk through an example text classification task for information extraction, where we use labeling functions involving keywords and distant supervision.
### Classification Task
<img src="imgs/sentence.jpg" width="700px;" onerror="this.onerror=null; this.src='/doks-theme/assets/images/sentence.jpg';" align="center" style="display: block; margin-left: auto; margin-right: auto;">

We want to classify each __candidate__, i.e., a pair of people mentioned in a sentence, according to whether the two people have been married at some point.

In the above example, our candidate represents the possible relation `(Barack Obama, Michelle Obama)`. As readers, we know this relation holds thanks to external knowledge and the keyword `wedding` occurring later in the sentence.
We begin with some basic setup and data downloading.


## Preamble

In [1]:
# Imports.
import os
import pandas as pd
import numpy as np
import pickle
import itertools
from sklearn.model_selection import train_test_split

from preprocessors import get_person_text
from preprocessors import get_left_tokens, get_person_last_names
from preprocessors import last_name

from snorkel.preprocess import preprocessor
from snorkel.labeling import labeling_function
from snorkel.labeling import PandasLFApplier
from snorkel.labeling import LFAnalysis
from snorkel.labeling.model import LabelModel
from snorkel.labeling.model import MajorityLabelVoter

from utils import load_data


%matplotlib inline

if os.path.basename(os.getcwd()) == "snorkel-tutorials":
    os.chdir("spouse")



In [2]:
# Import data.
((df_val, Y_val), df_train, (df_test, Y_test)) = load_data()

In [3]:
Y_val

array([0, 0, 1, ..., 0, 0, 1])

In [4]:
# Take a look at the data.
print("Total training instances:", len(df_train))
print("Total development instances:", len(df_val))
print("Total test instances:", len(df_test))

Total training instances: 22254
Total development instances: 2811
Total test instances: 2701


**Input Data:** `df_val`, `df_train`, and `df_test` are pandas `DataFrame` objects, where each row represents a particular __candidate__. For our problem, a candidate consists of a sentence and two people mentioned in that sentence. The DataFrames contain the fields `sentence`, the sentence of the candidate; `tokens`, the tokenized form of the sentence; and `person1_word_idx` and `person2_word_idx`, which give the inclusive `(start, end)` token indices at which the first and second person's names appear, respectively.

We also have some **preprocessed fields**, which we discuss a few cells below.
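To make the index convention concrete, here is a minimal sketch that rebuilds one candidate (row 4 of the `df_val.head()` preview in the next cell) as a plain dictionary and recovers both names by slicing `tokens` with the inclusive `(start, end)` spans. The helper `span_text` is hypothetical and not part of the tutorial's `preprocessors.py`; in the notebook itself, the provided `get_person_text` preprocessor fills in `person_names` for you.

```python
# Hypothetical helper for illustration; not part of the tutorial's preprocessors.py.
def span_text(tokens, span):
    """Join the tokens covered by an inclusive (start, end) word-index span."""
    start, end = span
    return " ".join(tokens[start:end + 1])


# Candidate copied from row 4 of df_val (only the fields needed here).
candidate = {
    "tokens": ["People", "reported", "Williams", "and", "Ven", "Veen", "tied",
               "the", "knot", "Saturday", "at", "Brush", "Creek", "Ranch", "in",
               "Saratoga", ",", "Wyoming", ",", "in", "front", "of", "about",
               "200", "guests", "."],
    "person1_word_idx": (2, 2),
    "person2_word_idx": (4, 5),
}

person1 = span_text(candidate["tokens"], candidate["person1_word_idx"])
person2 = span_text(candidate["tokens"], candidate["person2_word_idx"])
print(person1, "|", person2)  # Williams | Ven Veen
```

Because the end index is inclusive, the slice uses `end + 1`; the same convention explains the `+ 1` in `get_text_between` further down.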

In [5]:
# Don't truncate text fields in the display
pd.set_option("display.max_colwidth", None)  # None = no width limit

df_val.head()

|   | person1_word_idx | person2_word_idx | sentence | tokens | person1_right_tokens | person2_right_tokens | between_tokens |
|---|---|---|---|---|---|---|---|
| 0 | (1, 1) | (22, 24) | The Richards are half-sisters to Kathy Hilton, the mother of socialite Paris Hilton and spouse of luxury hotel magnate Richard Howard Hilton. | [The, Richards, are, half, -, sisters, to, Kathy, Hilton, ,, the, mother, of, socialite, Paris, Hilton, and, spouse, of, luxury, hotel, magnate, Richard, Howard, Hilton, ., ] | [are, half, -, sisters] | [., ] | [are, half, -, sisters, to, Kathy, Hilton, ,, the, mother, of, socialite, Paris, Hilton, and, spouse, of, luxury, hotel, magnate] |
| 1 | (1, 1) | (7, 8) | The Richards are half-sisters to Kathy Hilton, the mother of socialite Paris Hilton and spouse of luxury hotel magnate Richard Howard Hilton. | [The, Richards, are, half, -, sisters, to, Kathy, Hilton, ,, the, mother, of, socialite, Paris, Hilton, and, spouse, of, luxury, hotel, magnate, Richard, Howard, Hilton, ., ] | [are, half, -, sisters] | [,, the, mother, of] | [are, half, -, sisters, to] |
| 2 | (7, 8) | (22, 24) | The Richards are half-sisters to Kathy Hilton, the mother of socialite Paris Hilton and spouse of luxury hotel magnate Richard Howard Hilton. | [The, Richards, are, half, -, sisters, to, Kathy, Hilton, ,, the, mother, of, socialite, Paris, Hilton, and, spouse, of, luxury, hotel, magnate, Richard, Howard, Hilton, ., ] | [,, the, mother, of] | [., ] | [,, the, mother, of, socialite, Paris, Hilton, and, spouse, of, luxury, hotel, magnate] |
| 3 | (6, 6) | (20, 21) | Prior to both his guests, Colbert's monologue - parts of which he did sitting down - ripped into Donald Trump and his oft-mocked policy of building a wall at the US-Mexico border and not eating Oreos anymore. | [Prior, to, both, his, guests, ,, Colbert, s, monologue, -, parts, of, which, he, did, sitting, down, -, ripped, into, Donald, Trump, and, his, oft, -, mocked, policy, of, building, a, wall, at, the, US, -, Mexico, border, and, not, eating, Oreos, anymore, ., ] | [s, monologue, -, parts] | [and, his, oft, -] | [s, monologue, -, parts, of, which, he, did, sitting, down, -, ripped, into] |
| 4 | (2, 2) | (4, 5) | People reported Williams and Ven Veen tied the knot Saturday at Brush Creek Ranch in Saratoga, Wyoming, in front of about 200 guests. | [People, reported, Williams, and, Ven, Veen, tied, the, knot, Saturday, at, Brush, Creek, Ranch, in, Saratoga, ,, Wyoming, ,, in, front, of, about, 200, guests, .] | [and, Ven, Veen, tied] | [tied, the, knot, Saturday] | [and] |


Let's look at a candidate in the development set:

In [6]:
candidate = df_val.loc[2]
person_names = get_person_text(candidate).person_names

print("Sentence: ", candidate["sentence"])
print("Person 1: ", person_names[0])
print("Person 2: ", person_names[1])

Sentence:  The Richards are half-sisters to Kathy Hilton, the mother of socialite Paris Hilton and spouse of luxury hotel magnate Richard Howard Hilton.   
Person 1:  Kathy Hilton
Person 2:  Richard Howard Hilton


In [7]:
# Construct labeled train/val/test splits by subsetting labeled instances.

# Concatenate all labeled instances (duplicate removal is currently disabled; see the commented-out lines).
df_test["Label"] = Y_test
df_val["Label"] = Y_val
df_labeled = pd.concat([df_test, df_val])
#duplicated_rows = df_labeled.astype(str).duplicated()
#print(duplicated_rows.value_counts())
#df_labeled = df_labeled.iloc[df_labeled.astype(str).drop_duplicates().index]

print(df_labeled.Label.value_counts())
print(df_labeled.Label.value_counts(normalize = True))
display(df_labeled.head(1))

print("Total labeled instances:", len(df_labeled))

# 5512 total labeled instances.
# 3858 = 0.7
# 1102 = 0.2
# 552  = 0.1

X_train, X_test, y_train, y_test = train_test_split(df_labeled, 
                                                    df_labeled["Label"], 
                                                    test_size = 0.3,
                                                    stratify = df_labeled["Label"],
                                                    random_state = 42)

print("Total training instances:", len(X_train))

0    5104
1    408 
Name: Label, dtype: int64
0    0.92598
1    0.07402
Name: Label, dtype: float64


|   | person1_word_idx | person2_word_idx | sentence | tokens | person1_right_tokens | person2_right_tokens | between_tokens | Label |
|---|---|---|---|---|---|---|---|---|
| 0 | (0, 1) | (46, 47) | Mr Perpich had desperately searched for the 14-carat gold ring, which had cost $129 and was engraved with his initials The teenager then went to the local library to look at the school's 1974 yearbook and suspected the ring could belong to Mr Perpich. | [Mr, Perpich, had, desperately, searched, for, the, 14-carat, gold, ring, ,, which, had, cost, $, 129, and, was, engraved, with, his, initials, , The, teenager, then, went, to, the, local, library, to, look, at, the, school, s, 1974, yearbook, and, suspected, the, ring, could, belong, to, Mr, Perpich, ., ] | [had, desperately, searched, for] | [., ] | [had, desperately, searched, for, the, 14-carat, gold, ring, ,, which, had, cost, $, 129, and, was, engraved, with, his, initials, , The, teenager, then, went, to, the, local, library, to, look, at, the, school, s, 1974, yearbook, and, suspected, the, ring, could, belong, to] | 0 |


Total labeled instances: 5512
Total training instances: 3858


In [8]:
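# Total X_test = 1654 (30% of the labeled data).
# Split it so that validation is ~10% of all labeled data:
# test_size = 0.3337364 ≈ 552 / 1654; scikit-learn rounds the validation
# count up, which gives 553 validation and 1101 test instances.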
X_test, X_val, y_test, y_val = train_test_split(X_test, 
                                                y_test, 
                                                test_size = 0.3337364,
                                                stratify = y_test,
                                                random_state = 42)

print("Total testing instances:", len(X_test))
print("Total validation instances:", len(X_val))

Total testing instances: 1101
Total validation instances: 553
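The two calls to `train_test_split` above implement a stratified 70/20/10 split. Below is a reusable sketch of the same idea; `three_way_split` is a hypothetical name, and it deliberately differs from the notebook in one respect: it passes integer sizes (which `train_test_split` accepts as absolute row counts) instead of a tuned fraction like `0.3337364`, so floating-point rounding cannot move a row from one split to another.

```python
import pandas as pd
from sklearn.model_selection import train_test_split


def three_way_split(df, label_col, test_frac=0.2, val_frac=0.1, seed=42):
    """Hypothetical helper: stratified train/test/val split done in two stages."""
    n_test = round(len(df) * test_frac)
    n_val = round(len(df) * val_frac)
    # Stage 1: hold out test + validation rows, stratified on the label.
    train, holdout = train_test_split(
        df, test_size=n_test + n_val, stratify=df[label_col], random_state=seed
    )
    # Stage 2: split the holdout; an integer test_size is an absolute row count.
    test, val = train_test_split(
        holdout, test_size=n_val, stratify=holdout[label_col], random_state=seed
    )
    return train, test, val


# Toy data with a class imbalance similar to the labeled set (92% / 8%).
toy = pd.DataFrame({"x": range(1000), "Label": [0] * 920 + [1] * 80})
train, test, val = three_way_split(toy, "Label")
print(len(train), len(test), len(val))  # 700 200 100
```

Stratifying both stages keeps the positive rate roughly equal across the three splits, which is what the `value_counts(normalize = True)` checks in the next cell confirm for the real data.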


In [9]:
# View train, val, test dataframes.
print(X_train.info())
print(X_train.Label.value_counts())
print(X_train.Label.value_counts(normalize = True))
display(X_train.head(1))

print(X_test.info())
print(X_test.Label.value_counts())
print(X_test.Label.value_counts(normalize = True))
display(X_test.head(1))

print(X_val.info())
print(X_val.Label.value_counts())
print(X_val.Label.value_counts(normalize = True))
display(X_val.head(1))

<class 'pandas.core.frame.DataFrame'>
Int64Index: 3858 entries, 2057 to 285
Data columns (total 8 columns):
 #   Column                Non-Null Count  Dtype 
---  ------                --------------  ----- 
 0   person1_word_idx      3858 non-null   object
 1   person2_word_idx      3858 non-null   object
 2   sentence              3858 non-null   object
 3   tokens                3858 non-null   object
 4   person1_right_tokens  3858 non-null   object
 5   person2_right_tokens  3858 non-null   object
 6   between_tokens        3858 non-null   object
 7   Label                 3858 non-null   int64 
dtypes: int64(1), object(7)
memory usage: 271.3+ KB
None
0    3572
1    286 
Name: Label, dtype: int64
0    0.925868
1    0.074132
Name: Label, dtype: float64


|   | person1_word_idx | person2_word_idx | sentence | tokens | person1_right_tokens | person2_right_tokens | between_tokens | Label |
|---|---|---|---|---|---|---|---|---|
| 2057 | (1, 2) | (7, 8) | With Dellen Millard, 30, and Mark Smich, 27, in the prisoners’ box for preliminary proceedings ahead of their first-degree murder trial, Justice James Turnbull, the region’s senior judge, informed court that all hearings in the case would be postponed until next Monday. | [With, Dellen, Millard, ,, 30, ,, and, Mark, Smich, ,, 27, ,, in, the, prisoners’, box, for, preliminary, proceedings, ahead, of, their, first, -, degree, murder, trial, ,, Justice, James, Turnbull, ,, the, region, ’s, senior, judge, ,, informed, court, that, all, hearings, in, the, case, would, be, postponed, until, next, Monday, ., ] | [,, 30, ,, and] | [,, 27, ,, in] | [,, 30, ,, and] | 0 |


<class 'pandas.core.frame.DataFrame'>
Int64Index: 1101 entries, 2363 to 391
Data columns (total 8 columns):
 #   Column                Non-Null Count  Dtype 
---  ------                --------------  ----- 
 0   person1_word_idx      1101 non-null   object
 1   person2_word_idx      1101 non-null   object
 2   sentence              1101 non-null   object
 3   tokens                1101 non-null   object
 4   person1_right_tokens  1101 non-null   object
 5   person2_right_tokens  1101 non-null   object
 6   between_tokens        1101 non-null   object
 7   Label                 1101 non-null   int64 
dtypes: int64(1), object(7)
memory usage: 77.4+ KB
None
0    1020
1    81  
Name: Label, dtype: int64
0    0.926431
1    0.073569
Name: Label, dtype: float64


|   | person1_word_idx | person2_word_idx | sentence | tokens | person1_right_tokens | person2_right_tokens | between_tokens | Label |
|---|---|---|---|---|---|---|---|---|
| 2363 | (2, 3) | (6, 7) | Coordinated: Sam Burgess and fiancé Phoebe Hooke, who have now moved to England, looked the picture of elegance in matching black and white ensembles Red rose: Ex Bachelor star Anna Henrich wore an off-the-shoulder red gown and a bold red lip to the annual sporting awards event The former Summer Bay stunner's frock featured a low dip at the top to show her toned stomach, and a high split at the bottom to show her lean legs. | [Coordinated, :, Sam, Burgess, and, fiancé, Phoebe, Hooke, ,, who, have, now, moved, to, England, ,, looked, the, picture, of, elegance, in, matching, black, and, white, ensembles, , Red, rose, :, Ex, Bachelor, star, Anna, Henrich, wore, an, off, -, the, -, shoulder, red, gown, and, a, bold, red, lip, to, the, annual, sporting, awards, event, , The, former, Summer, Bay, stunner, s, frock, featured, a, low, dip, at, the, top, to, show, her, toned, stomach, ,, and, a, high, split, at, the, bottom, to, show, her, lean, legs, ., ] | [and, fiancé, Phoebe, Hooke] | [,, who, have, now] | [and, fiancé] | 1 |


<class 'pandas.core.frame.DataFrame'>
Int64Index: 553 entries, 222 to 2358
Data columns (total 8 columns):
 #   Column                Non-Null Count  Dtype 
---  ------                --------------  ----- 
 0   person1_word_idx      553 non-null    object
 1   person2_word_idx      553 non-null    object
 2   sentence              553 non-null    object
 3   tokens                553 non-null    object
 4   person1_right_tokens  553 non-null    object
 5   person2_right_tokens  553 non-null    object
 6   between_tokens        553 non-null    object
 7   Label                 553 non-null    int64 
dtypes: int64(1), object(7)
memory usage: 38.9+ KB
None
0    512
1    41 
Name: Label, dtype: int64
0    0.925859
1    0.074141
Name: Label, dtype: float64


|   | person1_word_idx | person2_word_idx | sentence | tokens | person1_right_tokens | person2_right_tokens | between_tokens | Label |
|---|---|---|---|---|---|---|---|---|
| 222 | (0, 0) | (14, 14) | Zac said of the Dawson's Creek alum: 'When I first saw Katie and her rapport on a red carpet, I saw that magic. | [Zac, said, of, the, Dawson, s, Creek, alum, :, , When, I, first, saw, Katie, and, her, rapport, on, a, red, carpet, ,, I, saw, that, magic, .] | [said, of, the, Dawson] | [and, her, rapport, on] | [said, of, the, Dawson, s, Creek, alum, :, , When, I, first, saw] | 0 |


### Preprocessing the Data

In a real application, there is a lot of data preparation, parsing, and database loading that needs to be completed before we generate candidates and dive into writing labeling functions. Here, we've pre-generated the candidates as one pandas DataFrame per split (train, val, test).

### Labeling Function Helpers

When writing labeling functions, there are several helper functions you will use over and over again. For text relation extraction tasks like this one, common helpers fetch the text between the two person mentions in a candidate, examine word windows around each mention, and so on. We wrap these helpers as Snorkel `preprocessor`s.

In [10]:
@preprocessor()
def get_text_between(cand):
    """
    Returns the text between the two person mentions in the sentence for a candidate
    """
    start = cand.person1_word_idx[1] + 1
    end = cand.person2_word_idx[0]
    cand.text_between = " ".join(cand.tokens[start:end])
    return cand
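The `get_text_between` preprocessor above assumes that person 1 appears before person 2 in the sentence. Here is a sketch of the same logic as a plain function that first orders the two spans, so it also works if the mentions are reversed. `text_between` is a hypothetical name, and the tokens are the first few tokens of row 4 of `df_val`.

```python
def text_between(tokens, span_a, span_b):
    """Return the text strictly between two inclusive (start, end) spans, in either order."""
    first, second = sorted([span_a, span_b])  # order spans by start index
    return " ".join(tokens[first[1] + 1:second[0]])


tokens = ["People", "reported", "Williams", "and", "Ven", "Veen", "tied", "the", "knot"]
print(text_between(tokens, (2, 2), (4, 5)))  # and
print(text_between(tokens, (4, 5), (2, 2)))  # and
```

The result matches the precomputed `between_tokens` value `[and]` for that row.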

### Candidate Preprocessors

For the purposes of the tutorial, we have three fields (`between_tokens`, `person1_right_tokens`, `person2_right_tokens`) preprocessed in the data, which can be used when creating labeling functions. We also provide the following set of `preprocessor`s for this task in `preprocessors.py`, along with the fields these populate.
* `get_person_text(cand)`: `person_names`
* `get_person_last_names(cand)`: `person_lastnames`
* `get_left_tokens(cand)`: `person1_left_tokens`, `person2_left_tokens`

In [11]:
# Label constants (Snorkel uses -1 to denote an abstain vote).
POSITIVE = 1
NEGATIVE = 0
ABSTAIN = -1

In [12]:
# Check for the `spouse` words appearing between the person mentions
spouses = {"spouse", "wife", "husband", "ex-wife", "ex-husband"}


@labeling_function(resources=dict(spouses=spouses))
def lf_husband_wife(x, spouses):
    return POSITIVE if len(spouses.intersection(set(x.between_tokens))) > 0 else ABSTAIN

In [13]:
# Check for the `spouse` words appearing to the left of the person mentions
@labeling_function(resources=dict(spouses=spouses), pre=[get_left_tokens])
def lf_husband_wife_left_window(x, spouses):
    if len(set(spouses).intersection(set(x.person1_left_tokens))) > 0:
        return POSITIVE
    elif len(set(spouses).intersection(set(x.person2_left_tokens))) > 0:
        return POSITIVE
    else:
        return ABSTAIN

In [14]:
# Check for the person mentions having the same last name
@labeling_function(pre=[get_person_last_names])
def lf_same_last_name(x):
    p1_ln, p2_ln = x.person_lastnames

    if p1_ln and p2_ln and p1_ln == p2_ln:
        return POSITIVE
    return ABSTAIN

In [15]:
# Check for the word `married` between person mentions
@labeling_function()
def lf_married(x):
    return POSITIVE if "married" in x.between_tokens else ABSTAIN

In [16]:
# Check for words that refer to `family` relationships between and to the left of the person mentions
family = {
    "father",
    "mother",
    "sister",
    "brother",
    "son",
    "daughter",
    "grandfather",
    "grandmother",
    "uncle",
    "aunt",
    "cousin",
}
family = family.union({f + "-in-law" for f in family})


@labeling_function(resources=dict(family=family))
def lf_familial_relationship(x, family):
    return NEGATIVE if len(family.intersection(set(x.between_tokens))) > 0 else ABSTAIN


@labeling_function(resources=dict(family=family), pre=[get_left_tokens])
def lf_family_left_window(x, family):
    if len(set(family).intersection(set(x.person1_left_tokens))) > 0:
        return NEGATIVE
    elif len(set(family).intersection(set(x.person2_left_tokens))) > 0:
        return NEGATIVE
    else:
        return ABSTAIN

In [17]:
# Check for `other` relationship words between person mentions
other = {"boyfriend", "girlfriend", "boss", "employee", "secretary", "co-worker"}


@labeling_function(resources=dict(other=other))
def lf_other_relationship(x, other):
    return NEGATIVE if len(other.intersection(set(x.between_tokens))) > 0 else ABSTAIN
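The keyword LFs above all follow one pattern: vote a fixed label if any keyword appears in a token field, otherwise abstain. A small factory can generate such functions. The sketch below is plain Python and does not use Snorkel's `labeling_function` decorator; `make_keyword_lf` is a hypothetical name, and data points are modeled with `types.SimpleNamespace` so that fields are read as attributes, as in the LFs above.

```python
from types import SimpleNamespace

POSITIVE, NEGATIVE, ABSTAIN = 1, 0, -1


def make_keyword_lf(keywords, label, field="between_tokens"):
    """Build a function voting `label` if any keyword occurs in x.<field>, else ABSTAIN."""
    keywords = frozenset(keywords)

    def lf(x):
        tokens = getattr(x, field)
        return label if keywords.intersection(tokens) else ABSTAIN

    lf.__name__ = f"keyword_{field}_{label}"
    return lf


lf_spouse_between = make_keyword_lf({"spouse", "wife", "husband"}, POSITIVE)
lf_family_between = make_keyword_lf({"mother", "father", "sister"}, NEGATIVE)

x = SimpleNamespace(between_tokens=["and", "his", "wife"])
print(lf_spouse_between(x), lf_family_between(x))  # 1 -1
```

Note that matching is exact and case-sensitive, just like the set intersections in the LFs above, so a capitalized `Wife` would not fire.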

### Distant Supervision Labeling Functions

In addition to labeling functions that encode pattern-matching heuristics, we can also write labeling functions that _distantly supervise_ data points. Here, we'll load a list of known spouse pairs and check whether the pair of people in a candidate matches one of them.

[**DBpedia**](http://wiki.dbpedia.org/): Our database of known spouses comes from DBpedia, a community-driven resource similar to Wikipedia but for curating structured data. We'll use a preprocessed snapshot as our knowledge base for all labeling function development.

We can look at some example entries from DBpedia and use them in a simple distant supervision labeling function.

Make sure `dbpedia.pkl` is in the `spouse/data` directory.

In [18]:
with open("data/dbpedia.pkl", "rb") as f:
    known_spouses = pickle.load(f)

list(known_spouses)[0:5]

[('Joseph Edward Davies', 'Marjorie Merriweather Post'),
 ('Bert Jansch', 'Heather Jansch'),
 ('Diarmait mac Maíl na mBó', 'Donnchad mac Briain'),
 ('Alfred Lunt', 'Lynn Fontanne'),
 ('Clara Gooding McMillan', 'Thomas S. McMillan')]

In [19]:
@labeling_function(resources=dict(known_spouses=known_spouses), pre=[get_person_text])
def lf_distant_supervision(x, known_spouses):
    p1, p2 = x.person_names
    if (p1, p2) in known_spouses or (p2, p1) in known_spouses:
        return POSITIVE
    else:
        return ABSTAIN
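`lf_distant_supervision` checks both orderings of the name pair. An equivalent approach, sketched below, stores each known pair once as a `frozenset`, so a single membership test is order-insensitive. The helper names are hypothetical, and the two pairs are taken from the DBpedia sample printed above.

```python
POSITIVE, ABSTAIN = 1, -1

known_spouses = {
    ("Alfred Lunt", "Lynn Fontanne"),
    ("Bert Jansch", "Heather Jansch"),
}
# One unordered key per pair: frozenset({a, b}) == frozenset({b, a}).
known_pairs = {frozenset(pair) for pair in known_spouses}


def distant_supervision_vote(p1, p2, pairs=known_pairs):
    """Vote POSITIVE if the (unordered) name pair is a known spouse pair, else ABSTAIN."""
    return POSITIVE if frozenset((p1, p2)) in pairs else ABSTAIN


print(distant_supervision_vote("Lynn Fontanne", "Alfred Lunt"))  # 1
print(distant_supervision_vote("Alfred Lunt", "Bert Jansch"))    # -1
```

This halves the number of lookups per candidate; the trade-off is a one-time pass over the knowledge base to build `known_pairs`.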

In [20]:
# Last name pairs for known spouses
last_names = set(
    [
        (last_name(x), last_name(y))
        for x, y in known_spouses
        if last_name(x) and last_name(y)
    ]
)


@labeling_function(resources=dict(last_names=last_names), pre=[get_person_last_names])
def lf_distant_supervision_last_names(x, last_names):
    p1_ln, p2_ln = x.person_lastnames

    return (
        POSITIVE
        if (p1_ln != p2_ln)
        and ((p1_ln, p2_ln) in last_names or (p2_ln, p1_ln) in last_names)
        else ABSTAIN
    )
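The `last_name` helper is imported from `preprocessors.py` and isn't shown. The sketch below assumes a simple rule purely for illustration (the last whitespace-separated token, or `None` for single-token names); the real helper may differ. It rebuilds a `last_names` set from a tiny sample and mirrors the check in `lf_distant_supervision_last_names`, including the requirement that the two last names differ.

```python
POSITIVE, ABSTAIN = 1, -1


def last_name(full_name):
    """Hypothetical rule: last token of a multi-token name, else None."""
    parts = full_name.split()
    return parts[-1] if len(parts) > 1 else None


known_spouses = {("Alfred Lunt", "Lynn Fontanne"), ("Bert Jansch", "Heather Jansch")}
last_names = {
    (last_name(a), last_name(b))
    for a, b in known_spouses
    if last_name(a) and last_name(b)
}


def last_name_vote(p1, p2):
    """Mirror of the last-name LF: distinct last names that form a known couple."""
    ln1, ln2 = last_name(p1), last_name(p2)
    if ln1 != ln2 and ((ln1, ln2) in last_names or (ln2, ln1) in last_names):
        return POSITIVE
    return ABSTAIN


print(last_name_vote("Jane Fontanne", "John Lunt"))  # 1 (matches a known couple)
print(last_name_vote("Ann Jansch", "Bob Jansch"))    # -1 (same last name is skipped)
```

Skipping identical last names keeps this LF from duplicating `lf_same_last_name` and from firing on siblings or other relatives who share a surname.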

#### Apply Labeling Functions to the Data
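We create a list of labeling functions and apply them to each split. Several of the LFs defined above are commented out in the list below, so only four of them vote in this run.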
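Conceptually, applying the LFs produces a label matrix with one row per data point and one column per LF, where each entry is that LF's vote (`-1` for abstain). The plain Python/NumPy sketch below builds such a matrix using two toy LFs defined inline (hypothetical, not the ones above); it illustrates the output shape and is not Snorkel's `PandasLFApplier` implementation.

```python
from types import SimpleNamespace

import numpy as np

POSITIVE, NEGATIVE, ABSTAIN = 1, 0, -1


def lf_wife_between(x):
    return POSITIVE if "wife" in x.between_tokens else ABSTAIN


def lf_mother_between(x):
    return NEGATIVE if "mother" in x.between_tokens else ABSTAIN


def apply_lfs(points, lfs):
    """Build an (n_points x n_lfs) integer matrix of LF votes."""
    return np.array([[lf(x) for lf in lfs] for x in points], dtype=int)


points = [
    SimpleNamespace(between_tokens=["and", "his", "wife"]),
    SimpleNamespace(between_tokens=[",", "the", "mother", "of"]),
    SimpleNamespace(between_tokens=["and"]),
]
L = apply_lfs(points, [lf_wife_between, lf_mother_between])
print(L)
```

In practice most entries are `-1`, because each LF fires on only a small fraction of candidates; the `np.unique(L_train, ...)` cell further down shows this for the real matrix.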

In [21]:
lfs = [
    lf_husband_wife,
    lf_husband_wife_left_window,
    #lf_same_last_name,
    #lf_married,
    lf_familial_relationship,
    lf_family_left_window,
    #lf_other_relationship,
    #lf_distant_supervision,
    #lf_distant_supervision_last_names,
]
applier = PandasLFApplier(lfs)

In [22]:
# Compute labeling function matrix.
L_train = applier.apply(X_train)
L_val = applier.apply(X_val)
L_test = applier.apply(X_test)

100%|█████████████████████████████████████| 3858/3858 [00:02<00:00, 1436.11it/s]
100%|███████████████████████████████████████| 553/553 [00:00<00:00, 1324.71it/s]
100%|█████████████████████████████████████| 1101/1101 [00:00<00:00, 1567.98it/s]


In [23]:
# Get labels.
Y_train = np.array(y_train)
Y_val = np.array(y_val)
Y_test = np.array(y_test)

In [24]:
# Summarize performance on training set.
LFAnalysis(L_train, lfs).lf_summary(Y_train)

|   | j | Polarity | Coverage | Overlaps | Conflicts | Correct | Incorrect | Emp. Acc. |
|---|---|---|---|---|---|---|---|---|
| lf_husband_wife | 0 | [1] | 0.105236 | 0.039917 | 0.017626 | 162 | 244 | 0.399015 |
| lf_husband_wife_left_window | 1 | [1] | 0.032141 | 0.026179 | 0.003888 | 58 | 66 | 0.467742 |
| lf_familial_relationship | 2 | [0] | 0.113271 | 0.043287 | 0.018144 | 423 | 14 | 0.967963 |
| lf_family_left_window | 3 | [0] | 0.038362 | 0.02929 | 0.004147 | 139 | 9 | 0.939189 |


In [25]:
# Summarize performance on val set.
LFAnalysis(L_val, lfs).lf_summary(Y_val)

|   | j | Polarity | Coverage | Overlaps | Conflicts | Correct | Incorrect | Emp. Acc. |
|---|---|---|---|---|---|---|---|---|
| lf_husband_wife | 0 | [1] | 0.079566 | 0.028933 | 0.012658 | 21 | 23 | 0.477273 |
| lf_husband_wife_left_window | 1 | [1] | 0.023508 | 0.018083 | 0.001808 | 7 | 6 | 0.538462 |
| lf_familial_relationship | 2 | [0] | 0.122966 | 0.036166 | 0.012658 | 64 | 4 | 0.941176 |
| lf_family_left_window | 3 | [0] | 0.034358 | 0.023508 | 0.0 | 18 | 1 | 0.947368 |


In [26]:
# Summarize performance on test set.
LFAnalysis(L_test, lfs).lf_summary(Y_test)

|   | j | Polarity | Coverage | Overlaps | Conflicts | Correct | Incorrect | Emp. Acc. |
|---|---|---|---|---|---|---|---|---|
| lf_husband_wife | 0 | [1] | 0.098093 | 0.035422 | 0.019982 | 40 | 68 | 0.37037 |
| lf_husband_wife_left_window | 1 | [1] | 0.029064 | 0.021798 | 0.006358 | 11 | 21 | 0.34375 |
| lf_familial_relationship | 2 | [0] | 0.102634 | 0.038147 | 0.019982 | 101 | 12 | 0.893805 |
| lf_family_left_window | 3 | [0] | 0.035422 | 0.023615 | 0.00545 | 37 | 2 | 0.948718 |
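The columns of `lf_summary` can be recomputed directly from a label matrix. The NumPy sketch below follows the column definitions as we understand them from the Snorkel docs: coverage is the fraction of data points an LF labels; overlaps, the fraction it labels together with at least one other LF; conflicts, the fraction where at least one other LF labels the same point differently; and empirical accuracy is correct / (correct + incorrect) over the points it labels. `lf_stats` is a hypothetical name, and the toy matrix is invented for illustration.

```python
import numpy as np

ABSTAIN = -1


def lf_stats(L, Y):
    """Recompute lf_summary-style statistics for each column of a label matrix."""
    L = np.asarray(L)
    Y = np.asarray(Y)
    voted = L != ABSTAIN          # True where an LF did not abstain
    n_votes = voted.sum(axis=1)   # number of non-abstaining LFs per data point
    stats = []
    for j in range(L.shape[1]):
        labeled = voted[:, j]
        # Votes of the other LFs, compared against LF j's vote on each point.
        others = np.delete(L, j, axis=1)
        disagrees = ((others != ABSTAIN) & (others != L[:, [j]])).any(axis=1)
        correct = int((L[labeled, j] == Y[labeled]).sum())
        incorrect = int(labeled.sum()) - correct
        stats.append({
            "coverage": labeled.mean(),
            "overlaps": (labeled & (n_votes > 1)).mean(),
            "conflicts": (labeled & disagrees).mean(),
            "correct": correct,
            "incorrect": incorrect,
            "emp_acc": correct / (correct + incorrect) if labeled.any() else float("nan"),
        })
    return stats


# Toy label matrix: 4 data points x 3 LFs, plus gold labels.
L_toy = np.array([[1, -1, 0], [1, 1, -1], [-1, -1, 0], [-1, 0, -1]])
Y_toy = np.array([1, 0, 0, 0])
for j, s in enumerate(lf_stats(L_toy, Y_toy)):
    print(j, s)
```

As a sanity check against the training table, `lf_husband_wife` has 162 correct and 244 incorrect votes, and 162 / (162 + 244) ≈ 0.399, which is its reported empirical accuracy.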


### Training the Label Model

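Now we train Snorkel's `LabelModel` on the training label matrix. Without using any ground-truth labels, it estimates the accuracies of the LFs from their agreements and disagreements and uses these estimates to weight their votes. Once the model is trained, we can combine the outputs of the LFs into a single, noise-aware training label set for our extractor.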
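For comparison, the simplest way to combine LF outputs is an unweighted majority vote, which is the idea behind the `MajorityLabelVoter` baseline used later. The sketch below implements it in NumPy with an "abstain" tie-break (ties and rows where every LF abstains get `-1`). It is an illustrative reimplementation, not Snorkel's code, and `majority_vote` is a hypothetical name. The `LabelModel` differs in that it uses each LF's estimated accuracy instead of counting all votes equally.

```python
import numpy as np

ABSTAIN = -1


def majority_vote(L, cardinality=2):
    """Unweighted majority vote over non-abstaining LFs; ties and all-abstain rows -> ABSTAIN."""
    L = np.asarray(L)
    preds = np.full(L.shape[0], ABSTAIN)
    for i, row in enumerate(L):
        counts = np.bincount(row[row != ABSTAIN], minlength=cardinality)
        if counts.sum() == 0:
            continue  # every LF abstained
        top = np.flatnonzero(counts == counts.max())
        if len(top) == 1:
            preds[i] = top[0]  # unique winner; ties stay ABSTAIN
    return preds


L_toy = np.array([[1, -1, 0, -1], [1, 1, -1, -1], [-1, -1, -1, -1], [0, 0, 1, -1]])
print(majority_vote(L_toy))
```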

In [27]:
label_model = LabelModel(cardinality = 2, verbose = True)
label_model.fit(L_train, 
                n_epochs = 300, 
                log_freq = 10, 
                seed = 12345)

INFO:root:Computing O...
INFO:root:Estimating \mu...
  0%|                                                | 0/300 [00:00<?, ?epoch/s]INFO:root:[0 epochs]: TRAIN:[loss=0.006]
INFO:root:[10 epochs]: TRAIN:[loss=0.003]
INFO:root:[20 epochs]: TRAIN:[loss=0.002]
INFO:root:[30 epochs]: TRAIN:[loss=0.002]
INFO:root:[40 epochs]: TRAIN:[loss=0.002]
INFO:root:[50 epochs]: TRAIN:[loss=0.002]
INFO:root:[60 epochs]: TRAIN:[loss=0.002]
INFO:root:[70 epochs]: TRAIN:[loss=0.002]
INFO:root:[80 epochs]: TRAIN:[loss=0.002]
INFO:root:[90 epochs]: TRAIN:[loss=0.002]
INFO:root:[100 epochs]: TRAIN:[loss=0.002]
 35%|████████████▍                       | 104/300 [00:00<00:00, 1033.47epoch/s]INFO:root:[110 epochs]: TRAIN:[loss=0.002]
INFO:root:[120 epochs]: TRAIN:[loss=0.002]
INFO:root:[130 epochs]: TRAIN:[loss=0.002]
INFO:root:[140 epochs]: TRAIN:[loss=0.001]
INFO:root:[150 epochs]: TRAIN:[loss=0.001]
INFO:root:[160 epochs]: TRAIN:[loss=0.001]
INFO:root:[170 epochs]: TRAIN:[loss=0.001]
INFO:root:[180 epochs]: 

In [28]:
np.unique(L_train, return_counts = True)

(array([-1,  0,  1]), array([14317,   585,   530]))

### Label Model Metrics – no tuning
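Since our dataset is highly imbalanced (about 93% of the labels are negative), even a trivial baseline that always predicts negative achieves high accuracy. We therefore evaluate the label model mainly with the F1 score and ROC-AUC rather than accuracy. Note also that with `tie_break_policy = "abstain"`, data points on which a model abstains are excluded from every metric except coverage, so the accuracies reported below are measured only on the covered part of the test set and are not directly comparable to the dummy baseline.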
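To see why accuracy is misleading here, the sketch below scores a trivial always-negative predictor on labels with the same class counts as the test split (1020 negatives, 81 positives). Precision, recall, and F1 are computed by hand with NumPy, treating an undefined ratio (0/0) as 0.

```python
import numpy as np

# Class counts taken from the test split: 1020 negatives, 81 positives.
y_true = np.array([0] * 1020 + [1] * 81)
y_pred = np.zeros_like(y_true)  # trivial baseline: always predict negative

accuracy = (y_pred == y_true).mean()
tp = int(((y_pred == 1) & (y_true == 1)).sum())
fp = int(((y_pred == 1) & (y_true == 0)).sum())
fn = int(((y_pred == 0) & (y_true == 1)).sum())
precision = tp / (tp + fp) if tp + fp else 0.0
recall = tp / (tp + fn) if tp + fn else 0.0
f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
print(f"accuracy={accuracy:.3f} f1={f1:.3f}")  # accuracy=0.926 f1=0.000
```

The baseline looks excellent by accuracy yet never finds a single spouse pair, which F1 exposes immediately.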

In [29]:
# Init majority vote model.
majority_model = MajorityLabelVoter()

In [30]:
# Compute model performance metrics.
majority_scores = majority_model.score(L = L_test, 
                                       Y = Y_test, 
                                       tie_break_policy = "abstain",
                                       metrics = ["f1", "accuracy", "precision", 
                                                  "recall", "roc_auc", "coverage"])
label_scores = label_model.score(L = L_test, 
                                 Y = Y_test, 
                                 tie_break_policy = "abstain",
                                 metrics = ["f1", "accuracy", "precision", 
                                            "recall", "roc_auc", "coverage"])



In [31]:
# Compare model performance metrics.
print(np.unique(Y_test, return_counts = True))
counts = np.unique(Y_test, return_counts = True)
print("Dummy accuracy:", counts[1].max() / len(Y_test))
print("-----------------------------------")

majority_f1 = majority_scores.get("f1")
majority_acc = majority_scores.get("accuracy")
majority_prec = majority_scores.get("precision")
majority_rec = majority_scores.get("recall")
majority_roc = majority_scores.get("roc_auc")
majority_cov = majority_scores.get("coverage")
print(f"{'Majority Model F1:':<25} {majority_f1 * 100:.1f}%")
print(f"{'Majority Model Accuracy:':<25} {majority_acc * 100:.1f}%")
print(f"{'Majority Model Precision:':<25} {majority_prec * 100:.1f}%")
print(f"{'Majority Model Recall:':<25} {majority_rec * 100:.1f}%")
print(f"{'Majority Model AUC ROC:':<25} {majority_roc * 100:.1f}%")
print(f"{'Majority Model Coverage:':<25} {majority_cov * 100:.1f}%")
print("-----------------------------------")

label_f1 = label_scores.get("f1")
label_acc = label_scores.get("accuracy")
label_prec = label_scores.get("precision")
label_rec = label_scores.get("recall")
label_roc = label_scores.get("roc_auc")
label_cov = label_scores.get("coverage")
print(f"{'Label Model F1:':<25} {label_f1 * 100:.1f}%")
print(f"{'Label Model Accuracy:':<25} {label_acc * 100:.1f}%")
print(f"{'Label Model Precision:':<25} {label_prec * 100:.1f}%")
print(f"{'Label Model Recall:':<25} {label_rec * 100:.1f}%")
print(f"{'Label Model AUC ROC:':<25} {label_roc * 100:.1f}%")
print(f"{'Label Model Coverage:':<25} {label_cov * 100:.1f}%")

(array([0, 1]), array([1020,   81]))
Dummy accuracy: 0.9264305177111717
-----------------------------------
Majority Model F1:        48.9%
Majority Model Accuracy:  65.7%
Majority Model Precision: 34.3%
Majority Model Recall:    85.0%
Majority Model AUC ROC:   73.0%
Majority Model Coverage:  18.8%
-----------------------------------
Label Model F1:           46.6%
Label Model Accuracy:     64.9%
Label Model Precision:    33.7%
Label Model Recall:       75.6%
Label Model AUC ROC:      72.7%
Label Model Coverage:     20.2%


## Model tuning

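We tune the label model's [hyperparameters](https://snorkel.readthedocs.io/en/v0.9.7/packages/_autosummary/labeling/snorkel.labeling.model.label_model.LabelModel.html#snorkel.labeling.model.label_model.LabelModel) (number of epochs, L2 regularization strength, and learning rate) with a grid search, evaluating each configuration on the validation set.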
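The validation loop below scores every configuration but leaves the final choice to the reader. One way to make the selection explicit is sketched here: a hypothetical `grid_search` helper that evaluates every combination from a dict of candidate values and returns the best one by recall, breaking ties by F1. `evaluate` stands in for a function that would fit a `LabelModel` with the given settings and return its validation scores; in the sketch it is a toy function so the example runs on its own.

```python
import itertools


def grid_search(param_grid, evaluate, key=("recall", "f1")):
    """Evaluate every combination in param_grid; return (best_params, best_scores, results).

    `evaluate` maps a params dict to a dict of metric scores. The best result
    maximizes the metrics in `key` lexicographically (recall first, then F1).
    """
    names = list(param_grid)
    results = []
    for values in itertools.product(*(param_grid[n] for n in names)):
        params = dict(zip(names, values))
        results.append((params, evaluate(params)))
    # max() keeps the first result among exact ties.
    best_params, best_scores = max(results, key=lambda r: tuple(r[1][k] for k in key))
    return best_params, best_scores, results


# Toy stand-in for "fit a LabelModel and score it on the validation set".
def toy_evaluate(params):
    return {"recall": 0.8 if params["lr"] == 0.1 else 0.7, "f1": params["n_epochs"] / 100}


grid = {"n_epochs": [50, 100], "lr": [0.01, 0.1]}
best_params, best_scores, results = grid_search(grid, toy_evaluate)
print(best_params, best_scores)
```

When many configurations tie exactly, as happens in several runs below, the first one in grid order is returned, so it is worth ordering the grid from the cheapest settings (fewest epochs) upward.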

In [32]:
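# Search space (distinct names so the loop variables below don't overwrite these lists).
epochs_grid = [50, 100, 250]
l2_grid = [0.0, 0.2, 0.4]
lr_grid = [0.001, 0.01, 0.1]
seed = 42

# Take the Cartesian product to obtain the grid search space.
search_space = list(itertools.product(epochs_grid, l2_grid, lr_grid))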
print("\n--- HYPERPARAMETER SEARCH SPACE: ---\n")
print("Total combinations:", len(search_space))
print()
print(search_space)


--- HYPERPARAMETER SEARCH SPACE: ---

Total combinations: 27

[(50, 0.0, 0.001), (50, 0.0, 0.01), (50, 0.0, 0.1), (50, 0.2, 0.001), (50, 0.2, 0.01), (50, 0.2, 0.1), (50, 0.4, 0.001), (50, 0.4, 0.01), (50, 0.4, 0.1), (100, 0.0, 0.001), (100, 0.0, 0.01), (100, 0.0, 0.1), (100, 0.2, 0.001), (100, 0.2, 0.01), (100, 0.2, 0.1), (100, 0.4, 0.001), (100, 0.4, 0.01), (100, 0.4, 0.1), (250, 0.0, 0.001), (250, 0.0, 0.01), (250, 0.0, 0.1), (250, 0.2, 0.001), (250, 0.2, 0.01), (250, 0.2, 0.1), (250, 0.4, 0.001), (250, 0.4, 0.01), (250, 0.4, 0.1)]


In [33]:
# Validation loop.
# Goal: pick the configuration that minimizes false negatives, i.e., maximizes
# recall (sensitivity, TPR), while also tracking F1 and the other metrics.
recalls = []
f1s = []
rocs = []
accuracies = []
precisions = []
for hparams in search_space:

    # Extract hyperparameter values.
    n_epochs, l2, lr = hparams
    
    # Label model.
    label_model = LabelModel(cardinality = 2, verbose = True)
    label_model.fit(L_train = L_train, 
                    #Y_dev = y_val,
                    #class_balance = [0.7, 0.3], 
                    n_epochs = n_epochs, 
                    l2 = l2,
                    lr = lr,
                    optimizer = "sgd",
                    seed = seed)

    # Compute model performance metrics.
    label_scores = label_model.score(L = L_val, 
                                     Y = Y_val,  # NumPy labels, as in the earlier scoring cells
                                     tie_break_policy = "abstain",
                                     metrics = ["f1", "accuracy", "precision", 
                                                "recall", "roc_auc", "coverage"])
    
    

    
    print("\n--- HYPERPARAMETERS (epochs, l2, lr): ---")
    print(hparams)

    label_f1 = label_scores.get("f1") * 100
    label_acc = label_scores.get("accuracy") * 100
    label_prec = label_scores.get("precision") * 100
    label_rec = label_scores.get("recall") * 100
    label_roc = label_scores.get("roc_auc") * 100
    label_cov = label_scores.get("coverage") * 100
    print("F1: {} | Accuracy: {} | Precision: {} | Recall: {} | AUC ROC: {} | Coverage: {}".format(round(label_f1, 2),
                                                                                                   round(label_acc, 2),
                                                                                                   round(label_prec, 2),
                                                                                                   round(label_rec, 2),
                                                                                                   round(label_roc, 2),
                                                                                                   round(label_cov, 2)))
    print("----------------------------------------------------")
    
    recalls.append(label_rec)
    f1s.append(label_f1)
    rocs.append(label_roc)
    accuracies.append(label_acc)
    precisions.append(label_prec)

INFO:root:Computing O...
INFO:root:Estimating \mu...
  0%|                                                 | 0/50 [00:00<?, ?epoch/s]INFO:root:[0 epochs]: TRAIN:[loss=0.018]
INFO:root:[10 epochs]: TRAIN:[loss=0.016]
INFO:root:[20 epochs]: TRAIN:[loss=0.014]
INFO:root:[30 epochs]: TRAIN:[loss=0.012]
INFO:root:[40 epochs]: TRAIN:[loss=0.010]
100%|███████████████████████████████████████| 50/50 [00:00<00:00, 971.12epoch/s]
INFO:root:Finished Training
INFO:root:Computing O...
INFO:root:Estimating \mu...



--- HYPERPARAMETERS (epochs, l2, lr): ---
(50, 0.0, 0.001)
F1: 59.38 | Accuracy: 77.19 | Precision: 46.34 | Recall: 82.61 | AUC ROC: 82.99 | Coverage: 20.61
----------------------------------------------------


  0%|                                                 | 0/50 [00:00<?, ?epoch/s]INFO:root:[0 epochs]: TRAIN:[loss=0.018]
INFO:root:[10 epochs]: TRAIN:[loss=0.008]
INFO:root:[20 epochs]: TRAIN:[loss=0.002]
INFO:root:[30 epochs]: TRAIN:[loss=0.002]
INFO:root:[40 epochs]: TRAIN:[loss=0.002]
100%|███████████████████████████████████████| 50/50 [00:00<00:00, 993.48epoch/s]
INFO:root:Finished Training
INFO:root:Computing O...
INFO:root:Estimating \mu...



--- HYPERPARAMETERS (epochs, l2, lr): ---
(50, 0.0, 0.01)
F1: 59.38 | Accuracy: 77.19 | Precision: 46.34 | Recall: 82.61 | AUC ROC: 82.99 | Coverage: 20.61
----------------------------------------------------


  0%|                                                 | 0/50 [00:00<?, ?epoch/s]INFO:root:[0 epochs]: TRAIN:[loss=0.018]
INFO:root:[10 epochs]: TRAIN:[loss=0.007]
INFO:root:[20 epochs]: TRAIN:[loss=0.004]
INFO:root:[30 epochs]: TRAIN:[loss=0.002]
INFO:root:[40 epochs]: TRAIN:[loss=0.002]
INFO:root:Finished Training
INFO:root:Computing O...
INFO:root:Estimating \mu...



--- HYPERPARAMETERS (epochs, l2, lr): ---
(50, 0.0, 0.1)
F1: 59.38 | Accuracy: 77.19 | Precision: 46.34 | Recall: 82.61 | AUC ROC: 82.99 | Coverage: 20.61
----------------------------------------------------


INFO:root:[0 epochs]: TRAIN:[loss=0.018]
INFO:root:[10 epochs]: TRAIN:[loss=0.017]
INFO:root:[20 epochs]: TRAIN:[loss=0.015]
INFO:root:[30 epochs]: TRAIN:[loss=0.013]
INFO:root:[40 epochs]: TRAIN:[loss=0.011]
INFO:root:Finished Training
INFO:root:Computing O...
INFO:root:Estimating \mu...



--- HYPERPARAMETERS (epochs, l2, lr): ---
(50, 0.2, 0.001)
F1: 59.38 | Accuracy: 77.19 | Precision: 46.34 | Recall: 82.61 | AUC ROC: 82.99 | Coverage: 20.61
----------------------------------------------------


INFO:root:[0 epochs]: TRAIN:[loss=0.018]
INFO:root:[10 epochs]: TRAIN:[loss=0.009]
INFO:root:[20 epochs]: TRAIN:[loss=0.003]
INFO:root:[30 epochs]: TRAIN:[loss=0.003]
INFO:root:[40 epochs]: TRAIN:[loss=0.003]
INFO:root:Finished Training
INFO:root:Computing O...
INFO:root:Estimating \mu...



--- HYPERPARAMETERS (epochs, l2, lr): ---
(50, 0.2, 0.01)
F1: 59.38 | Accuracy: 77.19 | Precision: 46.34 | Recall: 82.61 | AUC ROC: 82.99 | Coverage: 20.61
----------------------------------------------------


INFO:root:[0 epochs]: TRAIN:[loss=0.018]
INFO:root:[10 epochs]: TRAIN:[loss=0.003]
INFO:root:[20 epochs]: TRAIN:[loss=0.005]
INFO:root:[30 epochs]: TRAIN:[loss=0.003]
INFO:root:[40 epochs]: TRAIN:[loss=0.003]
INFO:root:Finished Training
INFO:root:Computing O...
INFO:root:Estimating \mu...



--- HYPERPARAMETERS (epochs, l2, lr): ---
(50, 0.2, 0.1)
F1: 59.38 | Accuracy: 77.19 | Precision: 46.34 | Recall: 82.61 | AUC ROC: 82.99 | Coverage: 20.61
----------------------------------------------------


INFO:root:[0 epochs]: TRAIN:[loss=0.021]
INFO:root:[10 epochs]: TRAIN:[loss=0.019]
INFO:root:[20 epochs]: TRAIN:[loss=0.017]
INFO:root:[30 epochs]: TRAIN:[loss=0.014]
INFO:root:[40 epochs]: TRAIN:[loss=0.012]
INFO:root:Finished Training
INFO:root:Computing O...
INFO:root:Estimating \mu...



--- HYPERPARAMETERS (epochs, l2, lr): ---
(50, 0.4, 0.001)
F1: 59.38 | Accuracy: 77.19 | Precision: 46.34 | Recall: 82.61 | AUC ROC: 82.99 | Coverage: 20.61
----------------------------------------------------


INFO:root:[0 epochs]: TRAIN:[loss=0.021]
INFO:root:[10 epochs]: TRAIN:[loss=0.010]
INFO:root:[20 epochs]: TRAIN:[loss=0.005]
INFO:root:[30 epochs]: TRAIN:[loss=0.004]
INFO:root:[40 epochs]: TRAIN:[loss=0.005]
INFO:root:Finished Training
INFO:root:Computing O...
INFO:root:Estimating \mu...



--- HYPERPARAMETERS (epochs, l2, lr): ---
(50, 0.4, 0.01)
F1: 59.38 | Accuracy: 77.19 | Precision: 46.34 | Recall: 82.61 | AUC ROC: 82.99 | Coverage: 20.61
----------------------------------------------------


INFO:root:[0 epochs]: TRAIN:[loss=0.021]
INFO:root:[10 epochs]: TRAIN:[loss=0.004]
INFO:root:[20 epochs]: TRAIN:[loss=0.004]
INFO:root:[30 epochs]: TRAIN:[loss=0.007]
INFO:root:[40 epochs]: TRAIN:[loss=0.005]
INFO:root:Finished Training
INFO:root:Computing O...
INFO:root:Estimating \mu...



--- HYPERPARAMETERS (epochs, l2, lr): ---
(50, 0.4, 0.1)
F1: 59.38 | Accuracy: 77.19 | Precision: 46.34 | Recall: 82.61 | AUC ROC: 82.99 | Coverage: 20.61
----------------------------------------------------


INFO:root:[0 epochs]: TRAIN:[loss=0.018]
INFO:root:[10 epochs]: TRAIN:[loss=0.016]
INFO:root:[20 epochs]: TRAIN:[loss=0.014]
INFO:root:[30 epochs]: TRAIN:[loss=0.012]
INFO:root:[40 epochs]: TRAIN:[loss=0.010]
INFO:root:[50 epochs]: TRAIN:[loss=0.009]
INFO:root:[60 epochs]: TRAIN:[loss=0.007]
INFO:root:[70 epochs]: TRAIN:[loss=0.006]
INFO:root:[80 epochs]: TRAIN:[loss=0.005]
INFO:root:[90 epochs]: TRAIN:[loss=0.005]
INFO:root:Finished Training
INFO:root:Computing O...
INFO:root:Estimating \mu...



--- HYPERPARAMETERS (epochs, l2, lr): ---
(100, 0.0, 0.001)
F1: 59.38 | Accuracy: 77.19 | Precision: 46.34 | Recall: 82.61 | AUC ROC: 82.99 | Coverage: 20.61
----------------------------------------------------


INFO:root:[0 epochs]: TRAIN:[loss=0.018]
INFO:root:[10 epochs]: TRAIN:[loss=0.008]
INFO:root:[20 epochs]: TRAIN:[loss=0.002]
INFO:root:[30 epochs]: TRAIN:[loss=0.002]
INFO:root:[40 epochs]: TRAIN:[loss=0.002]
INFO:root:[50 epochs]: TRAIN:[loss=0.002]
INFO:root:[60 epochs]: TRAIN:[loss=0.002]
INFO:root:[70 epochs]: TRAIN:[loss=0.002]
INFO:root:[80 epochs]: TRAIN:[loss=0.002]
INFO:root:[90 epochs]: TRAIN:[loss=0.002]
INFO:root:Finished Training
INFO:root:Computing O...
INFO:root:Estimating \mu...



--- HYPERPARAMETERS (epochs, l2, lr): ---
(100, 0.0, 0.01)
F1: 59.38 | Accuracy: 77.19 | Precision: 46.34 | Recall: 82.61 | AUC ROC: 82.99 | Coverage: 20.61
----------------------------------------------------


INFO:root:[0 epochs]: TRAIN:[loss=0.018]
INFO:root:[10 epochs]: TRAIN:[loss=0.007]
INFO:root:[20 epochs]: TRAIN:[loss=0.004]
INFO:root:[30 epochs]: TRAIN:[loss=0.002]
INFO:root:[40 epochs]: TRAIN:[loss=0.002]
INFO:root:[50 epochs]: TRAIN:[loss=0.001]
INFO:root:[60 epochs]: TRAIN:[loss=0.001]
INFO:root:[70 epochs]: TRAIN:[loss=0.001]
INFO:root:[80 epochs]: TRAIN:[loss=0.001]
INFO:root:[90 epochs]: TRAIN:[loss=0.001]
INFO:root:Finished Training
INFO:root:Computing O...
INFO:root:Estimating \mu...



--- HYPERPARAMETERS (epochs, l2, lr): ---
(100, 0.0, 0.1)
F1: 57.14 | Accuracy: 76.32 | Precision: 45.0 | Recall: 78.26 | AUC ROC: 84.23 | Coverage: 20.61
----------------------------------------------------


INFO:root:[0 epochs]: TRAIN:[loss=0.018]
INFO:root:[10 epochs]: TRAIN:[loss=0.017]
INFO:root:[20 epochs]: TRAIN:[loss=0.015]
INFO:root:[30 epochs]: TRAIN:[loss=0.013]
INFO:root:[40 epochs]: TRAIN:[loss=0.011]
INFO:root:[50 epochs]: TRAIN:[loss=0.010]
INFO:root:[60 epochs]: TRAIN:[loss=0.008]
INFO:root:[70 epochs]: TRAIN:[loss=0.007]
INFO:root:[80 epochs]: TRAIN:[loss=0.007]
INFO:root:[90 epochs]: TRAIN:[loss=0.006]
INFO:root:Finished Training
INFO:root:Computing O...
INFO:root:Estimating \mu...



--- HYPERPARAMETERS (epochs, l2, lr): ---
(100, 0.2, 0.001)
F1: 59.38 | Accuracy: 77.19 | Precision: 46.34 | Recall: 82.61 | AUC ROC: 82.99 | Coverage: 20.61
----------------------------------------------------


INFO:root:[0 epochs]: TRAIN:[loss=0.018]
INFO:root:[10 epochs]: TRAIN:[loss=0.009]
INFO:root:[20 epochs]: TRAIN:[loss=0.003]
INFO:root:[30 epochs]: TRAIN:[loss=0.003]
INFO:root:[40 epochs]: TRAIN:[loss=0.003]
INFO:root:[50 epochs]: TRAIN:[loss=0.003]
INFO:root:[60 epochs]: TRAIN:[loss=0.004]
INFO:root:[70 epochs]: TRAIN:[loss=0.003]
INFO:root:[80 epochs]: TRAIN:[loss=0.003]
INFO:root:[90 epochs]: TRAIN:[loss=0.003]
INFO:root:Finished Training
INFO:root:Computing O...
INFO:root:Estimating \mu...



--- HYPERPARAMETERS (epochs, l2, lr): ---
(100, 0.2, 0.01)
F1: 59.38 | Accuracy: 77.19 | Precision: 46.34 | Recall: 82.61 | AUC ROC: 82.99 | Coverage: 20.61
----------------------------------------------------


INFO:root:[0 epochs]: TRAIN:[loss=0.018]
INFO:root:[10 epochs]: TRAIN:[loss=0.003]
INFO:root:[20 epochs]: TRAIN:[loss=0.005]
INFO:root:[30 epochs]: TRAIN:[loss=0.003]
INFO:root:[40 epochs]: TRAIN:[loss=0.003]
INFO:root:[50 epochs]: TRAIN:[loss=0.004]
INFO:root:[60 epochs]: TRAIN:[loss=0.003]
INFO:root:[70 epochs]: TRAIN:[loss=0.003]
INFO:root:[80 epochs]: TRAIN:[loss=0.003]
INFO:root:[90 epochs]: TRAIN:[loss=0.003]
INFO:root:Finished Training
INFO:root:Computing O...
INFO:root:Estimating \mu...



--- HYPERPARAMETERS (epochs, l2, lr): ---
(100, 0.2, 0.1)
F1: 59.38 | Accuracy: 77.19 | Precision: 46.34 | Recall: 82.61 | AUC ROC: 82.99 | Coverage: 20.61
----------------------------------------------------


INFO:root:[0 epochs]: TRAIN:[loss=0.021]
INFO:root:[10 epochs]: TRAIN:[loss=0.019]
INFO:root:[20 epochs]: TRAIN:[loss=0.017]
INFO:root:[30 epochs]: TRAIN:[loss=0.014]
INFO:root:[40 epochs]: TRAIN:[loss=0.012]
INFO:root:[50 epochs]: TRAIN:[loss=0.011]
INFO:root:[60 epochs]: TRAIN:[loss=0.010]
INFO:root:[70 epochs]: TRAIN:[loss=0.009]
INFO:root:[80 epochs]: TRAIN:[loss=0.008]
INFO:root:[90 epochs]: TRAIN:[loss=0.008]
INFO:root:Finished Training
INFO:root:Computing O...
INFO:root:Estimating \mu...



--- HYPERPARAMETERS (epochs, l2, lr): ---
(100, 0.4, 0.001)
F1: 59.38 | Accuracy: 77.19 | Precision: 46.34 | Recall: 82.61 | AUC ROC: 82.99 | Coverage: 20.61
----------------------------------------------------


INFO:root:[0 epochs]: TRAIN:[loss=0.021]
INFO:root:[10 epochs]: TRAIN:[loss=0.010]
INFO:root:[20 epochs]: TRAIN:[loss=0.005]
INFO:root:[30 epochs]: TRAIN:[loss=0.004]
INFO:root:[40 epochs]: TRAIN:[loss=0.005]
INFO:root:[50 epochs]: TRAIN:[loss=0.006]
INFO:root:[60 epochs]: TRAIN:[loss=0.006]
INFO:root:[70 epochs]: TRAIN:[loss=0.006]
INFO:root:[80 epochs]: TRAIN:[loss=0.005]
INFO:root:[90 epochs]: TRAIN:[loss=0.006]
INFO:root:Finished Training
INFO:root:Computing O...
INFO:root:Estimating \mu...



--- HYPERPARAMETERS (epochs, l2, lr): ---
(100, 0.4, 0.01)
F1: 59.38 | Accuracy: 77.19 | Precision: 46.34 | Recall: 82.61 | AUC ROC: 82.99 | Coverage: 20.61
----------------------------------------------------


INFO:root:[0 epochs]: TRAIN:[loss=0.021]
INFO:root:[10 epochs]: TRAIN:[loss=0.004]
INFO:root:[20 epochs]: TRAIN:[loss=0.004]
INFO:root:[30 epochs]: TRAIN:[loss=0.007]
INFO:root:[40 epochs]: TRAIN:[loss=0.005]
INFO:root:[50 epochs]: TRAIN:[loss=0.005]
INFO:root:[60 epochs]: TRAIN:[loss=0.006]
INFO:root:[70 epochs]: TRAIN:[loss=0.006]
INFO:root:[80 epochs]: TRAIN:[loss=0.005]
INFO:root:[90 epochs]: TRAIN:[loss=0.006]
INFO:root:Finished Training
INFO:root:Computing O...
INFO:root:Estimating \mu...



--- HYPERPARAMETERS (epochs, l2, lr): ---
(100, 0.4, 0.1)
F1: 59.38 | Accuracy: 77.19 | Precision: 46.34 | Recall: 82.61 | AUC ROC: 82.99 | Coverage: 20.61
----------------------------------------------------


INFO:root:[0 epochs]: TRAIN:[loss=0.018]
INFO:root:[10 epochs]: TRAIN:[loss=0.016]
INFO:root:[20 epochs]: TRAIN:[loss=0.014]
INFO:root:[30 epochs]: TRAIN:[loss=0.012]
INFO:root:[40 epochs]: TRAIN:[loss=0.010]
INFO:root:[50 epochs]: TRAIN:[loss=0.009]
INFO:root:[60 epochs]: TRAIN:[loss=0.007]
INFO:root:[70 epochs]: TRAIN:[loss=0.006]
INFO:root:[80 epochs]: TRAIN:[loss=0.005]
INFO:root:[90 epochs]: TRAIN:[loss=0.005]
INFO:root:[100 epochs]: TRAIN:[loss=0.004]
INFO:root:[110 epochs]: TRAIN:[loss=0.004]
INFO:root:[120 epochs]: TRAIN:[loss=0.003]
INFO:root:[130 epochs]: TRAIN:[loss=0.003]
INFO:root:[140 epochs]: TRAIN:[loss=0.003]
INFO:root:[150 epochs]: TRAIN:[loss=0.003]
INFO:root:[160 epochs]: TRAIN:[loss=0.002]
INFO:root:[170 epochs]: TRAIN:[loss=0.002]
INFO:root:[180 epochs]: TRAIN:[loss=0.002]
INFO:root:[190 epochs]: TRAIN:[los


--- HYPERPARAMETERS (epochs, l2, lr): ---
(250, 0.0, 0.001)
F1: 59.38 | Accuracy: 77.19 | Precision: 46.34 | Recall: 82.61 | AUC ROC: 82.99 | Coverage: 20.61
----------------------------------------------------


INFO:root:[0 epochs]: TRAIN:[loss=0.018]
INFO:root:[10 epochs]: TRAIN:[loss=0.008]
INFO:root:[20 epochs]: TRAIN:[loss=0.002]
INFO:root:[30 epochs]: TRAIN:[loss=0.002]
INFO:root:[40 epochs]: TRAIN:[loss=0.002]
INFO:root:[50 epochs]: TRAIN:[loss=0.002]
INFO:root:[60 epochs]: TRAIN:[loss=0.002]
INFO:root:[70 epochs]: TRAIN:[loss=0.002]
INFO:root:[80 epochs]: TRAIN:[loss=0.002]
INFO:root:[90 epochs]: TRAIN:[loss=0.002]
INFO:root:[100 epochs]: TRAIN:[loss=0.002]
INFO:root:[110 epochs]: TRAIN:[loss=0.002]
INFO:root:[120 epochs]: TRAIN:[loss=0.002]
INFO:root:[130 epochs]: TRAIN:[loss=0.002]
INFO:root:[140 epochs]: TRAIN:[loss=0.002]
INFO:root:[150 epochs]: TRAIN:[loss=0.002]
INFO:root:[160 epochs]: TRAIN:[loss=0.002]
INFO:root:[170 epochs]: TRAIN:[loss=0.002]
INFO:root:[180 epochs]: TRAIN:[loss=0.002]
INFO:root:[190 epochs]: TRAIN:[los


--- HYPERPARAMETERS (epochs, l2, lr): ---
(250, 0.0, 0.01)
F1: 59.38 | Accuracy: 77.19 | Precision: 46.34 | Recall: 82.61 | AUC ROC: 82.99 | Coverage: 20.61
----------------------------------------------------


INFO:root:[0 epochs]: TRAIN:[loss=0.018]
INFO:root:[10 epochs]: TRAIN:[loss=0.007]
INFO:root:[20 epochs]: TRAIN:[loss=0.004]
INFO:root:[30 epochs]: TRAIN:[loss=0.002]
INFO:root:[40 epochs]: TRAIN:[loss=0.002]
INFO:root:[50 epochs]: TRAIN:[loss=0.001]
INFO:root:[60 epochs]: TRAIN:[loss=0.001]
INFO:root:[70 epochs]: TRAIN:[loss=0.001]
INFO:root:[80 epochs]: TRAIN:[loss=0.001]
INFO:root:[90 epochs]: TRAIN:[loss=0.001]
INFO:root:[100 epochs]: TRAIN:[loss=0.001]
INFO:root:[110 epochs]: TRAIN:[loss=0.001]
INFO:root:[120 epochs]: TRAIN:[loss=0.001]
INFO:root:[130 epochs]: TRAIN:[loss=0.001]
INFO:root:[140 epochs]: TRAIN:[loss=0.001]
INFO:root:[150 epochs]: TRAIN:[loss=0.001]
INFO:root:[160 epochs]: TRAIN:[loss=0.001]
INFO:root:[170 epochs]: TRAIN:[loss=0.001]
INFO:root:[180 epochs]: TRAIN:[loss=0.001]
INFO:root:[190 epochs]: TRAIN:[los


--- HYPERPARAMETERS (epochs, l2, lr): ---
(250, 0.0, 0.1)
F1: 0.0 | Accuracy: 79.82 | Precision: 0.0 | Recall: 0.0 | AUC ROC: 73.24 | Coverage: 20.61
----------------------------------------------------


INFO:root:[0 epochs]: TRAIN:[loss=0.018]
INFO:root:[10 epochs]: TRAIN:[loss=0.017]
INFO:root:[20 epochs]: TRAIN:[loss=0.015]
INFO:root:[30 epochs]: TRAIN:[loss=0.013]
INFO:root:[40 epochs]: TRAIN:[loss=0.011]
INFO:root:[50 epochs]: TRAIN:[loss=0.010]
INFO:root:[60 epochs]: TRAIN:[loss=0.008]
INFO:root:[70 epochs]: TRAIN:[loss=0.007]
INFO:root:[80 epochs]: TRAIN:[loss=0.007]
INFO:root:[90 epochs]: TRAIN:[loss=0.006]
INFO:root:[100 epochs]: TRAIN:[loss=0.005]
INFO:root:[110 epochs]: TRAIN:[loss=0.005]
INFO:root:[120 epochs]: TRAIN:[loss=0.005]
INFO:root:[130 epochs]: TRAIN:[loss=0.005]
INFO:root:[140 epochs]: TRAIN:[loss=0.004]
INFO:root:[150 epochs]: TRAIN:[loss=0.004]
INFO:root:[160 epochs]: TRAIN:[loss=0.004]
INFO:root:[170 epochs]: TRAIN:[loss=0.004]
INFO:root:[180 epochs]: TRAIN:[loss=0.004]
INFO:root:[190 epochs]: TRAIN:[los


--- HYPERPARAMETERS (epochs, l2, lr): ---
(250, 0.2, 0.001)
F1: 59.38 | Accuracy: 77.19 | Precision: 46.34 | Recall: 82.61 | AUC ROC: 82.99 | Coverage: 20.61
----------------------------------------------------


INFO:root:[0 epochs]: TRAIN:[loss=0.018]
INFO:root:[10 epochs]: TRAIN:[loss=0.009]
INFO:root:[20 epochs]: TRAIN:[loss=0.003]
INFO:root:[30 epochs]: TRAIN:[loss=0.003]
INFO:root:[40 epochs]: TRAIN:[loss=0.003]
INFO:root:[50 epochs]: TRAIN:[loss=0.003]
INFO:root:[60 epochs]: TRAIN:[loss=0.004]
INFO:root:[70 epochs]: TRAIN:[loss=0.003]
INFO:root:[80 epochs]: TRAIN:[loss=0.003]
INFO:root:[90 epochs]: TRAIN:[loss=0.003]
INFO:root:[100 epochs]: TRAIN:[loss=0.003]
INFO:root:[110 epochs]: TRAIN:[loss=0.003]
INFO:root:[120 epochs]: TRAIN:[loss=0.003]
INFO:root:[130 epochs]: TRAIN:[loss=0.003]
INFO:root:[140 epochs]: TRAIN:[loss=0.003]
INFO:root:[150 epochs]: TRAIN:[loss=0.003]
INFO:root:[160 epochs]: TRAIN:[loss=0.003]
INFO:root:[170 epochs]: TRAIN:[loss=0.003]
INFO:root:[180 epochs]: TRAIN:[loss=0.003]
INFO:root:[190 epochs]: TRAIN:[los


--- HYPERPARAMETERS (epochs, l2, lr): ---
(250, 0.2, 0.01)
F1: 59.38 | Accuracy: 77.19 | Precision: 46.34 | Recall: 82.61 | AUC ROC: 82.99 | Coverage: 20.61
----------------------------------------------------


INFO:root:[0 epochs]: TRAIN:[loss=0.018]
INFO:root:[10 epochs]: TRAIN:[loss=0.003]
INFO:root:[20 epochs]: TRAIN:[loss=0.005]
INFO:root:[30 epochs]: TRAIN:[loss=0.003]
INFO:root:[40 epochs]: TRAIN:[loss=0.003]
INFO:root:[50 epochs]: TRAIN:[loss=0.004]
INFO:root:[60 epochs]: TRAIN:[loss=0.003]
INFO:root:[70 epochs]: TRAIN:[loss=0.003]
INFO:root:[80 epochs]: TRAIN:[loss=0.003]
INFO:root:[90 epochs]: TRAIN:[loss=0.003]
INFO:root:[100 epochs]: TRAIN:[loss=0.003]
INFO:root:[110 epochs]: TRAIN:[loss=0.003]
INFO:root:[120 epochs]: TRAIN:[loss=0.003]
INFO:root:[130 epochs]: TRAIN:[loss=0.003]
INFO:root:[140 epochs]: TRAIN:[loss=0.003]
INFO:root:[150 epochs]: TRAIN:[loss=0.003]
INFO:root:[160 epochs]: TRAIN:[loss=0.003]
INFO:root:[170 epochs]: TRAIN:[loss=0.003]
INFO:root:[180 epochs]: TRAIN:[loss=0.003]
INFO:root:[190 epochs]: TRAIN:[los


--- HYPERPARAMETERS (epochs, l2, lr): ---
(250, 0.2, 0.1)
F1: 59.38 | Accuracy: 77.19 | Precision: 46.34 | Recall: 82.61 | AUC ROC: 82.99 | Coverage: 20.61
----------------------------------------------------


INFO:root:[0 epochs]: TRAIN:[loss=0.021]
INFO:root:[10 epochs]: TRAIN:[loss=0.019]
INFO:root:[20 epochs]: TRAIN:[loss=0.017]
INFO:root:[30 epochs]: TRAIN:[loss=0.014]
INFO:root:[40 epochs]: TRAIN:[loss=0.012]
INFO:root:[50 epochs]: TRAIN:[loss=0.011]
INFO:root:[60 epochs]: TRAIN:[loss=0.010]
INFO:root:[70 epochs]: TRAIN:[loss=0.009]
INFO:root:[80 epochs]: TRAIN:[loss=0.008]
INFO:root:[90 epochs]: TRAIN:[loss=0.008]
INFO:root:[100 epochs]: TRAIN:[loss=0.007]
INFO:root:[110 epochs]: TRAIN:[loss=0.007]
INFO:root:[120 epochs]: TRAIN:[loss=0.007]
INFO:root:[130 epochs]: TRAIN:[loss=0.006]
INFO:root:[140 epochs]: TRAIN:[loss=0.006]
INFO:root:[150 epochs]: TRAIN:[loss=0.006]
INFO:root:[160 epochs]: TRAIN:[loss=0.006]
INFO:root:[170 epochs]: TRAIN:[loss=0.006]
INFO:root:[180 epochs]: TRAIN:[loss=0.006]
INFO:root:[190 epochs]: TRAIN:[los


--- HYPERPARAMETERS (epochs, l2, lr): ---
(250, 0.4, 0.001)
F1: 59.38 | Accuracy: 77.19 | Precision: 46.34 | Recall: 82.61 | AUC ROC: 82.99 | Coverage: 20.61
----------------------------------------------------


INFO:root:[0 epochs]: TRAIN:[loss=0.021]
INFO:root:[10 epochs]: TRAIN:[loss=0.010]
INFO:root:[20 epochs]: TRAIN:[loss=0.005]
INFO:root:[30 epochs]: TRAIN:[loss=0.004]
INFO:root:[40 epochs]: TRAIN:[loss=0.005]
INFO:root:[50 epochs]: TRAIN:[loss=0.006]
INFO:root:[60 epochs]: TRAIN:[loss=0.006]
INFO:root:[70 epochs]: TRAIN:[loss=0.006]
INFO:root:[80 epochs]: TRAIN:[loss=0.005]
INFO:root:[90 epochs]: TRAIN:[loss=0.006]
INFO:root:[100 epochs]: TRAIN:[loss=0.006]
INFO:root:[110 epochs]: TRAIN:[loss=0.006]
INFO:root:[120 epochs]: TRAIN:[loss=0.006]
INFO:root:[130 epochs]: TRAIN:[loss=0.006]
INFO:root:[140 epochs]: TRAIN:[loss=0.006]
INFO:root:[150 epochs]: TRAIN:[loss=0.006]
INFO:root:[160 epochs]: TRAIN:[loss=0.006]
INFO:root:[170 epochs]: TRAIN:[loss=0.006]
INFO:root:[180 epochs]: TRAIN:[loss=0.006]
INFO:root:[190 epochs]: TRAIN:[los


--- HYPERPARAMETERS (epochs, l2, lr): ---
(250, 0.4, 0.01)
F1: 59.38 | Accuracy: 77.19 | Precision: 46.34 | Recall: 82.61 | AUC ROC: 82.99 | Coverage: 20.61
----------------------------------------------------


INFO:root:[0 epochs]: TRAIN:[loss=0.021]
INFO:root:[10 epochs]: TRAIN:[loss=0.004]
INFO:root:[20 epochs]: TRAIN:[loss=0.004]
INFO:root:[30 epochs]: TRAIN:[loss=0.007]
INFO:root:[40 epochs]: TRAIN:[loss=0.005]
INFO:root:[50 epochs]: TRAIN:[loss=0.005]
INFO:root:[60 epochs]: TRAIN:[loss=0.006]
INFO:root:[70 epochs]: TRAIN:[loss=0.006]
INFO:root:[80 epochs]: TRAIN:[loss=0.005]
INFO:root:[90 epochs]: TRAIN:[loss=0.006]
INFO:root:[100 epochs]: TRAIN:[loss=0.006]
INFO:root:[110 epochs]: TRAIN:[loss=0.006]
INFO:root:[120 epochs]: TRAIN:[loss=0.006]
INFO:root:[130 epochs]: TRAIN:[loss=0.006]
INFO:root:[140 epochs]: TRAIN:[loss=0.006]
INFO:root:[150 epochs]: TRAIN:[loss=0.006]
INFO:root:[160 epochs]: TRAIN:[loss=0.006]
INFO:root:[170 epochs]: TRAIN:[loss=0.006]
INFO:root:[180 epochs]: TRAIN:[loss=0.006]
INFO:root:[190 epochs]: TRAIN:[los


--- HYPERPARAMETERS (epochs, l2, lr): ---
(250, 0.4, 0.1)
F1: 59.38 | Accuracy: 77.19 | Precision: 46.34 | Recall: 82.61 | AUC ROC: 82.99 | Coverage: 20.61
----------------------------------------------------
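The log above lists the runs in a fixed order. The learning rate changes fastest, then `l2`, then the number of epochs, which is the order `itertools.product` produces. The cell that defines `search_space` is not part of this section. The following is therefore a minimal sketch that assumes the grid values can be read off the log. The `*_grid` and `grid` names are hypothetical.

```python
import itertools

# Grid values read off the log above; the *_grid names are hypothetical.
epochs_grid = [50, 100, 250]
l2_grid = [0.0, 0.2, 0.4]
lr_grid = [0.001, 0.01, 0.1]

# itertools.product varies the last factor fastest, matching the log order:
# (50, 0.0, 0.001), (50, 0.0, 0.01), (50, 0.0, 0.1), (50, 0.2, 0.001), ...
grid = list(itertools.product(epochs_grid, l2_grid, lr_grid))

print(len(grid))                     # 27
print(grid.index((100, 0.0, 0.1)))   # 11: F1 57.14, best AUC ROC
print(grid.index((250, 0.0, 0.1)))   # 20: F1 0.0, best accuracy
```

With this ordering, position `i` in the grid lines up with the i-th entries of `recalls`, `f1s`, `rocs`, `accuracies` and `precisions`. The indexing in the next cell relies on exactly that alignment.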


In [34]:
# Get the hyperparameter combinations that achieve the maximum of each metric.
max_recall = max(recalls)
max_recall_indices = [i for i,j in enumerate(recalls) if j == max_recall]
hparams_max_recall = [search_space[i] for i in max_recall_indices]
print("\n--- HYPERPARAMETERS FOR MAX RECALL OF {}: ---\n".format(max_recall))
print("Total models with optimal recall:", len(hparams_max_recall))
print()
print(hparams_max_recall)

max_f1 = max(f1s)
max_f1_indices = [i for i,j in enumerate(f1s) if j == max_f1]
hparams_max_f1 = [search_space[i] for i in max_f1_indices]
print("\n--- HYPERPARAMETERS FOR MAX F1 OF {}: ---\n".format(max_f1))
print("Total models with optimal F1:", len(hparams_max_f1))
print()
print(hparams_max_f1)

max_accuracies = max(accuracies)
max_accuracies_indices = [i for i,j in enumerate(accuracies) if j == max_accuracies]
hparams_max_accuracies = [search_space[i] for i in max_accuracies_indices]
print("\n--- HYPERPARAMETERS FOR MAX ACCURACY OF {}: ---\n".format(max_accuracies))
print("Total models with optimal accuracy:", len(hparams_max_accuracies))
print()
print(hparams_max_accuracies)

max_precision = max(precisions)
max_precision_indices = [i for i,j in enumerate(precisions) if j == max_precision]
hparams_max_precision = [search_space[i] for i in max_precision_indices]
print("\n--- HYPERPARAMETERS FOR MAX PRECISION OF {}: ---\n".format(max_precision))
print("Total models with optimal precision:", len(hparams_max_precision))
print()
print(hparams_max_precision)

max_roc = max(rocs)
max_roc_indices = [i for i,j in enumerate(rocs) if j == max_roc]
hparams_max_roc = [search_space[i] for i in max_roc_indices]
print("\n--- HYPERPARAMETERS FOR MAX AUC ROC OF {}: ---\n".format(max_roc))
print("Total models with optimal AUC ROC:", len(hparams_max_roc))
print()
print(hparams_max_roc)


--- HYPERPARAMETERS FOR MAX RECALL OF 82.6086956521739: ---

Total models with optimal recall: 25

[(50, 0.0, 0.001), (50, 0.0, 0.01), (50, 0.0, 0.1), (50, 0.2, 0.001), (50, 0.2, 0.01), (50, 0.2, 0.1), (50, 0.4, 0.001), (50, 0.4, 0.01), (50, 0.4, 0.1), (100, 0.0, 0.001), (100, 0.0, 0.01), (100, 0.2, 0.001), (100, 0.2, 0.01), (100, 0.2, 0.1), (100, 0.4, 0.001), (100, 0.4, 0.01), (100, 0.4, 0.1), (250, 0.0, 0.001), (250, 0.0, 0.01), (250, 0.2, 0.001), (250, 0.2, 0.01), (250, 0.2, 0.1), (250, 0.4, 0.001), (250, 0.4, 0.01), (250, 0.4, 0.1)]

--- HYPERPARAMETERS FOR MAX F1 OF 59.375: ---

Total models with optimal F1: 25

[(50, 0.0, 0.001), (50, 0.0, 0.01), (50, 0.0, 0.1), (50, 0.2, 0.001), (50, 0.2, 0.01), (50, 0.2, 0.1), (50, 0.4, 0.001), (50, 0.4, 0.01), (50, 0.4, 0.1), (100, 0.0, 0.001), (100, 0.0, 0.01), (100, 0.2, 0.001), (100, 0.2, 0.01), (100, 0.2, 0.1), (100, 0.4, 0.001), (100, 0.4, 0.01), (100, 0.4, 0.1), (250, 0.0, 0.001), (250, 0.0, 0.01), (250, 0.2, 0.001), (250, 0.2, 0.01), (
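In [34] repeats the same max / enumerate / lookup pattern once for each metric. A dictionary keyed by metric name can express the same tie-aware selection once. This is a minimal sketch: `best_configs`, `space`, `scores` and the toy numbers are all hypothetical and are not the notebook's results.

```python
def best_configs(values, space):
    """Return every configuration whose score equals the maximum (keeps ties)."""
    best = max(values)
    return [cfg for cfg, v in zip(space, values) if v == best]

# Toy inputs: one score per configuration, aligned with the configuration list.
space = [(50, 0.0, 0.01), (100, 0.0, 0.01), (100, 0.0, 0.1)]
scores = {
    "f1": [59.4, 59.4, 57.1],
    "roc_auc": [83.0, 83.0, 84.2],
}

best_by_metric = {name: best_configs(vals, space) for name, vals in scores.items()}
for name, cfgs in best_by_metric.items():
    print(f"{name}: max {max(scores[name])}, {len(cfgs)} config(s): {cfgs}")
```

Exact equality works here because tied runs report identical scores. If the scores were noisier, a tolerance such as `math.isclose` would be a safer test.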

In [35]:
# Check whether any configuration is optimal on every metric.
hparams_max = [hparams_max_accuracies, 
               hparams_max_f1, 
               hparams_max_recall,
               hparams_max_roc,
               hparams_max_precision]
best_intersect = set.intersection(*map(set, hparams_max))
print("Total models at intersection:", len(best_intersect))
print(best_intersect)

Total models at intersection: 0
set()
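The grid-search log explains why the intersection is empty:

- The only best-accuracy run, (250, 0.0, 0.1), has F1, precision and recall of 0.
- The only best AUC ROC run, (100, 0.0, 0.1), has a lower F1 (57.14) than the other 25 runs (59.38).

The sets below are transcribed from that log so the check can be seen in isolation. The variable names are illustrative.

```python
import itertools

all_configs = set(itertools.product([50, 100, 250], [0.0, 0.2, 0.4], [0.001, 0.01, 0.1]))

# Transcribed from the grid-search log above.
best_acc_set = {(250, 0.0, 0.1)}   # accuracy 79.82, but F1/precision/recall 0.0
best_roc_set = {(100, 0.0, 0.1)}   # AUC ROC 84.23, F1 57.14
best_f1_set = all_configs - best_acc_set - best_roc_set   # all other runs: F1 59.38

print(len(best_f1_set))                                           # 25
print(set.intersection(best_acc_set, best_f1_set, best_roc_set))  # set()
```

A model that reaches the best accuracy without finding a single positive pair likely reflects class imbalance: most candidate pairs are not spouses, so labeling everything negative already scores well on accuracy.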


In [36]:
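# No configuration is optimal on every metric: the best accuracy (79.82) comes from
# (250, 0.0, 0.1), whose F1, precision and recall are all 0. Hand-select one of the
# 25 configurations tied for the best F1, precision and recall instead.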
optimal_hparams = [100, 0.0, 0.01]
print("Optimal hyperparameters =", optimal_hparams)

Optimal hyperparameters = [100, 0.0, 0.01]


## Train optimal labeling model

In [37]:
# Extract hyperparameter values.
n_epochs = optimal_hparams[0]
l2 = optimal_hparams[1]
lr = optimal_hparams[2]
    
# Label model.
label_model = LabelModel(cardinality = 2, verbose = True)
label_model.fit(L_train = L_train, 
                #Y_dev = y_val,
                #class_balance = [0.7, 0.3], 
                n_epochs = n_epochs, 
                l2 = l2,
                lr = lr,
                optimizer = "sgd",
                seed = seed)

# Compute model performance metrics.
label_scores = label_model.score(L = L_test,
                                 Y = y_test, 
                                 tie_break_policy = "abstain",
                                 metrics = ["f1", "accuracy", "precision", 
                                            "recall", "roc_auc", "coverage"])
    
    

# Compare model performance metrics.
majority_f1 = majority_scores.get("f1")
majority_acc = majority_scores.get("accuracy")
majority_prec = majority_scores.get("precision")
majority_rec = majority_scores.get("recall")
majority_roc = majority_scores.get("roc_auc")
majority_cov = majority_scores.get("coverage")
print(f"{'Majority Model F1:':<25} {majority_f1 * 100:.1f}%")
print(f"{'Majority Model Accuracy:':<25} {majority_acc * 100:.1f}%")
print(f"{'Majority Model Precision:':<25} {majority_prec * 100:.1f}%")
print(f"{'Majority Model Recall:':<25} {majority_rec * 100:.1f}%")
print(f"{'Majority Model AUC ROC:':<25} {majority_roc * 100:.1f}%")
print(f"{'Majority Model Coverage:':<25} {majority_cov * 100:.1f}%")
print("--------------------------------------")
label_f1 = label_scores.get("f1")
label_acc = label_scores.get("accuracy")
label_prec = label_scores.get("precision")
label_rec = label_scores.get("recall")
label_roc = label_scores.get("roc_auc")
label_cov = label_scores.get("coverage")
print(f"{'Label Model F1:':<25} {label_f1 * 100:.1f}%")
print(f"{'Label Model Accuracy:':<25} {label_acc * 100:.1f}%")
print(f"{'Label Model Precision:':<25} {label_prec * 100:.1f}%")
print(f"{'Label Model Recall:':<25} {label_rec * 100:.1f}%")
print(f"{'Label Model AUC ROC:':<25} {label_roc * 100:.1f}%")
print(f"{'Label Model Coverage:':<25} {label_cov * 100:.1f}%")

INFO:root:Computing O...
INFO:root:Estimating \mu...
INFO:root:[0 epochs]: TRAIN:[loss=0.018]
INFO:root:[10 epochs]: TRAIN:[loss=0.008]
INFO:root:[20 epochs]: TRAIN:[loss=0.002]
INFO:root:[30 epochs]: TRAIN:[loss=0.002]
INFO:root:[40 epochs]: TRAIN:[loss=0.002]
INFO:root:[50 epochs]: TRAIN:[loss=0.002]
INFO:root:[60 epochs]: TRAIN:[loss=0.002]
INFO:root:[70 epochs]: TRAIN:[loss=0.002]
INFO:root:[80 epochs]: TRAIN:[loss=0.002]
INFO:root:[90 epochs]: TRAIN:[loss=0.002]
INFO:root:Finished Training


Majority Model F1:        48.9%
Majority Model Accuracy:  65.7%
Majority Model Precision: 34.3%
Majority Model Recall:    85.0%
Majority Model AUC ROC:   73.0%
Majority Model Coverage:  18.8%
--------------------------------------
Label Model F1:           46.6%
Label Model Accuracy:     64.9%
Label Model Precision:    33.7%
Label Model Recall:       75.6%
Label Model AUC ROC:      73.3%
Label Model Coverage:     20.2%
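The per-metric print calls in In [37] follow the same pattern for each model. A small helper could lay the two score dictionaries side by side instead. This is a minimal sketch: `comparison_table` is hypothetical. The inputs are the rounded percentages printed above, written back as fractions, which is how the code above treats the values returned by `score()`.

```python
def comparison_table(scores_by_model, metrics):
    """Format {model: {metric: fraction}} as one row per metric, in percent."""
    names = list(scores_by_model)
    lines = [f"{'metric':<10}" + "".join(f"{name:>12}" for name in names)]
    for metric in metrics:
        cells = "".join(f"{scores_by_model[name][metric] * 100:>11.1f}%" for name in names)
        lines.append(f"{metric:<10}{cells}")
    return "\n".join(lines)

# Rounded values from the output above, as fractions.
scores_by_model = {
    "Majority": {"f1": 0.489, "recall": 0.850, "coverage": 0.188},
    "Label": {"f1": 0.466, "recall": 0.756, "coverage": 0.202},
}
table = comparison_table(scores_by_model, ["f1", "recall", "coverage"])
print(table)
```

In the notebook, the call would be `comparison_table({"Majority": majority_scores, "Label": label_scores}, [...])` with the metric names passed to `score()`.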


## Export labeling function matrices

In [38]:
# Get predicted labels.
pred_train = label_model.predict(L_train)
pred_val = label_model.predict(L_val)
pred_test = label_model.predict(L_test)

print("TRAIN :\n", pd.Series(pred_train).value_counts())
print("VAL   :\n", pd.Series(pred_val).value_counts())
print("TEST  :\n", pd.Series(pred_test).value_counts())

print("TRAIN :\n", pd.Series(pred_train).value_counts(normalize = True))
print("VAL   :\n", pd.Series(pred_val).value_counts(normalize = True))
print("TEST  :\n", pd.Series(pred_test).value_counts(normalize = True))

TRAIN :
 -1    3019
 0    465 
 1    374 
dtype: int64
VAL   :
 -1    439
 0    73 
 1    41 
dtype: int64
TEST  :
 -1    879
 0    121
 1    101
dtype: int64
TRAIN :
 -1    0.782530
 0    0.120529
 1    0.096941
dtype: float64
VAL   :
 -1    0.793852
 0    0.132007
 1    0.074141
dtype: float64
TEST  :
 -1    0.798365
 0    0.109900
 1    0.091735
dtype: float64
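Coverage in the score tables is the share of points on which a model outputs a label instead of abstaining (-1). The test-set counts above reproduce the Label Model Coverage of 20.2% reported earlier. The check below rebuilds an array with the same counts. The name `rebuilt` is illustrative and avoids overwriting the notebook's `pred_test`.

```python
import numpy as np

# Rebuild an array with the test-set counts shown above (-1 = abstain).
counts = {-1: 879, 0: 121, 1: 101}
rebuilt = np.repeat(list(counts.keys()), list(counts.values()))

# Coverage: fraction of predictions that are not abstentions.
coverage = np.mean(rebuilt != -1)
print(len(rebuilt), f"{coverage * 100:.1f}%")   # 1101 20.2%
```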


In [39]:
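# Convert from Snorkel's convention (-1 = abstain, 0 = negative, 1 = positive)
# to ours (0 = abstain, -1 = negative, 1 = positive). The value 5 is a temporary
# placeholder so that the -1 -> 0 and 0 -> -1 swaps do not overwrite each other.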
pred_train = np.where(pred_train == -1, 5, pred_train)
pred_train = np.where(pred_train == 0, -1, pred_train)
pred_train = np.where(pred_train == 5, 0, pred_train)

pred_val = np.where(pred_val == -1, 5, pred_val)
pred_val = np.where(pred_val == 0, -1, pred_val)
pred_val = np.where(pred_val == 5, 0, pred_val)

pred_test = np.where(pred_test == -1, 5, pred_test)
pred_test = np.where(pred_test == 0, -1, pred_test)
pred_test = np.where(pred_test == 5, 0, pred_test)

print("TRAIN :\n", pd.Series(pred_train).value_counts())
print("VAL   :\n", pd.Series(pred_val).value_counts())
print("TEST  :\n", pd.Series(pred_test).value_counts())

TRAIN :
  0    3019
-1    465 
 1    374 
dtype: int64
VAL   :
  0    439
-1    73 
 1    41 
dtype: int64
TEST  :
  0    879
-1    121
 1    101
dtype: int64
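In [39] and In [40] swap -1 and 0 using three `np.where` calls and a placeholder value of 5. Because the labels are the small integers -1, 0 and 1, the same remapping can be written as a single lookup into a three-element table indexed by `label + 1`. This is a minimal sketch: `to_signed_labels` is a hypothetical helper, not part of Snorkel.

```python
import numpy as np

# Position i holds the new value for Snorkel label i - 1:
#   -1 (abstain) -> 0, 0 (negative) -> -1, 1 (positive) -> 1
_REMAP = np.array([0, -1, 1])

def to_signed_labels(labels):
    """Map Snorkel labels {-1, 0, 1} to {0, -1, 1}; works on 1-D and 2-D arrays."""
    labels = np.asarray(labels)
    if not np.isin(labels, (-1, 0, 1)).all():
        raise ValueError("expected only the labels -1, 0 and 1")
    return _REMAP[labels + 1]

print(to_signed_labels([-1, 0, 1, 0, -1]))   # [ 0 -1  1 -1  0]
print(to_signed_labels([[-1, 0], [1, -1]]))  # 2-D, like an L matrix
```

The helper also works on the gold labels. Gold labels contain no -1, so it gives the same result as `np.where(y == 0, -1, y)`.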


In [40]:
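# Convert from Snorkel's convention (-1 = abstain, 0 = negative, 1 = positive)
# to ours (0 = abstain, -1 = negative, 1 = positive); 5 is a temporary placeholder.
# Gold labels contain no abstains, so for them only 0 -> -1 is needed.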

################################################################
# TRAIN.
L_train_new = np.where(L_train == -1, 5, L_train)
L_train_new = np.where(L_train_new == 0, -1, L_train_new)
L_train_new = np.where(L_train_new == 5, 0, L_train_new)
y_train_new = np.where(y_train == 0, -1, y_train)

df_L_train = pd.DataFrame(L_train_new)
df_L_train["Label"] = y_train_new
df_L_train["Snorkel"] = pred_train
display(df_L_train)
################################################################

################################################################
# VAL.
L_val_new = np.where(L_val == -1, 5, L_val)
L_val_new = np.where(L_val_new == 0, -1, L_val_new)
L_val_new = np.where(L_val_new == 5, 0, L_val_new)
y_val_new = np.where(y_val == 0, -1, y_val)

df_L_val = pd.DataFrame(L_val_new)
df_L_val["Label"] = y_val_new
df_L_val["Snorkel"] = pred_val
display(df_L_val)
################################################################

################################################################
# TEST.
L_test_new = np.where(L_test == -1, 5, L_test)
L_test_new = np.where(L_test_new == 0, -1, L_test_new)
L_test_new = np.where(L_test_new == 5, 0, L_test_new)
y_test_new = np.where(y_test == 0, -1, y_test)

df_L_test = pd.DataFrame(L_test_new)
df_L_test["Label"] = y_test_new
df_L_test["Snorkel"] = pred_test
display(df_L_test)
################################################################

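|      | 0 | 1 | 2  | 3 | Label | Snorkel |
|------|---|---|----|---|-------|---------|
| 0    | 0 | 0 | 0  | 0 | -1    | 0       |
| 1    | 0 | 0 | 0  | 0 | -1    | 0       |
| 2    | 0 | 0 | -1 | 0 | -1    | -1      |
| 3    | 0 | 0 | 0  | 0 | -1    | 0       |
| 4    | 0 | 0 | 0  | 0 | -1    | 0       |
| ...  | ... | ... | ... | ... | ... | ...   |
| 3853 | 0 | 0 | -1 | 0 | -1    | -1      |
| 3854 | 0 | 0 | 0  | 0 | -1    | 0       |
| 3855 | 0 | 0 | 0  | 0 | -1    | 0       |
| 3856 | 0 | 0 | 0  | 0 | -1    | 0       |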


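|     | 0 | 1 | 2 | 3 | Label | Snorkel |
|-----|---|---|---|---|-------|---------|
| 0   | 0 | 0 | 0 | 0 | -1    | 0       |
| 1   | 0 | 0 | 0 | 0 | -1    | 0       |
| 2   | 0 | 0 | 0 | 0 | -1    | 0       |
| 3   | 0 | 0 | 0 | 0 | -1    | 0       |
| 4   | 0 | 0 | 0 | 0 | -1    | 0       |
| ... | ... | ... | ... | ... | ... | ...   |
| 548 | 0 | 0 | 0 | 0 | -1    | 0       |
| 549 | 1 | 1 | 0 | 0 | -1    | 1       |
| 550 | 0 | 0 | 0 | 0 | -1    | 0       |
| 551 | 0 | 0 | 0 | 0 | -1    | 0       |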


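|      | 0 | 1 | 2  | 3  | Label | Snorkel |
|------|---|---|----|----|-------|---------|
| 0    | 0 | 0 | 0  | 0  | 1     | 0       |
| 1    | 0 | 0 | 0  | 0  | -1    | 0       |
| 2    | 0 | 0 | -1 | -1 | -1    | -1      |
| 3    | 0 | 0 | 0  | 0  | -1    | 0       |
| 4    | 0 | 0 | 0  | 0  | -1    | 0       |
| ...  | ... | ... | ... | ... | ... | ...    |
| 1096 | 0 | 0 | 0  | 0  | -1    | 0       |
| 1097 | 0 | 0 | 0  | 0  | -1    | 0       |
| 1098 | 0 | 0 | 0  | 0  | -1    | 0       |
| 1099 | 0 | 0 | 0  | 0  | -1    | 0       |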
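The `Snorkel` column can be compared against `Label` in these frames. Rows where the label model abstained (0 in the remapped encoding) should be excluded before measuring accuracy. The minimal sketch below assumes the remapped encoding (0 abstain, -1 negative, 1 positive). `snorkel_vs_gold` is a hypothetical helper, and the toy frame only mimics the layout of `df_L_val`; its values are made up for illustration.

```python
import pandas as pd


def snorkel_vs_gold(df, pred_col="Snorkel", gold_col="Label", abstain=0):
    """Return (coverage of pred_col, accuracy of pred_col on non-abstained rows)."""
    covered = df[pred_col] != abstain
    coverage = covered.mean()
    if covered.any():
        matches = df.loc[covered, pred_col] == df.loc[covered, gold_col]
        accuracy = matches.mean()
    else:
        accuracy = float("nan")
    return float(coverage), float(accuracy)


# Toy frame with the same columns as df_L_val (values are illustrative only).
toy = pd.DataFrame({
    0: [0, 1, 0, 0],
    1: [0, 1, 0, 0],
    2: [0, 0, -1, 0],
    3: [0, 0, -1, 0],
    "Label": [-1, -1, -1, 1],
    "Snorkel": [0, 1, -1, 0],
})
print(snorkel_vs_gold(toy))
```

In the notebook this would be called as `snorkel_vs_gold(df_L_val)` or `snorkel_vs_gold(df_L_test)`. Reporting coverage next to accuracy matters here, because most candidates receive no label-model vote.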


In [None]:
# Export the remapped labeling function matrices (with gold Label and Snorkel columns) for external experiments.
df_L_train.to_csv("spouse_tuned_lf_matrix_train.csv", index = False)
df_L_val.to_csv("spouse_tuned_lf_matrix_val.csv", index = False)
df_L_test.to_csv("spouse_tuned_lf_matrix_test.csv", index = False)
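When these files are loaded in an external experiment, note that `to_csv(..., index=False)` writes no index column, so no `Unnamed: 0` column appears on reload. However, the integer LF column names come back as the strings `"0"` to `"3"`. The minimal sketch below uses an in-memory CSV as a stand-in for `spouse_tuned_lf_matrix_val.csv`, whose rows are illustrative. With the real file, `pd.read_csv("spouse_tuned_lf_matrix_val.csv")` would replace the `StringIO` call.

```python
import io

import pandas as pd

# Stand-in for the exported file: header as written by to_csv(index=False).
csv_text = "0,1,2,3,Label,Snorkel\n0,0,0,0,-1,0\n1,1,0,0,-1,1\n0,0,-1,0,-1,-1\n"

df = pd.read_csv(io.StringIO(csv_text))

# LF columns are every column except the gold label and the Snorkel prediction.
# Their names are now strings ("0", "1", ...), not integers.
lf_cols = [c for c in df.columns if c not in ("Label", "Snorkel")]
L = df[lf_cols].to_numpy()
y = df["Label"].to_numpy()
print(lf_cols, L.shape, y.tolist())
```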

In [None]:
# Export the corresponding candidate data splits (X_train, X_val, X_test) for external experiments.
X_train.to_csv("spouse_tuned_train.csv", index = False)
X_val.to_csv("spouse_tuned_val.csv", index = False)
X_test.to_csv("spouse_tuned_test.csv", index = False)

### End of document