# Data Augmentation

The code presented in this notebook is a modified version from a repository made available by Hemker (2018). Unfortunately, the author has since deleted the repository.

In [2]:
# Load packages
import os
import re
import gensim
import pickle
import pandas as pd
import numpy as np
from utils import preprocess
from nltk import pos_tag
from nltk.corpus import wordnet as wn
from gensim.models.keyedvectors import KeyedVectors

def get_corpus(corpus_, path=''):
    """Loads pre-trained word2vec model from src/ directory and
    returns a gensim word2vec object"""
    if corpus_ == 'google':
        return KeyedVectors.load_word2vec_format(path + 'GoogleNews-vectors-negative300.bin',
                                                 binary=True)
    if corpus_=='glove':
        return KeyedVectors.load_word2vec_format(path + 'glove.42B.300d.txt',
                                                 binary=False)
    if corpus_=='glove25':
        return KeyedVectors.load_word2vec_format(path + 'glove.twitter.27B.25d.txt',
                                                 binary=False)
    if corpus_=='fasttext':
        return KeyedVectors.load_word2vec_format(path + 'crawl-300d-2M.vec',
                                                 binary=False,
                                                 encoding='UTF-8')

class Augment():

    def __init__(self,
                 source_path,
                 target_path,
                 corpus_='none',
                 valid_tags=['NN'],
                 threshold=0.75,
                 x_col='tweet',
                 y_col='class',
                 path='',
                 model=None):
        """
        Constructor Arguments
        source_path (string): csv file that is meant to be augmented
        corpus_ (string): Word corpus that the similarity model should take in
            valid args: ['none', 'glove', 'fasttext', 'google']
        x_col (string): column name in csv from samples
        y_col (string): column name in csv for labels
        """
        if model is None:
            self.model = get_corpus(corpus_, path)  # Load model
            print('Loaded corpus: ', corpus_)
        else:
            self.model = model
           
        self.x_col=x_col
        self.y_col=y_col
        self.df=pd.read_csv(source_path)
        self.augmented=pd.DataFrame(columns=[x_col, y_col])
        self.valid_tags = valid_tags
        self.threshold_ = threshold
        
        # Store a mapping from original data point to augmented ones
        try:
            self.aug_idxs = pickle.load(open("data/augmentation_map.pickle", "rb"))
        except EnvironmentError:
            self.aug_idxs = dict()

        # Go through each row in dataframe
        for idx, row in self.df.iterrows():
           
            x = preprocess(row[self.x_col]).split()  # Preprocess input
            y = row[self.y_col]
            aug_temp = self.threshold(x)
            idx2aug = [] # mapping for this data point

            for elem in aug_temp:
                new_idx = self.augmented.shape[0] # index in the augmented dataset
                self.augmented.loc[new_idx] = [elem, y]
                idx2aug.append(new_idx)
               
            self.aug_idxs[idx] = idx2aug
            
            if (idx+1) % 10 == 0:
                print("{} rows successfully augmented.".format(idx+1))
                self.augmented.to_csv(target_path, encoding='utf-8')
                
                with open("data/augmentation_map.pickle", "wb") as f:
                    pickle.dump(self.aug_idxs, f, pickle.HIGHEST_PROTOCOL)
                
            if idx == 5: break
       
        print("Augmentation complete.")
        self.augmented.to_csv(target_path, encoding='utf-8')
        with open("data/augmentation_map.pickle", "wb") as f:
            pickle.dump(self.aug_idxs, f, pickle.HIGHEST_PROTOCOL)


    def create_augmented_samples(self, dict, n, x):
        """Function receives a dictionary which contains the acceptable substitutions for each
           word in x."""
        aug_tweets = [' '.join(x)]  # Save original tweet
       
        # For each possible substitution
        for i in range(n):  
           
            # copy the original tweet
            single_augment = x.copy()  
           
            # For each word in the tweet
            for idx, word in enumerate(single_augment):
               
                # If the word can be replaced and we haven't used all the possible replacements before
                if word in dict.keys() and len(dict[word]) >= i+1:
                   
                    # Replace that word
                    single_augment[idx] = dict[word][i]
                   
            # Join the words into a sentence
            single_augment = ' '.join(single_augment)
           
            # Save the augmented tweet
            aug_tweets.append(single_augment)
           
        return aug_tweets


    def threshold(self, x):
       
        # Create a dictionary that will save the possible replacements for each word
        dict = {}
        n = 0
       
        # Generate POS tags for the words in sentence x
        tags = pos_tag(x)  
       
        for idx, word in enumerate(x):  # For each word in x
           
            # Check if word is part of the vocabulary
            if word in self.model.vocab:  
               
                #get words with highest cosine similarity
                replacements = self.model.most_similar(positive=word, topn=5)
               
                #keep only words that pass the threshold
                replacements = [replacements[i][0] for i in range(5) if replacements[i][1] > self.threshold_]
               
                #check for POS tag equality, dismiss if unequal
                replacements = [elem for elem in replacements if pos_tag([elem.lower()])[0][1] == tags[idx][1]]
               
                #update dictionary with possible replacements for key word
                dict.update({word:replacements}) if len(replacements) > 0 else dict
                n = max(len(replacements), n) #update largest number of replacements
       
        return self.create_augmented_samples(dict, n, x)

In [3]:
source_path = 'data/labeled_data.csv'
target_path = 'data/augmented_data.csv'
corpus_='glove'
path="./glove/"  # Corpus path

In [4]:
model = get_corpus(corpus_, path) 

In [5]:
# May take several hours.
Augment(source_path=source_path, target_path=target_path, corpus_=corpus_, path=path, model=model)

Augmentation complete.


<__main__.Augment at 0x25242d7a588>