In [32]:
from IPython.display import clear_output

!pip install datasets transformers rouge_score nltk
# !pip install datasets transformers rouge-score nltk
# rouge-score is the google version
!pip install pyarrow
!pip install -q sentencepiece

clear_output()

In [33]:
# Optional Google API client libraries; the install below is intentionally left disabled.
# ! pip install --upgrade google-api-python-client google-auth-httplib2 google-auth-oauthlib
# Clears this cell's output area (relies on `clear_output` imported in the install cell).
clear_output()

In [34]:
from IPython.display import clear_output
import os
import re
import time
from tqdm.notebook import trange, tqdm
import pandas as pd
import numpy as np
from pprint import pprint
import matplotlib.pyplot as plt

# nlp stuff
import nltk
# Downloads the NLTK 'punkt' package as a side effect of running this cell.
# NOTE(review): presumably required for sentence splitting by the ROUGE metric — confirm.
nltk.download('punkt')

# tf stuff
import tensorflow_datasets as tfds 
import tensorflow as tf
from transformers import PegasusTokenizer, TFPegasusForConditionalGeneration # pegasus
from transformers import BartTokenizer, TFBartForConditionalGeneration # bart

# pytorch dataset types
import datasets
from datasets.dataset_dict import DatasetDict
from datasets import Dataset, load_metric, load_dataset

# pytorch bart stuff
import torch
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import AutoTokenizer

# clear_output()

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [35]:
# about this metric: https://huggingface.co/spaces/evaluate-metric/rouge
# Notebook-level ROUGE scorer; get_metrics_for_group() calls metric.compute(...) on it
# and indexes each result as result[rouge_type][1][0..2].
# NOTE(review): `datasets.load_metric` is reported as deprecated in favour of the
# `evaluate` package — confirm the installed `datasets` version still provides it,
# and re-check the result indexing downstream before switching loaders.
metric = load_metric("rouge")

In [41]:
# Make Google Drive visible to the Colab runtime before touching any files on it.
from google.colab import drive
drive.mount('/content/gdrive')

# specify your path to the repo here:
repo_path = '/content/gdrive/MyDrive/w266/w266_reddit_summarization'

# The manually summarised sample is stored as a tab-separated file inside the repo.
manual_summaries_path = os.path.join(repo_path, 'data/manual_summaries.tsv')
df = pd.read_csv(manual_summaries_path, sep='\t')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [42]:
# Drop every row whose manual_summary is the literal string 'na', then renumber
# the remaining rows from 0 so the index has no gaps.
df = (
    df.loc[df['manual_summary'].ne('na')]
    .reset_index(drop=True)
)
df

Unnamed: 0,subreddit_group,content,y,yhat_baseline,yhat_bart_subreddit,yhat_bart_full,rouge1_precision,rouge1_recall,rouge1_fmeasure,rouge2_precision,rouge2_recall,rouge2_fmeasure,manual_summary
0,other,"Sleep deprivation has serious, serious bad con...",Don't feel guilty! Take care of you.,If you're struggling to get a good night's sl...,Do what you have to do to get some sleep.,Sleep deprivation is bad for health.,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,Sleep is very important for good health
1,media_lifestyle_sports,People claim hygiene is of a high enough stand...,"Nobody wants to be forever known as ""Bobby blo...",The Ebola outbreak in West Africa has left man...,"Ebio is probably unlikely, but it's probably not.","It's unlikely, but it's possible.",0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,People claim this gym is hygenic but given how...
2,other,"The top Google search match for ""2001 cinemagr...","The top Google search match for ""2001 cinemagr...","I've been writing for more than a decade, but...",I'm not an idiot.,I'm a dick.,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,Just post the true source for 2001 cinemagraph.
3,advice_story,The thing about dictionaries is that they are ...,"Human practice shapes language rules, books me...",D dictionaries are used by many people to reco...,D dictionaries are biased.,"Dictionaries are biased, but there's no defaul...",0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,Dictionaries are biased. They are a record of ...
4,media_lifestyle_sports,There is apparently no way for your toothbrush...,irrelvent rl;dr's I respect. Incorrect ones a...,"There is no such thing as a clean toothbrush,...",You're a dick.,You can't clean your toothbrush with a toothbr...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,There is apparently no way for your toothbrush...
...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,other,"So, I noticed a post on here the other day tha...",What weapons/armor do you carry on your charac...,I'm a fan of the video game World of Warcraft...,What equipment do you carry on your character?,What weapons do you carry on your character?,1.000000,0.615385,0.761905,0.857143,0.500000,0.631579,Looking for advice on weapons to use.
96,advice_story,So my boyfriend and I have been together just ...,Boyfriend still hasn't told me he loves me aft...,"I am a student at the University of Glasgow, ...",Boyfriend and I have been together for 1.5 yea...,Boyfriend of 1.5 years doesn't tell me that he...,0.409091,0.642857,0.500000,0.190476,0.307692,0.235294,My boyfriend of nearly 2 years hasn't told me ...
97,gaming,They all depend heavily on the current common ...,Sylvanas > TBK >= Cairne,"Sunwalkers, players and players all have diffe...","Sylanas, TBK, Cairne, Sylvanas, and TBK.","Sylanas, TBK, Cairne, and TBK are all good cards.",0.222222,0.666667,0.333333,0.125000,0.500000,0.200000,The success of these cards all rely on the cur...
98,advice_story,"I'm a little late to the party, but oh well. ...",my high school math teacher was a hypocritical...,I'm going to go to a party with my high schoo...,"Teacher is a bitch, and I don't know if anyone...",My math teacher was a bitch for three years in...,0.318182,0.700000,0.437500,0.190476,0.444444,0.266667,I brought a coke to my math class and my teach...


