# Compare metrics between the models

In [None]:
from IPython.display import clear_output

!pip install datasets transformers rouge_score nltk
# !pip install datasets transformers rouge-score nltk
# rouge-score is the google version
!pip install pyarrow
!pip install -q sentencepiece

clear_output()

In [None]:
from IPython.display import clear_output
import os
import re
import time
from tqdm.notebook import trange, tqdm
import pandas as pd
import numpy as np
from pprint import pprint
import matplotlib.pyplot as plt

# nlp stuff
import nltk
nltk.download('punkt')

# tf stuff
import tensorflow_datasets as tfds 
import tensorflow as tf
from transformers import PegasusTokenizer, TFPegasusForConditionalGeneration # pegasus
from transformers import BartTokenizer, TFBartForConditionalGeneration # bart

# pytorch dataset types
import datasets
from datasets.dataset_dict import DatasetDict
from datasets import Dataset, load_metric, load_dataset

# pytorch bart stuff
import torch
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import AutoTokenizer

# clear_output()

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


In [None]:
# sign into huggingface
# from huggingface_hub import notebook_login
# notebook_login()

In [None]:
# !apt install git-lfs

In [None]:
# about this metric: https://huggingface.co/spaces/evaluate-metric/rouge
# Loaded once here; the scoring cells further down all call `metric.compute(...)`
# on this shared object.
# NOTE(review): `load_metric` is presumably deprecated in newer `datasets`
# releases in favour of the separate `evaluate` package — confirm before upgrading.
metric = load_metric("rouge")

Downloading builder script:   0%|          | 0.00/2.16k [00:00<?, ?B/s]

In [None]:
# specify your path to the repo here:
repo_path = '/content/gdrive/MyDrive/w266/w266_reddit_summarization'

from google.colab import drive
drive.mount('/content/gdrive')

# Every round-2 BART prediction file lives in this one folder.
preds_dir = os.path.join(repo_path, 'data/model_outputs/bart_preds/round2')


def _read_round2_preds(filename):
  """Read a single prediction parquet file from `preds_dir`."""
  return pd.read_parquet(os.path.join(preds_dir, filename))


def _sorted_by_post(frame):
  """Return `frame` ordered by (content, y) with a fresh 0..n-1 index.

  The prediction frames are later merged on their index, so each of them is
  put into this same order first.
  """
  return frame.sort_values(['content', 'y']).reset_index(drop=True)


# baseline bart
baseline_preds = _read_round2_preds('bart_baseline_preds.parquet')
baseline_preds.columns = ['content', 'y', 'yhat_baseline']
baseline_preds = _sorted_by_post(baseline_preds)

# finetuned genre specific: one file per subreddit group, tagged with its group
group_files = {
  'advice_story': 'bart_preds_advice.parquet',
  'media_lifestyle_sports': 'bart_preds_media.parquet',
  'gaming': 'bart_preds_gaming.parquet',
  'other': 'bart_preds_other.parquet',
}
finetuned_preds = pd.concat(
  [_read_round2_preds(filename).assign(subreddit_group=group)
   for group, filename in group_files.items()],
  ignore_index=True)
finetuned_preds.columns = ['content', 'y', 'yhat_bart_subreddit', 'subreddit_group']
finetuned_preds = _sorted_by_post(finetuned_preds)

# finetuned full (all genres): predictions were saved in four parts
finetuned_full = pd.concat([
  _read_round2_preds(f'bart_full_preds_pt{part}.parquet') for part in range(1, 5)
])
# Sorting happens before the rename, so the files must already carry
# `content` and `y` columns.
finetuned_full = _sorted_by_post(finetuned_full)
finetuned_full.columns = ['content', 'y', 'yhat_bart_full']

Mounted at /content/gdrive


In [None]:
# The three prediction frames are joined purely by row position (their index).
# That is only valid if every frame holds the same posts in the same order, so
# check it explicitly instead of silently pairing predictions with the wrong
# reference summary.
for frame_name, frame in [('finetuned_preds', finetuned_preds), ('finetuned_full', finetuned_full)]:
  for key in ['content', 'y']:
    if frame[key].tolist() != baseline_preds[key].tolist():
      raise ValueError(
        f"{frame_name}['{key}'] does not line up row-for-row with baseline_preds['{key}']; "
        "an index-based merge would mix up predictions and references")

