Cornell Newsroom summarization dataset (https://summari.es/)

In [0]:
!pip install transformers

In [0]:
import json
import re

from transformers import GPT2Tokenizer, GPT2LMHeadModel
from tqdm  import tqdm
import torch
import numpy as np
import pandas as pd

import nltk
from nltk import tokenize
nltk.download('punkt')

import warnings
warnings.filterwarnings('ignore')

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device  # bare expression so the notebook displays the selected device

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


device(type='cuda')

# Initial Dataset

In [0]:
# Mount Google Drive at /content/drive; the dataset and the checkpoint CSV
# used by later cells are read from / written to paths under this mount point.
from google.colab import drive
drive.mount('/content/drive')

In [0]:
def clean_tqdm():
    """
    Unregisters every progress bar that tqdm still tracks.

    Takes a snapshot of the private ``tqdm._instances`` registry and hands each
    entry to ``tqdm._decr_instances``.
    NOTE(review): presumably used so that bars left over from an interrupted
    cell do not disturb the rendering of the next bar - confirm.

    return: None
    """
    # Both attributes are private tqdm API. The registry is read with a default
    # so that a missing attribute is a no-op instead of an AttributeError.
    # NOTE(review): whether the registry can be missing before the first bar is
    # created depends on the installed tqdm version - confirm.
    # list(...) copies the registry because _decr_instances mutates it.
    for bar in list(getattr(tqdm, '_instances', ())):
        tqdm._decr_instances(bar)

# Run one throwaway progress bar over a dummy list.
# NOTE(review): presumably this primes tqdm's instance registry so that
# clean_tqdm() can be called before any real bar exists - confirm.
for e in tqdm([1,2,3]):
    pass

def clean_data(text):
    """
    Strips markup and escape artefacts from a raw article string.

    Steps, in order:
      1. remove anything that looks like a tag (an opening angle bracket up to
         the next closing one);
      2. replace the literal two-character sequence backslash + n with a space;
      3. replace named HTML entities (ampersand, up to 7 lowercase letters,
         semicolon) with a space;
      4. turn backslash + apostrophe into a plain apostrophe;
      5. remove backslash + x followed by 1-4 decimal digits;
      6. collapse every run of two or more whitespace characters to one space.

    text: str, raw text
    return: str, cleaned text
    """
    text = re.sub(r'\<[^>]*\>', '', text)
    text = re.sub(r'\\n', ' ', text)
    text = re.sub(r'&[a-z]{0,7};', ' ', text)
    text = re.sub(r"\\'", r"'", text)
    # NOTE(review): \d only matches decimal digits, so escaped bytes that contain
    # hex letters are left in place - confirm whether that is intended.
    text = re.sub(r'\\x\d{1,4}', '', text)
    # Whitespace is normalised last and without an upper bound: the removals
    # above can leave adjacent spaces behind, and a bounded quantifier would
    # leave several spaces for very long runs.
    return re.sub(r'\s{2,}', ' ', text)

def get_sentences(data):
    """
    Builds paired prompt / reference lists from raw texts.

    Every text is split into sentences with nltk's sent_tokenize. A text is
    kept only when it yields at least two sentences and its first sentence has
    fewer than 50 space-separated words.

    return: (prompts, references)
            prompts    - first sentence of each kept text (input for the gpt model)
            references - first two sentences joined by a space (real texts for
                         the classification model)
    """
    kept_pairs = []
    for document in data:
        sentences = tokenize.sent_tokenize(document)
        # Need a second sentence to form the real counterpart.
        if len(sentences) < 2:
            continue
        opening = sentences[0]
        # Skip overly long prompts.
        if len(opening.split(' ')) >= 50:
            continue
        kept_pairs.append((opening, ' '.join(sentences[:2])))

    texts_gpt2 = [prompt for prompt, _ in kept_pairs]
    texts_real = [reference for _, reference in kept_pairs]
    assert len(texts_real) == len(texts_gpt2)
    return texts_gpt2, texts_real

100%|██████████| 3/3 [00:00<00:00, 4064.25it/s]


In [0]:
# Number of leading lines of the JSONL file that are skipped.
# NOTE(review): why loading starts after line 220001 is not visible in this
# notebook - presumably earlier lines were consumed in a previous session; confirm.
SKIP_LINES = 220001

# Stream the file and parse each kept line straight away, instead of first
# materialising every raw line (including the skipped ones) in a list.
data = []
with open('/content/drive/My Drive/train-stats.jsonl', 'r') as json_file:
    for line_no, json_str in enumerate(json_file):
        if line_no < SKIP_LINES:
            continue
        result = json.loads(json_str)
        data.append(clean_data(result['text']))

# texts_gpt2: first sentences used as prompts; texts_r: first two sentences (real texts).
texts_gpt2, texts_r = get_sentences(data)

In [0]:
# get_sentences asserts that both returned lists have the same length,
# so counting the prompts also counts the real texts.
print('Total amount of texts: {}'.format(len(texts_gpt2)))

Total amount of texts: 735500


In [0]:
# First three prompts next to their two-sentence real counterparts.
texts_gpt2[:3], texts_r[:3]

(['Did you start your morning by preparing a strong cup of coffee?',
  'India is scrambling to protect its beleaguered tiger population after several big cats tested positive for a virus common among dogs but deadly to other carnivores, experts said.',
  "The new album from Bruce Springsteen, High Hopes, is a collection of 12 rare, previously unreleased and cover songs that, in some cases, date as far back as the '90s."],
 ['Did you start your morning by preparing a strong cup of coffee? Chances are you were also brewing an effective long-term memory tonic, according to researchers with Johns Hopkins University.',
  "India is scrambling to protect its beleaguered tiger population after several big cats tested positive for a virus common among dogs but deadly to other carnivores, experts said. In the last year, canine distemper virus has killed at least four tigers and several other animals across northern and eastern India, according to Rajesh Gopal of the government's National Tiger C

In [0]:
# Last three prompts next to their two-sentence real counterparts.
texts_gpt2[-3:], texts_r[-3:]

(['Elizabeth Taylor has White Diamonds.',
  'BALTIMORE, May 18 -- A disease believed to be equine herpes virus has swept through the barn area at Churchill Downs, site of the Kentucky Derby, leading to the death of two horses and the placement of a quarantine on three barns.',
  'Columnist Michelle Singletary was online to field questions about everything from retirement planning to protecting your credit rating.'],
 ['Elizabeth Taylor has White Diamonds. Coco Chanel had Chanel No.',
  'BALTIMORE, May 18 -- A disease believed to be equine herpes virus has swept through the barn area at Churchill Downs, site of the Kentucky Derby, leading to the death of two horses and the placement of a quarantine on three barns. The outbreak of the rare neurological virus, which can cause symptoms ranging from mild fever and upper respiratory infection to paralysis, has led to the scratching of three horses scheduled to run this weekend at Pimlico in major stakes races.',
  'Columnist Michelle Singlet

# Generation 

In [0]:
# Default for generate_sentence's max_length argument; that function adds the
# prompt's token count to it before calling model.generate.
MAX_LENGTH = 50
# Pretrained GPT-2 tokenizer.
# NOTE(review): '<pad>' is presumably not part of the stock GPT-2 vocabulary -
# confirm which id tokenizer.pad_token_id resolves to with the installed
# transformers version. The effect of max_length=500 here is also
# version-dependent - confirm.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', pad_token='<pad>', max_length=500)
# Pretrained GPT-2 language model, moved to the selected device.
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.to(device);  # trailing ';' suppresses the model repr in the cell output

HBox(children=(IntProgress(value=0, description='Downloading', max=1042301, style=ProgressStyle(description_wi…




HBox(children=(IntProgress(value=0, description='Downloading', max=456318, style=ProgressStyle(description_wid…




HBox(children=(IntProgress(value=0, description='Downloading', max=224, style=ProgressStyle(description_width=…




HBox(children=(IntProgress(value=0, description='Downloading', max=548118077, style=ProgressStyle(description_…




In [0]:
def generate_fake(texts, sampling_type, tokenizer=tokenizer, model=model):
    """
    Produces one generated text per prompt using a single sampling strategy.

    texts: iterable of prompt strings
    sampling_type: tuple (name, value), forwarded unchanged to generate_sentence
    return: (fake, samplings)
            fake      - generated texts with newline characters replaced by spaces
            samplings - the strategy name (sampling_type[0]), once per generated text
    """
    fake = [
        re.sub(r'\n', ' ', generate_sentence(prompt, model, tokenizer, sampling_type))
        for prompt in tqdm(texts)
    ]
    # One label per generated text, so both lists stay aligned.
    samplings = [sampling_type[0] for _ in fake]
    return fake, samplings

def _generation_kwargs(sampling_type):
    """Translates a (name, value) sampling spec into keyword arguments for model.generate."""
    name = sampling_type[0]
    if name == 'beam_search':
        # NOTE(review): do_sample=True combined with num_beams samples within the
        # beams rather than running deterministic beam search - kept unchanged,
        # confirm that this is intended.
        return dict(do_sample=True, num_beams=sampling_type[1], repetition_penalty=2.3)
    if name == 'temperature':
        return dict(do_sample=True, temperature=sampling_type[1])
    if name == 'top_k':
        # The value is the k of top-k filtering, not a temperature.
        return dict(do_sample=True, top_k=sampling_type[1])
    if name == 'nucleus':
        return dict(do_sample=True, top_p=sampling_type[1])
    # Argmax otherwise: any other name selects greedy decoding and its value is ignored.
    return dict(do_sample=False)


def generate_sentence(sentence, model, tokenizer, sampling_type, max_length=MAX_LENGTH,
                      max_prompt_tokens=500):
    """
    Generates a continuation of `sentence` with the strategy given by sampling_type.

    sentence: str, prompt text
    model: object exposing generate(...), e.g. the GPT-2 LM head model
    tokenizer: object exposing encode(...), decode(...) and pad_token_id
    sampling_type: tuple (name, value); name is one of 'beam_search',
                   'temperature', 'top_k', 'nucleus'; anything else means
                   greedy (argmax) decoding
    max_length: int, number of tokens allowed on top of the prompt
    max_prompt_tokens: int, the encoded prompt is cut to this many tokens
    return: str, decoded prompt + continuation with special tokens skipped

    NOTE(review): parameters that are not passed to model.generate keep the
    library's defaults (this may include a default top_k) - check the installed
    transformers version if pure temperature / nucleus sampling is required.
    """
    # Slice the token list itself (not the enclosing one-element batch list),
    # so the cap on the prompt length actually takes effect.
    prompt_ids = tokenizer.encode(sentence)[:max_prompt_tokens]
    context = torch.tensor([prompt_ids], dtype=torch.long).to(device)
    # model.generate's max_length counts the prompt tokens as well.
    total_length = max_length + context.size()[-1]

    outputs = model.generate(input_ids=context, max_length=total_length,
                             pad_token_id=tokenizer.pad_token_id,
                             **_generation_kwargs(sampling_type))
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

def df_checkpoint(texts_fake, texts_real, samplings,
                  local_path='dataset.csv',
                  drive_path='/content/drive/My Drive/dpl_dataset.csv'):
    """
    Creates a pandas DataFrame with texts (both real and fake) and sampling type
    and saves it as CSV (without the index) to two locations.

    texts_fake: list of generated texts
    texts_real: list of real texts
    samplings: list of sampling names, one per entry of texts_fake
    local_path: CSV path in the working directory
    drive_path: CSV path on the mounted Google Drive
    return: pandas DataFrame with columns 'text', 'label', 'sampling';
            fake rows come first, real rows get the sampling 'No sampling'
    raises: ValueError if samplings and texts_fake differ in length
    """
    if len(samplings) != len(texts_fake):
        raise ValueError(
            'samplings has {} entries but texts_fake has {}'.format(len(samplings), len(texts_fake))
        )

    n_fake, n_real = len(texts_fake), len(texts_real)
    df = pd.DataFrame({
        'text': list(texts_fake) + list(texts_real),
        'label': ['fake'] * n_fake + ['real'] * n_real,
        'sampling': list(samplings) + ['No sampling'] * n_real,
    })
    # Same content is written twice: working directory copy and Drive copy.
    for path in (local_path, drive_path):
        df.to_csv(path, index=False)
    return df

In [0]:
# If something crashes: resume from the checkpoint CSV on Google Drive.
# When no checkpoint exists yet (first run), start with empty accumulators so
# that the generation loop, which extends these three lists, still works.
try:
    df = pd.read_csv('/content/drive/My Drive/dpl_dataset.csv')
except FileNotFoundError:
    texts_fake, texts_real, samplings = [], [], []
else:
    # Take texts and sampling names from the same rows so the lists stay aligned.
    fake_rows = df['label'] == 'fake'
    texts_fake = list(df[fake_rows]['text'])
    samplings = list(df[fake_rows]['sampling'])
    texts_real = list(df[df['label'] == 'real']['text'])

In [19]:
# The three accumulators are expected to have the same length.
len(texts_fake), len(texts_real), len(samplings)

(260000, 260000, 260000)

In [0]:
n = 10000  # Number of texts to generate for every sampling type (7 types below -> 70k texts per run)
# Full list including the beam-search variants (currently disabled):
# sampling_types = [('temperature', 0.9), ('temperature', 0.8), ('top_k', 20), ('top_k', 100), ('nucleus', 0.9), ('nucleus', 0.8), ('argmax', 23), ('beam_search', 3), ('beam_search',5)]
# ('argmax', 23): generate_sentence decodes greedily for this name and does not use the value.
sampling_types = [('temperature', 0.9), ('temperature', 0.8), ('top_k', 20), ('top_k', 100), ('nucleus', 0.9), ('nucleus', 0.8), ('argmax', 23)]


# Each sampling type gets its own slice of n prompts: [ind*n, ind*n + n).
for ind, sampling_type in enumerate(sampling_types):
    clean_tqdm()
    print('Starting generation for {}'.format(sampling_type))
    fake, sampling  = generate_fake(texts_gpt2[ind*n:ind*n+n], sampling_type, tokenizer=tokenizer,model=model)
    # Extend the accumulators with the new fake texts, their sampling names
    # and the real texts that belong to the same prompts.
    texts_fake.extend(fake)
    samplings.extend(sampling)
    texts_real.extend(texts_r[ind*n:ind*n+n])

    # Save everything accumulated so far after every sampling type.
    df = df_checkpoint(texts_fake, texts_real, samplings)
    print('DataFrame saved')

In [15]:
# First rows of the checkpoint frame; df_checkpoint places the fake texts first.
df.head()

Unnamed: 0,text,label,sampling
0,"HAMBURG, Germany, June 3  As he left the socc...",fake,temperature
1,"WASHINGTON, Dec. 23 - The National Security Ag...",fake,temperature
2,IF outsized executive pay has indeed become a ...,fake,temperature
3,"BY A.J. Miller, Jr. The three men will make t...",fake,temperature
4,Spinach has terrorized generations of veggie-p...,fake,temperature


In [16]:
# Last rows of the checkpoint frame; df_checkpoint places the real texts last.
df.tail()

Unnamed: 0,text,label,sampling
519995,"With a bit of hardware, your iPhone can become...",real,No sampling
519996,The Centers for Disease Control and Prevention...,real,No sampling
519997,The number of new HIV infections in the United...,real,No sampling
519998,AT&T says the throttling will only be temporar...,real,No sampling
519999,"While you're frying up some eggs and bacon, we...",real,No sampling


In [17]:
# Every fake text must have a real counterpart.
assert len(texts_fake) == len(texts_real)
# Last three generated texts next to the last three real texts.
texts_fake[-3:], texts_real[-3:]

(['The number of new HIV infections in the United States has remained steady, at around 50,000 cases a year over the past four years, according to the Center for Disease Control and Prevention. Most people have only ever made personal comments through public channels which the FBI will no different to record now given his political agenda but can go on after such communications do something dangerous — if one looks deeply inside or off him at an establishment conference that many in',
  'AT&T says the throttling will only be temporary; full speed will be restored to throttled customers at the start of the each billing cycle. If someone in our "procedy line on request group pays 1 million euros without signing one petition from any channel but us that goes in 30.9 min", says that he may end Up In Line right away during every month until he notices for',
  'While you\'re frying up some eggs and bacon, we\'re cooking up something else: a way to celebrate today\'s food holiday. "A meal wit