# In[ ]:
import re
from pathlib import Path

import nlpaug.augmenter.word as naw
import nltk
import pandas as pd
from datasets import load_dataset
from googletrans import Translator
from nltk.corpus import wordnet
from tqdm import tqdm

# NLTK resources needed by the WordNet lookups below and by nlpaug's SynonymAug,
# which POS-tags its input through NLTK. NLTK 3.9 renamed the tagger and
# tokenizer resources to 'averaged_perceptron_tagger_eng' and 'punkt_tab'.
# Both the old and the new names are requested so the notebook works on either
# side of that change. On versions that do not know a name, nltk.download
# reports it and returns False instead of raising.
for nltk_resource in (
    'wordnet',
    'averaged_perceptron_tagger',
    'averaged_perceptron_tagger_eng',
    'punkt',
    'punkt_tab',
):
    nltk.download(nltk_resource)

# Source corpora: CNN/DailyMail articles and NewsQA stories (training splits).
cnn_dataset = load_dataset("abisee/cnn_dailymail", "1.0.0", split='train')
newsqa_dataset = load_dataset("glnmario/news-qa-summarization", split='train')

def get_synonym(word):
    """Return a WordNet synonym for ``word``, or ``word`` itself if none is found.

    Only the first synset WordNet returns for ``word`` is consulted. The
    result is the first lemma in it whose spelling differs
    (case-insensitively) from ``word``.

    WordNet stores multi-word lemmas with underscores (e.g. ``ice_cream``).
    These are converted to spaces so the synonym reads as ordinary text when
    it is spliced back into a sentence.
    """
    synsets = wordnet.synsets(word)
    if not synsets:
        return word
    for lemma_name in synsets[0].lemma_names():
        candidate = lemma_name.replace('_', ' ')
        if candidate.lower() != word.lower():
            return candidate
    return word

def augment_with_wordnet(text):
    """Replace each whitespace-separated token in ``text`` with a WordNet synonym.

    Leading and trailing non-word characters (punctuation, quotes, brackets)
    are split off before the lookup and re-attached afterwards. So ``"dog,"``
    is looked up as ``"dog"``; otherwise it would never match WordNet. Tokens
    made only of such characters are kept as-is.

    Tokens are re-joined with single spaces, so runs of whitespace (including
    newlines) collapse to one space.
    """
    augmented_tokens = []
    for token in text.split():
        # (\W*) prefix, lazily-matched core, greedy (\W*) suffix; always matches.
        prefix, core, suffix = re.fullmatch(r"(\W*)(.*?)(\W*)", token).groups()
        if core:
            core = get_synonym(core)
        augmented_tokens.append(prefix + core + suffix)
    return " ".join(augmented_tokens)

# Final augmentation stage (applied in augment_text after the WordNet pass).
# aug_p=0.1 asks nlpaug to substitute about 10% of eligible words, subject to
# nlpaug's own aug_min/aug_max limits. aug_src="wordnet" is nlpaug's default
# source, written out here for clarity.
nlpaug_augmenter = naw.SynonymAug(
    aug_src="wordnet",
    aug_p=0.1,
)

# Lazily created Translator shared by back_translate calls that do not pass one.
_default_translator = None


def back_translate(text, src_lang='en', tgt_lang='es', translator=None):
    """Paraphrase ``text`` by translating it to ``tgt_lang`` and back to ``src_lang``.

    Args:
        text: Text to paraphrase.
        src_lang: Language code of ``text`` and of the returned string.
        tgt_lang: Pivot language code.
        translator: Optional ``googletrans.Translator`` to use. When omitted,
            one module-level instance is created on first use and reused,
            instead of constructing a new Translator on every call.

    Returns:
        The back-translated text. Empty or whitespace-only input is returned
        unchanged without making any translation request.
    """
    global _default_translator
    if not text or not text.strip():
        return text
    if translator is None:
        if _default_translator is None:
            _default_translator = Translator()
        translator = _default_translator
    translated = translator.translate(text, src=src_lang, dest=tgt_lang).text
    return translator.translate(translated, src=tgt_lang, dest=src_lang).text

