In [1]:
import time
from examples.prediction_and_evaluation import pred_and_eval_gen_model, eval_majority_vote
from examples.utils import change_labels
from factor_graph import FactorGraph
import numpy as np

# Comparing the implemented factor graph against Snorkel (a latent MRF model)

## The data used consists of:
 - labels Y for the created task of discriminating professors from teachers in the Bias in Bios dataset
 - 99 selected labeling functions, usable for a standard data programming framework

In [2]:
def train_supervised(label_matrix, Y_true, dependencies, lf_prop=True, n_epoch=25, lr=0.1):
    """Fit a fully observed FactorGraph on true labels plus LF votes, then evaluate it.

    The true labels are placed as the first column in front of ``label_matrix``; the
    resulting array is used both for fitting and for evaluation against ``Y_true``.

    Args:
        label_matrix: 2-D array of labeling-function outputs, one column per LF.
            ASSUMPTION: each LF only votes for ONE label and abstains otherwise.
        Y_true: true labels; reshaped into a single column before stacking.
        dependencies: forwarded to ``FactorGraph`` as ``deps``.
        lf_prop: accepted for call compatibility; not used inside this function.
        n_epoch: number of epochs forwarded to ``FactorGraph.fit``.
        lr: learning rate forwarded to ``FactorGraph.fit``.

    Returns:
        Tuple ``(lm, stat, probs)``: the fitted model and the two values returned by
        ``pred_and_eval_gen_model``.
    """
    start_t = time.time()

    # In the supervised case the PGM simply sees everything concatenated:
    # column 0 holds the true label, the remaining columns hold the LF votes.
    label_column = Y_true.reshape((-1, 1))
    observations = np.hstack((label_column, label_matrix))

    # Polarity of each LF = sign of its column sum (relies on the assumption above).
    column_totals = np.sum(label_matrix, axis=0)
    polarities = np.sign(column_totals)

    # MRF in which every variable is observed.
    n_lfs = label_matrix.shape[1]
    lm = FactorGraph(n_LFs=n_lfs, LF_polarities=polarities, deps=dependencies)
    lm.fit(observations, lr=lr, n_epochs=n_epoch, batch_size=250)

    # Evaluate the learned generative model on the same observations.
    eval_options = dict(version=99, abst=0, verbose=True, print_MV=False, eps=0.0,
                        return_preds=True, coverage_stats=False, add_prefix="")
    stat, probs = pred_and_eval_gen_model(lm, observations, Y_true, **eval_options)

    # The reported time spans fitting AND the evaluation call above.
    duration = time.time() - start_t
    print(f"Time needed by generative model: {duration}")
    # Will train the downstream classifier:
    # stat_cl = train_and_eval_classifier(Xtrain, Xtest, probs, Ytest, label_matrix, library='torch',
    #                                    optim='Adam', devicestring=device, epochs=250, print_step=505)
    return lm, stat, probs

In [3]:
def train_snorkel(label_matrix, Y_true, n_epoch=1000, lr=0.1, seed=77):
    """Fit Snorkel's latent ``LabelModel`` on the LF votes and evaluate it.

    Args:
        label_matrix: labeling-function outputs that use 0 as the abstention label.
        Y_true: true labels; only used for evaluation, never for fitting.
        n_epoch: number of epochs forwarded to ``LabelModel.fit`` as ``n_epochs``.
        lr: learning rate forwarded to ``LabelModel.fit``.
        seed: random seed forwarded to ``LabelModel.fit``. Defaults to 77, the value
            that used to be hard-coded, so existing calls behave as before.

    Returns:
        Tuple ``(lm, stat, probs)``: the fitted label model and the two values
        returned by ``pred_and_eval_gen_model``.
    """
    # Kept as a function-local import, as in the original cell.
    # NOTE(review): presumably so the rest of the notebook runs without snorkel — confirm.
    from snorkel.labeling.model import LabelModel

    start_t = time.time()

    # Snorkel requires the abstention label to be -1, so relabel via change_labels
    # (old_label=0 -> new_label=-1). The results get their own names so the
    # function's inputs are not rebound.
    snorkel_L, snorkel_Y = change_labels(label_matrix, Y_true, new_label=-1, old_label=0)

    # Train the latent label model from Snorkel (the true labels are not passed to fit).
    lm = LabelModel(cardinality=2)
    lm.fit(snorkel_L, n_epochs=n_epoch, seed=seed, lr=lr)

    # Evaluate the learned generative model against the relabeled ground truth.
    stat, probs = pred_and_eval_gen_model(lm, snorkel_L, snorkel_Y, abst=-1, verbose=True,
                                          print_MV=False, eps=0.0, MV_policy="random",
                                          return_preds=True, version=10, coverage_stats=False)

    # The reported time spans relabeling, fitting AND the evaluation call above.
    duration = time.time() - start_t
    print(f"Time needed by Snorkel's generative model: {duration}")
    # Will train the downstream classifier:
    # stat_cl = train_and_eval_classifier(Xtrain, Xtest, probs, Ytest, label_matrix, library='torch',
    #                                    optim='Adam', devicestring=device, epochs=250, print_step=505)
    return lm, stat, probs

