# Crowdsourcing tutorial
In this tutorial, we'll provide a simple walkthrough of how to use Snorkel to resolve conflicts
in a noisy, hybrid dataset for a sentiment analysis task.
We have crowdsourced labels for a portion of the training dataset, and we combine these
with heuristic labeling functions to increase the number of labeled training examples.
As in most Snorkel labeling pipelines, we'll use the denoised labels to train a discriminative
model that can be applied to new, unseen data to automatically make predictions!

In this tutorial, we're using the
[Weather Sentiment](https://data.world/crowdflower/weather-sentiment)
dataset from Figure Eight.
Our goal is to label each tweet as either positive or negative so that
we can train a classifier over the text of the tweets themselves that can be applied
to new, unseen tweets.
Crowd workers were asked to grade the sentiment of a
particular tweet relating to the weather.
Crowd workers could choose among the following categories:

* Positive
* Negative
* I can't tell
* Neutral / author is just sharing information
* Tweet not related to weather condition

The catch is that 20 crowd workers graded each tweet, and in many cases
crowd workers assigned conflicting sentiment labels to the same tweet.
This is a common issue when dealing with crowdsourced labeling workloads.

We've also altered the dataset to reflect a realistic crowdsourcing pipeline
where only a subset of our full training set has received crowd labels.
We'll encode the crowd labels themselves as labeling functions in order
to learn trust weights for each crowd worker, and write a few heuristic
labeling functions to cover the data points without crowd labels.
Snorkel's ability to build high-quality datasets from multiple noisy labeling
signals makes it an ideal framework to approach this problem.

We start by loading our data, which has 287 examples in total.
We take 50 for our development set and 50 for our test set;
the remaining 187 examples form our training set.
This dataset is very small, and we're primarily using it for demonstration purposes.
In practice, we'd expect to have access to many more unlabeled tweets in order to
train a high-performance text model.

The answer categories above have been mapped to integers, as shown below.

## Loading Crowdsourcing Dataset

In [1]:
import os

if os.path.basename(os.getcwd()) == "snorkel-tutorials":
    os.chdir("crowdsourcing")

In [2]:
from data import load_data, answer_mapping

crowd_answers, df_train, df_dev, df_test = load_data()
Y_dev = df_dev.sentiment.values
Y_test = df_test.sentiment.values

print("Answer to int mapping:")
for k, v in sorted(answer_mapping.items(), key=lambda kv: kv[1]):
    print(f"{k:<50}{v}")

Answer to int mapping:
I can't tell                                      -1
Negative                                          0
Positive                                          1
Neutral / author is just sharing information      2
Tweet not related to weather condition            3


First, let's take a look at our development set to get a sense of
what the tweets look like.

In [3]:
df_dev.head()

| tweet_id (index) | tweet_id | tweet_text | sentiment |
|---|---|---|---|
| 79197834 | 79197834 | @mention not in sunny dover! haha | 1 |
| 80059939 | 80059939 | It is literally pissing it down in sideways ra... | 0 |
| 79196441 | 79196441 | Dear perfect weather, thanks for the vest lunc... | 1 |
| 84047300 | 84047300 | RT @mention: I can't wait for the storm tonigh... | 1 |
| 83255121 | 83255121 | 60 degrees. And its almost the end of may. Wis... | 0 |


Now let's take a look at the crowd labels.
We'll convert these into labeling functions.

In [4]:
crowd_answers.head()

| tweet_id (index) | worker_id | answer |
|---|---|---|
| 82510997 | 18034918 | 1 |
| 82510997 | 7450342 | 1 |
| 82510997 | 18465660 | 1 |
| 82510997 | 17475684 | 0 |
| 82510997 | 14472526 | 1 |

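To see the conflicts described earlier, the five rows shown above can be tallied per tweet. This is a plain-Python sketch (no pandas or Snorkel); the helper names `answer_counts_by_tweet` and `has_conflict` are hypothetical and not part of the tutorial's code:

```python
from collections import Counter, defaultdict

# The five (tweet_id, worker_id, answer) rows from crowd_answers.head() above.
rows = [
    (82510997, 18034918, 1),
    (82510997, 7450342, 1),
    (82510997, 18465660, 1),
    (82510997, 17475684, 0),
    (82510997, 14472526, 1),
]


def answer_counts_by_tweet(rows):
    """Map each tweet_id to a Counter of how many workers gave each answer."""
    counts = defaultdict(Counter)
    for tweet_id, _worker_id, answer in rows:
        counts[tweet_id][answer] += 1
    return counts


def has_conflict(answer_counter):
    """A tweet is in conflict when its workers gave more than one distinct answer."""
    return len(answer_counter) > 1


for tweet_id, answer_counter in answer_counts_by_tweet(rows).items():
    status = "conflict" if has_conflict(answer_counter) else "agreement"
    print(tweet_id, dict(answer_counter), status)
```

For tweet 82510997, four workers answered Positive (1) and one answered Negative (0), so even this small sample already contains a conflict.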

## Writing Labeling Functions
Each crowd worker can be thought of as a labeling function:
each worker labels only a subset of the tweets,
and their answers may be wrong or conflict with those of other workers / labeling functions.
We create labeling functions for every worker who labeled more than 10 tweets, two per worker:
one votes Positive (1) when the worker answered Positive, and the other votes Negative (0) when the worker answered Negative.
Both abstain (-1) on every other answer and on tweets the worker didn't label.

### Crowd worker labeling functions

In [5]:
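# Group the crowd answers by worker; each group is indexed by tweet_id.
labels_by_annotator = crowd_answers.groupby("worker_id")
worker_dicts = {}
for worker_id in labels_by_annotator.groups:
    worker_df = labels_by_annotator.get_group(worker_id)[["answer"]]
    # Keep only workers who labeled more than 10 tweets, as a tweet_id -> answer dict.
    if len(worker_df) > 10:
        worker_dicts[worker_id] = dict(zip(worker_df.index, worker_df.answer))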

print("Number of workers:", len(worker_dicts))

Number of workers: 68


In [6]:
from snorkel.labeling.lf import LabelingFunction


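def f_pos(x, worker_dict):
    # Vote positive (1) only if this worker answered Positive for the tweet;
    # abstain (-1) on any other answer or if the worker skipped the tweet.
    label = worker_dict.get(x.tweet_id)
    return 1 if label == 1 else -1


def f_neg(x, worker_dict):
    # Vote negative (0) only if this worker answered Negative; otherwise abstain (-1).
    label = worker_dict.get(x.tweet_id)
    return 0 if label == 0 else -1


def get_worker_labeling_function(worker_id, f):
    worker_dict = worker_dicts[worker_id]
    name = f"worker_{worker_id}"
    # Each entry of `resources` is passed to f as a keyword argument.
    return LabelingFunction(name, f=f, resources={"worker_dict": worker_dict})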


worker_lfs_pos = [
    get_worker_labeling_function(worker_id, f_pos) for worker_id in worker_dicts
]
worker_lfs_neg = [
    get_worker_labeling_function(worker_id, f_neg) for worker_id in worker_dicts
]
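As a sanity check on what `f_pos` and `f_neg` emit, the sketch below re-implements their logic in plain Python for every answer code in the mapping printed earlier. The names `pos_vote`, `neg_vote`, and `votes_for` are hypothetical and do not come from Snorkel:

```python
ABSTAIN, NEGATIVE, POSITIVE = -1, 0, 1

# Answer codes from the answer-to-int mapping printed earlier.
ANSWER_NAMES = {
    -1: "I can't tell",
    0: "Negative",
    1: "Positive",
    2: "Neutral / author is just sharing information",
    3: "Tweet not related to weather condition",
}


def pos_vote(answer):
    """Same rule as f_pos: vote POSITIVE only for a Positive answer."""
    return POSITIVE if answer == POSITIVE else ABSTAIN


def neg_vote(answer):
    """Same rule as f_neg: vote NEGATIVE only for a Negative answer."""
    return NEGATIVE if answer == NEGATIVE else ABSTAIN


def votes_for(worker_dict, tweet_id):
    """Return the (positive LF, negative LF) votes one worker casts on one tweet.

    A tweet missing from worker_dict gives answer None, so both rules abstain,
    just as dict.get() returns None inside f_pos and f_neg.
    """
    answer = worker_dict.get(tweet_id)
    return pos_vote(answer), neg_vote(answer)


for code, name in ANSWER_NAMES.items():
    print(f"{name:<50}{votes_for({'t': code}, 't')}")
print(f"{'(tweet not labeled by this worker)':<50}{votes_for({}, 't')}")
```

Only Positive and Negative answers ever produce a vote, so the label model never sees the other three categories. And because each `worker_dict` stores a single answer per tweet, a worker's two LFs never both vote on the same tweet, so they cannot conflict with each other.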

Let's take a quick look at how well they do on the development set.

In [7]:
from snorkel.labeling.apply import PandasLFApplier

lfs = worker_lfs_pos + worker_lfs_neg

applier = PandasLFApplier(lfs)
L_train = applier.apply(df_train)
L_dev = applier.apply(df_dev)

In [8]:
from snorkel.labeling.analysis import LFAnalysis

LFAnalysis(L_dev, lfs).lf_summary(Y_dev).head(10)

| LF | j | Polarity | Coverage | Overlaps | Conflicts | Correct | Incorrect | Emp. Acc. |
|---|---|---|---|---|---|---|---|---|
| worker_6340330 | 0 | [1] | 0.04 | 0.04 | 0.04 | 2 | 0 | 1.0 |
| worker_6344001 | 1 | [1] | 0.04 | 0.04 | 0.04 | 2 | 0 | 1.0 |
| worker_6346694 | 2 | [1] | 0.12 | 0.12 | 0.1 | 5 | 1 | 0.833333 |
| worker_6363996 | 3 | [1] | 0.04 | 0.04 | 0.02 | 2 | 0 | 1.0 |
| worker_6371053 | 4 | [1] | 0.06 | 0.06 | 0.06 | 2 | 1 | 0.666667 |
| worker_6453108 | 5 | [1] | 0.04 | 0.04 | 0.02 | 2 | 0 | 1.0 |
| worker_6737418 | 6 | [1] | 0.06 | 0.06 | 0.06 | 2 | 1 | 0.666667 |
| worker_7325249 | 7 | [1] | 0.1 | 0.1 | 0.1 | 4 | 1 | 0.8 |
| worker_7450342 | 8 | [] | 0.0 | 0.0 | 0.0 | 0 | 0 | 0.0 |
| worker_7860247 | 9 | [1] | 0.16 | 0.16 | 0.14 | 8 | 0 | 1.0 |

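Our reading of the summary columns, written as a plain-Python sketch: Coverage is the fraction of data points an LF labels; Overlaps, the fraction it labels where at least one other LF also votes; Conflicts, the fraction it labels where some other LF votes differently; and Emp. Acc. is Correct / (Correct + Incorrect) against the gold labels. For example, worker_6340330 labels 2 of the 50 dev tweets, giving a Coverage of 0.04. The function `lf_stats` and the toy matrix are illustrative assumptions, not Snorkel's implementation:

```python
ABSTAIN = -1


def lf_stats(L, Y, j):
    """Summary statistics for LF j, given a label matrix L and gold labels Y.

    L is a list of rows, one per data point; L[i][j] is LF j's label for
    data point i, or ABSTAIN (-1) if it did not vote.
    """
    n = len(L)
    covered = overlapped = conflicted = correct = incorrect = 0
    for row, gold in zip(L, Y):
        label = row[j]
        if label == ABSTAIN:
            continue
        covered += 1
        # Non-abstain votes from every other LF on this data point.
        others = [v for k, v in enumerate(row) if k != j and v != ABSTAIN]
        if others:
            overlapped += 1
        if any(v != label for v in others):
            conflicted += 1
        if label == gold:
            correct += 1
        else:
            incorrect += 1
    voted = correct + incorrect
    return {
        "Coverage": covered / n,
        "Overlaps": overlapped / n,
        "Conflicts": conflicted / n,
        "Correct": correct,
        "Incorrect": incorrect,
        # Report 0.0 for an LF that never votes, as the summary above does.
        "Emp. Acc.": correct / voted if voted else 0.0,
    }


# Toy label matrix: 4 data points x 3 LFs, with made-up gold labels.
L = [
    [1, 1, -1],
    [1, -1, 0],
    [-1, -1, -1],
    [0, 0, 0],
]
Y = [1, 1, 0, 0]
for j in range(3):
    print(j, lf_stats(L, Y, j))
```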

So the crowd labels are quite good on the dev tweets they do label! But how much of our dev and training
sets do they cover collectively?

In [9]:
print("Training set coverage:", LFAnalysis(L_train).label_coverage())
print("Dev set coverage:", LFAnalysis(L_dev).label_coverage())

Training set coverage: 0.5026737967914439
Dev set coverage: 0.5

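Label coverage is the fraction of data points that receive at least one non-abstain label from any LF. A minimal sketch, with a hypothetical `label_coverage` helper over a list-of-rows label matrix:

```python
ABSTAIN = -1


def label_coverage(L):
    """Fraction of rows (data points) with at least one non-abstain label."""
    if not L:
        return 0.0
    covered = sum(1 for row in L if any(label != ABSTAIN for label in row))
    return covered / len(L)


# Toy label matrix: 4 data points x 2 LFs.
L = [
    [-1, -1],
    [1, -1],
    [-1, 0],
    [-1, -1],
]
print(label_coverage(L))  # 0.5
```

By this definition, the outputs above mean that 94 of the 187 training tweets (94 / 187 is about 0.503) and 25 of the 50 dev tweets have at least one crowd vote.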

### Additional labeling functions

We can mix the crowd worker labeling functions with labeling
functions of other types.
Here we add three LFs that threshold the sentiment polarity score computed by TextBlob,
and let the label model learn how to combine their votes with the crowd workers'.

In [10]:
from snorkel.labeling.lf import labeling_function
from snorkel.labeling.preprocess import preprocessor
from textblob import TextBlob


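@preprocessor()
def textblob_polarity(x):
    # Attach TextBlob's polarity score (a float in [-1.0, 1.0]) to the data point.
    blob = TextBlob(x.tweet_text)
    x.polarity = blob.polarity
    return x


# Cache results so repeated calls on the same data point (one per LF below) reuse them.
textblob_polarity.memoize = True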


@labeling_function(preprocessors=[textblob_polarity])
def polarity_positive(x):
    return 1 if x.polarity > 0.3 else -1


@labeling_function(preprocessors=[textblob_polarity])
def polarity_negative(x):
    return 0 if x.polarity < -0.25 else -1


@labeling_function(preprocessors=[textblob_polarity])
def polarity_negative_2(x):
    return 0 if x.polarity <= 0.3 else -1
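The three thresholds can be checked without TextBlob by applying them to raw polarity scores. The sketch below uses hypothetical names (`pos_rule`, `neg_rule`, `neg_rule_2`) so it does not shadow the Snorkel LFs defined above:

```python
ABSTAIN, NEGATIVE, POSITIVE = -1, 0, 1


# Same thresholds as polarity_positive, polarity_negative, and polarity_negative_2,
# applied directly to a TextBlob-style polarity score in [-1.0, 1.0].
def pos_rule(p):
    return POSITIVE if p > 0.3 else ABSTAIN


def neg_rule(p):
    return NEGATIVE if p < -0.25 else ABSTAIN


def neg_rule_2(p):
    return NEGATIVE if p <= 0.3 else ABSTAIN


RULES = [pos_rule, neg_rule, neg_rule_2]


def text_votes(p):
    """Votes of the three rules on a single polarity score."""
    return [rule(p) for rule in RULES]


def covers_every_score(scores):
    """True if at least one rule votes on each score."""
    return all(any(v != ABSTAIN for v in text_votes(p)) for p in scores)


grid = [round(-1 + 0.05 * i, 2) for i in range(41)]  # -1.0, -0.95, ..., 1.0
print(covers_every_score(grid))  # True
for p in (-0.5, 0.0, 0.3, 0.31, 0.8):
    print(p, text_votes(p))
```

`pos_rule` and `neg_rule_2` split the polarity range at 0.3, so exactly one of them votes on every tweet; `neg_rule` fires only on strongly negative scores, where it agrees with `neg_rule_2`.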

### Applying labeling functions to the training and dev sets

In [11]:
from snorkel.labeling.apply import PandasLFApplier

text_lfs = [polarity_positive, polarity_negative, polarity_negative_2]
lfs = text_lfs + worker_lfs_pos + worker_lfs_neg

applier = PandasLFApplier(lfs)
L_train = applier.apply(df_train)
L_dev = applier.apply(df_dev)

In [12]:
LFAnalysis(L_dev, lfs).lf_summary(Y_dev).head(10)

| LF | j | Polarity | Coverage | Overlaps | Conflicts | Correct | Incorrect | Emp. Acc. |
|---|---|---|---|---|---|---|---|---|
| polarity_positive | 0 | [1] | 0.3 | 0.16 | 0.12 | 15 | 0 | 1.0 |
| polarity_negative | 1 | [0] | 0.1 | 0.1 | 0.04 | 5 | 0 | 1.0 |
| polarity_negative_2 | 2 | [0] | 0.7 | 0.4 | 0.32 | 26 | 9 | 0.742857 |
| worker_6340330 | 3 | [1] | 0.04 | 0.04 | 0.04 | 2 | 0 | 1.0 |
| worker_6344001 | 4 | [1] | 0.04 | 0.04 | 0.04 | 2 | 0 | 1.0 |
| worker_6346694 | 5 | [1] | 0.12 | 0.12 | 0.1 | 5 | 1 | 0.833333 |
| worker_6363996 | 6 | [1] | 0.04 | 0.04 | 0.02 | 2 | 0 | 1.0 |
| worker_6371053 | 7 | [1] | 0.06 | 0.06 | 0.06 | 2 | 1 | 0.666667 |
| worker_6453108 | 8 | [1] | 0.04 | 0.04 | 0.02 | 2 | 0 | 1.0 |
| worker_6737418 | 9 | [1] | 0.06 | 0.06 | 0.06 | 2 | 1 | 0.666667 |


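Because polarity_positive (polarity > 0.3) and polarity_negative_2 (polarity <= 0.3)
split the full range of polarity scores between them, every tweet now receives at least one label,
so the text-based LFs expand coverage on both our training set and dev set to 100%.
We'll now take these noisy and conflicting labels, and use the label model
to denoise and combine them.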

In [13]:
print("Training set coverage:", LFAnalysis(L_train).label_coverage())
print("Dev set coverage:", LFAnalysis(L_dev).label_coverage())

Training set coverage: 1.0
Dev set coverage: 1.0


## Train Label Model And Generate Soft Labels

In [14]:
from snorkel.labeling.model.label_model import LabelModel

# Train label model.
label_model = LabelModel(cardinality=2, verbose=True)
label_model.fit(L_train, n_epochs=100, seed=123, log_freq=20, l2=0.1, lr=0.01)

Computing O...
Estimating \mu...
[0 epochs]: TRAIN:[loss=2.371]
[20 epochs]: TRAIN:[loss=0.568]
[40 epochs]: TRAIN:[loss=0.536]
[60 epochs]: TRAIN:[loss=0.522]
[80 epochs]: TRAIN:[loss=0.524]
Finished Training

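For intuition about the per-worker trust weights mentioned at the start, consider a simplified model that is *not* Snorkel's training procedure (the label model above was fit on L_train alone, estimating LF accuracies from their agreements and disagreements without gold labels). If each LF's accuracy were already known, and LFs were conditionally independent given the true label, the posterior log-odds of the positive class would be a sum of log(acc / (1 - acc)) terms, one per non-abstaining vote. The function `weighted_vote_proba` is a hypothetical illustration of that idea:

```python
import math

ABSTAIN = -1


def weighted_vote_proba(votes, accuracies, prior_pos=0.5):
    """P(y = 1 | votes) under a naive-Bayes model with known LF accuracies.

    votes[j] is 0, 1, or ABSTAIN; accuracies[j] is the probability that LF j
    is correct when it votes, assumed strictly between 0 and 1. Abstentions
    are assumed to carry no information about the true label.
    """
    log_odds = math.log(prior_pos / (1 - prior_pos))
    for vote, acc in zip(votes, accuracies):
        if vote == ABSTAIN:
            continue
        weight = math.log(acc / (1 - acc))  # the LF's "trust weight"
        log_odds += weight if vote == 1 else -weight
    return 1 / (1 + math.exp(-log_odds))


# Two mediocre LFs vote positive; one highly accurate LF votes negative.
votes = [1, 1, 0]
accuracies = [0.6, 0.6, 0.95]
p = weighted_vote_proba(votes, accuracies)
print(round(p, 3))  # 0.106
```

The single highly accurate LF outweighs the two mediocre ones, whereas a simple majority vote would have chosen positive.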

As a spot-check for the quality of our label model, we'll score it on the dev set.

In [15]:
from snorkel.analysis.metrics import metric_score
from snorkel.analysis.utils import probs_to_preds

Y_dev_prob = label_model.predict_proba(L_dev)
Y_dev_pred = probs_to_preds(Y_dev_prob)

acc = metric_score(Y_dev, Y_dev_pred, probs=None, metric="accuracy")
print(f"Label Model Accuracy: {acc:.3f}")

Label Model Accuracy: 0.920
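`probs_to_preds` turns each row of class probabilities into a single predicted class, and accuracy is the fraction of predictions that match the gold labels; 0.920 on the 50 dev tweets corresponds to 46 correct predictions. A plain-Python sketch with hypothetical helper names; ties here go to the lowest class index, which may differ from how Snorkel's own `probs_to_preds` breaks ties:

```python
def probs_to_preds_sketch(probs):
    """Pick the most probable class for each row of class probabilities.

    max() returns the first maximal index, so ties go to the lowest class.
    """
    return [max(range(len(row)), key=lambda c: row[c]) for row in probs]


def accuracy(golds, preds):
    """Fraction of predictions equal to the gold label."""
    return sum(g == p for g, p in zip(golds, preds)) / len(golds)


probs = [[0.9, 0.1], [0.2, 0.8], [0.45, 0.55], [0.7, 0.3]]
golds = [0, 1, 0, 0]
preds = probs_to_preds_sketch(probs)
print(preds, accuracy(golds, preds))  # [0, 1, 1, 0] 0.75
```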


The label model reaches 0.920 accuracy on the development set.
This is largely due to the abundance of high-quality crowd worker labels.
**Since we don't have these crowdsourced labels for the
test set or for new incoming examples, we can't use the label model reliably
at inference time.**
In order to run inference on new incoming examples, we need to train a
discriminative model over the tweets themselves.
Let's generate a set of probabilistic labels for the training set.

In [16]:
Y_train_prob = label_model.predict_proba(L_train)

## Use Soft Labels to Train End Model

### Getting features from BERT
Since we have very limited training data, we cannot train a complex model with many parameters, such as an LSTM. Instead, we use a pre-trained model, [BERT](https://github.com/google-research/bert), to generate an embedding for each of our tweets, and treat the embedding values as features.

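In the `encode_text` function below, `model(input_ids)[0]` holds BERT's final-layer hidden states, one vector per token, and `.mean(1)` averages them over the token axis into a single tweet embedding (768 dimensions for BERT-base). The same mean pooling, sketched in plain Python on made-up numbers with a hypothetical `mean_pool` helper:

```python
def mean_pool(token_vectors):
    """Average a (seq_len x hidden) list of token vectors into one hidden-size vector."""
    seq_len = len(token_vectors)
    hidden = len(token_vectors[0])
    return [sum(vec[d] for vec in token_vectors) / seq_len for d in range(hidden)]


# Toy hidden states for a 3-token input with hidden size 4 (BERT-base uses 768).
tokens = [
    [1.0, 0.0, 2.0, -1.0],
    [3.0, 1.0, 0.0, 1.0],
    [2.0, 2.0, 1.0, 0.0],
]
print(mean_pool(tokens))  # [2.0, 1.0, 1.0, 0.0]
```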
In [17]:
import numpy as np
import torch
from pytorch_transformers import BertModel, BertTokenizer

model = BertModel.from_pretrained("bert-base-uncased")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")


def encode_text(text):
    input_ids = torch.tensor([tokenizer.encode(text)])  # shape: (1, seq_len)
    with torch.no_grad():  # inference only, so skip gradient tracking
        # model(...)[0] is the final-layer hidden states, shape (1, seq_len, 768);
        # averaging over the token axis gives one 768-dim vector per tweet.
        return model(input_ids)[0].mean(1)[0].detach().numpy()


train_vectors = np.array(list(df_train.tweet_text.apply(encode_text).values))
test_vectors = np.array(list(df_test.tweet_text.apply(encode_text).values))

### Model on soft labels
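Now, we train a simple logistic regression model on the BERT features, using labels
obtained from our label model.
Because scikit-learn's LogisticRegression expects hard class labels, we first convert the
probabilistic labels to hard labels with probs_to_preds.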

In [18]:
from sklearn.linear_model import LogisticRegression

sklearn_model = LogisticRegression(solver="liblinear")
sklearn_model.fit(train_vectors, probs_to_preds(Y_train_prob))

print(f"Accuracy of trained model: {sklearn_model.score(test_vectors, Y_test)}")

Accuracy of trained model: 0.86
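The cell above discards the label model's confidence by rounding to hard labels. One common alternative, not used in this tutorial, keeps it: copy each training example once per class, label each copy with that class, and weight it by the class probability. scikit-learn's `LogisticRegression.fit` accepts such weights through its `sample_weight` argument. A plain-Python sketch of the expansion, using a hypothetical `expand_soft_labels` helper:

```python
def expand_soft_labels(X, probs):
    """Turn probabilistic labels into a weighted, hard-labeled dataset.

    Each example is copied once per class with nonzero probability; the copy
    is labeled with that class and weighted by the class probability.
    """
    X_out, y_out, w_out = [], [], []
    for x, row in zip(X, probs):
        for label, p in enumerate(row):
            if p > 0:
                X_out.append(x)
                y_out.append(label)
                w_out.append(p)
    return X_out, y_out, w_out


# Toy features and label-model probabilities for two examples.
X = [[0.1, 0.2], [0.3, 0.4]]
probs = [[0.2, 0.8], [1.0, 0.0]]
X_exp, y_exp, w_exp = expand_soft_labels(X, probs)
print(y_exp, w_exp)  # [0, 1, 0] [0.2, 0.8, 1.0]

# Hypothetical usage with this notebook's arrays:
# X_exp, y_exp, w_exp = expand_soft_labels(train_vectors, Y_train_prob)
# sklearn_model.fit(np.array(X_exp), y_exp, sample_weight=w_exp)
```

With these weights, each original example's weighted log loss equals the cross-entropy against its probabilistic label, so the model trains on the soft labels directly (up to how regularization is scaled).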