df = pd.merge(baseline_preds, finetuned_preds[['yhat_bart_subreddit', 'subreddit_group']], left_index=True, right_index=True, how='inner')
df = pd.merge(df, finetuned_full[['yhat_bart_full']], left_index=True, right_index=True, how='inner')
# Put the group label first, then the post, the reference and the three predictions.
df = df[['subreddit_group', 'content', 'y', 'yhat_baseline', 'yhat_bart_subreddit', 'yhat_bart_full']]
df

Unnamed: 0,subreddit_group,content,y,yhat_baseline,yhat_bart_subreddit,yhat_bart_full
0,other,01/21[LUMPIA]-Oxide/PVP/Non-Craftable C4/Half ...,Lumpia is a type of spring roll. \n Edit: Form...,Here's a look at the most recent news from th...,LUMPIA! \n \n Server: \n Opide/PVP/Half Craft/...,LUMPIA! Welcome to Rust Community!
1,gaming,1 Can be given as a gift to someone in a court...,be creative and these things are awesome.,"As part of the BBC's Shrink Item season, I've...","1 is awesome, 2 is great, 3 is great. 2 is awe...","1, 2, 3, 4, 5 good for reasons, 6, 7, 10, 14, ..."
2,gaming,1) Doublelift =/= Rekkles. Double was a solid ...,You know nothing but act like you do. I'm done...,"In the summer of 2014, I wrote a series of le...","Rekkles was a solid player, but not a supersta...","Rekkles wasn't considered a ""superstar player""..."
3,advice_story,1) I have been told that estrogen level tests ...,don't read too much into the readings you've g...,If you are going to have a new blood test on ...,"Don't worry about getting a new patch on, just...",Don't worry about it.
4,advice_story,1) I was going to add a humorous comment in ad...,Adding urine frequently to a compost system ha...,"I am a bit of a fan of the word ""peeing on th...",Adding urine to the compost is not the best idea.,"Add urine to the compost, it will be unpleasan..."
...,...,...,...,...,...,...
3995,gaming,"you need a team that is coordinated, thats why...",You need to heavily engage on Kha'Zhix so he w...,"I'm a big fan of kha, but I'm not a fan of th...",Kha'Zhix is broken now and needs a coordinated...,Kha'Zhix is broken and needs to be more coordi...
3996,advice_story,"you should never be eating alone"" I agree with...",Personal time is often severely underrated.,I'm not a fan of the adage 'you should never ...,you should never be eating alone.,I agree with almost everything you've said.
3997,gaming,"you wont be able to do much, may experience sh...","yeah you can use it, but try to get the recomm...",If you're trying to use your computer to do s...,"if you want to run your pc, it's recommended w...",you wont be able to do much
3998,advice_story,you'd be hard pressed to prove that reproducin...,"Sometimes, having no impact on the commercial ...",If you've been in a copyright infringement di...,You can't prove that reproducing copyrighted w...,you're wrong.


# Metrics

