In [11]:
from transformers import BertTokenizer, BertForMaskedLM
from transformers import RobertaTokenizer, RobertaForMaskedLM
from nltk.tokenize.treebank import TreebankWordDetokenizer

# from transformers import BertForQuestionAnswering, AutoModelForQuestionAnswering, AutoTokenizer
# from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering

import os
import torch
import torch.nn as nn
import json
import re
import timeit
from copy import copy
from tqdm import tqdm

import contractions

import nltk
from nltk import word_tokenize

from proj_utils import *
from proj_config import *

In [2]:
# Location for augmented test files. Override with the AUGMENTATION_OUTPUT_DIR
# environment variable instead of editing this absolute path.
output_dir = os.environ.get('AUGMENTATION_OUTPUT_DIR', '/data/augmentation/test/')

In [3]:


# Required for identifying parts of speech: 'punkt' is used by word_tokenize and
# 'averaged_perceptron_tagger' by nltk.pos_tag. 'wordnet' is also fetched,
# although nothing in the code shown here references it.
#TODO should we do this using Bert based model?
NLTK_PACKAGES = ('punkt', 'wordnet', 'averaged_perceptron_tagger')

# Build a list (not a generator) so every package is attempted even if one
# fails; the displayed value is True only when all downloads succeeded.
all([nltk.download(package) for package in NLTK_PACKAGES])

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!


True


# Use BERT or RoBERTa models

In [4]:
# Which masked language model fills in masked words: 'bert' or 'roberta'.
model_type = 'roberta'

# Download pre-trained models from huggingface. ``model_mask`` is the literal
# mask-token string the chosen tokenizer recognises in input text.
if model_type == 'bert':
    # Download (Using cased to maintain case in output)
    tokenizer = BertTokenizer.from_pretrained('bert-large-cased')
    model = BertForMaskedLM.from_pretrained('bert-large-cased')
    model_mask = '[MASK]'
elif model_type == 'roberta':
    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
    model = RobertaForMaskedLM.from_pretrained('roberta-base')
    model_mask = '<mask>'
else:
    # Fail here instead of with a NameError on tokenizer/model in a later cell.
    raise ValueError(f"Unsupported model_type {model_type!r}; expected 'bert' or 'roberta'")
    

Some weights of RobertaForMaskedLM were not initialized from the model checkpoint at roberta-base and are newly initialized: ['lm_head.decoder.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
print(type(model))

# Expose only GPU 0 to this process. This has to happen before anything
# initialises CUDA for it to take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Wrapped in DataParallel unconditionally (the `device_count() > 1` guard is
# disabled). NOTE(review): masked_word_prediction passes CPU id tensors straight
# to model(...); DataParallel scatters its inputs onto the GPU, which is
# presumably why the wrapper is kept even with one GPU — confirm before removing.
model = nn.DataParallel(model).to(device)

print(type(model))

<class 'transformers.modeling_roberta.RobertaForMaskedLM'>
<class 'torch.nn.parallel.data_parallel.DataParallel'>


In [6]:
# Number of GPUs visible to this process (after CUDA_VISIBLE_DEVICES is applied).
visible_gpu_count = torch.cuda.device_count()
visible_gpu_count

1

### Convert Wiki word frequencies into json.gz

In [7]:
# wiki_word_count_file = "/data/misc/enwiki-20190320-words-frequency.txt"

# wiki_freq_dict = {}

# with open(wiki_word_count_file, 'r', encoding='utf-8') as f:
#     for line in f:
#         word_count = line.split(' ')
#         wiki_freq_dict[word_count[0]] = int(word_count[1][:-1])

# write_gzip_json("/data/misc/enwiki-20190320-words-frequency.json.gz", wiki_freq_dict)

In [8]:
# Gzipped JSON of the enwiki-20190320 word-frequency list (the conversion cell
# above shows how it was produced): maps word -> integer count.
WIKI_FREQ_URL = 'https://nlp-distribution.s3.ca-central-1.amazonaws.com/misc/enwiki-20190320-words-frequency.json.gz'
wiki_freq_dict = get_gzip_json_url(WIKI_FREQ_URL)

Fetching: https://nlp-distribution.s3.ca-central-1.amazonaws.com/misc/enwiki-20190320-words-frequency.json.gz


In [12]:
def masked_word_prediction(text):
    """Fill every mask token in ``text`` with the masked LM's top-1 prediction.

    ``text`` must already contain ``model_mask`` placeholders (see
    ``mask_sentence_by_part`` / ``mask_sentence_by_frequency``). Texts longer
    than 512 tokens are processed in several parts, each extended to the first
    '.'/'!'/'?' token after its nominal end. The predicted tokens of all parts
    are detokenized together, so spacing at part boundaries follows the same
    rules as inside a part.

    Uses the module-level ``tokenizer``, ``model``, ``model_type`` and
    ``model_mask``.

    Args:
        text: masked input text.

    Returns:
        tuple ``(full_paragraph, predicted_words)``: the detokenized text with
        masks replaced, and the top-1 predicted word for each mask, in order.
    """
    predicted_words = []

    # NOTE(review): encode() presumably adds the model's special tokens (e.g. a
    # leading <s>/[CLS]) while tokenize() does not, which is why an id position
    # is mapped back to token position ``pos - 1`` below.
    text_token_ids = tokenizer.encode(text, return_tensors='pt')
    tokenized_text = tokenizer.tokenize(text)

    # Nested so it can append to ``predicted_words`` while the paragraph is
    # broken into parts when > 512 tokens long. Returns this part's tokens with
    # the masks replaced (not yet detokenized).
    def predict_part(text_token_ids_part, tokenized_text_part):
        # Positions of the word masks within this part's ids (batch of one).
        mask_positions_2d = (text_token_ids_part.squeeze() == tokenizer.mask_token_id).nonzero()
        mask_positions = [mask.item() for mask in mask_positions_2d]

        with torch.no_grad():
            output = model(text_token_ids_part)

        # First output: per-position scores over the vocabulary (topk indices
        # below are decoded as token ids); squeeze away the batch dimension.
        prediction_scores = output[0].squeeze()

        token_predictions_list = []
        for mask_index in mask_positions:
            # Top-5 candidate token ids; only the first is used further down.
            idx = torch.topk(prediction_scores[mask_index], k=5, dim=0)[1]

            # The different models tokenize differently, so decode per model type.
            words = None
            if model_type == 'roberta':
                words = [tokenizer.decode(i.item()).strip() for i in idx]
            elif model_type == 'bert':
                words = tokenizer.decode(idx).split(' ')

            # Drop BERT's '##' word-piece continuation marker.
            words = [word[2:] if word[0:2] == '##' else word for word in words]
            token_predictions_list.append(words)

        # Make sure we have a list of predictions for each mask.
        assert len(mask_positions) == len(token_predictions_list)

        # Replace masks with predicted words in a copy of this part's tokens.
        predicted_text = copy(tokenized_text_part)
        for pos, new_word in zip(mask_positions, token_predictions_list):
            #TODO This is where we could search for a more appropriate replacement.
            # Prefix RoBERTa predictions with 'Ġ' so detokenize() treats them as
            # the start of a new word.
            if model_type == 'roberta':
                predicted_word = 'Ġ' + new_word[0]
            else:
                predicted_word = new_word[0]

            # Ugly hack: normally the id position is one ahead of the token
            # position, but sometimes it is two. NOTE(review): root cause unknown.
            if predicted_text[pos - 2] == model_mask:
                predicted_text[pos - 2] = predicted_word
            else:
                predicted_text[pos - 1] = predicted_word
            predicted_words.append(new_word[0])

        return predicted_text

    if len(tokenized_text) > 512:
        # Fewest roughly equal pieces that are each < 500 tokens:
        # int(n / pieces) < 500  <=>  pieces = n // 500 + 1.
        pieces = len(tokenized_text) // 500 + 1
        piece_size = len(tokenized_text) // pieces

        predicted_tokens = []
        start_pos = 0
        for i in range(pieces):
            end_pos = piece_size * (i + 1)

            if i == pieces - 1:
                # Last piece: grab to the end.
                split_idx = None
            else:
                # Extend to just past the first sentence-ending token at or
                # after end_pos, but by at most 512 tokens.
                split_idx = end_pos
                for idx, token in enumerate(tokenized_text[end_pos:], start=end_pos):
                    if token in '.!?' or (idx - end_pos) >= 512:
                        split_idx = idx + 1
                        break

            # NOTE(review): only the first id slice contains the leading special
            # token, so a mask at position 0 of a later slice maps to index -1
            # in predict_part — verify behaviour near piece boundaries.
            predicted_tokens += predict_part(
                text_token_ids[0:1, start_pos:split_idx], tokenized_text[start_pos:split_idx]
            )
            start_pos = split_idx
    else:
        predicted_tokens = predict_part(text_token_ids, tokenized_text)

    full_paragraph = detokenize(predicted_tokens)

    # Report any mask token that survived replacement.
    if model_mask in full_paragraph:
        print(full_paragraph)

    return full_paragraph, predicted_words

def detokenize(tokenized_text):
    """Join a model's sub-word tokens back into a plain string.

    For every token after the first, the first matching rule wins:
      * with RoBERTa, a token without a leading 'Ġ' continues the previous word;
      * a BERT '##' continuation piece is glued on without the marker;
      * no space before closing punctuation (any substring of "?:!.,;%)"),
        after an opening symbol ('(' or '$'), or next to an apostrophe,
        double quote or hyphen;
      * otherwise the token starts a new word and gets a leading space.
    Afterwards leftover 'Ġ' markers are removed (RoBERTa only) and "` ` " is
    turned into '"'.

    NOTE(review): RoBERTa's byte-level BPE also maps bytes such as newlines and
    non-ASCII characters to stand-in symbols; only 'Ġ' is undone here, so such
    text may come back altered — confirm.
    """
    closing_punct = "?:!.,;%)"
    opening_punct = "($"
    quote_like = ["'", '"', '-']
    is_roberta = model_type == 'roberta'

    text = ''
    for position, tok in enumerate(tokenized_text):
        if position == 0:
            # Seed with the first token so text[-1:] below has something to inspect.
            text = tok
            continue
        if is_roberta and tok[0:1] != 'Ġ':
            # RoBERTa marks word starts with 'Ġ'; a token without it is a piece
            # of the previous word and must be reconnected.
            text += tok
            continue
        if tok.startswith('##'):
            # BERT splits unrecognized words and prepends '##' to the pieces.
            text += tok[2:]
            continue
        no_space = (
            tok in closing_punct
            or text[-1:] in opening_punct
            or text[-1:] in quote_like
            or tok[0:1] in quote_like
        )
        text += tok if no_space else ' ' + tok

    if is_roberta:
        text = text.replace('Ġ', '')
    return text.replace("` ` ", '"')

## NLTK parts of speech 
# CC coordinating conjunction
# CD cardinal digit
# DT determiner
# EX existential there (like: "there is" ... think of it like "there exists")
# FW foreign word
# IN preposition/subordinating conjunction
# JJ adjective 'big'
# JJR adjective, comparative 'bigger'
# JJS adjective, superlative 'biggest'
# LS list marker 1)
# MD modal could, will
# NN noun, singular 'desk'
# NNS noun plural 'desks'
# NNP proper noun, singular 'Harrison'
# NNPS proper noun, plural 'Americans'
# PDT predeterminer 'all the kids'
# POS possessive ending parent's
# PRP personal pronoun I, he, she
# PRP$ possessive pronoun my, his, hers
# RB adverb very, silently,
# RBR adverb, comparative better
# RBS adverb, superlative best
# RP particle give up
# TO to go 'to' the store.
# UH interjection errrrrrrrm
# VB verb, base form take
# VBD verb, past tense took
# VBG verb, gerund/present participle taking
# VBN verb, past participle taken
# VBP verb, sing. present, non-3d take
# VBZ verb, 3rd person sing. present takes
# WDT wh-determiner which
# WP wh-pronoun who, what
# WP$ possessive wh-pronoun whose
# WRB wh-adverb where, when

# Pass a sentence and a collection of NLTK part-of-speech tags to be masked.
# Returns the masked sentence and the list of words replaced by the mask.
def mask_sentence_by_part(sentence, part):
    """Replace every word whose NLTK POS tag is in ``part`` with ``model_mask``.

    Contractions are expanded before tokenizing. The masked tokens are joined
    back with the Treebank detokenizer.

    Returns:
        tuple ``(masked_sentence, masked_words)``, where ``masked_words`` lists
        the original words in the order they were masked.
    """
    words = word_tokenize(expand_contractions(sentence))
    tagged = nltk.pos_tag(words)

    masked_words = []
    for position, (_token, tag) in enumerate(tagged):
        if tag not in part:
            continue
        masked_words.append(words[position])
        words[position] = model_mask

    return TreebankWordDetokenizer().detokenize(words), masked_words

def expand_contractions(text):
    """Expand English contractions (e.g. "don't" -> "do not") with ``contractions.fix``.

    The first character of ``text`` is kept exactly as given, and only the rest
    of the fixed string is used. NOTE(review): presumably this preserves the
    original capitalisation of the first character — confirm against
    ``contractions.fix`` behaviour.

    Empty input is returned unchanged (indexing ``text[0]`` would raise).
    """
    if not text:
        return text
    return text[0] + contractions.fix(text)[1:]

def augment_language_by_part(sentence, part):
    """Mask words whose POS tag is in ``part`` and refill them with the masked LM.

    Returns:
        tuple ``(new_sentence, masked_words, predicted_words)``.
    """
    masked = mask_sentence_by_part(sentence, part)
    filled = masked_word_prediction(masked[0])
    return filled[0], masked[1], filled[1]

def augment_language_by_frequency(sentence, percentile):
    """Mask the least-frequent ``percentile`` of words and refill them with the masked LM.

    Returns:
        tuple ``(new_sentence, masked_words, predicted_words)``.
    """
    masked = mask_sentence_by_frequency(sentence, percentile)
    filled = masked_word_prediction(masked[0])
    return filled[0], masked[1], filled[1]


# Mask words based on their frequency, masking the lower "percentile" passed.
def mask_sentence_by_frequency(sentence, percentile):
    """Mask the rarest words of ``sentence`` according to ``wiki_freq_dict``.

    Only tokens present in ``wiki_freq_dict`` are ranked (the lookup is exact,
    so case matters). The ``int(len(ranked) * percentile)`` distinct tokens
    with the lowest counts are masked everywhere they occur; ties keep
    first-appearance order.

    Returns:
        tuple ``(masked_sentence, masked_words)``, where ``masked_words`` lists
        the original words in sentence order.
    """
    tokens = word_tokenize(expand_contractions(sentence))

    # Distinct known tokens, rarest first (sorted is stable, dicts keep insertion order).
    ranked = sorted(
        {tok: wiki_freq_dict[tok] for tok in tokens if tok in wiki_freq_dict}.items(),
        key=lambda pair: pair[1],
    )
    cutoff = int(len(ranked) * percentile)
    rare_words = {tok for tok, _count in ranked[:cutoff]}

    masked_words = []
    for position, tok in enumerate(tokens):
        if tok in rare_words:
            masked_words.append(tok)
            tokens[position] = model_mask

    return TreebankWordDetokenizer().detokenize(tokens), masked_words


def augment_answers(orig_answer, orig_answer_start, augmented_paragraph, masked_words, predicted_words):
    """Find where an original answer ended up in the augmented paragraph.

    Builds a regex from ``orig_answer`` (after contraction expansion). The regex
    tolerates re-spacing around common punctuation and lets each masked word
    appear either as itself or as any word predicted for it. The
    case-insensitive match whose start is closest to ``orig_answer_start`` is
    chosen; on ties the first match wins.

    Args:
        orig_answer: answer text from the original paragraph.
        orig_answer_start: character offset of the answer in the original paragraph.
        augmented_paragraph: paragraph after masking and prediction.
        masked_words: words that were masked, in order.
        predicted_words: the prediction for each entry of ``masked_words``.

    Returns:
        tuple ``(new_answer, new_start_pos)``, or ``(None, None)`` when the
        answer cannot be found.
    """
    expanded_answer = expand_contractions(orig_answer)

    # Escape regex metacharacters by hand instead of re.escape, which would also
    # escape spaces (spaces are rewritten below). '$' and '.' are escaped only
    # after the whitespace step, so that step can still see them.
    augmented_answer_re = (
        expanded_answer
        .replace("\\", "\\\\")
        .replace(")", r"\)")
        .replace("(", r"\(")
        .replace("]", r"\]")
        .replace("[", r"\[")
        .replace("}", r"\}")
        .replace("{", r"\{")
        .replace(">", r"\>")
        .replace("<", r"\<")
        .replace('^', r'\^')
        .replace('|', r'\|')
        .replace('*', r'\*')
        .replace('?', r'\?')
        .replace('+', r'\+')
        .replace("``", '"')
    )

    # Ignore whitespace around some chars.
    augmented_answer_re = re.sub(r'([.,$=/\-"\'])', r' *\1 *', augmented_answer_re.strip())
    # '.' must be escaped too, otherwise it matches any character.
    augmented_answer_re = augmented_answer_re.replace('$', r'\$').replace('.', r'\.')

    already_swapped = set()
    for masked_word, predicted_word in zip(masked_words, predicted_words):
        if masked_word != predicted_word and masked_word not in already_swapped:
            already_swapped.add(masked_word)
            # Alternation of the masked word and every prediction made for it.
            alternatives = [re.escape(masked_word)] + [
                re.escape(sub_predicted_word)
                for sub_masked_word, sub_predicted_word in zip(masked_words, predicted_words)
                if sub_masked_word == masked_word
            ]
            word_pair_re = " *(" + "|".join(alternatives) + ") *"

            # Swap out masked words with a regex matching either as whole word.
            # A function replacement inserts word_pair_re verbatim; as a template
            # string its backslashes would be re-interpreted by re.sub.
            augmented_answer_re = re.sub(
                r'\b' + re.escape(masked_word) + r'\b',
                lambda _match, replacement=word_pair_re: replacement,
                augmented_answer_re,
            )

    # Collapse runs of optional-space markers produced above.
    augmented_answer_re = augmented_answer_re.replace(' * * *', ' *')
    augmented_answer_re = augmented_answer_re.replace(' * *', ' *')
    augmented_answer_re = augmented_answer_re.replace(' * ', ' *')
    augmented_answer_re = augmented_answer_re.replace('  ', ' ')

    matches = re.finditer(r'(' + augmented_answer_re + r')', augmented_paragraph, re.IGNORECASE)

    # Match closest to the original start position; min() keeps the first on ties.
    closest_match = min(
        matches,
        key=lambda match: abs(match.start() - orig_answer_start),
        default=None,
    )

    # If there is no match, signal it and let the caller drop the question.
    if closest_match is None:
        return None, None
    return closest_match.group(), closest_match.start()

## Run It

In [15]:

def process_qa(qa_json, parts_of_speech=None, frequency_percentile=None):
    """Augment every paragraph of a QA dataset and re-locate its answers.

    Exactly one of ``parts_of_speech`` (NLTK POS tags to mask) or
    ``frequency_percentile`` (fraction of least-frequent words to mask) must be
    given.

    Args:
        qa_json: the dataset's ``'data'`` list, in the JSON format below.

    Returns:
        dict ``{"data": [...]}`` with the same split/paragraph structure and
        augmented contexts. A question is dropped if any of its answers can no
        longer be found in the augmented context.

    JSON format:
      -data
        -paragraphs
          -context
          -qas
            -question
            -id
            -answers
               -text
               -answer_start
        -title
        -split
    """
    assert bool(parts_of_speech) ^ bool(frequency_percentile), "Can only pass one of parts_of_speech and frequency_percentile"

    if parts_of_speech:
        print(f"Processing parts of speech: {parts_of_speech}")

        def augment(context):
            return augment_language_by_part(context, parts_of_speech)
    else:
        print(f"Processing frequency percentile: {frequency_percentile}")

        def augment(context):
            return augment_language_by_frequency(context, frequency_percentile)

    new_splits = []
    for split in qa_json:
        new_paragraphs = []
        new_splits.append({"paragraphs": new_paragraphs, "split": split['split'], "title": split['title']})

        for paragraph in split['paragraphs']:
            augmented_paragraph, masked_words, predicted_words = augment(paragraph['context'])

            new_qas = []
            new_paragraphs.append({"context": augmented_paragraph, "qas": new_qas})

            for qa in paragraph['qas']:
                #TODO Skip any with missing answers
                new_answers = []
                for answer in qa['answers']:
                    aug_answer_text, aug_answer_start = augment_answers(
                        answer['text'], answer['answer_start'], augmented_paragraph, masked_words, predicted_words
                    )
                    new_answers.append({"answer_start": aug_answer_start, "text": aug_answer_text})

                # Keep the question only if every answer was re-located.
                if all(new_answer["answer_start"] is not None for new_answer in new_answers):
                    new_qas.append({"answers": new_answers, "id": qa['id'], "question": qa['question']})

    return {"data": new_splits}

def get_augmented_filename(output_dir, question_set, parts_of_speech=None, frequency_percentile=None, model_name=None):
    """Build the output path for one augmentation variant of a question set.

    Names look like ``<question_set>_<model>_<TAG1>_<TAG2>.json.gz`` for
    part-of-speech masking, or
    ``<question_set>_<model>_Percentile_<p>.json.gz`` for frequency masking.
    They are joined onto ``output_dir`` with ``os.path.join``, so a trailing
    separator on ``output_dir`` is optional.

    Args:
        model_name: model label used in the name; defaults to the module-level
            ``model_type``.
    """
    assert bool(parts_of_speech) ^ bool(frequency_percentile), "Can only pass one of parts_of_speech and frequency_percentile"

    if model_name is None:
        model_name = model_type

    if parts_of_speech:
        variant = "_".join(parts_of_speech)
    else:
        variant = "Percentile_" + str(frequency_percentile)

    return os.path.join(output_dir, question_set + "_" + model_name + "_" + variant + ".json.gz")
    

def write_json(augmented_qa_json, output_dir, question_set, parts_of_speech=None, frequency_percentile=None):
    """Write ``augmented_qa_json`` as gzipped JSON to its variant's output path.

    The path comes from ``get_augmented_filename``. Exactly one of
    ``parts_of_speech`` or ``frequency_percentile`` must be given. Returns
    whatever ``write_gzip_json`` returns.
    """
    assert bool(parts_of_speech) ^ bool(frequency_percentile), "Can only pass one of parts_of_speech and frequency_percentile"
    filename = get_augmented_filename(output_dir, question_set, parts_of_speech, frequency_percentile)

    # The data argument was previously missing, so nothing could be written.
    return write_gzip_json(filename, augmented_qa_json)


## These dicts are stored in proj_config.py.  Uncomment here to override and try different variations. 
# qa_urls = {
#     "amazon_reviews_v1_0": 'https://ndownloader.figshare.com/files/21500109?private_link=2f119bea3e8d711047ec',
#     "reddit_v1_0": 'https://ndownloader.figshare.com/files/21500112?private_link=2f119bea3e8d711047ec',
#     "new_wiki_v1.0": 'https://ndownloader.figshare.com/files/21500115?private_link=2f119bea3e8d711047ec',
#     "nyt_v1.0": 'https://ndownloader.figshare.com/files/21500118?private_link=2f119bea3e8d711047ec',
# }

# parts_of_speech_list = [
#     ['JJ', 'VB'],
#     ['JJ'],
#     ['VB', 'RB'],
#     ['VB'],
#     ['RB'],
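#     # NLTK's superlative-adverb tag is 'RBS' ('RBZ' does not exist), and
#     # 'VBG' 'VBN' needs a comma or Python concatenates them into 'VBGVBN'.
#     ['RB', 'RBR', 'RBS'],
#     ['VB', 'VBD', 'VBG', 'VBN', 'VBP'],
#     ['RB', 'RBR', 'RBS', 'VB', 'VBD', 'VBG', 'VBN', 'VBP']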
# ]

# frequency_percentiles = [
#     0.10,
#     0.20,
#     0.30,
#     0.50
# ]

qa_files = {}

# Cache the test sets so every augmentation variant reuses one download.
for name, url in qa_urls.items():
    qa_files[name] = get_json_url(url)['data']

for question_set, qa_data in qa_files.items():
    # One keyword-argument dict per augmentation variant: all part-of-speech
    # variants first, then all frequency variants.
    variants = [{"parts_of_speech": pos} for pos in parts_of_speech_list]
    variants += [{"frequency_percentile": pct} for pct in frequency_percentiles]

    for variant in variants:
        filepath = get_augmented_filename(output_dir, question_set, **variant)
        if os.path.exists(filepath):
            print("Skipping existing output: {}".format(filepath))
            continue

        # The loaded dataset (qa_data) and the output path (filepath) use
        # separate names so the dataset is never replaced by a path string.
        augmented_qa_json = process_qa(qa_data, **variant)
        write_gzip_json(filepath, augmented_qa_json)

Fetching: https://ndownloader.figshare.com/files/21500109?private_link=2f119bea3e8d711047ec
Fetching: https://ndownloader.figshare.com/files/21500112?private_link=2f119bea3e8d711047ec
Fetching: https://ndownloader.figshare.com/files/21500115?private_link=2f119bea3e8d711047ec
Fetching: https://ndownloader.figshare.com/files/21500118?private_link=2f119bea3e8d711047ec
Skipping existing output: /data/augmentation/test/amazon_reviews_v1_0_roberta_JJ_VB.json.gz
Skipping existing output: /data/augmentation/test/amazon_reviews_v1_0_roberta_JJ.json.gz
Skipping existing output: /data/augmentation/test/amazon_reviews_v1_0_roberta_VB_RB.json.gz
Skipping existing output: /data/augmentation/test/amazon_reviews_v1_0_roberta_VB.json.gz
Skipping existing output: /data/augmentation/test/amazon_reviews_v1_0_roberta_RB.json.gz
Skipping existing output: /data/augmentation/test/amazon_reviews_v1_0_roberta_RB_RBR_RBZ.json.gz
Skipping existing output: /data/augmentation/test/amazon_reviews_v1_0_roberta_VB_VBD

Token indices sequence length is longer than the specified maximum sequence length for this model (735 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (676 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (572 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (653 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (626 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for thi

Writting: /data/augmentation/test/nyt_v1.0_roberta_Percentile_0.5.json.gz


In [None]:
# Mask words based on their frequency, masking the lower "percentile" passed.
# Same logic as mask_sentence_by_frequency; alias it rather than keep a second
# copy of the definition that could drift.
mask_sentence_frequency = mask_sentence_by_frequency

# Sample review paragraphs for eyeballing the frequency masking.
sample_paragraphs = [
    "It's a very nice holder - not too big and not too small. It fits any lipstick, lip gloss, chapstick, etc nicely. I love that I'm able to see what I have and not have to dig through a makeup bag anymore. I would highly recommend.",
    "First of all, this thing is freakin' awesome. My wife put it together while I did some work, she assembled all the glass shelves in to the frame. We hung it up with 6 BIG 4\" wall toggle bolts. On the top shelf I house an external WD Hdd, a WDTV Live and a Actiontec HDMI WiFi Transmitter. On the second shelf I have my Dish Network 722 (fits perfectly even with all the wires coming out the back). On the bottom shelf I have an old Ken wood Stereo Receive (fits, but the front legs come off just a bit but this doesn't affect stability. The receiver measures 17\"w x 14\"d x 5\"h.The wire maintenance was a little trick. We used a wire coat hanger and some masking tape to thread the bulky wires through the back.This unit installed is an amazing addition to your small living space, I highly recommend it!",
]

for paragraph in sample_paragraphs:
    masked_sentence, word_mask = mask_sentence_frequency(paragraph, 0.25)
    print(masked_sentence)

### View JSON format

In [None]:
# Peek at the first split of a local copy of the dataset to see its JSON layout.
qa_json_path = '/data/distribution_shift/new_qa/amazon_reviews_v1.0.json'

# Explicit UTF-8: the default encoding is platform-dependent and review text
# may contain non-ASCII characters.
with open(qa_json_path, encoding='utf-8') as json_file:
    qa_json = json.load(json_file)['data']

first_split = next(iter(qa_json), None)
if first_split is not None:
    print(json.dumps(first_split, indent=4, sort_keys=True))