In [1]:
import random
import sys

import numpy as np
import pandas as pd
from tqdm import tqdm

# Put the parent directory on the import path; presumably that is where the
# local `finetune` package lives (it is imported right below).
sys.path.append('..')
# NOTE(review): star import. RatableText, evaluate_text, RatingOutput, config,
# create_db_entries and ChromaRatingExampleSelector are used later but are not
# defined anywhere else in this notebook, so they presumably come from here.
from finetune.auto_label import *  # noqa: F401,F403

# Seed every RNG the notebook may touch.
SEED = 93
random.seed(SEED)
np.random.seed(SEED)

# Registers `progress_apply` on pandas objects (used when rating the test set).
tqdm.pandas()

In [9]:
# Rater LLM settings.
MODEL_NAME = "gpt-3.5-turbo-16k"
TEMPERATURE = 0  # intended sampling temperature for the rater LLM

# Plain-text file holding the OpenAI API key.
API_KEY_PATH = "../key"

# Hand-rated samples: fine-tuned outputs and "og" outputs (presumably the
# original / baseline model — confirm).
FINETUNED_PATH = "gs://scraped-news-article-data-null/summary_eval_shilong.xlsx"
OG_PATH = "gs://scraped-news-article-data-null/summary_eval_og_shilong.xlsx"

# Test-set predictions to rate, and the destination of the rated copy.
TEST_SET_PATH = "gs://scraped-news-article-data-null/test_predicted.parquet"
TEST_SET_RATED_PATH = "gs://scraped-news-article-data-null/test_predicted_rated.parquet"

In [3]:
# Read the API key (surrounding whitespace/newline stripped).
with open(API_KEY_PATH, "r") as key_file:
    API_KEY = key_file.read().strip()

# Stack both hand-rated samples into one frame with a fresh 0..n-1 index.
finetuned_sample = pd.read_excel(FINETUNED_PATH)
og_sample = pd.read_excel(OG_PATH)
combined_sample = pd.concat([finetuned_sample, og_sample], ignore_index=True)
combined_sample.head()

Unnamed: 0,source,id,category,title,published,body,summary,summary_type,predicted,rating,thought
0,reuters,22071,Software company morphs into personal injury firm,Plaintiffs' BarLegal InnovationLegal IndustryL...,2022-06-27T22:46:00,summarize in bullet points: (Reuters) - A soft...,* Mighty Law says it will undercut personal in...,BULLETS,\n* A software company is opening its own law ...,4,The summary accurately captures the main point...
1,cnbc,8940,Personal Finance,College hopefuls have a new ultimate dream sch...,2023-03-15T13:11:19+00:00,summarize in bullet points: ## Massachusetts I...,"* This year, the school named by the highest n...",BULLETS,\n* A new survey of college-bound students and...,3,The summary accurately captures the main point...
2,reuters,50327,LNGFuel OilGasHydrogenRenewable Fuels,Spain's Cepsa to invest 3 bln euros in green h...,2022-12-01T15:03:00,"summarize in bullet points: CADIZ, Spain, Dec ...","* To produce 300,000 tons of green hydrogen pe...",BULLETS,\n* Cepsa to invest 3 bln euros in renewable e...,3,Thought process for the rating:\n- The summary...
3,reuters,22786,U.S. Markets,"Wall St falls after recent strong gains, Alpha...",2023-02-09T00:30:00,"summarize in bullet points: NEW YORK, Feb 8 (R...",* Alphabet shares dive after Google AI chatbot...,BULLETS,"\n* S&P 500 futures down 1.2%, yuan up 0.2%, w...",2,Thought process for the rating:\n- The summary...
4,cnbc,1864,Sports,Canadian billionaire Steve Apostolopoulos says...,2023-04-24T17:42:10+00:00,summarize in bullet points: Canadian billionai...,* Canadian businessman Steve Apostolopoulos sa...,BULLETS,\n* The saga over Washington's embattled footb...,4,The summary reflects the content of the text w...


In [4]:
from langchain.chat_models import ChatOpenAI
import chromadb