In [None]:
def get_metrics_for_group(group, data=None, scorer=None):
  """Score the three models' summaries for one subreddit group.

  Args:
    group: value of the `subreddit_group` column selecting the rows to score.
    data: frame with `subreddit_group`, `y`, `yhat_baseline`, `yhat_bart_full`
      and `yhat_bart_subreddit` columns. Defaults to the notebook-level `df`.
    scorer: object exposing `compute(predictions=..., references=...)` and
      returning, per ROUGE type, a sequence whose entry at index 1 holds
      (precision, recall, fmeasure). Defaults to the notebook-level `metric`.

  Returns:
    A DataFrame with columns model, metric, precision, recall, fmeasure:
    four rows (rouge1, rouge2, rougeL, rougeLsum) for each of the baseline,
    full-data and subreddit-group models, in that order.
  """
  data = df if data is None else data
  scorer = metric if scorer is None else scorer

  # Select the group's rows once; every model is scored against the same references.
  group_rows = data[data['subreddit_group'] == group]
  references = group_rows['y'].tolist()

  # (prediction column, display name), in the order the rows are emitted.
  models = [
    ('yhat_baseline', 'Baseline Model'),
    ('yhat_bart_full', 'BART trained on full data'),
    ('yhat_bart_subreddit', 'BART trained on subreddit groups'),
  ]
  rouge_types = ['rouge1', 'rouge2', 'rougeL', 'rougeLsum']

  rows = []
  for column, label in models:
    scores = scorer.compute(predictions=group_rows[column].tolist(), references=references)
    for rouge_type in rouge_types:
      # Index 1 is the entry this notebook reports for each ROUGE type
      # (presumably the "mid" aggregate — confirm against the metric card).
      picked = scores[rouge_type][1]
      rows.append({
        'model': label,
        'metric': rouge_type,
        'precision': picked[0],
        'recall': picked[1],
        'fmeasure': picked[2],
      })

  return pd.DataFrame(rows, columns=['model', 'metric', 'precision', 'recall', 'fmeasure'])



In [None]:
# ROUGE table for the advice/story group; within each metric the rows are
# ordered by ascending recall, so the last row per metric is the top model.
df_out = get_metrics_for_group("advice_story")
df_out.sort_values(by=['metric', 'recall'], ascending=True)
# bart full is the best here

Unnamed: 0,model,metric,precision,recall,fmeasure
0,Baseline Model,rouge1,0.169182,0.167656,0.148635
8,BART trained on subreddit groups,rouge1,0.268871,0.199846,0.203121
4,BART trained on full data,rouge1,0.269722,0.202436,0.20549
1,Baseline Model,rouge2,0.024719,0.025488,0.022039
9,BART trained on subreddit groups,rouge2,0.073519,0.056831,0.055697
5,BART trained on full data,rouge2,0.080019,0.059387,0.059819
2,Baseline Model,rougeL,0.128318,0.135003,0.115511
10,BART trained on subreddit groups,rougeL,0.215872,0.164085,0.164783
6,BART trained on full data,rougeL,0.219045,0.166508,0.167548
3,Baseline Model,rougeLsum,0.131983,0.136487,0.117461


In [None]:
# ROUGE table for the media/lifestyle/sports group; within each metric the rows
# are ordered by ascending recall, so the last row per metric is the top model.
df_out = get_metrics_for_group("media_lifestyle_sports")
df_out.sort_values(by=['metric', 'recall'], ascending=True)
# bart full is the best here (but baseline has the highest rouge1 recall)

Unnamed: 0,model,metric,precision,recall,fmeasure
8,BART trained on subreddit groups,rouge1,0.230202,0.139792,0.151601
4,BART trained on full data,rouge1,0.222868,0.143014,0.150424
0,Baseline Model,rouge1,0.14555,0.148259,0.127118
1,Baseline Model,rouge2,0.018619,0.021305,0.016727
9,BART trained on subreddit groups,rouge2,0.057942,0.031208,0.035155
5,BART trained on full data,rouge2,0.058074,0.035373,0.037407
10,BART trained on subreddit groups,rougeL,0.192378,0.117863,0.126592
2,Baseline Model,rougeL,0.112189,0.121094,0.100465
6,BART trained on full data,rougeL,0.187731,0.121861,0.127126
11,BART trained on subreddit groups,rougeLsum,0.195276,0.119062,0.128027


In [None]:
# ROUGE table for the gaming group; within each metric the rows are ordered by
# ascending recall, so the last row per metric is the top model.
df_out = get_metrics_for_group("gaming")
df_out.sort_values(by=['metric', 'recall'], ascending=True)
# bart full is the best