In [43]:
# Number of remaining rows per subreddit group (used as the `group` argument of get_metrics_for_group).
df['subreddit_group'].value_counts()

other                     25
media_lifestyle_sports    25
advice_story              25
gaming                    25
Name: subreddit_group, dtype: int64

In [19]:
def get_metrics_for_group(group, data=None, rouge_metric=None):
  """Score every summary source against the reference summaries of one group.

  For the rows whose 'subreddit_group' equals `group`, the columns
  'yhat_baseline', 'yhat_bart_full', 'yhat_bart_subreddit' and
  'manual_summary' are each compared with the reference column 'y' via
  `rouge_metric.compute(predictions=..., references=...)`.

  Args:
    group: value of the 'subreddit_group' column selecting the rows to score.
    data: DataFrame holding the columns listed above. Defaults to the
      notebook-level `df`.
    rouge_metric: object exposing `compute(predictions=..., references=...)`.
      Defaults to the notebook-level `metric`.

  Returns:
    A DataFrame with columns 'model', 'metric', 'precision', 'recall' and
    'fmeasure': one row per (summary source, ROUGE type), 16 rows in total,
    ordered baseline -> full-data BART -> subreddit BART -> manual summary.

  Raises:
    ValueError: if no row of `data` belongs to `group`.
  """
  data = df if data is None else data
  rouge_metric = metric if rouge_metric is None else rouge_metric

  # Filter once instead of re-evaluating the same boolean mask for every column.
  group_rows = data[data['subreddit_group'] == group]
  if group_rows.empty:
    # Fail early with a clear message rather than deep inside the metric.
    raise ValueError(f"no rows found for subreddit group {group!r}")

  references = group_rows['y'].tolist()

  # (prediction column, label shown in the result table), in output order.
  summary_sources = [
      ('yhat_baseline', 'Baseline Model'),
      ('yhat_bart_full', 'BART trained on full data'),
      ('yhat_bart_subreddit', 'BART trained on subreddit groups'),
      ('manual_summary', 'Manual summary'),
  ]
  rouge_types = ['rouge1', 'rouge2', 'rougeL', 'rougeLsum']

  records = []
  for column, label in summary_sources:
    scores = rouge_metric.compute(
        predictions=group_rows[column].tolist(),
        references=references,
    )
    for rouge_type in rouge_types:
      # Entry [1] of each result is used; presumably the `mid` value of a
      # (low, mid, high) aggregate — TODO confirm against the metric docs.
      # Within it, positions 0/1/2 are read as precision/recall/fmeasure.
      chosen = scores[rouge_type][1]
      records.append({
          'model': label,
          'metric': rouge_type,
          'precision': chosen[0],
          'recall': chosen[1],
          'fmeasure': chosen[2],
      })

  return pd.DataFrame(
      records,
      columns=['model', 'metric', 'precision', 'recall', 'fmeasure'],
  )





