# Compare metrics between the models

In [None]:
from IPython.display import clear_output

!pip install datasets transformers rouge_score nltk
# !pip install datasets transformers rouge-score nltk
# rouge-score is the google version
!pip install pyarrow
!pip install -q sentencepiece

clear_output()

In [None]:
from IPython.display import clear_output
import os
import re
import time
from tqdm.notebook import trange, tqdm
import pandas as pd
import numpy as np
from pprint import pprint
import matplotlib.pyplot as plt

# nlp stuff
import nltk
nltk.download('punkt')

# tf stuff
import tensorflow_datasets as tfds 
import tensorflow as tf
from transformers import PegasusTokenizer, TFPegasusForConditionalGeneration # pegasus
from transformers import BartTokenizer, TFBartForConditionalGeneration # bart

# pytorch dataset types
import datasets
from datasets.dataset_dict import DatasetDict
from datasets import Dataset, load_metric, load_dataset

# pytorch bart stuff
import torch
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import AutoTokenizer

# clear_output()

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


In [None]:
# sign into huggingface
# Per the recorded output, a successful login stores the access token under
# /root/.huggingface/token for the rest of the session.
# NOTE(review): none of the visible cells push to or read private data from the
# Hub — confirm this login is still needed for this notebook.
from huggingface_hub import notebook_login
notebook_login()

Login successful
Your token has been saved to /root/.huggingface/token
[1m[31mAuthenticated through git-credential store but this isn't the helper defined on your machine.
You might have to re-authenticate when pushing to the Hugging Face Hub. Run the following command in your terminal in case you want to set this credential helper as the default

git config --global credential.helper store[0m


In [None]:
# Install git-lfs with apt via a shell escape.
# NOTE(review): presumably required for pushing models to the Hugging Face Hub;
# no visible cell uses git-lfs — confirm it is still needed.
!apt install git-lfs

Reading package lists... Done
Building dependency tree       
Reading state information... Done
git-lfs is already the newest version (2.3.4-1).
The following package was automatically installed and is no longer required:
  libnvidia-common-460
Use 'apt autoremove' to remove it.
0 upgraded, 0 newly installed, 0 to remove and 49 not upgraded.


In [None]:
# Load the ROUGE metric; the Metrics cells below call
# `metric.compute(predictions=..., references=...)` on this object.
# about this metric: https://huggingface.co/spaces/evaluate-metric/rouge
# NOTE(review): `datasets.load_metric` is deprecated in newer `datasets` releases in
# favour of the separate `evaluate` package — confirm the installed version still
# provides it (the install cell does not pin a version).
metric = load_metric("rouge")

Downloading builder script:   0%|          | 0.00/2.16k [00:00<?, ?B/s]

In [None]:
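# Alternative (disabled): score with Google's rouge-score package directly
# instead of the `metric` object loaded above.
# `y` / `yhat` are placeholders, presumably one reference summary and one predicted summary.
# !pip install rouge-score # google package version
# from rouge_score import rouge_scorer
# scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rouge3', 'rougeL'], use_stemmer=True)
# scores = scorer.score(target=y, prediction=yhat)
# scores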

In [None]:
# Root of the project repo on the mounted Google Drive — edit for your own setup.
repo_path = '/content/gdrive/MyDrive/w266/w266_reddit_summarization'

from google.colab import drive
drive.mount('/content/gdrive')

# Both prediction tables are ordered by the same keys so that their rows line up.
sort_keys = ['content', 'y']

# baseline bart: one parquet file holding predictions for every post
baseline_preds = pd.read_parquet(os.path.join(repo_path, 'data/model_outputs/bart_preds/round2/bart_baseline_preds.parquet'))
baseline_preds.columns = ['content', 'y', 'yhat_baseline']
baseline_preds = baseline_preds.sort_values(sort_keys).reset_index(drop=True)

# finetuned bart: one parquet file per subreddit group; tag each with its group name
finetuned_files = [
  ('data/model_outputs/bart_preds/round2/bart_preds_advice.parquet', 'advice_story'),
  ('data/model_outputs/bart_preds/round2/bart_preds_media.parquet', 'media_lifestyle_sports'),
  ('data/model_outputs/bart_preds/round2/bart_preds_gaming.parquet', 'gaming'),
  ('data/model_outputs/bart_preds/round2/bart_preds_other.parquet', 'other'),
]
tagged_frames = []
for rel_path, group_name in finetuned_files:
  group_preds = pd.read_parquet(os.path.join(repo_path, rel_path))
  group_preds['subreddit_group'] = group_name
  tagged_frames.append(group_preds)

# keep the per-group frames bound to their original names in case other cells use them
df1, df2, df3, df4 = tagged_frames

# stack the groups, apply the column names, and put rows in the shared order
finetuned_preds = pd.concat(tagged_frames, ignore_index=True)
finetuned_preds.columns = ['content', 'y', 'yhat_finetune', 'subreddit_group']
finetuned_preds = finetuned_preds.sort_values(sort_keys).reset_index(drop=True)

Mounted at /content/gdrive


In [None]:
# Spot-check: the reference summary (`y`) at row position 99 of the sorted table.
finetuned_preds['y'].iloc[99]

'Sleepovers are socially important and the logic behind discouraging them is flawed.'

In [None]:
# Spot-check: the fine-tuned model's prediction at the same row position (99).
finetuned_preds['yhat_finetune'].iloc[99]

"It's not a bad idea to try to ban them, but it's a very good idea to make sure they don't violate the law of chastity."

In [None]:
# Combine baseline and fine-tuned predictions into one table, one row per post.
# The join is positional (on the index), which is only valid if both frames list
# the same (content, y) pairs in the same order. Verify that first: with an inner
# index merge a mismatch would otherwise silently drop rows or pair a prediction
# with the wrong reference summary and subreddit group.
if not baseline_preds[['content', 'y']].equals(finetuned_preds[['content', 'y']]):
  raise ValueError(
    "baseline_preds and finetuned_preds are not row-aligned on ('content', 'y'); "
    "merging on the index would pair predictions with the wrong rows."
  )

df = pd.merge(baseline_preds, finetuned_preds[['yhat_finetune', 'subreddit_group']], left_index=True, right_index=True, how='inner')

# Metrics

In [None]:
# Advice/Story
# ROUGE for baseline vs. fine-tuned predictions, restricted to this subreddit group.
group_rows = df[df['subreddit_group']=='advice_story']
references = group_rows['y'].tolist()

baseline_metrics = metric.compute(predictions=group_rows['yhat_baseline'].tolist(), references=references)
finetune_metrics = metric.compute(predictions=group_rows['yhat_finetune'].tolist(), references=references)

print("Baseline:")
pprint(baseline_metrics)

print("\n\nFinetuned:")
pprint(finetune_metrics)

Baseline:
{'rouge1': AggregateScore(low=Score(precision=0.16094365877617195, recall=0.1602015439870981, fmeasure=0.1425687666504958), mid=Score(precision=0.16889544982043658, recall=0.16741322678947881, fmeasure=0.14854337359613046), high=Score(precision=0.1764848050788522, recall=0.175687995232487, fmeasure=0.15397653185046078)),
 'rouge2': AggregateScore(low=Score(precision=0.02205258959722962, recall=0.02231012226302586, fmeasure=0.01972740526675428), mid=Score(precision=0.024659129999232084, recall=0.025384493175649264, fmeasure=0.02194717860461709), high=Score(precision=0.027439046554109333, recall=0.02917227366980293, fmeasure=0.02461692288625782)),
 'rougeL': AggregateScore(low=Score(precision=0.12278745268646098, recall=0.12873991552551745, fmeasure=0.11111018760538742), mid=Score(precision=0.1284271189684649, recall=0.1351269270902865, fmeasure=0.11540917192058497), high=Score(precision=0.1340352999815163, recall=0.14205069261126888, fmeasure=0.11961493728653738)),
 'rougeLsum

In [None]:
# media_lifestyle_sports
# ROUGE for baseline vs. fine-tuned predictions, restricted to this subreddit group.
group_rows = df[df['subreddit_group']=='media_lifestyle_sports']
references = group_rows['y'].tolist()

baseline_metrics = metric.compute(predictions=group_rows['yhat_baseline'].tolist(), references=references)
finetune_metrics = metric.compute(predictions=group_rows['yhat_finetune'].tolist(), references=references)

print("Baseline:")
pprint(baseline_metrics)

print("\n\nFinetuned:")
pprint(finetune_metrics)

Baseline:
{'rouge1': AggregateScore(low=Score(precision=0.13782575309369988, recall=0.14027384417329705, fmeasure=0.12157127106876033), mid=Score(precision=0.14577294570110721, recall=0.14801789946887003, fmeasure=0.12737209098262225), high=Score(precision=0.15383363030028804, recall=0.15671603232320702, fmeasure=0.132857328502763)),
 'rouge2': AggregateScore(low=Score(precision=0.016075028497345028, recall=0.018007252330367388, fmeasure=0.014456667499547646), mid=Score(precision=0.018505825269080584, recall=0.02118951752691096, fmeasure=0.016605443944585692), high=Score(precision=0.02107179422859783, recall=0.024667925171563593, fmeasure=0.018874622775641083)),
 'rougeL': AggregateScore(low=Score(precision=0.10697203905821571, recall=0.11376547402441321, fmeasure=0.09568433903364373), mid=Score(precision=0.11235333710196388, recall=0.12077124817147877, fmeasure=0.1004473743565866), high=Score(precision=0.11824752414577779, recall=0.12829435075819123, fmeasure=0.10527086380645524)),
 '

In [None]:
# gaming
# ROUGE for baseline vs. fine-tuned predictions, restricted to this subreddit group.
group_rows = df[df['subreddit_group']=='gaming']
references = group_rows['y'].tolist()

baseline_metrics = metric.compute(predictions=group_rows['yhat_baseline'].tolist(), references=references)
finetune_metrics = metric.compute(predictions=group_rows['yhat_finetune'].tolist(), references=references)

print("Baseline:")
pprint(baseline_metrics)

print("\n\nFinetuned:")
pprint(finetune_metrics)

Baseline:
{'rouge1': AggregateScore(low=Score(precision=0.1583076426740779, recall=0.1455572047226279, fmeasure=0.13481455440039594), mid=Score(precision=0.16626301978268632, recall=0.15273371785541023, fmeasure=0.14077543297673037), high=Score(precision=0.17451586472008315, recall=0.1598216468758118, fmeasure=0.14650242049789533)),
 'rouge2': AggregateScore(low=Score(precision=0.01804176302054014, recall=0.0161626554970315, fmeasure=0.015330705785555414), mid=Score(precision=0.02089377746991731, recall=0.01928926478751475, fmeasure=0.017824681758376305), high=Score(precision=0.023782090083779187, recall=0.02214126921517091, fmeasure=0.02024542975189972)),
 'rougeL': AggregateScore(low=Score(precision=0.11846574250429684, recall=0.11436087717361908, fmeasure=0.10320205244458239), mid=Score(precision=0.12386880828432575, recall=0.11995272243630092, fmeasure=0.10744931477207631), high=Score(precision=0.12987693089744148, recall=0.12580379407333572, fmeasure=0.11217104223821041)),
 'rouge

In [None]:
# other
# ROUGE for baseline vs. fine-tuned predictions, restricted to this subreddit group.
group_rows = df[df['subreddit_group']=='other']
references = group_rows['y'].tolist()

baseline_metrics = metric.compute(predictions=group_rows['yhat_baseline'].tolist(), references=references)
finetune_metrics = metric.compute(predictions=group_rows['yhat_finetune'].tolist(), references=references)

print("Baseline:")
pprint(baseline_metrics)

print("\n\nFinetuned:")
pprint(finetune_metrics)

Baseline:
{'rouge1': AggregateScore(low=Score(precision=0.15227905937171815, recall=0.14224619769835914, fmeasure=0.13075502626266045), mid=Score(precision=0.16033535781674657, recall=0.14950648042892578, fmeasure=0.13610454806146136), high=Score(precision=0.16812832637649694, recall=0.1566605877306079, fmeasure=0.1419170725530447)),
 'rouge2': AggregateScore(low=Score(precision=0.01862611420475025, recall=0.017075739643976302, fmeasure=0.015869110236825353), mid=Score(precision=0.021397063746590672, recall=0.020261734525464897, fmeasure=0.018293209479671307), high=Score(precision=0.024073885991680876, recall=0.023805415851009912, fmeasure=0.02059356581221012)),
 'rougeL': AggregateScore(low=Score(precision=0.11669371572233089, recall=0.11454710937865369, fmeasure=0.10197650490250112), mid=Score(precision=0.12227675862101661, recall=0.12085570297016784, fmeasure=0.10651167576808727), high=Score(precision=0.12848320573587912, recall=0.1271284072991241, fmeasure=0.11084877768989151)),
 '

# Word counts in each subgroup

In [None]:
# Rebuild the complete dataset by stacking the train, test and validation splits
# (in that order) under a fresh 0..n-1 index.
split_files = (
  'data/reddit_parquet/train_test_split_v2/reddit_train.parquet',
  'data/reddit_parquet/train_test_split_v2/reddit_test.parquet',
  'data/reddit_parquet/train_test_split_v2/reddit_validation.parquet',
)
split_frames = [pd.read_parquet(os.path.join(repo_path, rel_path)) for rel_path in split_files]
full_df = pd.concat(split_frames, ignore_index=True)

full_df

Unnamed: 0,content,summary,subreddit,subreddit_group
0,Can bars and/or clubs legally charge men for e...,Is it legal for bars and/or clubs to charge me...,legaladvice,advice/story
1,Great Link! Enjoyed the comments at the end es...,check out the link,WTF,advice/story
2,"Met my bf on DeviantArt, oddly enough. I comme...",since there was of course WAY more involved in...,SRSQuestions,advice/story
3,You told me you really liked me. \n When I mov...,"You seduced me when I tried to push away, stol...",offmychest,advice/story
4,"So, preface. I have a tiny Asian vagina. My bo...","Tiny asian vagina, tears after sex, halp?",sex,advice/story
...,...,...,...,...
67995,"Firstly, thanks. I think anyone can accomplish...",Almost the same. A lot cheaper weed. Worse qua...,trees,other
67996,That is just like real life ideal condition. C...,Why making your fantasy world as boring as you...,SteamTeamGreen,other
67997,I know I'm not the first one to suggest this n...,Stop using AT&T until they drop the cap.,self,other
67998,Went to Hopapalooza at the Alibi Room for Vanc...,Unlimited samples of exceptional beer. Heaven....,Homebrewing,other


In [None]:
%%time
def count_vocab(item):
  """Return the number of distinct words in ``item``, ignoring case.

  A word is a run of non-whitespace characters. ``str.split()`` with no
  argument is used so that spaces, tabs and newlines all separate words and
  repeated whitespace produces no empty tokens; splitting on a single space
  counted empty strings as words and kept newline-joined words glued together.

  Args:
    item: the text to analyse (str).

  Returns:
    int: count of unique lower-cased words; 0 for empty or whitespace-only text.
  """
  return len(set(item.lower().split()))

def count_total_words(x):
  """Return the total number of words in ``x``.

  A word is a run of non-whitespace characters. ``str.split()`` with no
  argument is used so that spaces, tabs and newlines all separate words and
  repeated whitespace produces no empty tokens; splitting on a single space
  inflated the count with empty strings and under-counted words separated
  only by newlines. Lower-casing is unnecessary for counting and was dropped.

  Args:
    x: the text to analyse (str).

  Returns:
    int: number of words; 0 for empty or whitespace-only text.
  """
  return len(x.split())

# Derived per-row statistics: new column -> (source text column, counting function).
# Dict order fixes the order in which the columns are appended to full_df.
derived_columns = {
  'content_vocab': ('content', count_vocab),
  'summary_vocab': ('summary', count_vocab),
  'content_total_words': ('content', count_total_words),
  'summary_total_words': ('summary', count_total_words),
}

for new_col, (source_col, counter) in derived_columns.items():
  full_df[new_col] = full_df[source_col].map(counter)

CPU times: user 3.57 s, sys: 16.4 ms, total: 3.59 s
Wall time: 3.59 s


In [None]:
# Average post / summary length in words: first over the whole dataset, then
# separately for each subreddit group.
avg_post = np.mean(full_df['content_total_words'])
avg_summary = np.mean(full_df['summary_total_words'])
print(f"overall avg post length: {avg_post:.2f}")
print(f"overall avg summary length: {avg_summary:.2f}")

# value stored in `subreddit_group` -> label used in the printed report
group_labels = {
  'advice/story': 'advice_story',
  'media/lifestyle/sports': 'media_lifestyle_sports',
  'gaming': 'gaming',
  'other': 'other',
}

for group_value, label in group_labels.items():
  group_rows = full_df[full_df['subreddit_group']==group_value]
  avg_post = np.mean(group_rows['content_total_words'])
  avg_summary = np.mean(group_rows['summary_total_words'])
  # the leading newline separates each group's pair of lines from the previous one
  print(f"\n{label} avg post length: {avg_post:.2f}")
  print(f"{label} avg summary length: {avg_summary:.2f}")

overall avg post length: 234.56
overall avg summary length: 22.44

advice_story avg post length: 278.16
advice_story avg summary length: 22.81

media_lifestyle_sports avg post length: 209.82
media_lifestyle_sports avg summary length: 21.09

gaming avg post length: 210.69
gaming avg summary length: 22.76

other avg post length: 239.56
other avg summary length: 23.12