Unnamed: 0,model,metric,precision,recall,fmeasure
0,Baseline Model,rouge1,0.166484,0.152679,0.140723
8,BART trained on subreddit groups,rouge1,0.252087,0.155419,0.166379
4,BART trained on full data,rouge1,0.254808,0.164656,0.174469
1,Baseline Model,rouge2,0.020813,0.019154,0.017753
9,BART trained on subreddit groups,rouge2,0.060869,0.036829,0.038915
5,BART trained on full data,rouge2,0.061246,0.040805,0.041727
2,Baseline Model,rougeL,0.123719,0.119968,0.107449
10,BART trained on subreddit groups,rougeL,0.205861,0.127489,0.135569
6,BART trained on full data,rougeL,0.208121,0.134956,0.141622
3,Baseline Model,rougeLsum,0.128696,0.122163,0.110212


In [None]:
# ROUGE table for the "other" group; within each metric the rows are ordered by
# ascending recall, so the last row per metric is the top model.
df_out = get_metrics_for_group("other")
df_out.sort_values(by=['metric', 'recall'], ascending=True)
# subreddit bart wins here!

Unnamed: 0,model,metric,precision,recall,fmeasure
0,Baseline Model,rouge1,0.160364,0.149492,0.136373
4,BART trained on full data,rouge1,0.246992,0.166482,0.171508
8,BART trained on subreddit groups,rouge1,0.245763,0.16805,0.173338
1,Baseline Model,rouge2,0.021279,0.020206,0.018182
5,BART trained on full data,rouge2,0.063425,0.039606,0.042193
9,BART trained on subreddit groups,rouge2,0.062601,0.043322,0.044441
2,Baseline Model,rougeL,0.122193,0.12103,0.106384
6,BART trained on full data,rougeL,0.204561,0.139062,0.14209
10,BART trained on subreddit groups,rougeL,0.204423,0.141117,0.143851
3,Baseline Model,rougeLsum,0.126551,0.123191,0.109157


BART Baseline & 14.9 & 2.0 & 12.1 \\
BART Full & 16.6 & 4.0 & 13.9 \\
BART Subreddit Split & 16.8 & 4.3 & 14.1 \\

# Word counts in each subgroup

In [None]:
# Train, test and validation splits, stacked in that order into one frame.
split_files = [
  'data/reddit_parquet/train_test_split_v2/reddit_train.parquet',
  'data/reddit_parquet/train_test_split_v2/reddit_test.parquet',
  'data/reddit_parquet/train_test_split_v2/reddit_validation.parquet',
]
split_frames = [pd.read_parquet(os.path.join(repo_path, rel_path)) for rel_path in split_files]
full_df = pd.concat(split_frames, ignore_index=True)

full_df

Unnamed: 0,content,summary,subreddit,subreddit_group
0,Can bars and/or clubs legally charge men for e...,Is it legal for bars and/or clubs to charge me...,legaladvice,advice/story
1,Great Link! Enjoyed the comments at the end es...,check out the link,WTF,advice/story
2,"Met my bf on DeviantArt, oddly enough. I comme...",since there was of course WAY more involved in...,SRSQuestions,advice/story
3,You told me you really liked me. \n When I mov...,"You seduced me when I tried to push away, stol...",offmychest,advice/story
4,"So, preface. I have a tiny Asian vagina. My bo...","Tiny asian vagina, tears after sex, halp?",sex,advice/story
...,...,...,...,...
67995,"Firstly, thanks. I think anyone can accomplish...",Almost the same. A lot cheaper weed. Worse qua...,trees,other
67996,That is just like real life ideal condition. C...,Why making your fantasy world as boring as you...,SteamTeamGreen,other
67997,I know I'm not the first one to suggest this n...,Stop using AT&T until they drop the cap.,self,other
67998,Went to Hopapalooza at the Alibi Room for Vanc...,Unlimited samples of exceptional beer. Heaven....,Homebrewing,other


In [None]:
%%time
def count_vocab(item):
  """Return the number of distinct lower-cased whitespace-separated tokens in `item`.

  Splits on any run of whitespace, so repeated spaces do not add an empty
  token and stand-alone newlines are not counted as words. An empty or
  all-whitespace string has a vocabulary of 0.
  """
  return len(set(item.lower().split()))