# In-memory Chroma client; collections created on it live only in this kernel.
chroma = chromadb.EphemeralClient()
# Rater LLM. Sampling settings come from the config cell rather than being
# hard-coded here, so TEMPERATURE is the single knob that controls them.
plan_llm = ChatOpenAI(
    openai_api_key=API_KEY,
    temperature=TEMPERATURE,
    model_name=MODEL_NAME,
)


def rate_summary(row, selector):
    """Ask the rater LLM to score one predicted summary against its article.

    Args:
        row: Mapping/Series with at least "body" (the source text) and
            "predicted" (the model summary to rate).
        selector: Few-shot example selector forwarded to ``evaluate_text``
            as its ``meal_selector`` keyword.

    Returns:
        The dict produced by ``evaluate_text``. If ``evaluate_text`` raises
        ``ValueError``, returns ``RatingOutput(rating=NaN, thought=<error
        text>).model_dump()`` instead, so one bad row does not abort a batch.
    """
    ratable = RatableText(context_text=row["body"], output=row["predicted"])
    try:
        return evaluate_text(
            plan_llm,
            ratable,
            config.LABEL_SUMMARY_SYSTEM,
            config.LABEL_SUMMARY_USER,
            meal_selector=selector,
        )
    except ValueError as err:
        return RatingOutput(rating=np.nan, thought=str(err)).model_dump()

In [5]:
# Leave-one-out check of the auto-rater: score each hand-labelled summary using
# every *other* labelled row as the few-shot example pool, then compare with
# the human rating.
LOO_COLLECTION = "loo-examples"  # own name so it can never clash with the test-set cell's collection

sample_ratings = []
for idx, row in tqdm(combined_sample.iterrows(), total=len(combined_sample.index)):
    # Exclude the row being rated. Also drop evaluation columns this notebook
    # adds later, so a re-run builds the same example pool as a fresh run.
    other_examples = combined_sample.drop(
        index=idx, columns=["predicted_rating", "rating_diff"], errors="ignore"
    )
    collection = chroma.get_or_create_collection(LOO_COLLECTION)
    try:
        # Query before inserting; presumably this warms up the embedding
        # function — TODO confirm it is still needed.
        collection.query(query_texts=["warm-up"], n_results=1)
        records = create_db_entries(other_examples, collection)
        collection.add(**records)
        # The collection must contain exactly this iteration's examples.
        assert len(records["ids"]) == collection.count()
        example_selector = ChromaRatingExampleSelector(collection)
        sample_ratings.append(rate_summary(row, example_selector)["rating"])
    finally:
        # Always drop the collection, even if rating failed, so the next
        # iteration (or a re-run of this cell) starts from an empty index.
        chroma.delete_collection(LOO_COLLECTION)

combined_sample["predicted_rating"] = sample_ratings
combined_sample[["rating", "predicted_rating"]].head()

100%|██████████| 20/20 [01:24<00:00,  4.22s/it]


Unnamed: 0,rating,predicted_rating
0,4,3.0
1,3,3.0
2,3,3.0
3,2,1.0
4,4,2.0


In [6]:
# Signed error (human minus model); positive means the model rated lower.
combined_sample["rating_diff"] = combined_sample["rating"].sub(combined_sample["predicted_rating"])
rating_diff = combined_sample["rating_diff"]
# Mean absolute error; NaN predictions (failed ratings) are skipped.
print(rating_diff.abs().mean())
print(rating_diff.describe())

0.85
count    20.000000
mean      0.350000
std       1.225819
min      -2.000000
25%       0.000000
50%       0.000000
75%       1.000000
max       3.000000
Name: rating_diff, dtype: float64


In [7]:
# Test-set articles with the model's predicted summaries (to be rated below).
test_set = pd.read_parquet(path=TEST_SET_PATH)
test_set.head(n=5)

