Cornell Newsroom summarization dataset (https://summari.es/)

In [0]:
!pip install transformers

In [2]:
import json
import re

from transformers import GPT2Tokenizer, GPT2LMHeadModel
from tqdm  import tqdm
import torch
import numpy as np
import pandas as pd

import nltk
from nltk import tokenize
nltk.download('punkt')

import warnings
warnings.filterwarnings('ignore')

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


device(type='cuda')

# Initial Dataset

In [0]:
# Mount Google Drive at /content/drive; the dataset file and the CSV checkpoint
# used by the cells below live under '/content/drive/My Drive'.
from google.colab import drive
drive.mount('/content/drive')

In [6]:
def clean_tqdm():
    """
    Unregisters every progress bar that tqdm still tracks in ``tqdm._instances``
    (e.g. bars left behind by an interrupted cell).

    Safe to call before any progress bar has been created: when tqdm has not
    set up its ``_instances`` registry yet there is simply nothing to clean.
    """
    # Iterate over a snapshot of the registry rather than the live collection,
    # since each _decr_instances call presumably removes an entry from it.
    for instance in list(getattr(tqdm, '_instances', ())):
        tqdm._decr_instances(instance)

# Throwaway progress bar over three dummy items; it does no real work.
# NOTE(review): presumably this is here so that tqdm creates its `_instances`
# registry before clean_tqdm() is first called — confirm for the tqdm version in use.
for e in tqdm([1,2,3]):
    pass

def clean_data(text):
    r"""
    Strips markup and escape-sequence residue from a raw article text.

    Steps, applied in this order:
      1. remove anything that looks like an HTML/XML tag (``<...>``)
      2. replace the literal two-character sequence ``\n`` with a space
      3. replace HTML entities made of lowercase letters (``&amp;``, ``&nbsp;``) with a space
      4. collapse every run of two or more whitespace characters into one space
      5. turn an escaped quote ``\'`` into ``'``
      6. drop a literal ``\x`` followed by one to four digits

    text: str, raw text
    return: str, cleaned text
    """
    text = re.sub(r'<[^>]*>', '', text)
    text = re.sub(r'\\n', ' ', text)
    text = re.sub(r'&[a-z]{0,7};', ' ', text)
    # A single unbounded pass: two bounded `\s{2,10}` passes would still leave
    # several spaces behind for whitespace runs longer than 100 characters.
    text = re.sub(r'\s{2,}', ' ', text)
    text = re.sub(r"\\'", "'", text)
    text = re.sub(r'\\x\d{1,4}', '', text)
    return text

def get_sentences(data):
    """
    Builds prompt / reference pairs from a collection of texts.

    Every text is split into sentences. A text is kept only when it has at
    least two sentences and its first sentence has fewer than 50
    space-separated words.

    return: list with the first sentence of each kept text (prompt for the gpt model)
            list with the first two sentences of each kept text joined by a space
            (real texts for the classification model)
    """
    prompts, references = [], []

    for document in data:
        sentences = tokenize.sent_tokenize(document)
        # Need a second sentence to serve as the real continuation.
        if len(sentences) < 2:
            continue
        opening = sentences[0]
        # Skip overly long opening sentences.
        if len(opening.split(' ')) >= 50:
            continue
        prompts.append(opening)
        references.append(' '.join(sentences[:2]))

    assert len(references) == len(prompts)
    return prompts, references

100%|██████████| 3/3 [00:00<00:00, 13079.95it/s]


In [0]:
# Lines with index below this value are skipped when reading the file.
# NOTE(review): presumably those records were already consumed by an earlier
# generation run — confirm before changing.
N_SKIPPED_RECORDS = 70001

# Stream the JSON-lines file record by record so that only the cleaned texts
# are kept in memory, never the full list of raw lines.
data = []
with open('/content/drive/My Drive/train-stats.jsonl', 'r') as json_file:
    for line_no, json_str in enumerate(json_file):
        if line_no < N_SKIPPED_RECORDS:
            continue
        result = json.loads(json_str)
        data.append(clean_data(result['text']))

# First sentences (GPT-2 prompts) and first-two-sentence real texts.
texts_gpt2, texts_r = get_sentences(data)

In [39]:
# Number of prompts that passed the filter in get_sentences.
print(f'Total amount of texts: {len(texts_gpt2)}')

Total amount of texts: 878615


In [40]:
# First three prompts (first sentences) next to their real two-sentence texts.
texts_gpt2[:3], texts_r[:3]

(['Hewlett-Packard CEO Meg Whitman said Wednesday the company embarked on another round of layoffs in part because the technology market is changing so rapidly.',
  "It's worth noting that some of these stocks, like telecom company Frontier Communications and oil and gas driller Helmerich & Payne, have sky-high valuations because earnings expectations are nearly nil; the story for oil exploration and production company ConocoPhillips is similar.",
  "TAIPEI, Sept 10 (Reuters) - Computer maker Dell Inc will invest $125 billion in China over the next five years, its chief executive said on Thursday, as the company continues to expand in the world's second-largest economy."],
 ['Hewlett-Packard CEO Meg Whitman said Wednesday the company embarked on another round of layoffs in part because the technology market is changing so rapidly. "It\'s remarkable what\'s happening to our services business.',
  "It's worth noting that some of these stocks, like telecom company Frontier Communications 

In [41]:
# Last three prompts (first sentences) next to their real two-sentence texts.
texts_gpt2[-3:], texts_r[-3:]

(['Elizabeth Taylor has White Diamonds.',
  'BALTIMORE, May 18 -- A disease believed to be equine herpes virus has swept through the barn area at Churchill Downs, site of the Kentucky Derby, leading to the death of two horses and the placement of a quarantine on three barns.',
  'Columnist Michelle Singletary was online to field questions about everything from retirement planning to protecting your credit rating.'],
 ['Elizabeth Taylor has White Diamonds. Coco Chanel had Chanel No.',
  'BALTIMORE, May 18 -- A disease believed to be equine herpes virus has swept through the barn area at Churchill Downs, site of the Kentucky Derby, leading to the death of two horses and the placement of a quarantine on three barns. The outbreak of the rare neurological virus, which can cause symptoms ranging from mild fever and upper respiratory infection to paralysis, has led to the scratching of three horses scheduled to run this weekend at Pimlico in major stakes races.',
  'Columnist Michelle Singlet

# Generation 

In [0]:
# Default generation budget: generate_sentence adds this to the prompt length
# to obtain the max_length it passes to model.generate.
MAX_LENGTH = 50
# Pretrained GPT-2 tokenizer with an explicit '<pad>' token; its pad_token_id
# is forwarded to model.generate by generate_sentence.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', pad_token='<pad>', max_length=500)
# Pretrained GPT-2 language model, moved to the device selected earlier.
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.to(device);  # trailing ';' keeps the model repr out of the cell output

In [0]:
def generate_fake(texts, sampling_type, tokenizer=tokenizer, model=model):
    """
    Produces one generated text per prompt using a single sampling setup.

    texts: iterable of prompt strings
    sampling_type: tuple (name, value), e.g. ('top_k', 20); forwarded to generate_sentence
    return: list of generated texts with newlines replaced by spaces,
            list holding the sampling name (sampling_type[0]) once per generated text
    """
    generated = [
        re.sub(r'\n', ' ', generate_sentence(prompt, model, tokenizer, sampling_type))
        for prompt in tqdm(texts)
    ]
    # One label per generated text so both lists stay aligned.
    labels = [sampling_type[0] for _ in generated]
    return generated, labels

def generate_sentence(sentence, model, tokenizer, sampling_type, max_length=MAX_LENGTH):
    """
    Generates a continuation of `sentence` with the decoding setup in sampling_type.

    sentence: str, prompt text
    model: language model exposing `generate`
    tokenizer: tokenizer exposing `encode`, `decode` and `pad_token_id`
    sampling_type: tuple (name, value) where name selects the decoding setup:
                   'beam_search' -> value is the number of beams
                   'temperature' -> value is the sampling temperature
                   'top_k'       -> value is k for top-k sampling
                   'nucleus'     -> value is top_p for nucleus sampling
                   anything else -> greedy (argmax) decoding, value is ignored
    max_length: number of tokens allowed on top of the prompt length
    return: str, decoded output (prompt plus continuation) without special tokens
    """
    # Truncate the prompt's token ids to 500, then wrap them into a batch of
    # one. The slice must apply to the ids, not to the one-element batch list.
    prompt_ids = tokenizer.encode(sentence)[:500]
    context = torch.tensor([prompt_ids]).to(device)
    # Length budget handed to generate: prompt length plus max_length.
    total_length = max_length + context.size()[-1]

    strategy = sampling_type[0]
    if strategy == 'beam_search':
        # NOTE(review): do_sample=True combined with num_beams samples within
        # the beams instead of running a deterministic beam search — confirm
        # this is intended.
        decoding = dict(do_sample=True, num_beams=sampling_type[1],
                        repetition_penalty=2.3)
    elif strategy == 'temperature':
        decoding = dict(do_sample=True, temperature=sampling_type[1])
    elif strategy == 'top_k':
        # The value is k, so it has to go to `top_k`, not to `temperature`.
        decoding = dict(do_sample=True, top_k=sampling_type[1])
    elif strategy == 'nucleus':
        decoding = dict(do_sample=True, top_p=sampling_type[1])
    else:  # Argmax otherwise
        decoding = dict(do_sample=False)

    outputs = model.generate(input_ids=context, max_length=total_length,
                             pad_token_id=tokenizer.pad_token_id, **decoding)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

def df_checkpoint(texts_fake, texts_real, samplings):
    """
    Assembles the labelled dataset and checkpoints it as CSV.

    Fake texts come first, followed by the real ones. Fake rows carry their
    sampling name, real rows carry 'No sampling'. The frame is written
    (without the index) to 'dataset.csv' in the working directory and to
    Google Drive.

    return: pandas DataFrame with columns 'text', 'label', 'sampling'
    """
    n_fake, n_real = len(texts_fake), len(texts_real)
    df = pd.DataFrame(
        {
            'text': texts_fake + texts_real,
            'label': ['fake'] * n_fake + ['real'] * n_real,
            'sampling': samplings + ['No sampling'] * n_real,
        },
        columns=['text', 'label', 'sampling'],
    )
    # Local copy first, then the Google Drive copy.
    for target in ('dataset.csv', '/content/drive/My Drive/dpl_dataset.csv'):
        df.to_csv(target, index=False)
    return df

In [0]:
# Recovery cell: if something crashes, reload the checkpoint CSV from Google
# Drive and rebuild the three running lists from it.

df = pd.read_csv('/content/drive/My Drive/dpl_dataset.csv')
texts_fake = list(df.loc[df['label'] == 'fake', 'text'])
texts_real = list(df.loc[df['label'] == 'real', 'text'])
samplings = list(df.loc[df['sampling'] != 'No sampling', 'sampling'])

In [42]:
# Sizes of the three running lists (fake texts, real texts, sampling labels).
len(texts_fake), len(texts_real), len(samplings)

(70000, 70000, 70000)

So far I didn't use beam_search as it takes a lot of time to generate texts.

In [0]:
n = 10000  # Number of texts to generate for every sampling setup (7 setups below -> 70k texts in total)
# Full list of setups including beam search (beam search is currently left out):
# sampling_types = [('temperature', 0.9), ('temperature', 0.8), ('top_k', 20), ('top_k', 100), ('nucleus', 0.9), ('nucleus', 0.8), ('argmax', 23), ('beam_search', 3), ('beam_search',5)]
# For a fresh start (no checkpoint loaded) the running lists must be initialised first:
# texts_fake = []
# texts_real = []
# samplings = []
sampling_types = [('temperature', 0.9), ('temperature', 0.8), ('top_k', 20), ('top_k', 100), ('nucleus', 0.9), ('nucleus', 0.8), ('argmax', 23)]


# Setup number `ind` consumes its own slice [ind*n, ind*n+n) of the prompts and
# of the matching real texts. texts_fake / texts_real / samplings are extended
# in place, so they must already exist when this cell runs.
for ind, sampling_type in enumerate(sampling_types):
    clean_tqdm()
    print('Starting generation for {}'.format(sampling_type))
    fake, sampling  = generate_fake(texts_gpt2[ind*n:ind*n+n], sampling_type, tokenizer=tokenizer,model=model)
    texts_fake.extend(fake)
    samplings.extend(sampling)
    texts_real.extend(texts_r[ind*n:ind*n+n])

    # Write a checkpoint after every setup, not only after the whole loop.
    df = df_checkpoint(texts_fake, texts_real, samplings)
    print('DataFrame saved')

  1%|          | 69/10000 [26:57<64:40:12, 23.44s/it]
  0%|          | 0/10000 [00:00<?, ?it/s]

Starting generation for ('temperature', 0.9)


100%|██████████| 10000/10000 [1:35:01<00:00,  1.75it/s]
  0%|          | 0/10000 [00:00<?, ?it/s]

DataFrame saved
Starting generation for ('temperature', 0.8)


 97%|█████████▋| 9748/10000 [1:33:29<02:09,  1.94it/s]

In [27]:
# First rows of the checkpoint frame; df_checkpoint puts the fake texts first.
df.head()

Unnamed: 0,text,label,sampling
0,"HAMBURG, Germany, June 3  As he left the socc...",fake,temperature
1,"WASHINGTON, Dec. 23 - The National Security Ag...",fake,temperature
2,IF outsized executive pay has indeed become a ...,fake,temperature
3,"BY A.J. Miller, Jr. The three men will make t...",fake,temperature
4,Spinach has terrorized generations of veggie-p...,fake,temperature


In [28]:
# Last rows of the checkpoint frame; real texts follow the fake ones and carry
# the label 'real' with sampling 'No sampling'.
df.tail()

Unnamed: 0,text,label,sampling
59995,"WASHINGTON, July 11— Barred by a recent Suprem...",real,No sampling
59996,"WASHINGTON, Sept. 12— Justice Department offic...",real,No sampling
59997,The army of prisoners grows at an alarming pac...,real,No sampling
59998,Things are really jumping at the Queens Museum...,real,No sampling
59999,"Now that I am an ordinary citizen again, permi...",real,No sampling


In [23]:
# Sanity check: there must be exactly as many generated texts as real ones.
assert len(texts_fake) == len(texts_real)
# Show the first five generated texts and the first five real texts.
texts_fake[:5], texts_real[:5]

(["HAMBURG, Germany, June 3 \x97 As he left the soccer field after a club match in the eastern German city of Halle on March 25, the Nigerian forward Adebowale Ogungbure was spit upon, jeered with racial remarks and mocked with monkey noises. The incident, which took place in front of the national team's training facility in Dortmund, Germany, the day after the World Cup, drew national attention to the country's military culture and practices. More than 1,400 protesters took to the streets of",
  'WASHINGTON, Dec. 23 - The National Security Agency has traced and analyzed large volumes of telephone and Internet communications flowing into and out of the United States as part of the eavesdropping program that President Bush approved after the Sept. 11, 2001, attacks to hunt for evidence of terrorist activity, according to current and former government officials.  In addition, the National Security Agency has collected hundreds of millions of telephone and wire communications during the t