def count_total_words(x):
  """Return the number of whitespace-separated tokens in `x`.

  Splits on any run of whitespace, so repeated spaces do not add an empty
  token and stand-alone newlines are not counted as words. An empty or
  all-whitespace string has 0 words.
  """
  return len(x.split())

# Add vocabulary-size and total-word columns for both the post and its summary.
# Column creation order: content_vocab, summary_vocab, content_total_words,
# summary_total_words.
for counter, suffix in ((count_vocab, 'vocab'), (count_total_words, 'total_words')):
  for text_col in ('content', 'summary'):
    full_df[f'{text_col}_{suffix}'] = full_df[text_col].map(counter)

CPU times: user 3.57 s, sys: 16.4 ms, total: 3.59 s
Wall time: 3.59 s


In [None]:
# Average post / summary length (in words) over the whole corpus.
x = np.mean(full_df['content_total_words'])
print(f"overall avg post length: {x:.2f}")

x = np.mean(full_df['summary_total_words'])
print(f"overall avg summary length: {x:.2f}")

# Same two averages per subreddit group. Keys are the group values stored in
# `full_df['subreddit_group']`; values are the labels used in the printout.
group_labels = {
  'advice/story': 'advice_story',
  'media/lifestyle/sports': 'media_lifestyle_sports',
  'gaming': 'gaming',
  'other': 'other',
}
for group_key, label in group_labels.items():
  # Filter once per group and reuse the subset for both averages.
  group_rows = full_df[full_df['subreddit_group'] == group_key]

  x = np.mean(group_rows['content_total_words'])
  print(f"\n{label} avg post length: {x:.2f}")

  x = np.mean(group_rows['summary_total_words'])
  print(f"{label} avg summary length: {x:.2f}")

overall avg post length: 234.56
overall avg summary length: 22.44

advice_story avg post length: 278.16
advice_story avg summary length: 22.81

media_lifestyle_sports avg post length: 209.82
media_lifestyle_sports avg summary length: 21.09

gaming avg post length: 210.69
gaming avg summary length: 22.76

other avg post length: 239.56
other avg summary length: 23.12


# Error Analysis

In [None]:
# baseline: score every row's baseline summary against its reference, one ROUGE call per row
output_dict = {
  'rouge1_precision': [], 'rouge1_recall': [], 'rouge1_fmeasure': [],
  'rouge2_precision': [], 'rouge2_recall': [], 'rouge2_fmeasure': [],
}

for i in tqdm(range(len(df))):
  output = metric.compute(predictions=[df['yhat_baseline'][i]], references=[df['y'][i]])
  for rouge_type in ('rouge1', 'rouge2'):
    # entry 1 of each ROUGE result carries (precision, recall, fmeasure) at positions 0/1/2
    picked = output[rouge_type][1]
    for position, stat in enumerate(('precision', 'recall', 'fmeasure')):
      output_dict[f'{rouge_type}_{stat}'].append(picked[position])

output_df = pd.DataFrame(output_dict)
output_df

  0%|          | 0/4000 [00:00<?, ?it/s]

Unnamed: 0,rouge1_precision,rouge1_recall,rouge1_fmeasure,rouge2_precision,rouge2_recall,rouge2_fmeasure
0,0.142857,0.222222,0.173913,0.000000,0.000000,0.000000
1,0.038462,0.142857,0.060606,0.000000,0.000000,0.000000
2,0.200000,0.065217,0.098361,0.000000,0.000000,0.000000
3,0.150000,0.187500,0.166667,0.000000,0.000000,0.000000
4,0.260870,0.150000,0.190476,0.000000,0.000000,0.000000
...,...,...,...,...,...,...
3995,0.263158,0.161290,0.200000,0.000000,0.000000,0.000000
3996,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
3997,0.181818,0.307692,0.228571,0.000000,0.000000,0.000000
3998,0.333333,0.218182,0.263736,0.085714,0.055556,0.067416


In [None]:
# bart full: score every row's full-data BART summary against its reference, one ROUGE call per row
output_dict = {
  'rouge1_precision': [], 'rouge1_recall': [], 'rouge1_fmeasure': [],
  'rouge2_precision': [], 'rouge2_recall': [], 'rouge2_fmeasure': [],
}

