# Baseline BART metrics and exploration

Computing rouge metrics by using the baseline BART model without any fine tuning on our test dataset. 

# Setups

In [1]:
from IPython.display import clear_output

!pip install datasets transformers rouge_score rouge-score nltk
# rouge-score is the google version
!pip install pyarrow
!pip install -q sentencepiece

clear_output()

In [2]:
import os
import re
import time
from tqdm.notebook import trange, tqdm
import pandas as pd
import numpy as np
from pprint import pprint
import matplotlib.pyplot as plt

# nlp stuff
import nltk
nltk.download('punkt')

# tf stuff
import tensorflow_datasets as tfds 
import tensorflow as tf
from transformers import PegasusTokenizer, TFPegasusForConditionalGeneration # pegasus
from transformers import BartTokenizer, TFBartForConditionalGeneration # bart

# pytorch dataset types
import datasets
from datasets.dataset_dict import DatasetDict
from datasets import Dataset, load_metric, load_dataset

# pytorch bart stuff
import torch
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import AutoTokenizer

clear_output()

In [3]:
# sign into huggingface
from huggingface_hub import notebook_login
notebook_login()

Login successful
Your token has been saved to /root/.huggingface/token
[1m[31mAuthenticated through git-credential store but this isn't the helper defined on your machine.
You might have to re-authenticate when pushing to the Hugging Face Hub. Run the following command in your terminal in case you want to set this credential helper as the default

git config --global credential.helper store[0m


In [4]:
#!apt install git-lfs

Reading package lists... Done
Building dependency tree       
Reading state information... Done
git-lfs is already the newest version (2.3.4-1).
The following package was automatically installed and is no longer required:
  libnvidia-common-460
Use 'apt autoremove' to remove it.
0 upgraded, 0 newly installed, 0 to remove and 49 not upgraded.


# Load data

In [3]:
# specify your path to the repo here:
repo_path = '/content/gdrive/MyDrive/w266/w266_reddit_summarization'

In [4]:
%%time
from google.colab import drive
drive.mount('/content/gdrive')
data_path = os.path.join(repo_path, 'data/reddit_parquet/train_test_split_v2')
os.chdir(data_path)
files = [i for i in os.listdir(data_path) if re.search("reddit", i)]

train = pd.read_parquet('reddit_train.parquet')
test = pd.read_parquet('reddit_test.parquet')
valid = pd.read_parquet('reddit_validation.parquet')

Mounted at /content/gdrive
CPU times: user 1.5 s, sys: 399 ms, total: 1.9 s
Wall time: 21.2 s


Check subreddit group counts

In [5]:
print("train")
print(train['subreddit_group'].value_counts())

print("\n\ntest:")
print(test['subreddit_group'].value_counts())

print("\n\nvalid:")
valid['subreddit_group'].value_counts()

train
advice/story              15000
gaming                    15000
media/lifestyle/sports    15000
other                     15000
Name: subreddit_group, dtype: int64


test:
advice/story              1000
gaming                    1000
media/lifestyle/sports    1000
other                     1000
Name: subreddit_group, dtype: int64


valid:


advice/story              1000
gaming                    1000
media/lifestyle/sports    1000
other                     1000
Name: subreddit_group, dtype: int64

# Modeling

Testing out several different model checkpoints and finding which produces the best summaries out-of-the-box.

In [6]:
# bunch of diff checkpoints to consider

# bart checkpoints
# model_checkpoint = 'facebook/bart-base' # keep returning the first sentence for me, extractive.
# model_checkpoint = 'facebook/bart-large-mnli' # same as above, only returns first sentences. extractive.
# model_checkpoint = 'sshleifer/distilbart-cnn-12-6' # works a bit better, but seems to produce extractive summaries still. 
model_checkpoint = 'sshleifer/distilbart-xsum-6-6' # trained on both xsum and cnn/dm, so far the best of the above. 

# pegasus checkpoints:
# model_checkpoint = "google/pegasus-xsum" # works well
# model_checkpoint = 'google/pegasus-reddit_tifu' # also works well

Load model, tokenizer, and rouge metric

In [7]:
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
# tokenizer = PegasusTokenizer.from_pretrained(model_checkpoint) # can try this for pegasus
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
metric = load_metric("rouge")

clear_output()

In [8]:
# convert data to torch Dataset
raw_datasets = DatasetDict({
    'train': Dataset.from_dict({
        'content': train['content'],
        'summary': train['summary'],
        'subreddit': train['subreddit'],
        'subreddit_group': train['subreddit_group']
    }), 

    'test': Dataset.from_dict({
        'content': test['content'],
        'summary': test['summary'],
        'subreddit': test['subreddit'],
        'subreddit_group': test['subreddit_group']
    }), 

    'valid': Dataset.from_dict({
        'content': valid['content'],
        'summary': valid['summary'],
        'subreddit': valid['subreddit'],
        'subreddit_group': valid['subreddit_group']
    })
})

raw_datasets

DatasetDict({
    train: Dataset({
        features: ['content', 'summary', 'subreddit', 'subreddit_group'],
        num_rows: 60000
    })
    test: Dataset({
        features: ['content', 'summary', 'subreddit', 'subreddit_group'],
        num_rows: 4000
    })
    valid: Dataset({
        features: ['content', 'summary', 'subreddit', 'subreddit_group'],
        num_rows: 4000
    })
})

In [9]:
# tokenize everything
max_input_length = 1024
max_target_length = 128

def preprocess_function(examples):
    inputs = [doc for doc in examples["content"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

    # Setup the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["summary"], max_length=max_target_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)



  0%|          | 0/60 [00:00<?, ?ba/s]

  0%|          | 0/4 [00:00<?, ?ba/s]

  0%|          | 0/4 [00:00<?, ?ba/s]

In [10]:
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    # Rouge expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]
    
    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    # Extract a few results
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    
    # Add mean generated length
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    
    return {k: round(v, 4) for k, v in result.items()}

Uncomment the text to re-train. But keep it commented to just load the checkpoint and produce summaries from there. 

In [11]:
%%time

# args = Seq2SeqTrainingArguments(
#     f"bart",
#     evaluation_strategy = "epoch",
#     learning_rate=2e-5,
#     per_device_train_batch_size=4, # 16
#     per_device_eval_batch_size=4, #16
#     weight_decay=0.01,
#     save_total_limit=1,
#     num_train_epochs=1,
#     predict_with_generate=True,
#     fp16=True,
#     # push_to_hub=True,
# )

# run this to train, which we won't do at the moment
# trainer = Seq2SeqTrainer(
#     model,
#     args,
#     train_dataset=tokenized_datasets["train"],
#     eval_dataset=tokenized_datasets["valid"],
#     data_collator=data_collator,
#     tokenizer=tokenizer,
#     compute_metrics=compute_metrics
# )

# trainer.train()

CPU times: user 3 µs, sys: 0 ns, total: 3 µs
Wall time: 7.39 µs


In [None]:
# optional to save to huggingface
#trainer.push_to_hub()

In [None]:
# then load model back in
# model = AutoModelForSeq2SeqLM.from_pretrained("trevorj/model_name")

In [12]:
%%time
# generate one prediction to test
ind = 4
output = model.generate(
    torch.tensor([tokenized_datasets['test']['input_ids'][ind]]),
    num_beams=2, 
    # length_penalty=0.001, # doesn't seem to do anything
    max_length=60,
    min_length=2,
    no_repeat_ngram_size=3
)

print("Input text:")
pprint(tokenized_datasets['test']['content'][ind])

print("\nTrue summary:")
pprint(tokenized_datasets['test']['summary'][ind])

summary1 = tokenizer.decode(output.squeeze(), skip_special_tokens=True)
print(f"\nPredicted summary (n words = {len(summary1.split(' '))}):")
pprint(summary1)

print('\nRouge metrics')
rouge_metrics_summary1 = metric.compute(predictions=[summary1], references=[tokenized_datasets['test']['summary'][ind]])
pprint(rouge_metrics_summary1)

Input text:
('Ugh this thread is making me cringe so much. Here is my story. I had my one '
 'and only veruca form on the heel of my right foot after stepping on a sea '
 'urchin in Cyprus. I thought that most of the spines had come out but one of '
 'the fuckers must have broken off, causing this monstrosity of nature to form '
 'around it in the weeks afterwards. Fuck it was big. I ended up gouging it '
 "out of my foot with a swiss-army penknife, leaving a decent size hole. I'll "
 'always remember looking at the roots of it and the intricate symmetrical '
 'design with little black dots in after shifting the external covering, prior '
 'to digging it out of my foot. I really feel sick writing this.')

True summary:
'Fuck sea urchins'

Predicted summary (n words = 15):
" I've been writing about the horror of having a giant veruca in the sea."

Rouge metrics
{'rouge1': AggregateScore(low=Score(precision=0.06666666666666667, recall=0.3333333333333333, fmeasure=0.1111111111111111), mid

Making preds w/ this bart model takes aboout 1.75 min for 20. Or ~5.5 sec per obs. Expect to take about 7.5 hrs to predict on 5k obs. Started at 11:04 am.
- Ended up taking 4:58 hrs (~5 hrs) on 5k obs
- round two, start at 12:03pm. 

In [15]:
%%time
# batch predict and write to disk
def model_predict(model, input_ids):
  output = model.generate(torch.tensor([input_ids]), num_beams=2, max_length=60, min_length=2, no_repeat_ngram_size=3)
  output_decoded = tokenizer.decode(output.squeeze(), skip_special_tokens=True)
  return output_decoded


from tqdm.notebook import trange, tqdm
tqdm.pandas()

df_results = pd.DataFrame({
    'content': tokenized_datasets['test']['content'],
    'y': tokenized_datasets['test']['summary'],
    'input_ids': tokenized_datasets['test']['input_ids']
})


df_results['yhat'] = df_results['input_ids'].progress_map(lambda x: model_predict(model, input_ids=x))
df_results = df_results[['content', 'y', 'yhat']]

  0%|          | 0/4000 [00:00<?, ?it/s]

CPU times: user 3h 53min 29s, sys: 1min 59s, total: 3h 55min 28s
Wall time: 3h 54min 28s


In [14]:
# maybe learn parallelizing this later, since we have 2 cores.
df_results

Unnamed: 0,content,y,yhat
0,"Living in the Sierra Nevadas, it gets very col...","Too damn cold, showered, dry my hair, dry my p...",In a series of letters from African journalis...
1,"So around midnight, after a day's worth of fis...","went hard, bubbled up",I'm a bit of a bit over the top of my bladder...
2,"Recently, I've read an article on betrayal tha...",Mandatory summary/question!,"In the wake of a woman’s affair, the BBC News..."
3,"As you all know, saturday was valentines day. ...",beer slushies 3/10 would not recommend.,"It was a very special day for me on Saturday,..."
4,Ugh this thread is making me cringe so much. H...,Fuck sea urchins,I've been writing about the horror of having ...
5,In the process of reading and learning about t...,"Could get promotion soon, losing out on overti...",I am a former employee at a small company in ...
6,Backstory: My boyfriend and I have been togeth...,I told a stupid lie that blew into a huge figh...,"In a series of letters from women, one of the..."
7,My Gf and I are currently in a LDR. We have be...,I [23 M] have been hit on/sexually advanced up...,"My Gf, I'm a 23-year-old student from London, ..."
8,"i agree. i live right next door to a skool, a...",Kids need to get out and play more! Stop codd...,I'm a bit scared of the little dumplings being...
9,Preface: I work in tech support for an ISP of ...,"Hulk, Smash!","I'm a tech support worker, and I'm not the on..."


In [16]:
%%time
# write results to disk
out_path ="/content/gdrive/MyDrive/w266/w266_reddit_summarization/data/model_outputs/bart_preds/round2/"
f1 = os.path.join(out_path, "bart_baseline_preds.parquet")
df_results.to_parquet(f1)

CPU times: user 39 ms, sys: 13 ms, total: 52 ms
Wall time: 547 ms


In [17]:
# read back in and calc results
df_results_final = pd.read_parquet(f1)
df_results_final

Unnamed: 0,content,y,yhat
0,"Living in the Sierra Nevadas, it gets very col...","Too damn cold, showered, dry my hair, dry my p...",In a series of letters from African journalis...
1,"So around midnight, after a day's worth of fis...","went hard, bubbled up",I'm a bit of a bit over the top of my bladder...
2,"Recently, I've read an article on betrayal tha...",Mandatory summary/question!,"In the wake of a woman’s affair, the BBC News..."
3,"As you all know, saturday was valentines day. ...",beer slushies 3/10 would not recommend.,"It was a very special day for me on Saturday,..."
4,Ugh this thread is making me cringe so much. H...,Fuck sea urchins,I've been writing about the horror of having ...
...,...,...,...
3995,In the US we don't have that system at all. Yo...,I'm okay with using the framework to organize ...,I'm a Spanish student and I don't think that ...
3996,"Ok, this is gonna be a ranty brain-dump to try...",American politics is a circle jerk & we've for...,"I've been talking about the case of the ""Awla..."
3997,"Hey guys, I just wanted to bring about a topic...",I don't think people should really be allowed ...,"I'm in line for a chat on Twitter, and I'm no..."
3998,"My father, who I do truly love, had some fucke...",chasing deer with dogs in NC,"I'm a hunter, but I'm not a hunter."


In [19]:
# rouge metrics:
# rouge_metrics_summary1 = metric.compute(predictions=[summary1], references=[tokenized_datasets['test']['summary'][ind]])
# pprint(rouge_metrics_summary1)
test_metrics = metric.compute(predictions=df_results_final['yhat'].tolist(), references=df_results_final['y'].tolist())
test_metrics

{'rouge1': AggregateScore(low=Score(precision=0.15663784358071894, recall=0.1507560860569759, fmeasure=0.13540118428662584), mid=Score(precision=0.16041282702773596, recall=0.15436607015450204, fmeasure=0.13802582351664905), high=Score(precision=0.16406055643179293, recall=0.15781700593144446, fmeasure=0.14068382415609923)),
 'rouge2': AggregateScore(low=Score(precision=0.019878321130524484, recall=0.019841566726755678, fmeasure=0.017355625218134593), mid=Score(precision=0.02135058834256739, recall=0.021549037430523044, fmeasure=0.0186670812006466), high=Score(precision=0.022791455756902605, recall=0.023255925582335687, fmeasure=0.019941602023524033)),
 'rougeL': AggregateScore(low=Score(precision=0.11904083274803201, recall=0.12089369088040308, fmeasure=0.10530395992246952), mid=Score(precision=0.12171198977513278, recall=0.12422892204602842, fmeasure=0.10747430877189493), high=Score(precision=0.12449516310497827, recall=0.12755177751405006, fmeasure=0.10975652888564924)),
 'rougeLsum