# Snorkel Transformation Functions Tutorial

In this tutorial, we walk through using Snorkel Transformation Functions (TFs) to augment the training data for a classifier that labels YouTube comments as `SPAM` or `HAM` (not spam). For more details on the task, see the main labeling functions [tutorial](https://github.com/snorkel-team/snorkel-tutorials/blob/master/spam/01_spam_tutorial.ipynb).
For an overview of Snorkel, visit [snorkel.org](http://snorkel.org).
You can also check out the [Snorkel API documentation](https://snorkel.readthedocs.io/).

For our task, we have access to some labeled YouTube comments for training. We generate additional data by transforming the labeled comments using **_Transformation Functions_**.

The tutorial is divided into four parts:
1. **Loading Data**: We load a [YouTube comments dataset](https://www.kaggle.com/goneee/youtube-spam-classifiedcomments) from Kaggle.
2. **Writing Transformation Functions**: We write Transformation Functions (TFs) that can be applied to training examples to generate new training examples.
3. **Applying Transformation Functions**: We apply a sequence of TFs to each training data point, using a random policy, to generate an augmented training set.
4. **Training A Model**: We use the augmented training set to train an LSTM model for classifying new comments as `SPAM` or `HAM`.

The next cell takes care of some notebook-specific housekeeping.
You can ignore it.

In [1]:
import numpy as np
import os
import random

# Make sure we're running from the spam/ directory
if os.path.basename(os.getcwd()) == "snorkel-tutorials":
    os.chdir("spam")

# Turn off TensorFlow logging messages
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

# For reproducibility
seed = 0
os.environ["PYTHONHASHSEED"] = str(seed)
np.random.seed(seed)
random.seed(seed)

This next cell makes sure a spaCy English model is downloaded.
If this is your first time downloading this model, restart the kernel after executing the next cell.

In [2]:
# Download the spaCy english model
! python -m spacy download en_core_web_sm



✔ Download and installation successful
You can now load the model via spacy.load('en_core_web_sm')


## 1. Loading Data

We load the Kaggle dataset and create Pandas DataFrame objects for the training, validation, and test sets.
The two main columns in the DataFrames are:
* **`text`**: Raw text content of the comment
* **`label`**: Whether the comment is `SPAM` (1) or `HAM` (0).

For more details, check out the [labeling tutorial](https://github.com/snorkel-team/snorkel-tutorials/blob/master/spam/01_spam_tutorial.ipynb).

In [3]:
from utils import load_spam_dataset

df_train, _, df_valid, df_test = load_spam_dataset(load_train_labels=True)

# We pull out the label vectors for ease of use later
Y_valid = df_valid["label"].values
Y_train = df_train["label"].values
Y_test = df_test["label"].values

In [4]:
df_train.head()

| | author | date | text | label | video |
|---|---|---|---|---|---|
| 0 | Alessandro leite | 2014-11-05T22:21:36 | pls http://www10.vakinha.com.br/VaquinhaE.aspx... | 1 | 1 |
| 1 | Salim Tayara | 2014-11-02T14:33:30 | if your like drones, plz subscribe to Kamal Ta... | 1 | 1 |
| 2 | Phuc Ly | 2014-01-20T15:27:47 | go here to check the views :3 | 0 | 1 |
| 3 | DropShotSk8r | 2014-01-19T04:27:18 | Came here to check the views, goodbye. | 0 | 1 |
| 4 | css403 | 2014-11-07T14:25:48 | i am 2,126,492,636 viewer :D | 0 | 1 |


## 2. Writing Transformation Functions

Transformation Functions are functions that can be applied to a training example to create another valid training example.
For example, for image classification problems, it is common to rotate or crop images in the training data to create new training inputs.
Transformation functions should be atomic, e.g., a small rotation of an image or a change to a single word in a sentence.
We then compose multiple transformation functions when applying them to training examples.
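
To make the idea of composing atomic transformations concrete, here is a minimal plain-Python sketch that does not use Snorkel. The helper names (`swap_first_two_words`, `replace_please`, `compose`) are hypothetical and exist only for illustration. Each function makes one small change or returns `None` when it cannot apply, and the composition gives up as soon as any step fails.

```python
from typing import Callable, List, Optional

# Hypothetical atomic transformations on plain strings (not part of Snorkel).
TextTF = Callable[[str], Optional[str]]


def swap_first_two_words(text: str) -> Optional[str]:
    """Swap the first two words; return None if there are fewer than two."""
    words = text.split()
    if len(words) < 2:
        return None
    words[0], words[1] = words[1], words[0]
    return " ".join(words)


def replace_please(text: str) -> Optional[str]:
    """Replace the word 'please' with 'pls'; return None if it is absent."""
    words = text.split()
    if "please" not in words:
        return None
    return " ".join("pls" if w == "please" else w for w in words)


def compose(tfs: List[TextTF], text: str) -> Optional[str]:
    """Apply each TF in order; stop and return None if any step fails."""
    current: Optional[str] = text
    for tf in tfs:
        current = tf(current)
        if current is None:
            return None
    return current


print(compose([replace_please, swap_first_two_words], "please subscribe now"))
print(compose([replace_please], "nice video"))
```

The first call prints `subscribe pls now`; the second prints `None` because `replace_please` has nothing to replace.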

Common ways to augment text include replacing words with their synonyms or replacing named entities with other entities.
More info can be found
[here](https://towardsdatascience.com/data-augmentation-in-nlp-2801a34dfc28) or
[here](https://towardsdatascience.com/these-are-the-easiest-data-augmentation-techniques-in-natural-language-processing-you-can-think-of-88e393fd610).
Applying these operations to a comment shouldn't change whether it is `SPAM` or not.

Transformation functions in Snorkel are created with the
[`transformation_function` decorator](https://snorkel.readthedocs.io/en/master/packages/_autosummary/augmentation/snorkel.augmentation.transformation_function.html#snorkel.augmentation.transformation_function),
which wraps a function that takes in a single data point and returns a transformed version of the data point.
If no transformation is possible, the function should return `None`.

Just like the `labeling_function` decorator, `transformation_function` accepts a `pre` argument for `Preprocessor` objects.
Here, we'll use a
[`SpacyPreprocessor`](https://snorkel.readthedocs.io/en/master/packages/_autosummary/preprocess/snorkel.preprocess.nlp.SpacyPreprocessor.html#snorkel.preprocess.nlp.SpacyPreprocessor).
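
As a rough mental model of what a memoized preprocessor does, the sketch below attaches a derived `doc` field to a data point before a function uses it, and caches the result so the expensive step runs only once per distinct text. This is a conceptual stand-in written in plain Python, not Snorkel's implementation; `tokenize`, `preprocess`, and `count_tokens` are hypothetical names, and a simple whitespace split stands in for spaCy.

```python
from types import SimpleNamespace

# Conceptual stand-in for a memoized preprocessor (NOT Snorkel's implementation).
_cache = {}
call_count = 0


def tokenize(text):
    """Hypothetical expensive step; a real preprocessor would run spaCy here."""
    global call_count
    call_count += 1
    return text.split()


def preprocess(x):
    """Attach a derived `doc` field to the data point, reusing cached results."""
    if x.text not in _cache:
        _cache[x.text] = tokenize(x.text)
    x.doc = _cache[x.text]
    return x


def count_tokens(x):
    """A function that relies on the `doc` field added by the preprocessor."""
    return len(preprocess(x).doc)


a = SimpleNamespace(text="check out my channel")
b = SimpleNamespace(text="check out my channel")
print(count_tokens(a), count_tokens(b), call_count)
```

Both data points have four tokens, and `call_count` stays at 1 because the second call hits the cache. Caching matters here because several TFs share the same preprocessor and would otherwise re-parse the same comment.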

In [5]:
from snorkel.preprocess.nlp import SpacyPreprocessor

spacy = SpacyPreprocessor(text_field="text", doc_field="doc", memoize=True)

In [6]:
import names
from snorkel.augmentation import transformation_function

replacement_names = [names.get_full_name() for _ in range(50)]

# Replace a random named entity with a different entity of the same type.
@transformation_function(pre=[spacy])
def change_person(x):
    person_names = [ent.text for ent in x.doc.ents if ent.label_ == "PERSON"]
    # If there is at least one person name, replace a random one. Else return None.
    if person_names:
        name_to_replace = np.random.choice(person_names)
        replacement_name = np.random.choice(replacement_names)
        x.text = x.text.replace(name_to_replace, replacement_name)
        return x


# Swap two adjectives at random.
@transformation_function(pre=[spacy])
def swap_adjectives(x):
    adjective_idxs = [i for i, token in enumerate(x.doc) if token.pos_ == "ADJ"]
    # Check that there are at least two adjectives to swap.
    if len(adjective_idxs) >= 2:
        idx1, idx2 = sorted(np.random.choice(adjective_idxs, 2, replace=False))
        # Swap tokens in positions idx1 and idx2.
        x.text = " ".join(
            [
                x.doc[:idx1].text,
                x.doc[idx2].text,
                x.doc[1 + idx1 : idx2].text,
                x.doc[idx1].text,
                x.doc[1 + idx2 :].text,
            ]
        )
        return x

We add some transformation functions that use `wordnet` from [NLTK](https://www.nltk.org/) to replace different parts of speech with their synonyms.

In [7]:
import nltk
from nltk.corpus import wordnet as wn

nltk.download("wordnet")


def get_synonym(word, pos=None):
    """Get synonym for word given its part-of-speech (pos)."""
    synsets = wn.synsets(word, pos=pos)
    # Return None if wordnet has no synsets (synonym sets) for this word and pos.
    if synsets:
        words = [lemma.name() for lemma in synsets[0].lemmas()]
        if words[0].lower() != word.lower():  # Skip if synonym is same as word.
            # Multi word synonyms in wordnet use '_' as a separator e.g. reckon_with. Replace it with space.
            return words[0].replace("_", " ")


def replace_token(spacy_doc, idx, replacement):
    """Replace token in position idx with replacement."""
    return " ".join([spacy_doc[:idx].text, replacement, spacy_doc[1 + idx :].text])


@transformation_function(pre=[spacy])
def replace_verb_with_synonym(x):
    # Get indices of verb tokens in sentence.
    verb_idxs = [i for i, token in enumerate(x.doc) if token.pos_ == "VERB"]
    if verb_idxs:
        # Pick random verb idx to replace.
        idx = np.random.choice(verb_idxs)
        synonym = get_synonym(x.doc[idx].text, pos="v")
        # If there's a valid verb synonym, replace it. Otherwise, return None.
        if synonym:
            x.text = replace_token(x.doc, idx, synonym)
            return x


@transformation_function(pre=[spacy])
def replace_noun_with_synonym(x):
    # Get indices of noun tokens in sentence.
    noun_idxs = [i for i, token in enumerate(x.doc) if token.pos_ == "NOUN"]
    if noun_idxs:
        # Pick random noun idx to replace.
        idx = np.random.choice(noun_idxs)
        synonym = get_synonym(x.doc[idx].text, pos="n")
        # If there's a valid noun synonym, replace it. Otherwise, return None.
        if synonym:
            x.text = replace_token(x.doc, idx, synonym)
            return x


@transformation_function(pre=[spacy])
def replace_adjective_with_synonym(x):
    # Get indices of adjective tokens in sentence.
    adjective_idxs = [i for i, token in enumerate(x.doc) if token.pos_ == "ADJ"]
    if adjective_idxs:
        # Pick random adjective idx to replace.
        idx = np.random.choice(adjective_idxs)
        synonym = get_synonym(x.doc[idx].text, pos="a")
        # If there's a valid adjective synonym, replace it. Otherwise, return None.
        if synonym:
            x.text = replace_token(x.doc, idx, synonym)
            return x

[nltk_data] Downloading package wordnet to /home/ubuntu/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


Let's check out a few examples of transformed data points to see what our TFs are doing.

In [8]:
import pandas as pd
from utils import preview_tfs

# Prevent truncating displayed sentences.
pd.set_option("display.max_colwidth", 0)
tfs = [
    change_person,
    swap_adjectives,
    replace_verb_with_synonym,
    replace_noun_with_synonym,
    replace_adjective_with_synonym,
]

preview_tfs(df_train, tfs)

| | TF Name | Original Text | Transformed Text |
|---|---|---|---|
| 0 | change_person | Check out Berzerk video on my channel ! :D | Check out Jennifer Selby video on my channel ! :D |
| 1 | swap_adjectives | hey guys look im aware im spamming and it pisses people off but please take a moment to check out my music. im a young rapper and i love to do it and i just wanna share my music with more people just click my picture and then see if you like my stuff | hey guys look im aware im spamming and it pisses people off but please take a moment to check out my music. im a more rapper and i love to do it and i just wanna share my music with young people just click my picture and then see if you like my stuff |
| 2 | replace_verb_with_synonym | "eye of the tiger" "i am the champion" seems like katy perry is using titles of old rock songs for lyrics.. | "eye of the tiger" "i be the champion" seems like katy perry is using titles of old rock songs for lyrics.. |
| 3 | replace_noun_with_synonym | Hey, check out my new website!! This site is about kids stuff. kidsmediausa . com | Hey, check out my new web site !! This site is about kids stuff. kidsmediausa . com |
| 4 | replace_adjective_with_synonym | I started hating Katy Perry after finding out that she stole all of the ideas on her videos from an old comic book. Yet, her music is catchy. | I started hating Katy Perry after finding out that she stole all of the ideas on her videos from an old amusing book. Yet, her music is catchy. |


We notice a couple of things about the TFs.
* Sometimes they make trivial changes (e.g., `"website"` to `"web site"` for `replace_noun_with_synonym`).
  This can still be helpful for training our model, because it teaches the model that these variations have similar meanings.
* Sometimes they make the sentence less meaningful (e.g., swapping `"young"` and `"more"` for `swap_adjectives`).

Data augmentation can be tricky for text inputs, so we expect most TFs to be a little flawed.
But these TFs can be useful despite the flaws; see [this paper](https://arxiv.org/pdf/1901.11196.pdf) for gains resulting from similar TFs.

## 3. Applying Transformation Functions

To apply one or more TFs that we've written to a collection of data points, we use a `TFApplier`.
Because our data points are represented with a Pandas DataFrame in this tutorial, we use a
[`PandasTFApplier`](https://snorkel.readthedocs.io/en/master/packages/_autosummary/augmentation/snorkel.augmentation.PandasTFApplier.html).
In addition, we can apply multiple TFs in a sequence to each example.
A `Policy` is used to determine what sequence of TFs to apply to each example.
In this case, we just use a [`MeanFieldPolicy`](https://snorkel.readthedocs.io/en/master/packages/_autosummary/augmentation/snorkel.augmentation.MeanFieldPolicy.html)
that samples a sequence of 2 TFs per augmented example (`sequence_length=2`), drawing each TF with the probabilities given by `p`.
We give higher probabilities to the `replace_X_with_synonym` TFs, since those provide more information to the model.
The `n_per_original` argument determines how many augmented examples to generate per original example.
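
The following plain-Python sketch illustrates what such a policy does conceptually for a single example. It is not Snorkel's code: `sample_tf_sequences` is a hypothetical helper, and we assume each position in a sequence is drawn independently from `p` and that keeping the original corresponds to an empty sequence.

```python
import random
from typing import List


def sample_tf_sequences(
    p: List[float],
    sequence_length: int,
    n_per_original: int,
    keep_original: bool,
    rng: random.Random,
) -> List[List[int]]:
    """Sample TF index sequences for one example (conceptual sketch only)."""
    sequences: List[List[int]] = []
    if keep_original:
        sequences.append([])  # empty sequence = keep the example unchanged
    tf_indices = list(range(len(p)))
    for _ in range(n_per_original):
        # Each position is drawn independently according to the weights p.
        sequences.append(rng.choices(tf_indices, weights=p, k=sequence_length))
    return sequences


rng = random.Random(0)
seqs = sample_tf_sequences([0.05, 0.05, 0.3, 0.3, 0.3], 2, 2, True, rng)
print(seqs)
```

With the settings used in this tutorial, each original example yields three sequences: the empty one plus two sequences of two TF indices each. The indices refer to positions in the `tfs` list, so index 2 would be `replace_verb_with_synonym`.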

In [9]:
from snorkel.augmentation import MeanFieldPolicy, PandasTFApplier

policy = MeanFieldPolicy(
    len(tfs),
    sequence_length=2,
    n_per_original=2,
    keep_original=True,
    p=[0.05, 0.05, 0.3, 0.3, 0.3],
)
tf_applier = PandasTFApplier(tfs, policy)
df_train_augmented = tf_applier.apply(df_train)
Y_train_augmented = df_train_augmented["label"].values

100%|██████████| 1586/1586 [00:26<00:00, 59.34it/s]




In [10]:
print(f"Original training set size: {len(df_train)}")
print(f"Augmented training set size: {len(df_train_augmented)}")

Original training set size: 1586
Augmented training set size: 2486


We have grown our dataset by more than 50% using TFs!
Note that despite `n_per_original` being set to 2, our dataset does not triple in size, because sometimes TFs return `None` instead of a new example (e.g. `change_person` when applied to a sentence with no persons).
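
The arithmetic behind these numbers can be checked with a short calculation. This sketch assumes that every original example is kept (`keep_original=True`) and that each additional row comes from one TF sequence that succeeded; the variable names are ours.

```python
n_original = 1586          # size of df_train
n_augmented_total = 2486   # size of df_train_augmented
n_per_original = 2         # TF sequences attempted per original example

# Largest possible size: each original plus all of its attempted sequences.
upper_bound = n_original * (1 + n_per_original)
# Rows that were actually added by the TFs.
n_new = n_augmented_total - n_original
# Total number of TF sequences that were attempted.
n_attempted = n_original * n_per_original

growth = n_augmented_total / n_original
success_rate = n_new / n_attempted

print(f"Upper bound: {upper_bound}")
print(f"New examples: {n_new}")
print(f"Growth factor: {growth:.2f}x")
print(f"Share of attempted sequences that produced an example: {success_rate:.1%}")
```

Under these assumptions, the upper bound is 4758 rows, the TFs contributed 900 new rows (a growth factor of about 1.57x), and roughly 28.4% of the attempted sequences produced an example.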

## 4. Training A Model

Our final step is to use the augmented data to train a model. We train an LSTM (Long Short-Term Memory) model, a standard architecture for text processing tasks.

The next cell makes Keras results reproducible. You can ignore it.

In [11]:
import tensorflow as tf

session_conf = tf.compat.v1.ConfigProto(
    intra_op_parallelism_threads=1, inter_op_parallelism_threads=1
)

tf.compat.v1.set_random_seed(0)
sess = tf.compat.v1.Session(graph=tf.compat.v1.get_default_graph(), config=session_conf)
tf.compat.v1.keras.backend.set_session(sess)

Now we'll train our LSTM on both the original and augmented data sets to compare performance.

In [12]:
from utils import featurize_df_tokens, get_keras_lstm, get_keras_early_stopping

X_train = featurize_df_tokens(df_train)
X_train_augmented = featurize_df_tokens(df_train_augmented)
X_valid = featurize_df_tokens(df_valid)
X_test = featurize_df_tokens(df_test)


def train_and_test(
    X_train,
    Y_train,
    X_valid=X_valid,
    Y_valid=Y_valid,
    X_test=X_test,
    Y_test=Y_test,
    num_buckets=30000,
):
    lstm_model = get_keras_lstm(num_buckets)
    lstm_model.fit(
        X_train,
        Y_train,
        epochs=25,
        validation_data=(X_valid, Y_valid),
        # Set up early stopping based on val set accuracy.
        callbacks=[get_keras_early_stopping(5)],
        verbose=0,
    )
    preds_test = lstm_model.predict(X_test)[:, 0] > 0.5
    return (preds_test == Y_test).mean()


acc_augmented = train_and_test(X_train_augmented, Y_train_augmented)
acc_original = train_and_test(X_train, Y_train)

print(f"Test Accuracy (original training data): {100 * acc_original:.1f}%")
print(f"Test Accuracy (augmented training data): {100 * acc_augmented:.1f}%")


Restoring model weights from the end of the best epoch.
Epoch 00017: early stopping


Restoring model weights from the end of the best epoch.
Epoch 00016: early stopping
Test Accuracy (original training data): 91.2%
Test Accuracy (augmented training data): 92.8%
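
The "early stopping" messages above come from the callback passed to `fit`. The internals of `get_keras_early_stopping` are not shown in this tutorial, so the sketch below only illustrates the general technique of patience-based early stopping, assuming the argument `5` is a patience value. The function `early_stopping_epoch` and the accuracy values are hypothetical.

```python
from typing import List, Tuple


def early_stopping_epoch(val_accuracies: List[float], patience: int) -> Tuple[int, int]:
    """Return (stop_epoch, best_epoch), both 1-indexed.

    Training stops once `patience` consecutive epochs fail to beat the best
    validation accuracy seen so far; the best epoch's weights would be restored.
    """
    best_acc = float("-inf")
    best_epoch = 0
    for epoch, acc in enumerate(val_accuracies, start=1):
        if acc > best_acc:
            best_acc, best_epoch = acc, epoch
        elif epoch - best_epoch >= patience:
            return epoch, best_epoch
    return len(val_accuracies), best_epoch


# Made-up validation accuracies, for illustration only.
history = [0.80, 0.85, 0.88, 0.87, 0.88, 0.86, 0.87, 0.85, 0.90]
print(early_stopping_epoch(history, patience=5))
```

Here the best accuracy is first reached at epoch 3, and five epochs pass without improvement, so training stops at epoch 8 and the call prints `(8, 3)`. The later value of 0.90 is never seen, which is the trade-off a patience setting makes between training time and the chance of late improvements.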