for i in tqdm(range(len(df))):
  output = metric.compute(predictions=[df['yhat_bart_full'][i]], references=[df['y'][i]])
  for rouge_type in ('rouge1', 'rouge2'):
    # entry 1 of each ROUGE result carries (precision, recall, fmeasure) at positions 0/1/2
    picked = output[rouge_type][1]
    for position, stat in enumerate(('precision', 'recall', 'fmeasure')):
      output_dict[f'{rouge_type}_{stat}'].append(picked[position])

output_df_bartfull = pd.DataFrame(output_dict)
output_df_bartfull

  0%|          | 0/4000 [00:00<?, ?it/s]

Unnamed: 0,rouge1_precision,rouge1_recall,rouge1_fmeasure,rouge2_precision,rouge2_recall,rouge2_fmeasure
0,0.200000,0.111111,0.142857,0.000000,0.000000,0.000000
1,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
2,0.090909,0.021739,0.035088,0.000000,0.000000,0.000000
3,0.600000,0.187500,0.285714,0.250000,0.066667,0.105263
4,0.538462,0.175000,0.264151,0.000000,0.000000,0.000000
...,...,...,...,...,...,...
3995,0.400000,0.129032,0.195122,0.111111,0.033333,0.051282
3996,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
3997,0.285714,0.153846,0.200000,0.000000,0.000000,0.000000
3998,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000


In [None]:
# Attach the per-row full-data BART scores to the predictions (joined on the
# row index), then order the rows from lowest to highest rouge1 recall.
bartfull_scored = pd.merge(df, output_df_bartfull, how='inner', left_index=True, right_index=True)
df2 = bartfull_scored.sort_values(by='rouge1_recall')
df2

Unnamed: 0,subreddit_group,content,y,yhat_baseline,yhat_bart_subreddit,yhat_bart_full,rouge1_precision,rouge1_recall,rouge1_fmeasure,rouge2_precision,rouge2_recall,rouge2_fmeasure
1623,advice_story,I work for Nielsen TV Ratings and we are in pe...,roach encrusted house and the old man from fam...,I've been living in a house in the US state o...,"Roaches ate mice, roaches ate dead mice and ro...",Roaches are disgusting.,0.000000,0.0,0.000000,0.000000,0.0,0.000000
2883,advice_story,"So about five months ago, this girl moved to m...","Having trouble getting with a friend of mine, ...","I'm a 14-year-old girl who loves me so much, ...",How can I get a girl to like me?,How can I get this girl to like me?,0.000000,0.0,0.000000,0.000000,0.0,0.000000
840,other,Hi everyone! \n I'm working on a Monte Carlo s...,Python 2: Can anyone give me any suggestions a...,"I'm a bit of a bit behind the wheel, but I'm ...",I need to print a 2D for loop dartboard for '<...,I need help making a dartboard for a 2D for lo...,0.000000,0.0,0.000000,0.000000,0.0,0.000000
2865,gaming,So I'm normally a guy that lets things roll of...,Entitled people in BoAs are levelling cancer.,"I'm a fan of the game, and I'm not a player o...","Tanks pull relentlessly, no matter what we sai...","Tanks rock, don't let them.",0.000000,0.0,0.000000,0.000000,0.0,0.000000
2861,advice_story,"So I wasn't at work this time, but waiting in ...","Lady tries to cut line, waved sanitary items i...",I'm a regular at work and I'm used to getting...,Snowball lady gives a stink eye to a guy in fr...,I'm an asshole and I'm not sorry.,0.000000,0.0,0.000000,0.000000,0.0,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...
1603,advice_story,I was working as a temp in a corporate office....,Bomb threat,I was working at a company in the US state of...,I ran into a bomb tech who called in a bomb th...,"Guy was a bomb tech who called in bomb threat,...",0.064516,1.0,0.121212,0.033333,1.0,0.064516
1237,other,"I have used many printers, ranging from many d...",Apple should make their own printer.,"Apple should make its own printer, according ...",Apple should make their own printer.,Apple should make their own printer.,1.000000,1.0,1.000000,1.000000,1.0,1.000000
2718,media_lifestyle_sports,ROFL. NGO Monitor? An Israeli website that cal...,you're an idiot.,What do you know about the Israeli-Palestinia...,ROFL.,You're an idiot.,1.000000,1.0,1.000000,1.000000,1.0,1.000000
3238,media_lifestyle_sports,The situation at Newcastle reminds me of the s...,Ashley is the problem.,Alan Pardew is a great manager but he is also ...,I don't know how much control Alan Pardew has ...,"Ashley is the problem, Cabaye is gone.",0.571429,1.0,0.727273,0.500000,1.0,0.666667


