<a href="https://colab.research.google.com/github/sayarghoshroy/Augment4Gains/blob/main/pegasus_paraphrase.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Using a pre-trained Transformer Encoder-Decoder based Paraphraser

In [2]:
%%capture

# Getting necessary libraries
!pip install -U transformers
!pip install sentencepiece

import os.path
from os import path
import json
import torch
import nltk
import sentencepiece
from tqdm import tqdm
from transformers import PegasusForConditionalGeneration, PegasusTokenizer, BertTokenizer

In [3]:
# Sentence/word tokenizer models used by nltk.sent_tokenize and
# nltk.word_tokenize. Recent NLTK releases (3.8.2+) load 'punkt_tab' instead of
# the pickled 'punkt'; older releases only know 'punkt' and merely report
# 'punkt_tab' as not found (nltk.download returns False instead of raising).
for nltk_resource in ('punkt', 'punkt_tab'):
  nltk.download(nltk_resource)

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [4]:
# Tokenizer: the standard Pegasus vocabulary.
tokenizer_model_name = 'google/pegasus-large'

# Weights: a community Pegasus checkpoint used for paraphrasing.
paraphrasing_model_name = 'tuner007/pegasus_paraphrase'

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'

tokenizer = PegasusTokenizer.from_pretrained(tokenizer_model_name)
model = (
  PegasusForConditionalGeneration
  .from_pretrained(paraphrasing_model_name)
  .to(torch_device)
)

In [5]:
def get_paraphrase(input_text, num_return_sequences = 2, num_beams = 10,
                   preprocess = True, preprocess_len = 52, max_len = 60):
  """Paraphrase `input_text` with the Pegasus paraphrasing model.

  Relies on the module-level `tokenizer`, `model` and `torch_device` created
  in the model-loading cell.

  Args:
    input_text: Text to paraphrase.
    num_return_sequences: How many paraphrases to return; `model.generate`
      requires this to be <= `num_beams`.
    num_beams: Beam width for beam search.
    preprocess: If True, keep only the leading whole sentences (per
      `nltk.sent_tokenize`) whose combined NLTK word-token count does not
      exceed `preprocess_len`, re-joined with single spaces. If even the
      first sentence is over budget, its first `preprocess_len` tokens are kept.
    preprocess_len: Word-token budget used when `preprocess` is True.
    max_len: Maximum length in model tokens, applied both when truncating the
      encoder input and to the generated output.

  Returns:
    A list of `num_return_sequences` decoded paraphrase strings.

  Raises:
    ValueError: If `input_text` is empty or whitespace-only, since there is
      nothing to paraphrase.
  """
  if not input_text or not input_text.strip():
    raise ValueError('get_paraphrase() received empty input text')

  if preprocess:
    kept_tokens = []
    for sentence in nltk.sent_tokenize(input_text):
      tokens = nltk.word_tokenize(sentence)
      if len(kept_tokens) + len(tokens) > preprocess_len:
        if not kept_tokens:
          # The very first sentence is already over budget. Breaking here with
          # nothing kept would send an empty prompt to the model, so keep the
          # sentence's leading tokens instead.
          kept_tokens = tokens[:preprocess_len]
        break
      kept_tokens.extend(tokens)
    processed_text = ' '.join(kept_tokens)
  else:
    processed_text = input_text

  batch = tokenizer([processed_text],
                    truncation = True,
                    padding = 'longest',
                    max_length = max_len,
                    return_tensors = 'pt').to(torch_device)

  # Plain beam search (no do_sample): a `temperature` argument would have no
  # effect here, and newer transformers versions warn when it is passed.
  translated = model.generate(**batch,
                              max_length = max_len,
                              num_beams = num_beams,
                              num_return_sequences = num_return_sequences)

  return tokenizer.batch_decode(translated, skip_special_tokens = True)

In [6]:
# Viewing Sample Paraphrases
# For each sample input, show the source text and the first paraphrase
# returned by get_paraphrase, followed by a blank separator line.
examples = [
  'you should watch louis le vau \'s latest video . steven oh of tyt is disturbing as hell and makes me hope that jimmy dore wakes the left up .',
  'kill yourself you whiny , self-righteous faggot .',
  'but why do they make that face',
]

for sample in examples:
  print(f'Source: {sample}')
  primary = get_paraphrase(sample)[0]
  print(f'Primary Paraphrase: {primary}')
  print()