def augment_text(text):
    """Run the three-stage augmentation pipeline on ``text``.

    The stages are chained: back-translation, then WordNet synonym
    replacement on the back-translated text, then nlpaug synonym replacement
    on the WordNet output.

    Returns:
        A tuple ``(back_translated, wordnet_augmented, nlpaug_augmented)``
        of strings.
    """
    back_translated_text = back_translate(text)
    wordnet_augmented_text = augment_with_wordnet(back_translated_text)
    nlpaug_output = nlpaug_augmenter.augment(wordnet_augmented_text)
    # Recent nlpaug releases return a list even for a single input string.
    # Unwrap it so a plain string (not "['...']") ends up in the CSV. Fall back
    # to the input if nlpaug produced nothing.
    if isinstance(nlpaug_output, list):
        nlpaug_output = nlpaug_output[0] if nlpaug_output else wordnet_augmented_text
    return back_translated_text, wordnet_augmented_text, nlpaug_output

# Rows for the three output tables. Each stage is stored next to the text it
# was derived from, and all three lists stay index-aligned.
augmented_data_back_translation = []
augmented_data_wordnet = []
augmented_data_nlpaug = []
# Positions of CNN/NewsQA pairs skipped because augmentation raised.
failed_pair_indices = []

# The datasets have different sizes. Pairing row i of one with row i of the
# other must stop at the shorter one; indexing past the end of newsqa_dataset
# would raise IndexError and lose all unsaved work.
num_pairs = min(len(cnn_dataset), len(newsqa_dataset))

for i, (cnn_row, newsqa_row) in enumerate(
    tqdm(zip(cnn_dataset, newsqa_dataset), total=num_pairs)
):
    example_cnn = cnn_row['article']
    example_newsqa = newsqa_row['story']

    try:
        cnn_back_translated, cnn_wordnet_augmented, cnn_nlpaug_augmented = augment_text(example_cnn)
        newsqa_back_translated, newsqa_wordnet_augmented, newsqa_nlpaug_augmented = augment_text(example_newsqa)
    except Exception as exc:
        # Best-effort batch job: a single failed translation request must not
        # abort the whole run. Record the failure and keep the tables aligned
        # by skipping the pair entirely.
        failed_pair_indices.append(i)
        tqdm.write(f"Skipping pair {i}: {type(exc).__name__}: {exc}")
        continue

    augmented_data_back_translation.append({
        'cnn_original': example_cnn,
        'cnn_back_translated': cnn_back_translated,
        'newsqa_original': example_newsqa,
        'newsqa_back_translated': newsqa_back_translated
    })

    augmented_data_wordnet.append({
        'cnn_back_translated': cnn_back_translated,
        'cnn_wordnet_augmented': cnn_wordnet_augmented,
        'newsqa_back_translated': newsqa_back_translated,
        'newsqa_wordnet_augmented': newsqa_wordnet_augmented
    })

    augmented_data_nlpaug.append({
        'cnn_wordnet_augmented': cnn_wordnet_augmented,
        'cnn_nlpaug_augmented': cnn_nlpaug_augmented,
        'newsqa_wordnet_augmented': newsqa_wordnet_augmented,
        'newsqa_nlpaug_augmented': newsqa_nlpaug_augmented
    })

df_back_translation = pd.DataFrame(augmented_data_back_translation)
df_wordnet = pd.DataFrame(augmented_data_wordnet)
df_nlpaug = pd.DataFrame(augmented_data_nlpaug)

# Machine-specific destination; change it here to relocate all three CSVs.
OUTPUT_DIR = Path("/Users/fanyuanhao/python/path/to/save/augmented_dataset_fixed")
# DataFrame.to_csv does not create missing parent directories, so a fresh
# checkout would otherwise fail here after the entire augmentation run.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

df_back_translation.to_csv(OUTPUT_DIR / "cnn_newsqa_back_translation.csv", index=False)
df_wordnet.to_csv(OUTPUT_DIR / "cnn_newsqa_wordnet_augmented.csv", index=False)
df_nlpaug.to_csv(OUTPUT_DIR / "cnn_newsqa_nlpaug_augmented.csv", index=False)

print("All augmented data saved successfully.")