In [None]:
# Mean per-row rouge1 recall within each subreddit group.
df2['rouge1_recall'].groupby(df2['subreddit_group']).mean()

subreddit_group
advice_story              0.202180
gaming                    0.164526
media_lifestyle_sports    0.142903
other                     0.166616
Name: rouge1_recall, dtype: float64

In [None]:
# Read back the previously saved predictions-plus-metrics frame.
metrics_path = os.path.join(repo_path, 'data/model_outputs/bart_preds/round2/preds_metrics.parquet')
#df2.to_parquet(metrics_path)
data1 = pd.read_parquet(metrics_path)
data1

In [None]:
# Display the merged predictions frame (`df`: group, post, reference and the
# three models' summaries) again for reference.
df

Unnamed: 0,subreddit_group,content,y,yhat_baseline,yhat_bart_subreddit,yhat_bart_full
0,other,01/21[LUMPIA]-Oxide/PVP/Non-Craftable C4/Half ...,Lumpia is a type of spring roll. \n Edit: Form...,Here's a look at the most recent news from th...,LUMPIA! \n \n Server: \n Opide/PVP/Half Craft/...,LUMPIA! Welcome to Rust Community!
1,gaming,1 Can be given as a gift to someone in a court...,be creative and these things are awesome.,"As part of the BBC's Shrink Item season, I've...","1 is awesome, 2 is great, 3 is great. 2 is awe...","1, 2, 3, 4, 5 good for reasons, 6, 7, 10, 14, ..."
2,gaming,1) Doublelift =/= Rekkles. Double was a solid ...,You know nothing but act like you do. I'm done...,"In the summer of 2014, I wrote a series of le...","Rekkles was a solid player, but not a supersta...","Rekkles wasn't considered a ""superstar player""..."
3,advice_story,1) I have been told that estrogen level tests ...,don't read too much into the readings you've g...,If you are going to have a new blood test on ...,"Don't worry about getting a new patch on, just...",Don't worry about it.
4,advice_story,1) I was going to add a humorous comment in ad...,Adding urine frequently to a compost system ha...,"I am a bit of a fan of the word ""peeing on th...",Adding urine to the compost is not the best idea.,"Add urine to the compost, it will be unpleasan..."
...,...,...,...,...,...,...
3995,gaming,"you need a team that is coordinated, thats why...",You need to heavily engage on Kha'Zhix so he w...,"I'm a big fan of kha, but I'm not a fan of th...",Kha'Zhix is broken now and needs a coordinated...,Kha'Zhix is broken and needs to be more coordi...
3996,advice_story,"you should never be eating alone"" I agree with...",Personal time is often severely underrated.,I'm not a fan of the adage 'you should never ...,you should never be eating alone.,I agree with almost everything you've said.
3997,gaming,"you wont be able to do much, may experience sh...","yeah you can use it, but try to get the recomm...",If you're trying to use your computer to do s...,"if you want to run your pc, it's recommended w...",you wont be able to do much
3998,advice_story,you'd be hard pressed to prove that reproducin...,"Sometimes, having no impact on the commercial ...",If you've been in a copyright infringement di...,You can't prove that reproducing copyrighted w...,you're wrong.


In [None]:
# bart subreddit: score every row's subreddit-group BART summary against its reference, one ROUGE call per row
output_dict = {
  'rouge1_precision': [], 'rouge1_recall': [], 'rouge1_fmeasure': [],
  'rouge2_precision': [], 'rouge2_recall': [], 'rouge2_fmeasure': [],
}

