# Compare metrics between the models

In [1]:
from IPython.display import clear_output

!pip install datasets transformers rouge_score rouge-score nltk
# rouge-score is the google version
!pip install pyarrow
!pip install -q sentencepiece

clear_output()

In [2]:
import os
import re
import time
from tqdm import tqdm
import pandas as pd
import numpy as np
from pprint import pprint
import matplotlib.pyplot as plt

# nlp stuff
import nltk
nltk.download('punkt')

# tf stuff
import tensorflow_datasets as tfds 
import tensorflow as tf
from transformers import PegasusTokenizer, TFPegasusForConditionalGeneration # pegasus
from transformers import BartTokenizer, TFBartForConditionalGeneration # bart

# pytorch dataset types
import datasets
from datasets.dataset_dict import DatasetDict
from datasets import Dataset, load_metric, load_dataset

# pytorch bart stuff
import torch
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import AutoTokenizer

clear_output()

In [3]:
%%time

# specify your path to the repo here:
repo_path = '/content/gdrive/MyDrive/w266/w266_reddit_summarization'

from google.colab import drive
drive.mount('/content/gdrive')

# baseline bart
baseline_preds = pd.read_parquet(os.path.join(repo_path, 'data/model_outputs/bart_preds/bart_baseline_preds.parquet'))

# finetuned bart
# so far i only have the "advice_story" subreddit category done
advice_preds_finetune = pd.read_parquet(os.path.join(repo_path, 'data/model_outputs/bart_preds/bart_preds_advice.parquet'))

Mounted at /content/gdrive
CPU times: user 2.24 s, sys: 378 ms, total: 2.62 s
Wall time: 22.6 s


In [4]:
# Rename the columns of both frames positionally so that they share the
# 'content' and 'y' columns and differ only in the prediction column.
join_keys = ['content', 'y']
baseline_preds.columns = join_keys + ['yhat_baseline']
advice_preds_finetune.columns = join_keys + ['yhat_finetune']

# Left join: keep every fine-tuned row and attach the matching baseline prediction.
advice_preds = advice_preds_finetune.merge(baseline_preds, how='left', on=join_keys)
advice_preds

Unnamed: 0,content,y,yhat_finetune,yhat_baseline
0,What? No. A few very large banks started giv...,bad banking policy and a lack of governmental ...,Banks gave out mortgages to people who had no ...,The BBC News website looks at what happened t...
1,I don't know- I was raised in and around Phila...,you have to know the questions before you can ...,I didn't see religious clothing outside of wor...,"I'm a Catholic, but I've never been a fan of ..."
2,"So I've known him for a few years now, but onl...",What should I do?? I really don't want my inse...,I'm not sure if he's just interested in me or ...,"I love him so much, he's even more interested..."
3,My girlfriend of almost 3 years doesn't want t...,girlfriend won't have sex any longer. Don't kn...,Girlfriend won't have sex anymore and I don't ...,In a series of letters from African journalis...
4,People judging what you order when going out t...,Don't be an asshole at meal time OOOOOOR on re...,I agree with you more than most people.,"I'm a vegan, I'm not a vegetarian, I don't ea..."
...,...,...,...,...
1941,I went to a private high school in a relativel...,"gave Harry Connick, Jr. and his wife an admiss...",Harry Connick asked me to play piano at his sc...,"In the US, it is not easy to get into a priva..."
1942,"This is a cross post from askreddit, but I wan...",I am learning valuable new skills and being pa...,I am a C-level executive in a company that has...,"In a post on the microblogging site Reddit, e..."
1943,looks bigger \n What? The foreskin is a double...,"5 Swings and 5 misses, and that's without goin...",Circumcision is a good thing.,A look at some of the more interesting snippe...
1944,I was driving down fairly empty 55mph side str...,I crashed into a giant black guy's pick up tru...,"Saw a pick-up truck steps a man, I thought I w...",A young girl in the US state of New Jersey ha...


In [5]:
# Compute ROUGE for the baseline and the fine-tuned predictions against the
# same reference column ('y') of the joined frame.
# NOTE(review): `load_metric` is deprecated in newer `datasets` releases in
# favour of the separate `evaluate` package -- confirm the installed version
# still provides it.
metric = load_metric("rouge")
references = advice_preds['y'].tolist()
baseline_metrics = metric.compute(predictions=advice_preds['yhat_baseline'].tolist(), references=references)
finetune_metrics = metric.compute(predictions=advice_preds['yhat_finetune'].tolist(), references=references)

print("Baseline:")
print(baseline_metrics)

# This header used to read "Baseline:" as well, which mislabelled the
# fine-tuned scores.
print("\n\nFinetuned:")
print(finetune_metrics)

Downloading builder script:   0%|          | 0.00/2.16k [00:00<?, ?B/s]

Baseline:
{'rouge1': AggregateScore(low=Score(precision=0.16831263613157998, recall=0.15869763297399342, fmeasure=0.14541169149656044), mid=Score(precision=0.17419289137267552, recall=0.16391570846084666, fmeasure=0.14968092341112565), high=Score(precision=0.18015059849469509, recall=0.16958857214009354, fmeasure=0.15410101718805244)), 'rouge2': AggregateScore(low=Score(precision=0.026013263385480545, recall=0.024172843732139043, fmeasure=0.021562440811585173), mid=Score(precision=0.028386253418740938, recall=0.026905636359783887, fmeasure=0.0234941758389646), high=Score(precision=0.03070843658187854, recall=0.02990644729920436, fmeasure=0.025308122311153213)), 'rougeL': AggregateScore(low=Score(precision=0.12701577125345306, recall=0.12572870400252112, fmeasure=0.11190512610084428), mid=Score(precision=0.13134255733338607, recall=0.13043130097726074, fmeasure=0.11538173182341588), high=Score(precision=0.13594662644313377, recall=0.13541061530123, fmeasure=0.118989839959064)), 'rougeLs

In [8]:
# Ad-hoc ratio checks on hard-coded literals; none of the operands are read
# from variables computed in this notebook.
# NOTE(review): presumably fine-tuned vs. baseline comparisons -- confirm what
# each pair of numbers represents and derive them from the metric results.
# Only the value of the last expression (20/15) is shown as the cell output;
# the first two results are computed and discarded.
.27/.17
195/164
20/15

1.3333333333333333