# 📈 Snorkel Intro Tutorial: Data Augmentation

In this tutorial, we will walk through the process of using *transformation functions* (TFs) to perform data augmentation.
Like the labeling tutorial, our goal is to train a classifier that labels YouTube comments as `SPAM` or `HAM` (not spam).
In the [previous tutorial](https://github.com/snorkel-team/snorkel-tutorials/blob/master/spam/01_spam_tutorial.ipynb),
we demonstrated how to label training sets programmatically with Snorkel.
In this tutorial, we'll assume that step has already been done, and start with labeled training data,
which we'll aim to augment using transformation functions.


* For more details on the task, check out the [labeling tutorial](https://github.com/snorkel-team/snorkel-tutorials/blob/master/spam/01_spam_tutorial.ipynb)
* For an overview of Snorkel, visit [snorkel.org](http://snorkel.org)
* You can also check out the [Snorkel API documentation](https://snorkel.readthedocs.io/)


Data augmentation is a popular technique for increasing the size of labeled training sets by applying class-preserving transformations to create copies of labeled data points.
In the image domain, it is a crucial factor in almost every state-of-the-art result today and is quickly gaining
popularity in text-based applications.
Snorkel models the data augmentation process by applying user-defined *transformation functions* (TFs) in sequence.
You can learn more about data augmentation in
[this blog post about our NeurIPS 2017 work on automatically learned data augmentation](https://snorkel.org/tanda/).

The tutorial is divided into four parts:
1. **Loading Data**: We load a [YouTube comments dataset](http://www.dt.fee.unicamp.br/~tiago//youtubespamcollection/).
2. **Writing Transformation Functions**: We write Transformation Functions (TFs) that can be applied to training data points to generate new training data points.
3. **Applying Transformation Functions to Augment Our Dataset**: We apply a sequence of TFs to each training data point, using a random policy, to generate an augmented training set.
4. **Training a Model**: We use the augmented training set to train an LSTM model for classifying new comments as `SPAM` or `HAM`.

This next cell takes care of some notebook-specific housekeeping.
You can ignore it.

```python
import os
import random

import numpy as np

# Make sure we're running from the spam/ directory
if os.path.basename(os.getcwd()) == "snorkel-tutorials":
    os.chdir("spam")

# Turn off TensorFlow logging messages
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

# For reproducibility
seed = 0
os.environ["PYTHONHASHSEED"] = str(seed)
np.random.seed(seed)
random.seed(seed)
```

If you want to display all comment text untruncated, change `DISPLAY_ALL_TEXT` to `True` below.

```python
import pandas as pd


DISPLAY_ALL_TEXT = False

# None disables column truncation (pandas >= 1.0); 50 truncates long comments
pd.set_option("display.max_colwidth", None if DISPLAY_ALL_TEXT else 50)
```

This next cell makes sure a spaCy English model is downloaded.
If this is your first time downloading this model, restart the kernel after executing the next cell.

```
# Download the spaCy English model
! python -m spacy download en_core_web_sm
```



```
✔ Download and installation successful
You can now load the model via spacy.load('en_core_web_sm')
```


## 1. Loading Data

We load the YouTube comments dataset and create Pandas DataFrame objects for the training, validation, and test sets (the second split returned by `load_spam_dataset` is not used in this tutorial).
The two main columns in the DataFrames are:
* **`text`**: Raw text content of the comment
* **`label`**: Whether the comment is `SPAM` (1) or `HAM` (0).

For more details, check out the [labeling tutorial](https://github.com/snorkel-team/snorkel-tutorials/blob/master/spam/01_spam_tutorial.ipynb).

```python
from utils import load_spam_dataset

df_train, _, df_valid, df_test = load_spam_dataset(load_train_labels=True)

# We pull out the label vectors for ease of use later
Y_valid = df_valid["label"].values
Y_train = df_train["label"].values
Y_test = df_test["label"].values
```

```python
df_train.head()
```

| | author | date | text | label | video |
|---|---|---|---|---|---|
| 0 | Alessandro leite | 2014-11-05T22:21:36 | pls http://www10.vakinha.com.br/VaquinhaE.aspx... | 1 | 1 |
| 1 | Salim Tayara | 2014-11-02T14:33:30 | if your like drones, plz subscribe to Kamal Ta... | 1 | 1 |
| 2 | Phuc Ly | 2014-01-20T15:27:47 | go here to check the views :3 | 0 | 1 |
| 3 | DropShotSk8r | 2014-01-19T04:27:18 | Came here to check the views, goodbye. | 0 | 1 |
| 4 | css403 | 2014-11-07T14:25:48 | i am 2,126,492,636 viewer :D | 0 | 1 |


## 2. Writing Transformation Functions (TFs)

Transformation functions are functions that can be applied to a training data point to create another valid training data point of the same class.
For example, for image classification problems, it is common to rotate or crop images in the training data to create new training inputs.
Transformation functions should be atomic, e.g., a small rotation of an image or a change to a single word in a sentence.
We then compose multiple transformation functions when applying them to training data points.
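To make the idea of composing atomic TFs concrete, here is a minimal plain-Python sketch that does not use Snorkel. It assumes data points are dicts with `text` and `label` keys; the two TFs (`lowercase_first_word`, `drop_trailing_exclamations`) and the `compose` helper are hypothetical illustrations. In this sketch, a TF that cannot apply returns `None` and is skipped, and the composed transformation returns `None` only if every TF in the sequence does.

```python
from typing import Callable, List, Optional

# Hypothetical data point format: a dict with "text" and "label" keys.
TF = Callable[[dict], Optional[dict]]


def lowercase_first_word(x: dict) -> Optional[dict]:
    """Hypothetical atomic TF: lowercase the first word only."""
    words = x["text"].split(" ")
    if not words[0] or words[0].islower():
        return None  # Nothing to change, so signal "no transformation".
    words[0] = words[0].lower()
    return {**x, "text": " ".join(words)}  # New dict; the label is unchanged.


def drop_trailing_exclamations(x: dict) -> Optional[dict]:
    """Hypothetical atomic TF: strip trailing '!' characters."""
    stripped = x["text"].rstrip("!")
    if stripped == x["text"]:
        return None
    return {**x, "text": stripped}


def compose(tfs: List[TF]) -> TF:
    """Apply TFs left to right, skipping any that return None.

    The composed TF returns None only if every TF in the sequence did.
    """
    def composed(x: dict) -> Optional[dict]:
        changed = False
        for tf in tfs:
            result = tf(x)
            if result is not None:
                x, changed = result, True
        return x if changed else None

    return composed


augment = compose([lowercase_first_word, drop_trailing_exclamations])
print(augment({"text": "Check out my channel!!!", "label": 1}))
# {'text': 'check out my channel', 'label': 1}
```

Each TF returns a new dict, so the original data point is never mutated, and neither TF touches the label, which is what keeps the transformation class-preserving.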

Common ways to augment text include replacing words with their synonyms or replacing named entities with other entities.
More info can be found
[here](https://towardsdatascience.com/data-augmentation-in-nlp-2801a34dfc28) or
[here](https://towardsdatascience.com/these-are-the-easiest-data-augmentation-techniques-in-natural-language-processing-you-can-think-of-88e393fd610).
Our basic modeling assumption is that applying these operations to a comment generally shouldn't change whether it is `SPAM` or not.

Transformation functions in Snorkel are created with the
[`transformation_function` decorator](https://snorkel.readthedocs.io/en/master/packages/_autosummary/augmentation/snorkel.augmentation.transformation_function.html#snorkel.augmentation.transformation_function),
which wraps a function that takes in a single data point and returns a transformed version of the data point.
If no transformation is possible, a TF can return `None` or the original data point.
If all the TFs applied to a data point return `None`, the data point won't be included in
the augmented dataset when we apply our TFs below.

Just like the `labeling_function` decorator, the `transformation_function` decorator
accepts a `pre` argument for `Preprocessor` objects.
Here, we'll use a
[`SpacyPreprocessor`](https://snorkel.readthedocs.io/en/master/packages/_autosummary/preprocess/snorkel.preprocess.nlp.SpacyPreprocessor.html#snorkel.preprocess.nlp.SpacyPreprocessor).

```python
from snorkel.preprocess.nlp import SpacyPreprocessor

spacy = SpacyPreprocessor(text_field="text", doc_field="doc", memoize=True)
```

```python
import names
from snorkel.augmentation import transformation_function

# Pregenerate some random person names to replace existing ones with
# for the transformation strategies below
replacement_names = [names.get_full_name() for _ in range(50)]


# Replace a random person name with a different, randomly chosen full name.
@transformation_function(pre=[spacy])
def change_person(x):
    person_names = [ent.text for ent in x.doc.ents if ent.label_ == "PERSON"]
    # If there is at least one person name, replace a random one. Else return None.
    if person_names:
        name_to_replace = np.random.choice(person_names)
        replacement_name = np.random.choice(replacement_names)
        x.text = x.text.replace(name_to_replace, replacement_name)
        return x


# Swap two adjectives at random.
@transformation_function(pre=[spacy])
def swap_adjectives(x):
    adjective_idxs = [i for i, token in enumerate(x.doc) if token.pos_ == "ADJ"]
    # Check that there are at least two adjectives to swap.
    if len(adjective_idxs) >= 2:
        idx1, idx2 = sorted(np.random.choice(adjective_idxs, 2, replace=False))
        # Swap tokens in positions idx1 and idx2.
        x.text = " ".join(
            [
                x.doc[:idx1].text,
                x.doc[idx2].text,
                x.doc[1 + idx1 : idx2].text,
                x.doc[idx1].text,
                x.doc[1 + idx2 :].text,
            ]
        )
        return x
```
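The index bookkeeping in `swap_adjectives` can be easier to follow on a plain word list. The sketch below reproduces the same swap on tokens that have been hand-tagged with part-of-speech labels (an assumption standing in for spaCy's `token.pos_`). `swap_two_adjectives` is a hypothetical helper, and it uses Python's `random.Random.sample` in place of `np.random.choice(..., replace=False)`.

```python
import random
from typing import List, Optional, Tuple

# Hand-tagged (word, part_of_speech) pairs stand in for a spaCy Doc here.
Token = Tuple[str, str]


def swap_two_adjectives(tokens: List[Token], rng: random.Random) -> Optional[List[str]]:
    """Hypothetical helper mirroring swap_adjectives on a tagged word list."""
    adjective_idxs = [i for i, (_, pos) in enumerate(tokens) if pos == "ADJ"]
    if len(adjective_idxs) < 2:
        return None  # Fewer than two adjectives: nothing to swap.
    # Pick two distinct adjective positions, then order them.
    idx1, idx2 = sorted(rng.sample(adjective_idxs, 2))
    words = [word for word, _ in tokens]
    # Equivalent to rebuilding the text from doc[:idx1], doc[idx2],
    # doc[idx1 + 1:idx2], doc[idx1], and doc[idx2 + 1:].
    words[idx1], words[idx2] = words[idx2], words[idx1]
    return words


tagged = [("a", "DET"), ("big", "ADJ"), ("red", "ADJ"), ("ball", "NOUN")]
print(" ".join(swap_two_adjectives(tagged, random.Random(0))))  # a red big ball
```

Swapping in place and joining once also sidesteps a quirk of joining spaCy spans: when the two adjectives are adjacent, the middle span `doc[idx1 + 1:idx2]` is empty, so `" ".join(...)` leaves a double space in the transformed text.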

We add some transformation functions that use `wordnet` from [NLTK](https://www.nltk.org/) to replace different parts of speech with their synonyms.

```python
import nltk
from nltk.corpus import wordnet as wn

nltk.download("wordnet")


def get_synonym(word, pos=None):
    """Get synonym for word given its part-of-speech (pos)."""
    synsets = wn.synsets(word, pos=pos)
    # Return None if wordnet has no synsets (synonym sets) for this word and pos.
    if synsets:
        words = [lemma.name() for lemma in synsets[0].lemmas()]
        if words[0].lower() != word.lower():  # Skip if synonym is same as word.
            # Multi word synonyms in wordnet use '_' as a separator e.g. reckon_with. Replace it with space.
            return words[0].replace("_", " ")


def replace_token(spacy_doc, idx, replacement):
    """Replace token in position idx with replacement."""
    return " ".join([spacy_doc[:idx].text, replacement, spacy_doc[1 + idx :].text])


@transformation_function(pre=[spacy])
def replace_verb_with_synonym(x):
    # Get indices of verb tokens in sentence.
    verb_idxs = [i for i, token in enumerate(x.doc) if token.pos_ == "VERB"]
    if verb_idxs:
        # Pick random verb idx to replace.
        idx = np.random.choice(verb_idxs)
        synonym = get_synonym(x.doc[idx].text, pos="v")
        # If there's a valid verb synonym, replace it. Otherwise, return None.
        if synonym:
            x.text = replace_token(x.doc, idx, synonym)
            return x


@transformation_function(pre=[spacy])
def replace_noun_with_synonym(x):
    # Get indices of noun tokens in sentence.
    noun_idxs = [i for i, token in enumerate(x.doc) if token.pos_ == "NOUN"]
    if noun_idxs:
        # Pick random noun idx to replace.
        idx = np.random.choice(noun_idxs)
        synonym = get_synonym(x.doc[idx].text, pos="n")
        # If there's a valid noun synonym, replace it. Otherwise, return None.
        if synonym:
            x.text = replace_token(x.doc, idx, synonym)
            return x


@transformation_function(pre=[spacy])
def replace_adjective_with_synonym(x):
    # Get indices of adjective tokens in sentence.
    adjective_idxs = [i for i, token in enumerate(x.doc) if token.pos_ == "ADJ"]
    if adjective_idxs:
        # Pick random adjective idx to replace.
        idx = np.random.choice(adjective_idxs)
        synonym = get_synonym(x.doc[idx].text, pos="a")
        # If there's a valid adjective synonym, replace it. Otherwise, return None.
        if synonym:
            x.text = replace_token(x.doc, idx, synonym)
            return x
```

```
[nltk_data] Downloading package wordnet to /Users/braden/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
```
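The synonym helpers above can be exercised without WordNet. This sketch replaces WordNet with a tiny hand-written synonym table (hypothetical data, keyed by lowercased word and a WordNet-style part-of-speech letter) while keeping the same rules as `get_synonym`: take the first candidate, skip it if it equals the input word, and turn `_` into a space. `toy_get_synonym` and `replace_at` are illustrative names, not part of the tutorial's code.

```python
from typing import Dict, List, Optional, Tuple

# Hypothetical synonym table standing in for WordNet. Keys are
# (lowercased word, part-of-speech letter); the first candidate is used.
SYNONYMS: Dict[Tuple[str, str], List[str]] = {
    ("website", "n"): ["web_site", "internet_site"],
    ("check", "v"): ["check", "verify"],
    ("am", "v"): ["be"],
}


def toy_get_synonym(word: str, pos: str) -> Optional[str]:
    """Same rules as get_synonym, applied to the toy table."""
    candidates = SYNONYMS.get((word.lower(), pos))
    if candidates and candidates[0].lower() != word.lower():
        # Multi-word entries use "_" as a separator, like WordNet lemma names.
        return candidates[0].replace("_", " ")
    return None


def replace_at(words: List[str], idx: int, replacement: str) -> str:
    """Rebuild the text with words[idx] replaced, like replace_token."""
    return " ".join(words[:idx] + [replacement] + words[idx + 1 :])


words = ["check", "out", "my", "new", "website"]
print(replace_at(words, 4, toy_get_synonym("website", "n")))
# check out my new web site
print(toy_get_synonym("check", "v"))  # None: the first candidate is the word itself
```

The "first candidate equals the word" check matters because the first lemma of a word's first synset is often the word itself, which would produce a no-op transformation.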


```python
tfs = [
    change_person,
    swap_adjectives,
    replace_verb_with_synonym,
    replace_noun_with_synonym,
    replace_adjective_with_synonym,
]
```

Let's check out a few examples of transformed data points to see what our TFs are doing.

```python
from utils import preview_tfs

preview_tfs(df_train, tfs)
```

| | TF Name | Original Text | Transformed Text |
|---|---|---|---|
| 0 | change_person | Check out Berzerk video on my channel ! :D | Check out Jennifer Selby video on my channel ! :D |
| 1 | swap_adjectives | hey guys look im aware im spamming and it piss... | hey guys look im aware im spamming and it piss... |
| 2 | replace_verb_with_synonym | "eye of the tiger" "i am the champion" seems l... | "eye of the tiger" "i be the champion" seems l... |
| 3 | replace_noun_with_synonym | Hey, check out my new website!! This site is a... | Hey, check out my new web site !! This site is... |
| 4 | replace_adjective_with_synonym | I started hating Katy Perry after finding out ... | I started hating Katy Perry after finding out ... |


We notice a couple of things about the TFs.

* Sometimes they make trivial changes (`"website"` to `"web site"` for `replace_noun_with_synonym`).
  This can still be helpful for training our model, because it teaches the model to be invariant to such small changes.
* Sometimes they introduce grammatical errors (e.g., `swap_adjectives` swapping `"young"` and `"more"` in the example above; the swap falls outside the truncated preview, so set `DISPLAY_ALL_TEXT = True` to see it).

The TFs are expected to be heuristic strategies that indeed preserve the class most of the time, but
[don't need to be perfect](https://arxiv.org/pdf/1901.11196.pdf).
This is especially true when using automated
[data augmentation techniques](https://snorkel.org/tanda/)
which can learn to avoid particularly corrupted data points.
As we'll see below, Snorkel is compatible with such learned augmentation policies.

## 3. Applying Transformation Functions

We'll first define a `Policy` to determine what sequence of TFs to apply to each data point.
We'll start with a [`RandomPolicy`](https://snorkel.readthedocs.io/en/master/packages/_autosummary/augmentation/snorkel.augmentation.RandomPolicy.html)
that samples `sequence_length=2` TFs to apply uniformly at random per data point.
The `n_per_original` argument determines how many augmented data points to generate per original data point, and `keep_original=True` keeps each original data point in the augmented set as well.

```python
from snorkel.augmentation import RandomPolicy

random_policy = RandomPolicy(
    len(tfs), sequence_length=2, n_per_original=2, keep_original=True
)
```
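Conceptually, a random policy only decides *which* TFs to apply, by index into `tfs`. The sketch below is a plain-Python approximation of that behavior, not Snorkel's `RandomPolicy` implementation: for each of the `n_per_original` copies it draws `sequence_length` indices uniformly (with replacement, an assumption), and it represents the kept original as an empty sequence, meaning "apply no TFs" (also an assumed convention).

```python
import random
from typing import List


def random_policy_sequences(
    n_tfs: int,
    sequence_length: int,
    n_per_original: int,
    rng: random.Random,
    keep_original: bool = True,
) -> List[List[int]]:
    """Hypothetical approximation of a uniform random policy.

    Returns one list of TF indices per augmented copy. An empty list stands
    for the original data point (assumed convention: apply no TFs).
    """
    sequences = [
        [rng.randrange(n_tfs) for _ in range(sequence_length)]
        for _ in range(n_per_original)
    ]
    return ([[]] if keep_original else []) + sequences


# Same settings as random_policy above: 5 TFs, sequences of 2, 2 copies.
print(random_policy_sequences(5, 2, 2, random.Random(0)))
```

With `keep_original=True`, each data point yields `1 + n_per_original` sequences, which is where the "at most triple" bound on the augmented set size comes from.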

In some cases, we can do better than uniform random sampling.
We might have domain knowledge that some TFs should be applied more frequently than others,
or have trained an [automated data augmentation model](https://snorkel.org/tanda/)
that learned a sampling distribution for the TFs.
Snorkel supports this use case with a
[`MeanFieldPolicy`](https://snorkel.readthedocs.io/en/master/packages/_autosummary/augmentation/snorkel.augmentation.MeanFieldPolicy.html),
which allows you to specify a sampling distribution for the TFs.
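Here, we give higher probabilities to the three `replace_[X]_with_synonym` TFs, which we expect to provide more useful variation to the model than the other two. The entries of `p` follow the order of the `tfs` list and should sum to 1.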

```python
from snorkel.augmentation import MeanFieldPolicy

mean_field_policy = MeanFieldPolicy(
    len(tfs),
    sequence_length=2,
    n_per_original=2,
    keep_original=True,
    p=[0.05, 0.05, 0.3, 0.3, 0.3],
)
```
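A mean-field policy changes only the sampling step: each position in a sequence is drawn independently from the distribution `p` instead of uniformly. The hypothetical `mean_field_sequences` helper below sketches that with `random.Random.choices`; it is not Snorkel's code, and the empty-sequence-for-the-original convention is an assumption. With the `p` used above, roughly 90% of the sampled TFs should be one of the three synonym TFs.

```python
import random
from collections import Counter
from typing import List, Sequence


def mean_field_sequences(
    p: Sequence[float],
    sequence_length: int,
    n_per_original: int,
    rng: random.Random,
    keep_original: bool = True,
) -> List[List[int]]:
    """Hypothetical sketch: draw each TF index independently from p."""
    if abs(sum(p) - 1.0) > 1e-8:
        raise ValueError("p must sum to 1")
    sequences = [
        rng.choices(range(len(p)), weights=p, k=sequence_length)
        for _ in range(n_per_original)
    ]
    # An empty sequence stands for the kept original (assumed convention).
    return ([[]] if keep_original else []) + sequences


p = [0.05, 0.05, 0.3, 0.3, 0.3]  # Same order as the tfs list.
rng = random.Random(0)
counts = Counter(
    idx
    for _ in range(1000)
    for seq in mean_field_sequences(p, 2, 2, rng, keep_original=False)
    for idx in seq
)
synonym_share = (counts[2] + counts[3] + counts[4]) / sum(counts.values())
print(f"Share of synonym TFs: {synonym_share:.2f}")  # roughly 0.9
```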

To apply our TFs to a collection of data points according to a policy (here, `mean_field_policy`), we use a
[`PandasTFApplier`](https://snorkel.readthedocs.io/en/master/packages/_autosummary/augmentation/snorkel.augmentation.PandasTFApplier.html),
since our data points are stored in a Pandas DataFrame.
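An applier walks over the data, asks the policy for TF sequences, and collects the results. Below is a simplified, hypothetical version of that loop over a list of dicts rather than a DataFrame; Snorkel's `PandasTFApplier` may differ in details such as how it copies rows. The sketch assumes an empty sequence means "keep the original", skips a TF that returns `None`, and drops a copy when every TF in its sequence returned `None` (the rule described earlier).

```python
from typing import Callable, List, Optional, Sequence

TF = Callable[[dict], Optional[dict]]


def apply_tfs(
    data: List[dict],
    tfs: Sequence[TF],
    generate_sequences: Callable[[], List[List[int]]],
) -> List[dict]:
    """Hypothetical, simplified TF applier over a list of dicts."""
    augmented = []
    for x in data:
        for seq in generate_sequences():
            if not seq:  # Empty sequence: keep the original data point.
                augmented.append(x)
                continue
            current, changed = x, False
            for idx in seq:
                result = tfs[idx](current)
                if result is not None:  # Skip TFs that cannot apply.
                    current, changed = result, True
            if changed:  # Drop copies where every TF returned None.
                augmented.append(current)
    return augmented


def shout(x: dict) -> Optional[dict]:
    return {**x, "text": x["text"].upper()}


def never_applies(x: dict) -> Optional[dict]:
    return None


data = [{"text": "hi", "label": 0}, {"text": "buy now", "label": 1}]
# Each point asks for: the original, shout, and never_applies.
out = apply_tfs(data, [shout, never_applies], lambda: [[], [0], [1]])
print(len(out))  # 4: the never_applies copies are dropped
```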

```python
from snorkel.augmentation import PandasTFApplier

tf_applier = PandasTFApplier(tfs, mean_field_policy)
df_train_augmented = tf_applier.apply(df_train)
Y_train_augmented = df_train_augmented["label"].values
```

```
100%|██████████| 1586/1586 [00:29<00:00, 53.41it/s]
```

```python
print(f"Original training set size: {len(df_train)}")
print(f"Augmented training set size: {len(df_train_augmented)}")
```

```
Original training set size: 1586
Augmented training set size: 2486
```


We have increased the size of our dataset by about 57% using TFs!
Note that despite `n_per_original` being set to 2, our dataset does not triple in size
(the maximum with `keep_original=True`), because sometimes TFs return `None` instead of a new data point
(e.g., `change_person` when applied to a sentence with no persons).
If you prefer to have exact proportions for your dataset, you can write TFs that return the original
data point, rather than `None` as ours do, when they can't perform a valid transformation.
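The size numbers above can be checked with a little arithmetic. Assuming all 1,586 originals are kept (`keep_original=True`) and each can yield at most `n_per_original=2` transformed copies, the upper bound is three times the original size, so most attempted copies were discarded.

```python
n_original = 1586      # len(df_train)
n_augmented = 2486     # len(df_train_augmented)
n_per_original = 2

# With keep_original=True, each row contributes itself plus at most
# n_per_original transformed copies.
max_size = n_original * (1 + n_per_original)
attempted_copies = n_original * n_per_original
kept_copies = n_augmented - n_original
growth = kept_copies / n_original

print(max_size)                          # 4758
print(f"{growth:.1%}")                   # 56.7%
print(attempted_copies - kept_copies)    # 2272 copies discarded
```

In other words, only 900 of the 3,172 attempted copies survived, which suggests many comments give our TFs nothing to work with (no person names, fewer than two adjectives, or no WordNet synonym).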

## 4. Training A Model

Our final step is to use the augmented data to train a model. We train an LSTM (Long Short-Term Memory) model, a standard architecture for text processing tasks.

The next cell makes Keras results reproducible. You can ignore it.

```python
import tensorflow as tf

session_conf = tf.compat.v1.ConfigProto(
    intra_op_parallelism_threads=1, inter_op_parallelism_threads=1
)

tf.compat.v1.set_random_seed(0)
sess = tf.compat.v1.Session(graph=tf.compat.v1.get_default_graph(), config=session_conf)
tf.compat.v1.keras.backend.set_session(sess)
```

Now we'll train our LSTM on both the original and augmented datasets to compare performance.

```python
from utils import featurize_df_tokens, get_keras_lstm, get_keras_early_stopping

X_train = featurize_df_tokens(df_train)
X_train_augmented = featurize_df_tokens(df_train_augmented)
X_valid = featurize_df_tokens(df_valid)
X_test = featurize_df_tokens(df_test)


def train_and_test(
    X_train,
    Y_train,
    X_valid=X_valid,
    Y_valid=Y_valid,
    X_test=X_test,
    Y_test=Y_test,
    num_buckets=30000,
):
    # Define a vanilla LSTM model with Keras
    lstm_model = get_keras_lstm(num_buckets)
    lstm_model.fit(
        X_train,
        Y_train,
        epochs=25,
        validation_data=(X_valid, Y_valid),
        callbacks=[get_keras_early_stopping(5)],
        verbose=0,
    )
    preds_test = lstm_model.predict(X_test)[:, 0] > 0.5
    return (preds_test == Y_test).mean()


acc_augmented = train_and_test(X_train_augmented, Y_train_augmented)
acc_original = train_and_test(X_train, Y_train)
```

```
Restoring model weights from the end of the best epoch.
Epoch 00017: early stopping

Restoring model weights from the end of the best epoch.
Epoch 00016: early stopping
```


```python
print(f"Test Accuracy (original training data): {100 * acc_original:.1f}%")
print(f"Test Accuracy (augmented training data): {100 * acc_augmented:.1f}%")
```

```
Test Accuracy (original training data): 91.2%
Test Accuracy (augmented training data): 92.8%
```
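Both accuracies above come from the last two lines of `train_and_test`: the model's predicted scores are thresholded at 0.5 to get 0/1 predictions, which are then compared with the labels. Here is a dependency-free sketch of that computation; `accuracy_from_probs` is a hypothetical helper, and the scores in the example are made up.

```python
from typing import Sequence


def accuracy_from_probs(
    probs: Sequence[float], labels: Sequence[int], threshold: float = 0.5
) -> float:
    """Threshold scores into 0/1 predictions and return the fraction correct."""
    if len(probs) != len(labels) or not labels:
        raise ValueError("probs and labels must be non-empty and the same length")
    preds = [int(p > threshold) for p in probs]  # 1 = SPAM, 0 = HAM
    return sum(int(pred == y) for pred, y in zip(preds, labels)) / len(labels)


# Made-up scores and labels: predictions are [1, 0, 1, 0], three of four correct.
print(accuracy_from_probs([0.9, 0.2, 0.6, 0.4], [1, 0, 0, 0]))  # 0.75
```

Like the tutorial's `> 0.5` comparison, a score of exactly 0.5 counts as `HAM` here.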


So training on the augmented dataset improved test accuracy from 91.2% to 92.8% in this run!
There is a lot more you can do with data augmentation, so try out a few ideas
on your own!