for i in tqdm(range(len(df))):
  output = metric.compute(predictions=[df['yhat_bart_subreddit'][i]], references=[df['y'][i]])
  for rouge_type in ('rouge1', 'rouge2'):
    # entry 1 of each ROUGE result carries (precision, recall, fmeasure) at positions 0/1/2
    picked = output[rouge_type][1]
    for position, stat in enumerate(('precision', 'recall', 'fmeasure')):
      output_dict[f'{rouge_type}_{stat}'].append(picked[position])

output_df_bart_subreddit = pd.DataFrame(output_dict)
output_df_bart_subreddit

  0%|          | 0/4000 [00:00<?, ?it/s]

Unnamed: 0,rouge1_precision,rouge1_recall,rouge1_fmeasure,rouge2_precision,rouge2_recall,rouge2_fmeasure
0,0.136364,0.333333,0.193548,0.000000,0.000000,0.000000
1,0.065217,0.428571,0.113208,0.000000,0.000000,0.000000
2,0.100000,0.021739,0.035714,0.000000,0.000000,0.000000
3,0.222222,0.250000,0.235294,0.058824,0.066667,0.062500
4,0.600000,0.150000,0.240000,0.111111,0.025641,0.041667
...,...,...,...,...,...,...
3995,0.400000,0.129032,0.195122,0.111111,0.033333,0.051282
3996,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
3997,0.363636,0.307692,0.333333,0.000000,0.000000,0.000000
3998,0.090909,0.018182,0.030303,0.000000,0.000000,0.000000


In [None]:
# Attach the per-row subreddit-group BART scores to the predictions (joined on
# the row index), order by ascending rouge1 recall, and export the result.
subreddit_scored = pd.merge(df, output_df_bart_subreddit, how='inner', left_index=True, right_index=True)
df2 = subreddit_scored.sort_values(by='rouge1_recall')
# df2
csv_path = os.path.join(repo_path, 'data/model_outputs/bart_preds/round2/bart_subreddit_metrics.csv')
df2.to_csv(csv_path, index=False)

In [None]:
# In a top-to-bottom run `df2` holds the subreddit-group model's per-row scores
# (it was last built from `output_df_bart_subreddit`), so show that model's
# prediction next to the recall it produced.
print(df2[['y', 'yhat_bart_subreddit', 'rouge1_recall']].tail(1))

                             y  \
2497  Jews and a pound of weed   

                                         yhat_bart_full  rouge1_recall  
2497  I made out something about the jews and a poun...            1.0  


In [None]:
# `df2` is sorted by ascending rouge1_recall but keeps its original index
# labels, so `df2['y'][0]` would fetch the row *labelled* 0, not the first
# (lowest-recall) row. Use positional access, and show the subreddit-group
# prediction, since that is the model whose recall `df2` is sorted by.
lowest_recall_row = df2.iloc[0]
print("true:")
pprint(lowest_recall_row['y'])
print("\nPred")
pprint(lowest_recall_row['yhat_bart_subreddit'])

true:
'Lumpia is a type of spring roll. \n Edit: Formatting'

Pred
'LUMPIA! Welcome to Rust Community!'


In [None]:
# `df2` is sorted by ascending rouge1_recall but keeps its original index
# labels, so `df2['y'][3999]` would fetch the row *labelled* 3999, not the last
# (highest-recall) row, and breaks if the frame has a different size. Use
# positional access, and show the subreddit-group prediction, since that is the
# model whose recall `df2` is sorted by.
highest_recall_row = df2.iloc[-1]
print("true:")
pprint(highest_recall_row['y'])
print("\nPred")
pprint(highest_recall_row['yhat_bart_subreddit'])

true:
('reality proves you wrong so learn from it. \n'
 ' >if you are moving relative to a stopped car, you can easily pick it up '
 'with peripheral vision \n'
 ' tell that to the guy who just crashed into the vehicle in the video')

Pred
"you're an idiot."
