# Snorkel Transformation Functions Tutorial

In this tutorial, we walk through using Snorkel `Transformation Functions (TFs)` to augment the training data for a classifier that labels YouTube comments as `SPAM` or `HAM` (not spam). For more details on the task, check out the main labeling functions [tutorial](https://github.com/snorkel-team/snorkel-tutorials/blob/master/spam/spam_tutorial.ipynb).
For an overview of Snorkel, visit [snorkel.org](http://snorkel.org).
You can also check out the [Snorkel API documentation](https://snorkel.readthedocs.io/).

For our task, we have access to some labeled data.
We use **_Transformation Functions_** to perform data augmentation to get additional training data.

The tutorial is divided into four parts:
1. **Loading Data**: We load a [YouTube comments dataset](https://www.kaggle.com/goneee/youtube-spam-classifiedcomments) from Kaggle.
2. **Writing Transformation Functions**: We write Transformation Functions (TFs) that can be applied to training examples to generate new training examples.
3. **Applying Transformation Functions**: We apply a sequence of TFs to each training data point, using a random policy, to generate an augmented training set.
4. **Training an End Model**: We use the augmented training set to train an LSTM model for classifying new comments as `SPAM` or `HAM`.

### Data Splits in Snorkel

We split our data into 3 sets:
* **Training Set**: The largest split of the dataset. These are the examples used for training, and also the ones that transformation functions are applied to.
* **Validation Set**: A labeled set used to tune hyperparameters and/or perform early stopping while training the classifier.
* **Test Set**: A labeled set for final evaluation of our classifier. This set should only be used for final evaluation, _not_ tuning.
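
As a rough illustration of the three-way split described above, the sketch below shuffles a list of examples and cuts it into training, validation, and test portions. The `split_dataset` helper, the 80/10/10 fractions, and the seed are assumptions made for illustration; the tutorial itself obtains its splits from `load_spam_dataset`.

```python
import random


def split_dataset(examples, valid_frac=0.1, test_frac=0.1, seed=0):
    """Shuffle examples and split them into train/valid/test lists.

    Hypothetical helper: the fractions and seed are illustrative only.
    """
    shuffled = list(examples)
    random.Random(seed).shuffle(shuffled)
    n_valid = int(len(shuffled) * valid_frac)
    n_test = int(len(shuffled) * test_frac)
    valid = shuffled[:n_valid]
    test = shuffled[n_valid : n_valid + n_test]
    train = shuffled[n_valid + n_test :]  # the largest split
    return train, valid, test


train, valid, test = split_dataset(range(100))
print(len(train), len(valid), len(test))  # 80 10 10
```

Only the training portion would be passed to the transformation functions; the validation and test portions stay untouched so that evaluation reflects real, unaugmented comments.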

## 1. Loading Data

We load the Kaggle dataset and create Pandas DataFrame objects for each of the sets described above.
The two main columns in the DataFrames are:
* **`text`**: Raw text content of the comment
* **`label`**: Whether the comment is `SPAM` (1), `HAM` (0), or `UNKNOWN/ABSTAIN` (-1)

For more details, check out the labeling functions [tutorial](https://github.com/snorkel-team/snorkel-tutorials/blob/master/spam/spam_tutorial.ipynb).

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os

# Make sure we're running from the spam/ directory
if os.path.basename(os.getcwd()) == "snorkel-tutorials":
    os.chdir("spam")

In [3]:
from utils import load_spam_dataset

df_train, _, df_valid, df_test = load_spam_dataset(delete_train_labels=False)

# We pull out the label vectors for ease of use later
Y_valid = df_valid["label"].values
Y_train = df_train["label"].values
Y_test = df_test["label"].values

In [4]:
df_train.head()

| Unnamed: 0 | author | date | text | label | video |
|---|---|---|---|---|---|
| 0 | Alessandro leite | 2014-11-05T22:21:36 | pls http://www10.vakinha.com.br/VaquinhaE.aspx... | 1 | 1 |
| 1 | Salim Tayara | 2014-11-02T14:33:30 | if your like drones, plz subscribe to Kamal Ta... | 1 | 1 |
| 2 | Phuc Ly | 2014-01-20T15:27:47 | go here to check the views :3 | 0 | 1 |
| 3 | DropShotSk8r | 2014-01-19T04:27:18 | Came here to check the views, goodbye. | 0 | 1 |
| 4 | css403 | 2014-11-07T14:25:48 | i am 2,126,492,636 viewer :D | 0 | 1 |


## 2. Writing Transformation Functions

Transformation Functions are functions that can be applied to a training example to create another valid training example. For example, for image classification problems, it is common to rotate or crop images in the training data to create new training inputs.

Our task involves processing text. Some common ways to augment text include replacing words with their synonyms, or replacing named entities with other entities. Applying these operations to a comment shouldn't change whether it is `SPAM` or not.

Transformation functions in Snorkel are created with the `@transformation_function()` decorator, which wraps a function for taking a single data point and returning a transformed version of the data point.

We start with a simple transformation function that changes a random character in the text to simulate a typo.

In [5]:
import string

import numpy as np
from snorkel.augmentation.tf import transformation_function


@transformation_function()
def change_character(x):
    # Pick a random position (the last character is never chosen) and
    # overwrite it with a random lowercase letter.
    idx = np.random.choice(range(len(x.text) - 1))
    char = np.random.choice(list(string.ascii_lowercase))
    x.text = x.text[:idx] + char + x.text[idx + 1 :]
    return x
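
To see the string manipulation in isolation, here is a minimal sketch that applies the same idea to a plain string, without Snorkel. The `replace_random_char` name, the length guard, and the use of the standard-library `random` module instead of NumPy are assumptions made for illustration; the index range mirrors `change_character` in that the last character is never chosen.

```python
import random
import string


def replace_random_char(text, rng):
    """Replace one character (never the last) with a random lowercase letter."""
    if len(text) < 2:
        return text  # too short: there is no position to pick from
    idx = rng.randrange(len(text) - 1)
    char = rng.choice(string.ascii_lowercase)
    return text[:idx] + char + text[idx + 1 :]


rng = random.Random(0)
original = "subscribe to my channel"
typo = replace_random_char(original, rng)

# Count the positions where the two strings differ (0 if the random
# letter happens to equal the character it replaces).
n_diff = sum(a != b for a, b in zip(original, typo))
print(typo, n_diff)
```

The transformed string always has the same length as the original and differs from it in at most one position.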

### Adding `pre` mappers
Some TFs rely on fields that aren't present in the raw data, but can be derived from it.
We can enrich our data (providing more fields for the TFs to refer to) using map functions specified in the `pre` field of the `@transformation_function` decorator (similar to the `preprocessor` used for Labeling Functions).

For example, we can use the fantastic NLP tool [spaCy](https://spacy.io/) to add lemmas, part-of-speech (pos) tags, etc. to each token.
Snorkel provides a prebuilt preprocessor for spaCy called `SpacyPreprocessor` which adds a new field to the
data point containing a [spaCy `Doc` object](https://spacy.io/api/doc).
For more info, see the [`SpacyPreprocessor` documentation](https://snorkel.readthedocs.io/en/master/source/snorkel.labeling.preprocess.html#snorkel.labeling.preprocess.nlp.SpacyPreprocessor).
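
Conceptually, a `pre` mapper runs before the TF and attaches derived fields to the data point. The sketch below imitates that flow in plain Python; it is not how Snorkel implements it internally. `add_tokens`, `drop_first_token`, and `apply_with_pre` are hypothetical names, and a whitespace split stands in for the spaCy parse.

```python
from types import SimpleNamespace


def add_tokens(x):
    """Stand-in for a `pre` mapper: derive a `tokens` field from `text`."""
    x.tokens = x.text.split()
    return x


def drop_first_token(x):
    """Toy transformation that relies on the derived `tokens` field."""
    if len(x.tokens) < 2:
        return x
    x.text = " ".join(x.tokens[1:])
    return x


def apply_with_pre(tf, pre, x):
    """Run each pre-mapper in order, then the transformation itself."""
    for mapper in pre:
        x = mapper(x)
    return tf(x)


point = SimpleNamespace(text="please check out my channel")
result = apply_with_pre(drop_first_token, [add_tokens], point)
print(result.text)  # check out my channel
```

The transformation never parses the text itself; it only reads the field that the mapper prepared, which is the same division of labor as between `SpacyPreprocessor` and a TF that reads `x.doc`.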


In [6]:
# Download the spaCy english model
# If you see an error in the next cell, restart the kernel
! python -m spacy download en_core_web_sm



✔ Download and installation successful
You can now load the model via spacy.load('en_core_web_sm')

In [7]:
from snorkel.labeling.preprocess.nlp import SpacyPreprocessor

# The SpacyPreprocessor parses the text in text_field and
# stores the new enriched representation in doc_field
spacy = SpacyPreprocessor(text_field="text", doc_field="doc", memoize=True)

In [8]:
import numpy as np

# TFs for replacing a random named entity with a different entity of the same type.
@transformation_function(pre=[spacy])
def change_person(x):
    persons = [str(ent) for ent in x.doc.ents if ent.label_ == "PERSON"]
    if persons:
        to_replace = np.random.choice(persons)
        replacement = np.random.choice(["Bob", "Alice"])
        x.text = x.text.replace(to_replace, replacement)
        return x


@transformation_function(pre=[spacy])
def change_date(x):
    dates = [str(ent) for ent in x.doc.ents if ent.label_ == "DATE"]
    if dates:
        to_replace = np.random.choice(dates)
        replacement = np.random.choice(["31st December", "01/03/99"])
        x.text = x.text.replace(to_replace, replacement)
        return x


# Drop the last sentence of a multi-sentence comment, as this shouldn't change its spam / ham nature.
@transformation_function(pre=[spacy])
def drop_last_sentence(x):
    sentences = [str(span) for span in x.doc.sents]
    if len(sentences) > 1:
        x.text = ". ".join(sentences[:-1])
        return x


# Remove a random stop word.
@transformation_function(pre=[spacy])
def drop_stop_word(x):
    words = [token.text for token in x.doc]
    stop_word_idxs = [i for i, token in enumerate(x.doc) if token.is_stop]
    if len(stop_word_idxs) < 2:
        return x
    to_drop = np.random.choice(stop_word_idxs[:-1])
    x.text = " ".join(words[:to_drop] + words[1 + to_drop :])
    return x


# Swap two nouns at random.
@transformation_function(pre=[spacy])
def swap_nouns(x):
    words = [token.text for token in x.doc]
    noun_idxs = [i for i, token in enumerate(x.doc) if token.pos_ == "NOUN"]
    if len(noun_idxs) < 3:
        return x
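    # Sample without replacement so the two indices are distinct; otherwise
    # the same noun could be picked twice and end up duplicated in the text.
    idx1, idx2 = sorted(np.random.choice(noun_idxs[:-1], 2, replace=False))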
    x.text = " ".join(
        words[:idx1]
        + [words[idx2]]
        + words[1 + idx1 : idx2]
        + [words[idx1]]
        + words[1 + idx2 :]
    )
    return x
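
The list slicing used in `swap_nouns` can be checked on its own. The following sketch uses a hypothetical `swap_tokens` helper on a whitespace-tokenized sentence, with hand-picked indices instead of spaCy part-of-speech tags.

```python
def swap_tokens(words, idx1, idx2):
    """Return a copy of `words` with the tokens at idx1 < idx2 exchanged."""
    return (
        words[:idx1]
        + [words[idx2]]
        + words[idx1 + 1 : idx2]
        + [words[idx1]]
        + words[idx2 + 1 :]
    )


words = "my dog chased the cat today".split()
print(" ".join(swap_tokens(words, 1, 4)))  # my cat chased the dog today
```

The slicing assumes `idx1 < idx2`. If both indices are equal, the token at that position appears twice in the result, so the two indices need to be distinct.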

We add some transformation functions that use `wordnet` from [NLTK](https://www.nltk.org/) to replace different parts of speech with their synonyms.

In [9]:
import nltk
from nltk.corpus import wordnet as wn

nltk.download("wordnet")


def get_synonym(word, pos=None):
    synsets = wn.synsets(word, pos=pos)
    if not synsets:
        return word
    else:
        words = [lemma.name() for lemma in synsets[0].lemmas()]
        return words[0] if words else word


@transformation_function(pre=[spacy])
def replace_verb_with_synonym(x):
    words = [token.text for token in x.doc]
    verb_idxs = [i for i, token in enumerate(x.doc) if token.pos_ == "VERB"]
    if len(verb_idxs) < 2:
        return x
    to_replace = np.random.choice(verb_idxs[:-1])
    synonym = get_synonym(words[to_replace], pos="v")
    x.text = " ".join(words[:to_replace] + [synonym] + words[1 + to_replace :])
    return x


@transformation_function(pre=[spacy])
def replace_noun_with_synonym(x):
    words = [token.text for token in x.doc]
    noun_idxs = [i for i, token in enumerate(x.doc) if token.pos_ == "NOUN"]
    if len(noun_idxs) < 2:
        return x
    to_replace = np.random.choice(noun_idxs[:-1])
    synonym = get_synonym(words[to_replace], pos="n")
    x.text = " ".join(words[:to_replace] + [synonym] + words[1 + to_replace :])
    return x

[nltk_data] Downloading package wordnet to /home/ubuntu/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
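
The two synonym TFs share one pattern: pick a token index, look up a replacement, and rebuild the text. Here is a minimal sketch of that pattern that runs without NLTK or spaCy. The `TOY_SYNONYMS` table and both helper names are hypothetical stand-ins for the WordNet lookup.

```python
# Hypothetical lookup table standing in for WordNet.
TOY_SYNONYMS = {"watch": "view", "song": "track"}


def get_toy_synonym(word):
    """Return a synonym from the toy table, or the word itself if none exists."""
    return TOY_SYNONYMS.get(word.lower(), word)


def replace_at(words, idx, replacement):
    """Return a copy of `words` with the token at `idx` replaced."""
    return words[:idx] + [replacement] + words[idx + 1 :]


words = "i love this song".split()
idx = words.index("song")
print(" ".join(replace_at(words, idx, get_toy_synonym(words[idx]))))
# i love this track
```

As with `get_synonym`, a word with no known synonym is returned unchanged, so the transformation can be a no-op.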


## 3. Applying Transformation Functions

To apply one or more TFs that we've written to a collection of data points, we use a `TFApplier`.
Because our data points are represented with a Pandas DataFrame in this tutorial, we use the `PandasTFApplier` class. In addition, we can apply multiple TFs in a sequence to each example. A `policy` is used to determine what sequence of TFs to apply to each example. In this case, we just use a `RandomPolicy` that picks 3 TFs at random per example. The `n_per_original` argument determines how many augmented examples to generate per original example.
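
To make the role of the policy concrete, the sketch below generates the kind of index sequences a random policy could produce: for each original example, `n_per_original` sequences of `sequence_length` TF indices. This is a conceptual stand-in, not Snorkel's `RandomPolicy`; the function name and the choice to sample uniformly with replacement are assumptions.

```python
import random


def random_policy_sequences(n_tfs, sequence_length, n_per_original, rng):
    """Draw `n_per_original` sequences of TF indices, each index uniform in [0, n_tfs)."""
    return [
        [rng.randrange(n_tfs) for _ in range(sequence_length)]
        for _ in range(n_per_original)
    ]


rng = random.Random(0)
sequences = random_policy_sequences(
    n_tfs=8, sequence_length=3, n_per_original=2, rng=rng
)
for seq in sequences:
    print(seq)  # each line is one sequence of three TF indices
```

Each sequence would then be applied to a copy of the original example, one TF after another, to yield one augmented example.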


In [10]:
from snorkel.augmentation.apply import PandasTFApplier
from snorkel.augmentation.policy import RandomPolicy

tfs = [
    change_character,
    change_person,
    change_date,
    drop_last_sentence,
    drop_stop_word,
    swap_nouns,
    replace_verb_with_synonym,
    replace_noun_with_synonym,
]

policy = RandomPolicy(len(tfs), sequence_length=3, n_per_original=2)
tf_applier = PandasTFApplier(tfs, policy)
df_train_augmented = tf_applier.apply(df_train).infer_objects()
Y_train_augmented = df_train_augmented["label"].values


100%|██████████| 1586/1586 [00:43<00:00, 36.47it/s]




In [11]:
print(f"Original training set size: {len(df_train)}")
print(f"Augmented training set size: {len(df_train_augmented)}")

Original training set size: 1586
Augmented training set size: 4684


We have nearly tripled our dataset using TFs! Note that despite `n_per_original` being set to 2, our dataset may not exactly triple in size, because some TFs keep the example unchanged (e.g. `change_person` when applied to a sentence with no persons).
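
The size arithmetic can be spelled out as follows. The sketch assumes the original examples are kept alongside the generated ones, which is consistent with the "nearly tripled" figure above.

```python
n_original = 1586
n_augmented = 4684
n_per_original = 2

# Upper bound: every original plus n_per_original new examples for each.
upper_bound = n_original * (1 + n_per_original)
print(upper_bound)  # 4758

# Shortfall relative to the upper bound.
print(upper_bound - n_augmented)  # 74

# Overall growth factor.
print(round(n_augmented / n_original, 2))  # 2.95
```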

## 4. Training an End Model

Our final step is to use the augmented data to train a model. We train an LSTM (Long Short-Term Memory) model, which is a commonly used architecture for text processing tasks.

In [12]:
import tensorflow as tf


def train_and_test(
    train_set, train_labels, num_buckets=30000, embed_dim=16, rnn_state_size=64
):
    def map_pad_or_truncate(string, max_length=30):
        ids = tf.keras.preprocessing.text.hashing_trick(
            string, n=num_buckets, hash_function="md5"
        )
        return ids[:max_length] + [0] * (max_length - len(ids))

    train_tokens = np.array(list(map(map_pad_or_truncate, train_set.text)))
    lstm_model = tf.keras.Sequential()
    lstm_model.add(tf.keras.layers.Embedding(num_buckets, embed_dim))
    lstm_model.add(tf.keras.layers.LSTM(rnn_state_size, activation=tf.nn.relu))
    lstm_model.add(tf.keras.layers.Dense(1, activation=tf.nn.sigmoid))
    lstm_model.compile("Adagrad", "binary_crossentropy", metrics=["accuracy"])

    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor="val_acc", patience=10, verbose=1, restore_best_weights=True
    )

    valid_tokens = np.array(list(map(map_pad_or_truncate, df_valid.text)))

    lstm_model.fit(
        train_tokens,
        train_labels,
        epochs=50,
        validation_data=(valid_tokens, Y_valid),
        callbacks=[early_stopping],
        verbose=0,
    )

    test_tokens = np.array(list(map(map_pad_or_truncate, df_test.text)))
    test_probs = lstm_model.predict(test_tokens)
    test_preds = test_probs[:, 0] > 0.5
    return (test_preds == Y_test).mean()


test_accuracy_original = train_and_test(df_train, Y_train)
test_accuracy_augmented = train_and_test(df_train_augmented, Y_train_augmented)

print(f"Test Accuracy when training on original dataset: {test_accuracy_original}")
print(f"Test Accuracy when training on augmented dataset: {test_accuracy_augmented}")

Restoring model weights from the end of the best epoch.
Epoch 00021: early stopping


Restoring model weights from the end of the best epoch.
Epoch 00013: early stopping


Test Accuracy when training on original dataset: 0.936
Test Accuracy when training on augmented dataset: 0.912
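
The `map_pad_or_truncate` helper in the training cell above turns each comment into a fixed-length list of integer ids by hashing its tokens into buckets and then padding or truncating. The sketch below reproduces that idea with the standard library only. The helper names are hypothetical, and the whitespace tokenization is a simplification, so the ids are not guaranteed to match those produced by the Keras helper.

```python
import hashlib


def hash_token(token, num_buckets):
    """Map a token to a bucket id in [1, num_buckets - 1]; 0 is kept for padding."""
    digest = hashlib.md5(token.encode("utf-8")).hexdigest()
    return int(digest, 16) % (num_buckets - 1) + 1


def pad_or_truncate(ids, max_length, pad_id=0):
    """Cut `ids` to `max_length`, or right-pad it with `pad_id`."""
    return ids[:max_length] + [pad_id] * (max_length - len(ids))


def encode(text, num_buckets=30000, max_length=30):
    """Hash each whitespace-separated token, then fix the sequence length."""
    ids = [hash_token(tok, num_buckets) for tok in text.lower().split()]
    return pad_or_truncate(ids, max_length)


encoded = encode("check out my channel", max_length=6)
print(len(encoded), encoded[-2:])  # 6 [0, 0]
```

Because the hash is deterministic, the same token always maps to the same id, so no vocabulary needs to be stored. The trade-off is that distinct tokens can collide in the same bucket.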
