### Set up PaLM

In [1]:
!pip install -q google-cloud-secret-manager
!pip install -q google-generativeai

In [2]:
# Colab-only step: authenticate the current user.
# NOTE(review): presumably this is what lets the Secret Manager client and the
# gsutil commands further down access the project's resources — confirm.
from google.colab import auth
auth.authenticate_user()

In [3]:
import google.generativeai as palm

In [4]:
from google.cloud import secretmanager

# Fully-qualified Secret Manager path of the secret version that stores the
# PaLM API key (project / secret name / version).
resource_name = "projects/708283936417/secrets/google-generative-ai-key/versions/1"

# Fetch that secret version and decode its payload bytes into text.
client = secretmanager.SecretManagerServiceClient()
response = client.access_secret_version(request={"name": resource_name})
secret_payload = response.payload
GOOGLE_GENERATIVE_AI_KEY = secret_payload.data.decode("UTF-8")

# Hand the key to the PaLM client. The key is deliberately never displayed,
# so it does not end up in the saved notebook output.
palm.configure(api_key=GOOGLE_GENERATIVE_AI_KEY)

In [5]:
# Materialize the listing into a list: list_models() may hand back a one-shot
# iterator, in which case iterating `models` a second time (e.g. re-running
# the next cell) would silently yield nothing.
models = list(palm.list_models())

In [6]:
# Display the name of every model returned by the listing above.
[model.name for model in models]

['models/chat-bison-001',
 'models/text-bison-001',
 'models/embedding-gecko-001']

### Load Dataset

In [7]:
# Copy the validation CSV from the Cloud Storage bucket into the current
# working directory, where the next cell reads it.
!gsutil cp gs://cnn-dailymail-data/validation-100.csv .

Copying gs://cnn-dailymail-data/validation-100.csv...
/ [0 files][    0.0 B/412.0 KiB]                                                / [1 files][412.0 KiB/412.0 KiB]                                                
Operation completed over 1 objects/412.0 KiB.                                    


In [8]:
import pandas as pd

# Load the validation file downloaded by the previous cell.
val_df = pd.read_csv("validation-100.csv")

# Fail fast if the file does not have the columns the rest of the notebook
# reads ("article" for the prompt, "id" as the prediction key, "highlights" as
# the ROUGE reference), rather than discovering it after spending API calls.
REQUIRED_COLUMNS = {"id", "article", "highlights"}
missing_columns = REQUIRED_COLUMNS - set(val_df.columns)
assert not missing_columns, (
    f"validation-100.csv is missing columns: {sorted(missing_columns)}")

### Run PaLM completion model

In [9]:
prompt_template = '''Here is a news article. Summarize this article into 3 or 4 bullet points. Use less than 75 words in total.

```{ARTICLE}```

CONCISE SUMMARY:'''

In [11]:
import time

from tqdm import tqdm

# Retry policy for rate-limited requests. A previous run of this cell aborted
# on the very first article with `ResourceExhausted`, so each request is now
# retried with exponential backoff instead of killing the whole loop.
MAX_ATTEMPTS = 6          # total tries per article before giving up
INITIAL_BACKOFF_S = 2.0   # wait before the first retry; doubles every retry


def _is_quota_error(err):
  """Return True if `err` is an instance of a class named ResourceExhausted.

  The exception is matched by class name (walking the MRO so subclasses also
  match) to avoid adding a direct import of the library that defines it.
  """
  return any(cls.__name__ == "ResourceExhausted" for cls in type(err).__mro__)


def generate_summary(prompt, max_attempts=MAX_ATTEMPTS,
                     initial_backoff_s=INITIAL_BACKOFF_S):
  """Call the text-bison model on `prompt`, retrying on ResourceExhausted.

  Args:
    prompt: Full prompt text to send to the model.
    max_attempts: Maximum number of calls to make; at least one call is
      always made.
    initial_backoff_s: Seconds to sleep before the first retry. The sleep
      doubles after each further ResourceExhausted error.

  Returns:
    The completion object returned by `palm.generate_text`.

  Raises:
    Any exception from `palm.generate_text` that is not ResourceExhausted,
    or the last ResourceExhausted error once `max_attempts` is reached.
  """
  delay = initial_backoff_s
  attempt = 1
  while True:
    try:
      return palm.generate_text(
          model="models/text-bison-001",
          prompt=prompt,
          temperature=0,
          max_output_tokens=800)
    except Exception as err:
      # Only ResourceExhausted is retried; everything else propagates.
      if attempt >= max_attempts or not _is_quota_error(err):
        raise
    time.sleep(delay)
    delay *= 2
    attempt += 1


# Maps article id -> generated summary, or the sentinel "BLOCKED" when the
# model returned an empty result.
predictions = {}
num_blocked = 0

for _, row in tqdm(val_df.iterrows(), total=len(val_df)):
  # str.replace (not str.format) so braces inside the article are harmless.
  prompt = prompt_template.replace("{ARTICLE}", row["article"])
  completion = generate_summary(prompt)
  if completion.result:
    predictions[row["id"]] = completion.result
  else:
    predictions[row["id"]] = "BLOCKED"
    num_blocked += 1

  0%|          | 0/100 [00:00<?, ?it/s]


ResourceExhausted: ignored

In [None]:
import pandas as pd

# One row per article: its id and the generated summary (or "BLOCKED").
# The columns are given explicitly so the frame, and the CSV written from it,
# keeps its "id"/"prediction" header even when `predictions` is empty.
predictions_df = pd.DataFrame(
    list(predictions.items()),
    columns=["id", "prediction"],
)

In [None]:
# Write the predictions to a local CSV (without the DataFrame index column),
# then copy that file to the predictions bucket. `$OUTPUT_FILE` on the shell
# line is expanded by IPython from the Python variable of the same name.
OUTPUT_FILE = "google-palm-text-bison-001.csv"
predictions_df.to_csv(OUTPUT_FILE, index=False)
!gsutil cp $OUTPUT_FILE gs://cnn-dailymail-predictions/

### Compute ROUGE scores

In [None]:
!pip install -q rouge-score

In [None]:
import numpy as np
from rouge_score import rouge_scorer
from tqdm import tqdm

scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"])

# F-measures of every scored article, one list per ROUGE variant.
rouge1_scores = []
rouge2_scores = []
rougeL_scores = []
# Articles left out of the averages: either no prediction was recorded for
# their id (e.g. the generation loop stopped early) or it was "BLOCKED".
num_skipped = 0

for _, row in tqdm(val_df.iterrows(), total=val_df.shape[0]):
  # .get() instead of indexing: a missing id is skipped, not a KeyError.
  prediction = predictions.get(row["id"])
  if prediction is None or prediction == "BLOCKED":
    num_skipped += 1
    continue
  score = scorer.score(target=row["highlights"], prediction=prediction)
  rouge1_scores.append(score["rouge1"].fmeasure)
  rouge2_scores.append(score["rouge2"].fmeasure)
  rougeL_scores.append(score["rougeL"].fmeasure)

# Report coverage so that averages over a partial run are not mistaken for
# results over the whole validation set.
print(f"scored {len(rouge1_scores)} of {val_df.shape[0]} articles "
      f"({num_skipped} skipped)")

if rouge1_scores:
  print(f'''
rouge 1: {np.average(rouge1_scores)}
rouge 2: {np.average(rouge2_scores)}
rouge L: {np.average(rougeL_scores)}
''')
else:
  # Averaging empty lists would only produce NaN plus a runtime warning.
  print("No predictions available to score.")