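# Augmentation

Augments the news-article dataset by back-translation (German → English → German) with the `facebook/wmt19` FSMT translation models, for a hand-picked number of articles per category.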

In [1]:
!pip install torch



In [2]:
from transformers import FSMTForConditionalGeneration, FSMTTokenizer
from pandas import concat, read_parquet, DataFrame, set_option
from data import file
from tqdm import tqdm

In [3]:
# apply progress bar on pandas operations
# Article texts are long: show full cell contents instead of truncating them.
set_option('display.max_colwidth', None)
# Register tqdm's progress-bar methods on pandas objects (progress_map is used below).
tqdm.pandas()

In [4]:
# Load both directions of the WMT19 FSMT translation models once, at cell run time.
mname_de_en = "facebook/wmt19-de-en"
tokenizer_de_en = FSMTTokenizer.from_pretrained(mname_de_en)
model_de_en = FSMTForConditionalGeneration.from_pretrained(mname_de_en)

mname_en_de = "facebook/wmt19-en-de"
tokenizer_en_de = FSMTTokenizer.from_pretrained(mname_en_de)
model_en_de = FSMTForConditionalGeneration.from_pretrained(mname_en_de)


def _translate(text, tokenizer, model, **generate_kwargs):
    """Translate one string with a matching tokenizer/model pair.

    Args:
        text: source-language text to translate.
        tokenizer: tokenizer belonging to ``model``.
        model: sequence-to-sequence model used for generation.
        **generate_kwargs: forwarded to ``model.generate`` (e.g. ``max_length``,
            ``num_beams``) to override the checkpoint's generation defaults.

    Returns:
        The decoded first generated sequence, with special tokens removed.
    """
    # NOTE(review): the checkpoint's generation config may cap the output
    # length (see ``model.config.max_length``); long articles could come back
    # truncated -- confirm and pass ``max_length=...`` if needed.
    input_ids = tokenizer.encode(text, return_tensors="pt")
    output_ids = model.generate(input_ids, **generate_kwargs)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


# `input` shadows the built-in but is kept as the parameter name for compatibility.
def translate_de_en(input, **generate_kwargs):
    """Translate German text to English with facebook/wmt19-de-en."""
    return _translate(input, tokenizer_de_en, model_de_en, **generate_kwargs)


def translate_en_de(input, **generate_kwargs):
    """Translate English text to German with facebook/wmt19-en-de."""
    return _translate(input, tokenizer_en_de, model_en_de, **generate_kwargs)


def augment(input, **generate_kwargs):
    """Back-translate German text (de -> en -> de) to obtain an augmented variant.

    Args:
        input: German source text.
        **generate_kwargs: forwarded to ``generate`` for both translation legs.

    Returns:
        German text produced by the round trip through English.
    """
    english = translate_de_en(input, **generate_kwargs)
    return translate_en_de(english, **generate_kwargs)

In [5]:
# Raw (un-augmented) articles; later cells read the `label` and `text_original` columns.
data = read_parquet(path=file.news_articles_raw)


In [6]:
# Number of articles to back-translate per category (the first N rows of each label).
config = {
    'Etat': 2,
    'Inland': 2,
    'International': 1,
    'Kultur': 3,
    'Panorama': 0,
    'Sport': 1,
    'Web': 0,
    'Wirtschaft': 0,
    'Wissenschaft': 1,
}

def fake_tranlsate(input):
    """Cheap stand-in for ``augment`` that only prefixes the text.

    Not called in this cell; swap it in for ``augment`` to dry-run the loop
    without running the translation models.
    """
    return f"augmented: {input}"

# NOTE: `list` shadows the built-in; the name is kept because the next cell
# reads the collected per-category frames from it.
list = []
for category, count in config.items():
    if not count:
        # Nothing to augment for this category; don't append an empty frame.
        continue
    # First `count` articles of this category (positional, like the old [0:count]).
    selection = data.loc[data.label == category].head(count)
    text_augmented = selection.text_original.progress_map(augment)
    # Concatenating Series along axis=1 already yields a DataFrame.
    list.append(concat([text_augmented, selection.label], axis=1))
    

100%|██████████| 2/2 [01:08<00:00, 34.46s/it]
100%|██████████| 2/2 [01:14<00:00, 37.03s/it]
100%|██████████| 1/1 [00:22<00:00, 22.11s/it]
100%|██████████| 3/3 [01:59<00:00, 39.68s/it]
0it [00:00, ?it/s]
100%|██████████| 1/1 [00:16<00:00, 16.38s/it]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
100%|██████████| 1/1 [00:34<00:00, 34.74s/it]


In [7]:
# Stack the per-category frames into one table holding only the augmented rows.
augmented_only = concat(list)
augmented_only.to_parquet(file.news_articles_augmented_only)

In [24]:
# Original articles followed by their back-translated copies.
# NOTE(review): augmented rows keep the index labels of the rows they were
# derived from, so `augmented` contains duplicate index values; consider
# `ignore_index=True` if nothing downstream relies on those labels.
augmented = concat((data, augmented_only))
augmented.to_parquet(file.news_articles_augmented)