In [6]:
# Run configuration. Neither `seed` nor `n_runs` is referenced by the other cells
# shown in this notebook.
seed = 77
n_runs = 5

# np.load on an .npz archive returns an NpzFile that keeps the file open until it is
# closed, so read both arrays inside a `with` block. After the block, `data` is closed
# and only `L_arr` / `Ytrain` should be used.
with np.load("../data/professor_vs_teacher_99LFs.npz") as data:
    # "L": label matrix (one column per labeling function), "Y": labels.
    L_arr, Ytrain = data["L"], data["Y"]

In [5]:
# Dimensions of the label matrix: rows are samples, columns are labeling functions.
n_samples, nlf = L_arr.shape
# Passed as `lf_prop` to the supervised training calls further down.
lfprop = False

# Majority-vote baseline over all LFs (0 = abstain, ties resolved with the 'random' policy).
print("---------------------------------- MAJORITY VOTE STATS --------------------------------------------------")
print("MV on all samples with ", nlf, "LFs")
eval_majority_vote(L_arr, Ytrain, abst=0, MV_policy='random')
print("---------------------------------------------------------------------------------------------------------")
# PRINT LF descriptions: [print(d) for d in descr]

---------------------------------- MAJORITY VOTE STATS --------------------------------------------------
MV on all samples with  99 LFs
Majority vote stats:
Accuracy:0.754 | Precision:0.772 | Recall:0.713 | F1 score:0.742 | AUC:0.796 | Log loss:5.506 | Brier:0.917 | Coverage:1.000 | MSE, MAE:0.917, 0.751
---------------------------------------------------------------------------------------------------------


# Supervised (ours)

In [9]:
# One supervised run with no LF dependencies and 10 epochs; the returned
# (model, stats, probabilities) triple is discarded, only the printed metrics matter.
_, _, _ = train_supervised(
    L_arr,
    Ytrain,
    dependencies=[],
    lf_prop=lfprop,
    lr=0.1,
    n_epoch=10,
)

Epoch 0...
Accuracy:0.787 | Precision:0.754 | Recall:0.845 | F1 score:0.797 | AUC:0.890 | Log loss:1.182 | Brier:0.949 | Coverage:1.000 | MSE, MAE:0.949, 0.726
Time needed by generative model: 59.90506148338318


In [6]:
# Three supervised runs (10, 25 and 25 epochs) with no LF dependencies; the returned
# triples are discarded, only the printed metrics and timings matter.
for epochs in (10, 25, 25):
    _, _, _ = train_supervised(L_arr, Ytrain, [], lf_prop=lfprop, lr=0.1, n_epoch=epochs)


Epoch 0...
Accuracy:0.787 | Precision:0.763 | Recall:0.828 | F1 score:0.794 | AUC:0.883 | Log loss:1.142 | Brier:0.938 | Coverage:1.000 | MSE, MAE:0.938, 0.730
Time needed by generative model: 68.06458353996277
Epoch 0...
Accuracy:0.783 | Precision:0.760 | Recall:0.820 | F1 score:0.789 | AUC:0.876 | Log loss:1.880 | Brier:0.947 | Coverage:1.000 | MSE, MAE:0.947, 0.730
Time needed by generative model: 150.51205468177795
Epoch 0...
Accuracy:0.781 | Precision:0.736 | Recall:0.870 | F1 score:0.797 | AUC:0.886 | Log loss:2.385 | Brier:0.987 | Coverage:1.000 | MSE, MAE:0.987, 0.727
Time needed by generative model: 159.4011058807373


# Snorkel
### Note: this is the newer, faster Snorkel label model. (The older Snorkel, which used SGD + MLE, is similarly slow to the supervised model above, or slower.)

In [8]:
# Three Snorkel runs of 1000 epochs each with learning rates 0.01, 0.01 and 0.1;
# the returned triples are discarded, only the printed metrics and timings matter.
for learning_rate in (0.01, 0.01, 0.1):
    _, _, _ = train_snorkel(L_arr, Ytrain, lr=learning_rate, n_epoch=1000)


Accuracy:0.769 | Precision:0.718 | Recall:0.878 | F1 score:0.790 | AUC:0.880 | Log loss:0.678 | Brier:0.159 | Coverage:1.000 | MSE, MAE:0.159, 0.275
Time needed by Snorkel's generative model: 3.0689585208892822
Accuracy:0.769 | Precision:0.718 | Recall:0.878 | F1 score:0.790 | AUC:0.880 | Log loss:0.678 | Brier:0.159 | Coverage:1.000 | MSE, MAE:0.159, 0.275
Time needed by Snorkel's generative model: 2.7420010566711426
Accuracy:0.769 | Precision:0.718 | Recall:0.878 | F1 score:0.790 | AUC:0.880 | Log loss:0.678 | Brier:0.159 | Coverage:1.000 | MSE, MAE:0.159, 0.275
Time needed by Snorkel's generative model: 2.7009992599487305