Unnamed: 0,source,id,category,title,published,body,summary,summary_type,predicted
0,reuters,112266,China,Millions tested in Shanghai as China grapples ...,2022-07-07T21:01:00,"summarize in bullet points: SHANGHAI/BEIJING, ...",* Third day of mass testing in many Shanghai d...,BULLETS,\n* China reported 338 new COVID cases for Wed...
1,reuters,66543,Business,Ecigarettes and heated tobacco light up Imperi...,2022-05-17T08:37:00,"summarize in bullet points: LONDON, May 17 (Re...",* H1 adjusted net revenue up 0.3% at constant ...,BULLETS,"\n* Shares up more than 7% in morning trade, b..."
2,cnbc,7668,Sustainable Future,Private jet flights in Europe soar to record l...,2023-03-30T05:24:48+00:00,summarize in bullet points: A private jet is l...,* Analysis published Thursday by environmental...,BULLETS,\n* The number of private jet flights in Europ...
3,reuters,87382,Middle East,Israel's Netanyahu returns with hard-right cab...,2022-12-29T19:28:00,"summarize in bullet points: JERUSALEM, Dec 29 ...","* New government has nationalist, religious pa...",BULLETS,\n* Netanyahu's hard-right cabinet includes re...
4,reuters,15571,DiversityCorporate CounselCorporate Governance,Activist behind Harvard race case takes aim at...,2021-07-13T21:56:00,summarize in bullet points: (Reuters) - Edward...,* Lawsuit claims statutes discriminate against...,BULLETS,\n* The lawsuit seeks to invalidate laws requi...


In [8]:
# Rate every test-set summary, using all hand-labelled rows as the few-shot
# example pool. Slow: roughly 3.8 s per row for the full test set.
TEST_EXAMPLES_COLLECTION = "test-set-examples"  # own name so it never clashes with the leave-one-out cell

# Only the original labelled columns; evaluation columns added by earlier
# cells would otherwise depend on which cells happened to run first.
labelled_examples = combined_sample.drop(columns=["predicted_rating", "rating_diff"], errors="ignore")

collection = chroma.get_or_create_collection(TEST_EXAMPLES_COLLECTION)
try:
    # Query before inserting; presumably this warms up the embedding
    # function — TODO confirm it is still needed.
    collection.query(query_texts=["warm-up"], n_results=1)
    records = create_db_entries(labelled_examples, collection)
    collection.add(**records)
    # The collection must contain exactly the labelled examples, no leftovers.
    assert len(records["ids"]) == collection.count()
    meal_selector = ChromaRatingExampleSelector(collection)

    ratings = test_set.progress_apply(lambda row: rate_summary(row, meal_selector), axis=1).tolist()
finally:
    # Drop the collection so re-running this cell starts from an empty index.
    chroma.delete_collection(TEST_EXAMPLES_COLLECTION)
ratings[0]

100%|██████████| 4527/4527 [4:43:33<00:00,  3.76s/it]  


{'rating': 5.0,
 'thought': 'The summary accurately captures the main points of the text. It mentions that China reported 338 new COVID cases, with Shanghai reporting 54 new cases. It also highlights the concerns about the impact of the outbreaks on the Chinese economy and the potential for restrictions that could disrupt global supply chains and trade. These are the key details that would be relevant to a retail investor.',
 'raw': 'Thought process:\nThe summary accurately captures the main points of the text. It mentions that China reported 338 new COVID cases, with Shanghai reporting 54 new cases. It also highlights the concerns about the impact of the outbreaks on the Chinese economy and the potential for restrictions that could disrupt global supply chains and trade. These are the key details that would be relevant to a retail investor.\n\nFinal Rating:\n5'}

In [10]:
# Unpack the per-row result dicts. Keep the model's reasoning next to its score
# (the hand-labelled samples carry the same "rating"/"thought" pair) so the
# expensive LLM output is persisted and low ratings can be audited later.
test_set["rating"] = [r["rating"] for r in ratings]
test_set["thought"] = [r.get("thought") for r in ratings]

# Rows where evaluate_text raised ValueError come back with a NaN rating.
n_unrated = int(test_set["rating"].isna().sum())
print(f"{n_unrated} of {len(test_set)} rows could not be rated")
print(test_set.rating.describe())

count    4526.000000
mean        3.500000
std         1.494705
min         1.000000
25%         2.000000
50%         4.000000
75%         5.000000
max         5.000000
Name: rating, dtype: float64


In [11]:
# Persist the rated test set; the RangeIndex carries no information.
test_set.to_parquet(
    TEST_SET_RATED_PATH,
    index=False,
)