### Load the validation dataset from GCS

In [1]:
# Authenticate the current Colab user; this must run before the cells that
# touch Google Cloud resources.
from google.colab import auth
auth.authenticate_user()

# Set the gcloud CLI's active project (same project id the GCS file-system
# handle is created with in the next cell).
!gcloud config set project cnn-dailymail-387022

Updated property [core/project].


In [2]:
!pip install gcsfs
import gcsfs
fs = gcsfs.GCSFileSystem(project="cnn-dailymail-387022")

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [3]:
import pandas as pd

# Open the validation split on GCS through the gcsfs handle and parse it into
# a DataFrame; the remote file is closed as soon as parsing finishes.
with fs.open("gs://cnn-dailymail-data/validation.csv") as validation_csv:
  val_df = pd.read_csv(validation_csv)

### Use the first 300 characters of each article as the predicted summary

In [4]:
# Number of leading article characters kept as the prediction.
PREFIX_CHARS = 300

# Map each article id to the first PREFIX_CHARS characters of its text.
# The id and article columns are paired directly instead of going through
# iterrows(), which materialises a full Series for every row just to read
# two fields. Later duplicates of an id still overwrite earlier ones, exactly
# as in a row-by-row loop.
predictions = {
    article_id: article[:PREFIX_CHARS]
    for article_id, article in zip(val_df["id"], val_df["article"])
}

In [5]:
# Write results to a csv file.
# Build the frame from (id, prediction) pairs with explicit column names so
# the header row is present even when `predictions` is empty.
predictions_df = pd.DataFrame(
    list(predictions.items()),
    columns=["id", "prediction"],
)
with fs.open("gs://cnn-dailymail-predictions/first-300-chars.csv", "w") as fp:
  # index=False: the positional row index carries no information and would
  # otherwise reappear as a stray "Unnamed: 0" column when the file is read
  # back with pd.read_csv.
  predictions_df.to_csv(fp, index=False)

### Compute ROUGE scores

In [6]:
!pip install rouge-score
from rouge_score import rouge_scorer

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [7]:
# Read the saved predictions back from GCS.

import pandas as pd

with fs.open("gs://cnn-dailymail-predictions/first-300-chars.csv") as saved_csv:
  df = pd.read_csv(saved_csv)

# Lookup table: article id -> stored prediction text.
preds = dict(zip(df["id"], df["prediction"]))

In [8]:
import numpy as np
from tqdm import tqdm

# Scorer configured for the three ROUGE variants reported below.
scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"])

# Per-article F-measures, one list per ROUGE variant, in validation-set order.
rouge1_scores, rouge2_scores, rougeL_scores = [], [], []

# Score every validation row: the reference is the row's `highlights` text and
# the candidate is the prediction stored under the row's id. tqdm shows progress.
for _, example in tqdm(val_df.iterrows(), total=len(val_df)):
  result = scorer.score(
      target=example["highlights"],
      prediction=preds[example["id"]],
  )
  rouge1_scores.append(result["rouge1"].fmeasure)
  rouge2_scores.append(result["rouge2"].fmeasure)
  rougeL_scores.append(result["rougeL"].fmeasure)

# Report the average F-measure of each variant over all scored rows.
print(f'''
rouge 1: {np.average(rouge1_scores)}
rouge 2: {np.average(rouge2_scores)}
rouge L: {np.average(rougeL_scores)}
''')

100%|██████████| 13368/13368 [00:24<00:00, 541.40it/s]


rouge 1: 0.3739955026578642
rouge 2: 0.15504765049190278
rouge L: 0.23351822585273302