In [20]:
# ROUGE table for the "advice_story" group: grouped by ROUGE type, lowest recall first.
df_out = get_metrics_for_group("advice_story")
df_out.sort_values(by=['metric', 'recall'], ascending=True)

Unnamed: 0,model,metric,precision,recall,fmeasure
0,Baseline Model,rouge1,0.16157,0.152304,0.145619
8,BART trained on subreddit groups,rouge1,0.280727,0.21359,0.217831
4,BART trained on full data,rouge1,0.296516,0.216286,0.219553
12,Manual summary,rouge1,0.262447,0.263831,0.242829
1,Baseline Model,rouge2,0.028981,0.028741,0.027833
5,BART trained on full data,rouge2,0.053522,0.039168,0.040019
9,BART trained on subreddit groups,rouge2,0.077736,0.059177,0.054813
13,Manual summary,rouge2,0.100799,0.121751,0.100857
2,Baseline Model,rougeL,0.128073,0.131438,0.120926
10,BART trained on subreddit groups,rougeL,0.22475,0.158531,0.16668


In [21]:
# ROUGE table for the "media_lifestyle_sports" group: grouped by ROUGE type, lowest recall first.
df_out = get_metrics_for_group("media_lifestyle_sports")
df_out.sort_values(by=['metric', 'recall'], ascending=True)

Unnamed: 0,model,metric,precision,recall,fmeasure
4,BART trained on full data,rouge1,0.192791,0.133095,0.136265
0,Baseline Model,rouge1,0.11251,0.145953,0.113721
8,BART trained on subreddit groups,rouge1,0.19117,0.149435,0.144431
12,Manual summary,rouge1,0.165647,0.171515,0.149715
9,BART trained on subreddit groups,rouge2,0.010302,0.005,0.00678
5,BART trained on full data,rouge2,0.019231,0.011154,0.013368
1,Baseline Model,rouge2,0.008855,0.01636,0.009675
13,Manual summary,rouge2,0.024483,0.033319,0.024206
2,Baseline Model,rougeL,0.07908,0.106039,0.082084
6,BART trained on full data,rougeL,0.159521,0.109691,0.112951


In [22]:
# ROUGE table for the "other" group: grouped by ROUGE type, lowest recall first.
df_out = get_metrics_for_group("other")
df_out.sort_values(by=['metric', 'recall'], ascending=True)

Unnamed: 0,model,metric,precision,recall,fmeasure
8,BART trained on subreddit groups,rouge1,0.181516,0.1435,0.147772
4,BART trained on full data,rouge1,0.181072,0.144073,0.152346
0,Baseline Model,rouge1,0.118531,0.167509,0.12799
12,Manual summary,rouge1,0.1972,0.22727,0.180398
1,Baseline Model,rouge2,0.018828,0.033896,0.023592
9,BART trained on subreddit groups,rouge2,0.063997,0.044466,0.049776
5,BART trained on full data,rouge2,0.07399,0.051969,0.059034
13,Manual summary,rouge2,0.059496,0.061701,0.049623
10,BART trained on subreddit groups,rougeL,0.166585,0.128133,0.132018
6,BART trained on full data,rougeL,0.162321,0.129141,0.136456


In [23]:
# ROUGE table for the "gaming" group: grouped by ROUGE type, lowest recall first.
df_out = get_metrics_for_group("gaming")
df_out.sort_values(by=['metric', 'recall'], ascending=True)

Unnamed: 0,model,metric,precision,recall,fmeasure
0,Baseline Model,rouge1,0.133651,0.155643,0.133667
12,Manual summary,rouge1,0.147022,0.213021,0.149659
4,BART trained on full data,rouge1,0.240801,0.219338,0.191647
8,BART trained on subreddit groups,rouge1,0.244478,0.240799,0.19504
1,Baseline Model,rouge2,0.009122,0.01777,0.011346
13,Manual summary,rouge2,0.023786,0.042025,0.029058
9,BART trained on subreddit groups,rouge2,0.050683,0.053835,0.042343
5,BART trained on full data,rouge2,0.05741,0.084872,0.053179
2,Baseline Model,rougeL,0.104663,0.124933,0.105531
14,Manual summary,rougeL,0.116127,0.172612,0.119057