Source: you should watch louis le vau 's latest video . steven oh of tyt is disturbing as hell and makes me hope that jimmy dore wakes the left up .
Primary Paraphrase: louis le vau's latest video is disturbing and makes me hope that jimmy dore wakes the left up.

Source: kill yourself you whiny , self-righteous faggot .
Primary Paraphrase: You are self-righteous and should kill yourself.

Source: but why do they make that face
Primary Paraphrase: Why do they make that face?



In [7]:
# Generating the Augmented Training Data
# Dataset directory (under data/) whose train.json is augmented.
set_type = 'gab'

# Path relative to the notebook's working directory; assumes Google Drive is
# mounted at ./drive (the Colab default is /content/drive) — confirm.
data_path = 'drive/My Drive/Augment4Gains/data/' + set_type

# Only read access is needed; the former 'r+' mode also demanded write
# permission on train.json. UTF-8 is stated explicitly rather than relying on
# the platform's locale encoding.
with open(path.join(data_path, 'train.json'), 'r', encoding = 'utf-8') as f:
  raw_train = json.load(f)

In [8]:
# Getting the Augmented Datapoints
# For every example in raw_train[done:limit], keep its primary paraphrase as a
# new example (without the 'type'/'set' keys) when the paraphrase has at least
# `minimum_length` NLTK word tokens. Progress is checkpointed so that an
# interrupted run can resume where it left off.
augmented_data = []
overwrite_data = True
save_name = data_path + '/' + 'paraphrased_train.json'
# Sidecar file holding the index of the next raw_train example to process.
# len(augmented_data) cannot serve as that index: failed or too-short
# paraphrases are never appended, so it undercounts and a resumed run would
# paraphrase (and duplicate) some examples again.
progress_name = save_name + '.progress'
limit = len(raw_train)
minimum_length = 4  # minimum NLTK word tokens for a paraphrase to be kept
interval = 500      # checkpoint whenever index % interval == 0

test_mode = True
if test_mode:
  interval = 2
  # Cap at the dataset size so a small dataset cannot raise IndexError.
  limit = min(100, len(raw_train))


def _dump_json_atomic(obj, file_name):
  """Write `obj` as JSON to `file_name` via a temp file and rename, so an
  interrupted write cannot leave a truncated backup behind."""
  temp_name = file_name + '.tmp'
  with open(temp_name, 'w') as f:
    json.dump(obj, f)
  os.replace(temp_name, file_name)


def _save_checkpoint(next_index):
  """Persist augmented_data and the next raw_train index to process."""
  _dump_json_atomic(augmented_data, save_name)
  _dump_json_atomic({'next_index': next_index}, progress_name)


backup_present = path.isfile(save_name)
done = 0

if backup_present:
  print('Pre-processed Data Backup Found: ' + str(backup_present), flush = True)
  with open(save_name, 'r') as f:
    augmented_data = json.load(f)
  if path.isfile(progress_name):
    with open(progress_name, 'r') as f:
      done = json.load(f)['next_index']
  else:
    # Backup written before progress tracking existed: fall back to the old,
    # approximate resume point.
    done = len(augmented_data)
  print('Starting from ' + str(done) + ' onwards.', flush = True)

failures = 0
next_index = done
try:
  for index in tqdm(range(done, limit)):
    unit = raw_train[index]

    try:
      raw_text = unit['source'].replace('\n', ' ')
      target = get_paraphrase(raw_text)[0]
    except Exception as err:
      # Skip examples that fail. Unlike a bare `except:`, this lets
      # KeyboardInterrupt through, so the run can still be stopped.
      failures += 1
      tqdm.write('Skipping example ' + str(index) + ': ' + repr(err))
    else:
      if len(nltk.word_tokenize(target)) >= minimum_length:
        new_unit = unit.copy()
        new_unit.pop('type', None)
        new_unit.pop('set', None)
        new_unit['source'] = target
        augmented_data.append(new_unit)

    # Example `index` is fully handled (kept or skipped) from here on.
    next_index = index + 1
    if index % interval == 0 and overwrite_data:
      _save_checkpoint(next_index)
finally:
  # Runs on normal completion and on interruption, so finished work is kept.
  if overwrite_data:
    _save_checkpoint(next_index)

if failures:
  print(str(failures) + ' example(s) failed and were skipped.', flush = True)

Pre-processed Data Backup Found: True
Starting from 74 onwards.


100%|██████████| 26/26 [00:12<00:00,  2.16it/s]


In [9]:
# Number of augmented examples currently held in augmented_data.
num_augmented = len(augmented_data)
num_augmented

100

In [10]:
# ^_^ Thank You