In [3]:
#@title Installing required Python packages
!pip install boto3
!pip install mistralai
!pip install openai
!pip install google-search-results

Collecting openai
  Using cached openai-1.35.10-py3-none-any.whl (328 kB)
Installing collected packages: openai
Successfully installed openai-1.35.10


In [4]:
#@title Importing Python libraries and modules


import os
import re
import time
import datetime
import pytz
import dateutil
import requests
import json
import csv

import pandas as pd

from google.colab import files
import ipywidgets as widgets
from IPython.display import display

from openai import OpenAI
import tabulate
import textwrap

import boto3

from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage

from serpapi import GoogleSearch

# Today's date in US Pacific time, formatted like "July 04, 2024"; it is
# interpolated into the LLM prompts and demo answers defined below.
current_date = f"{datetime.datetime.now(pytz.timezone('America/Los_Angeles')):%B %d, %Y}"

In [5]:
#@title API keys


# OpenAI's API key (sign up at https://platform.openai.com/signup to get $5 in
# free credit that can be used during your first 3 months)
openai_api_key = "ADD API KEY"  # @param {type:"string"}
# SerpApi's API key (sign up at https://serpapi.com/users/sign_up?plan=free for
# a free plan with 100 searches/month)
serpapi_api_key = "ADD API KEY"  # @param {type:"string"}

# Credentials may also come from environment variables, which keeps real keys
# out of the saved notebook. A non-empty environment variable takes
# precedence over the value typed above.
openai_api_key = os.environ.get("OPENAI_API_KEY") or openai_api_key
serpapi_api_key = os.environ.get("SERPAPI_API_KEY") or serpapi_api_key
mistral_api_key = os.environ.get("MISTRAL_API_KEY") or "ADD API KEY"
aws_access_key_id = os.environ.get("AWS_ACCESS_KEY_ID") or "ADD API KEY"
aws_secret_access_key = (
    os.environ.get("AWS_SECRET_ACCESS_KEY") or "ADD API KEY"
)

# Validate before any client is constructed.
assert (
    openai_api_key is not None and openai_api_key != ""
), "OpenAI's API key is not set"
assert (
    serpapi_api_key is not None and serpapi_api_key != ""
), "SerpApi's API key is not set"

openai_client = OpenAI(
    api_key=openai_api_key,
)

mistral_client = MistralClient(api_key=mistral_api_key)

llama2_client = boto3.client(
    service_name='bedrock-runtime',
    region_name='us-east-1',
    aws_access_key_id=aws_access_key_id,
    aws_secret_access_key=aws_secret_access_key,
)


In [6]:
#@title Function calling for the base LLM


def call_llm_gpt_api(prompt, model, temperature, max_tokens, chat_completions=True):
  """Send `prompt` to an OpenAI model and return the generated text.

  Args:
    prompt: The full prompt text.
    model: OpenAI model name.
    temperature: Sampling temperature.
    max_tokens: Maximum number of tokens to generate.
    chat_completions: If True use the Chat Completions API (with a system
      message and a fixed "what's today's date" exchange); otherwise use the
      legacy Completions API with the raw prompt.

  Returns:
    The text of the first returned choice.
  """
  # See https://platform.openai.com/docs/guides/gpt for details
  if not chat_completions:
    # Completions API: the prompt is sent verbatim.
    completion = openai_client.completions.create(
        model=model,
        temperature=temperature,
        max_tokens=max_tokens,
        prompt=prompt,
    )
    return completion.choices[0].text

  # Chat completions API: system instructions, then a canned one-turn exchange
  # stating today's date, then the real prompt.
  system_message = {
      "role": "system",
      "content": (
          "You are a helpful assistant. Answer as concisely as"
          f" possible. Knowledge cutoff: {current_date}."
      ),
  }
  date_exchange = [
      {"role": "user", "content": "What's today's date?"},
      {
          "role": "assistant",
          "content": f"Today is {current_date} in Pacific Standard Time.",
      },
  ]
  conversation = [system_message, *date_exchange, {"role": "user", "content": prompt}]

  completion = openai_client.chat.completions.create(
      model=model,
      temperature=temperature,
      max_tokens=max_tokens,
      messages=conversation,
  )
  return completion.choices[0].message.content


In [7]:
def call_llm_mistral_api(prompt, model, temperature, max_tokens=None):
  """Send `prompt` to a Mistral chat model and return the stripped reply.

  The prompt is sent as a single user message, prefixed with the
  "helpful assistant / knowledge cutoff" instruction.

  Args:
    prompt: The full FreshPrompt text (evidence + question).
    model: Mistral model name, e.g. "mistral-small-latest".
    temperature: Sampling temperature.
    max_tokens: Optional cap on generated tokens; None keeps the API default.

  Returns:
    The content of the first choice with surrounding whitespace removed.
  """
  # See https://docs.mistral.ai/api/ for details
  # Built as a plain string: the previous indented triple-quoted f-string sent
  # literal double quotes plus stray newlines/indentation to the model.
  content = (
      "You are a helpful assistant. Answer as concisely as possible."
      f" Knowledge cutoff: {current_date}. {prompt}"
  )

  chat_completion = mistral_client.chat(
      model=model,
      temperature=temperature,
      max_tokens=max_tokens,
      messages=[ChatMessage(role="user", content=content)],
  )

  return chat_completion.choices[0].message.content.strip()

In [8]:
def call_llm_llama2_api(prompt, model, temperature, max_gen_len=None):
  """Send `prompt` to a Llama 2 chat model on Amazon Bedrock.

  The prompt is wrapped in the Llama 2 chat template: the assistant
  instructions go inside <<SYS>>...<</SYS>> and the FreshPrompt text is the
  user turn that follows, before [/INST]. (Previously the user prompt was
  placed inside the system block and wrapped in literal quote characters.)

  Args:
    prompt: The full FreshPrompt text (evidence + question).
    model: Bedrock model id, e.g. "meta.llama2-70b-chat-v1".
    temperature: Sampling temperature.
    max_gen_len: Optional cap on generated tokens; None keeps the Bedrock
      default.

  Returns:
    The generated text with surrounding whitespace removed.

  Raises:
    RuntimeError: if the Bedrock response has no "generation" field.
  """
  # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
  system_prompt = (
      "You are a helpful assistant. Answer as concisely as possible."
      f" Knowledge cutoff: {current_date}."
  )
  llama_prompt = (
      f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{prompt.strip()} [/INST]"
  )

  request = {"prompt": llama_prompt, "temperature": temperature}
  if max_gen_len is not None:
    request["max_gen_len"] = max_gen_len

  response = llama2_client.invoke_model(
      body=json.dumps(request),
      modelId=model,
      accept='application/json',
      contentType='application/json',
  )

  payload = json.loads(response["body"].read())
  generation = payload.get("generation")
  if generation is None:
    raise RuntimeError(
        f"Bedrock response for {model} has no 'generation' field: {payload}"
    )
  return generation.strip()

In [9]:
#@title Function calling for the search engine


def call_search_engine(query):
  """Run a Google search for `query` through SerpApi.

  Results are requested in English ("hl") from google.com with US
  geolocation ("gl").

  Returns:
    The raw SerpApi response as a dict.
  """
  search_params = dict(
      q=query,
      hl="en",
      gl="us",
      google_domain="google.com",
      api_key=serpapi_api_key,
  )
  # To bias results geographically, add e.g. location="California, United States".
  return GoogleSearch(search_params).get_dict()


In [10]:
#@title Utility functions for FreshPrompt


def is_date(string, fuzzy=False):
  """Return True if dateutil can parse `string` as a date.

  Args:
    string: Text to test.
    fuzzy: Passed to dateutil's parser; when True, unknown tokens are ignored.

  Returns:
    True if parsing succeeds, False otherwise.
  """
  # Import the submodule explicitly: a bare `import dateutil` does not
  # guarantee that `dateutil.parser` is loaded.
  import dateutil.parser

  try:
    dateutil.parser.parse(string, fuzzy=fuzzy)
  except (ValueError, OverflowError):
    # dateutil raises OverflowError (not ValueError) for numeric tokens too
    # large to be a date, e.g. long digit strings found in snippets.
    return False
  return True


def format_date(d, today=None):
  """Normalize a search-result date to the "%b %d, %Y" format.

  Relative dates ("3 hours ago", "2 days ago", "a day ago", "2 weeks ago")
  are resolved against today's date. Anything else is parsed with dateutil in
  fuzzy mode, falling back to the first whitespace-separated token that
  parses as a date.

  Args:
    d: Raw date value, usually a string. None or float NaN yields None;
      other non-string values are converted with str().
    today: Reference date string used for relative dates. Defaults to the
      module-level `current_date`.

  Returns:
    The formatted date (e.g. "Jul 04, 2024"), or None if no date is found.
  """
  # Explicit submodule import: a bare `import dateutil` may not load it.
  import dateutil.parser

  fmt = "%b %d, %Y"
  if d is None or (isinstance(d, float) and pd.isna(d)):
    return None
  d = str(d)

  reference = dateutil.parser.parse(
      current_date if today is None else today, fuzzy=True
  )

  # Anything reported in seconds/minutes/hours ago happened today.
  for unit in ("second", "minute", "hour"):
    if f"{unit} ago" in d or f"{unit}s ago" in d:
      return reference.strftime(fmt)

  # Days/weeks ago are subtracted from today; "a day ago" (no digits) is 1.
  for unit, unit_days in (("day", 1), ("week", 7)):
    if f"{unit} ago" in d or f"{unit}s ago" in d:
      match = re.search(rf"(\d+)\s+{unit}s?\s+ago", d)
      count = int(match.group(1)) if match else 1
      return (
          reference - datetime.timedelta(days=count * unit_days)
      ).strftime(fmt)

  try:
    return dateutil.parser.parse(d, fuzzy=True).strftime(fmt)
  except (ValueError, OverflowError):
    # Whole-string parse failed; use the first token that is a date.
    for token in d.split():
      if is_date(token):
        return dateutil.parser.parse(token, fuzzy=True).strftime(fmt)
  return None


def extract_source_webpage(link):
  """Reduce a URL or displayed link to its host name.

  Strips surrounding whitespace, removes every occurrence of "https://www.",
  "http://www.", "https://" and "http://" (in that order), then keeps the
  text before the first "/".
  """
  host = link.strip()
  for scheme_prefix in ("https://www.", "http://www.", "https://", "http://"):
    host = host.replace(scheme_prefix, "")
  return host.partition("/")[0]


def simplify_displayed_link(displayed_link):
  """Turn a SerpApi displayed link ("site.com › path › ...") into a host name.

  Returns None when `displayed_link` is None.
  """
  if displayed_link is not None:
    first_segment = displayed_link.split(' › ')[0]
    return extract_source_webpage(first_segment)
  return None


def format_search_results(search_data, title_field=None, highlight_field=None):
  """Standardize one SerpApi result as shown in Figure 3 (left) in the paper.

  SerpApi answer boxes of type "local_time" and "population_result" are
  handled specially. Every other result (organic results, related questions,
  generic answer boxes) goes through the generic branch.

  Args:
    search_data: One result dict from a SerpApi response (may be empty). It is
      not modified.
    title_field: Key to read the title from instead of 'title' (e.g.
      'question' for related questions).
    highlight_field: Key to read the highlight from instead of
      'snippet_highlighted_words' (e.g. 'answer' for answer boxes).

  Returns:
    A dict with keys 'source', 'date', 'title', 'snippet', 'highlight'; any
    value may be None.
  """
  # Work on a shallow copy: the normalizations below used to be written back
  # into the caller's dict, mutating the cached search response.
  search_data = dict(search_data)

  field = 'snippet_highlighted_words'
  if field in search_data and isinstance(search_data[field], list):
    search_data[field] = ' | '.join(search_data[field])

  field = 'displayed_link'
  if field in search_data:
    search_data[field] = simplify_displayed_link(search_data[field])

  result_type = search_data.get('type')

  if result_type == 'local_time':
    # Edge case 1: local-time answer box; the answer lives in 'result'.
    source = search_data.get('displayed_link')
    date = format_date(search_data.get('date'))
    title = search_data.get('title')

    snippet = search_data.get('snippet')
    if snippet is None and 'result' in search_data:
      if 'extensions' in search_data and isinstance(
          search_data['extensions'], list
      ):
        snippet = '\n\t'.join(
            [search_data['result']] + search_data['extensions']
        )
      else:
        snippet = search_data['result']

    highlight = search_data.get('snippet_highlighted_words')
    if highlight is None and 'result' in search_data:
      highlight = search_data['result']

  elif result_type == 'population_result':
    # Edge case 2: population answer box; the figure lives in 'population'.
    source = search_data.get('displayed_link')
    sources = search_data.get('sources')
    if (
        source is None
        and isinstance(sources, list)
        and sources  # an empty list previously raised IndexError
        and 'link' in sources[0]
    ):
      source = extract_source_webpage(sources[0]['link'])

    date = format_date(search_data.get('date'))
    if date is None and 'year' in search_data:
      date = format_date(search_data['year'])

    title = search_data.get('title')

    snippet = search_data.get('snippet')
    if snippet is None and 'population' in search_data:
      if 'place' in search_data:
        snippet = '\n\t'.join(
            [f"{search_data['place']} / Population", search_data['population']]
        )
      else:
        snippet = search_data['population']

    highlight = search_data.get('snippet_highlighted_words')
    if highlight is None and 'population' in search_data:
      highlight = search_data['population']

  else:
    source = search_data.get('displayed_link')
    date = format_date(search_data.get('date'))
    title = (
        search_data.get('title')
        if title_field is None
        else search_data.get(title_field)
    )
    highlight = (
        search_data.get('snippet_highlighted_words')
        if highlight_field is None
        else search_data.get(highlight_field)
    )
    # `or ''` also covers an explicit None snippet, which used to crash the
    # string joins/concatenation below.
    snippet = search_data.get('snippet') or ''

    if 'rich_snippet' in search_data:
      for key in ['top', 'bottom']:
        if (
            key in search_data['rich_snippet']
            and 'extensions' in search_data['rich_snippet'][key]
        ):
          snippet = '\n\t'.join(
              [snippet] + search_data['rich_snippet'][key]['extensions']
          )

    if 'list' in search_data:
      assert isinstance(search_data['list'], list)
      snippet = '\n\t'.join([snippet] + search_data['list'])

    if 'contents' in search_data and 'table' in search_data['contents']:
      tbl = search_data['contents']['table']
      assert isinstance(tbl, list)
      snippet += '\n'
      for row in tbl:
        snippet += f'\n{",".join(row)}'

    if snippet is not None and snippet.strip() == '':
      snippet = None

  return {
      'source': source,
      'date': date,
      'title': title,
      'snippet': snippet,
      'highlight': highlight,
  }


def format_knowledge_graph(search_data):
  """Standardize a SerpApi knowledge graph as shown in Figure 3 (left) in the paper.

  The title is the entity title, followed by its type on a tab-indented line
  when present. The snippet lists every other plain-string attribute as
  "field: value" lines. It skips 'title', 'type', 'kgmid', keys containing
  '_link' or '_stick', and values starting with "http". Knowledge graphs
  carry no date or highlight.

  Args:
    search_data: The 'knowledge_graph' dict of a SerpApi response (may be empty).

  Returns:
    A dict with keys 'source', 'date', 'title', 'snippet', 'highlight'.
  """
  source = (
      extract_source_webpage(search_data["source"]["link"])
      if "source" in search_data and "link" in search_data["source"]
      else None
  )

  if "title" not in search_data:
    title = None
  elif "type" not in search_data:
    title = search_data["title"]
  else:
    title = search_data["title"]
    title += f"\n\t{search_data['type']}"

  skipped_fields = ("title", "type", "kgmid")
  attribute_lines = [
      f"\n\t{key}: {value}"
      for key, value in search_data.items()
      if key not in skipped_fields
      and "_link" not in key
      and "_stick" not in key
      and isinstance(value, str)
      and not value.startswith("http")
  ]
  snippet = "".join(attribute_lines).strip() or None

  return {
      "source": source,
      "date": None,
      "title": title,
      "snippet": snippet,
      "highlight": None,
  }


def format_questions_and_answers(search_data):
  """Standardize a SerpApi questions-and-answers item (Figure 3, left, in the paper).

  The question becomes the title and the answer the snippet; the link (if
  any) is reduced to its host name. There is no date or highlight.

  Returns:
    A dict with keys "source", "date", "title", "snippet", "highlight".
  """
  source = (
      extract_source_webpage(search_data["link"])
      if "link" in search_data
      else None
  )
  return {
      "source": source,
      "date": None,
      "title": search_data.get("question"),
      "snippet": search_data.get("answer"),
      "highlight": None,
  }


def freshprompt_format(
    question,
    search_data,
    reasoning_and_answer,
    num_organic_results,
    num_related_questions,
    num_questions_and_answers,
    num_retrieved_evidences,
):
  """Build FreshPrompt for each question

  Evidence is gathered from organic results, related questions, questions and
  answers, the knowledge graph and the answer box. It is sorted oldest-first
  (undated evidence first), and the last `num_retrieved_evidences` rows are
  kept, so the most recent evidence ends up closest to the question.

  Args:
    question: The question to process.
    search_data: Search data (a SerpApi response dict).
    reasoning_and_answer: The reasoning and answer (or answer prompt suffix)
      appended after the question.
    num_organic_results: Number of organic results to keep.
    num_related_questions: Number of related questions to keep.
    num_questions_and_answers: Number of questions and answers to keep.
    num_retrieved_evidences: Number of retrieved evidences to keep.

  Returns:
    A prompt that incorporates retrieved evidences for each question.
  """
  columns = ['source', 'date', 'title', 'snippet', 'highlight']
  records = []

  # Each group is added in reverse rank order so that, among evidences with
  # equal dates, higher-ranked results come later (closer to the question).
  organic = (search_data.get('organic_results') or [])[:num_organic_results]
  records.extend(format_search_results(r) for r in organic[::-1])

  related = (search_data.get('related_questions') or [])[:num_related_questions]
  records.extend(
      format_search_results(r, title_field='question') for r in related[::-1]
  )

  qas = (search_data.get('questions_and_answers') or [])[
      :num_questions_and_answers
  ]
  records.extend(format_questions_and_answers(r) for r in qas[::-1])

  if 'knowledge_graph' in search_data:
    records.append(format_knowledge_graph(search_data['knowledge_graph']))

  if 'answer_box' in search_data:
    records.append(
        format_search_results(search_data['answer_box'], highlight_field='answer')
    )

  # Build the frame once (instead of repeated pd.concat with placeholder rows).
  df = pd.DataFrame(records, columns=columns)

  # Sort by date. format_date always emits "%b %d, %Y", so parse with that
  # exact format; anything unparseable becomes NaT.
  df['date'] = df['date'].apply(format_date)
  df['datetime'] = pd.to_datetime(df['date'], format='%b %d, %Y', errors='coerce')
  # Undated rows go first so tail() prefers dated evidence. A stable sort keeps
  # the insertion order above among ties (the default quicksort does not).
  df = df.sort_values(by='datetime', na_position='first', kind='stable')
  # Drop rows with no information at all (NaT counts as missing).
  df = df.dropna(how='all')

  # Select top_k supporting evidences overall
  evidences = [
      f"\n\nsource: {row['source']}\ndate: {row['date']}\ntitle: {row['title']}"
      f"\nsnippet: {row['snippet']}\nhighlight: {row['highlight']}"
      for _, row in df.tail(num_retrieved_evidences).iterrows()
  ]

  return (
      f'\n\n\nquery: {question}'
      + ''.join(evidences)
      + f'\n\nquestion: {question}{reasoning_and_answer}'
  )


In [11]:
#@title Demonstration examples


demo_questions = [
    "What year is considered Albert Einstein's annus mirabilis?",
    "Which photographer took the most expensive photograph in the world?",
    "How many days are left until the 2023 Grammy Awards?",
    "How many years ago did the Boxing Day Tsunami happen?",
    (
        "When did Amazon become the first publicly traded company to exceed a"
        " market value of $3 trillion?"
    ),
]

# Years elapsed since the Boxing Day Tsunami (December 26, 2004) as of
# `current_date`. The demo answers used to hard-code "19 years ago", which is
# only true until December 26, 2024 and then contradicts the
# "As of today {current_date}" prefix added below. `current_date` is produced
# with the "%B %d, %Y" format, so it parses back with the same format string.
_demo_today = datetime.datetime.strptime(current_date, "%B %d, %Y").date()
tsunami_years_ago = (
    _demo_today.year
    - 2004
    - (1 if (_demo_today.month, _demo_today.day) < (12, 26) else 0)
)

concise_demo_reasonings_and_answers = [
    (
        "1905 is considered Albert Einstein's annus mirabilis, his miraculous"
        " year."
    ),
    (
        'The most expensive photograph in the world is "Le Violon d\'Ingres".'
        " The photograph was created by Man Ray."
    ),
    (
        "The 2023 Grammy Awards ceremony was held on February 5, 2023. Thus,"
        " the ceremony has already taken place."
    ),
    (
        "The disaster occurred on December 26, 2004. Thus, it happened"
        f" {tsunami_years_ago} years ago."
    ),
    "Amazon's market capitalization has never exceeded $3 trillion.",
]

verbose_demo_reasonings_and_answers = [
    (
        "In the year of 1905, Albert Einstein published four groundbreaking"
        " papers that revolutionized scientific understanding of the universe."
        " Thus, scientists call 1905 Albert Einstein's annus mirabilis — his"
        " year of miracles."
    ),
    (
        "Man Ray's famed \"Le Violon d'Ingres\" became the most expensive"
        " photograph ever to sell at auction, sold for $12.4 million on May"
        " 14th, 2022 at Christie's New York. The black and white image, taken"
        " in 1924 by the American surrealist artist, transforms a woman's naked"
        " body into a violin by overlaying the picture of her back with"
        " f-holes. Thus, Man Ray is the photographer who took the most"
        " expensive photograph in the world."
    ),
    (
        "The 2023 Grammy Awards, officially known as the 65th Annual Grammy"
        " Awards ceremony, was held in Los Angeles on February 5, 2023. Thus,"
        " the event has already taken place."
    ),
    (
        "The Boxing Day Tsunami refers to the 2004 Indian Ocean earthquake and"
        " tsunami, which is one of the deadliest natural disasters in recorded"
        " history, killing an estimated 230,000 people across 14 countries. The"
        f" disaster occurred on December 26, 2004, which is {tsunami_years_ago}"
        " years ago."
    ),
    (
        "Amazon's market capitalization hit a peak of roughly $1.9 trillion in"
        " July 2021. In 2022, Amazon became the first public company ever to"
        " lose $1 trillion in market value. Thus, Amazon's market value has"
        " never exceeded $3 trillion. In fact, Apple became the first publicly"
        " traded U.S. company to exceed a market value of $3 trillion in"
        " January 2022."
    ),
]

prefix = (
    f"\nanswer: As of today {current_date}, the most up-to-date and relevant"
    " information regarding this query is as follows. "
)

concise_demo_reasonings_and_answers = [
    prefix + x for x in concise_demo_reasonings_and_answers
]
verbose_demo_reasonings_and_answers = [
    prefix + x for x in verbose_demo_reasonings_and_answers
]



In [12]:
#@title Retrieving search data for demonstration examples


# One SerpApi search per demo question, re-run every time this cell executes.
demo_search_data = list(map(call_search_engine, demo_questions))

In [22]:
#@title Function calling for FreshPrompt


def call_freshprompt(
    model, question, check_premise=False, verbose=False, use_demos=False
):
  """Answer `question` with FreshPrompt: web-search evidence plus an LLM.

  Args:
    model: Model identifier. Names starting with "gpt-" go to OpenAI
      ("*-instruct" models via the legacy Completions API), Mistral API names
      ("mistral-*", "open-mistral*", "open-mixtral*") to Mistral, and
      "meta.llama2-*" to Amazon Bedrock.
    question: The question to answer.
    check_premise: If True, ask the model to check the question's premise.
    verbose: Use the verbose demo answers (only relevant with use_demos).
    use_demos: If True, prepend few-shot demonstrations built from
      `demo_questions`, `demo_search_data` and the selected demo answers.
      Defaults to False, i.e. a zero-shot prompt.

  Returns:
    The model's answer text.

  Raises:
    ValueError: if `model` does not match a supported provider. Previously an
      unsupported model ended in an UnboundLocalError on `answer`.
  """
  temperature = 0.0
  max_tokens = 256
  # OpenAI "*-instruct" models are served by the Completions endpoint only.
  chat_completions = not model.endswith("-instruct")

  # Resolve the backend first so an unsupported model does not cost a search.
  if model.startswith("gpt-"):
    def call_llm(prompt):
      return call_llm_gpt_api(
          prompt, model, temperature, max_tokens, chat_completions
      )
  elif model.startswith(("mistral-", "open-mistral", "open-mixtral")):
    def call_llm(prompt):
      return call_llm_mistral_api(prompt, model, temperature)
  elif model.startswith("meta.llama2-"):
    def call_llm(prompt):
      return call_llm_llama2_api(prompt, model, temperature)
  else:
    raise ValueError(f"Unsupported model: {model!r}")

  if model.startswith('gpt-4'):
    num_organic_results = 15
    num_related_questions = 3
    num_questions_and_answers = 3
    num_retrieved_evidences = 15
  else:
    num_organic_results = 15
    num_related_questions = 2
    num_questions_and_answers = 2
    num_retrieved_evidences = 5

  def build_prompt(q, search_data, reasoning_and_answer):
    # Shared retrieval settings for both demo and target prompts.
    return freshprompt_format(
        q,
        search_data,
        reasoning_and_answer,
        num_organic_results,
        num_related_questions,
        num_questions_and_answers,
        num_retrieved_evidences,
    )

  # Generate prompts for demo examples (opt-in).
  demo_prompts = []
  if use_demos:
    demo_reasonings_and_answers = (
        verbose_demo_reasonings_and_answers
        if verbose
        else concise_demo_reasonings_and_answers
    )
    demo_prompts = [
        build_prompt(q, s, ra)
        for q, s, ra in zip(
            demo_questions, demo_search_data, demo_reasonings_and_answers
        )
    ]
  freshprompt_demo = ''.join(demo_prompts).strip()

  if check_premise:
    suffix = (
        "\nPlease check if the question contains a valid premise before"
        " answering.\nanswer: "
    )
  else:
    suffix = "\nanswer: "

  freshprompt_question = build_prompt(
      question, call_search_engine(question), suffix
  )

  return call_llm(freshprompt_demo + freshprompt_question)


In [14]:
def process_freshqa(csv_path="/content/freshqa.csv", num_questions=2):
    """Load the FreshQA CSV and return parallel per-question lists.

    As read by pandas, the real column names are on row iloc[1], and the
    questions start on the row after it.

    Args:
        csv_path: Path to the FreshQA CSV export.
        num_questions: Keep only the first this-many questions; None keeps
            all. Defaults to 2, the subset the notebook has been testing.

    Returns:
        (ques_id_list, query_list, ans_list, effective_year_list,
        num_hops_list, fact_type_list, premise_list), taken from the 'id',
        'question', 'answer_0', 'effective_year', 'num_hops', 'fact_type' and
        'false_premise' columns.
    """
    raw = pd.read_csv(csv_path)
    header = raw.iloc[1].tolist()  # the second row holds the real column names
    data = raw.iloc[2:].copy()
    if num_questions is not None:
        data = data.iloc[:num_questions]
    data.columns = header

    columns = (
        "id",
        "question",
        "answer_0",
        "effective_year",
        "num_hops",
        "fact_type",
        "false_premise",
    )
    return tuple(data[column].tolist() for column in columns)

In [15]:
def process_QAQA(csv_path="/content/QAQA.csv", num_questions=2):
    """Load the QAQA CSV and return parallel per-question lists.

    Args:
        csv_path: Path to the QAQA CSV export.
        num_questions: Keep only the first this-many rows; None keeps all.
            Defaults to 2, the subset the notebook has been testing.

    Returns:
        (ques_id_list, query_list, ans_list, premise_list), taken from the
        'idx', 'question', 'abstractive_answer' and 'all_assumptions_valid'
        columns.
    """
    # processing specific to the QAQA dataset
    df = pd.read_csv(csv_path)
    if num_questions is not None:
        df = df.iloc[:num_questions]
    columns = ("idx", "question", "abstractive_answer", "all_assumptions_valid")
    return tuple(df[column].tolist() for column in columns)

In [None]:
# Upload the dataset CSV to /content manually, or mount Google Drive:
# from google.colab import drive
# drive.mount('/content/drive')

In [25]:
#@title FreshPrompt


# @markdown ---
dataset = "QAQA"
# model_name = "meta.llama2-70b-chat-v1"
model_name = "mistral-small-latest"
# model_name = "gpt-3.5-turbo-1106" #@param ["gpt-4-0125-preview", "gpt-4-turbo-preview", "gpt-4-1106-preview", "gpt-4", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613", "gpt-3.5-turbo-1106", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-instruct", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-0301"]
check_premise = False  # @param {type:"boolean"}

# Load the selected benchmark. `extra_columns` holds the dataset-specific
# metadata written next to the answers, in the original column order.
if dataset == "freshqa":
  (ques_id_list, query_list, ans_list, effective_year_list, num_hops_list,
   fact_type_list, premise_list) = process_freshqa()
  extra_columns = {
      "effective_year": effective_year_list,
      "num_hops": num_hops_list,
      "fact_type": fact_type_list,
      "premise": premise_list,
  }
elif dataset == "QAQA":
  ques_id_list, query_list, ans_list, premise_list = process_QAQA()
  extra_columns = {"all_assumptions_valid": premise_list}
else:
  # Previously an unknown name fell through to a NameError on qa_data_dict.
  raise ValueError(f"Unknown dataset {dataset!r}; expected 'freshqa' or 'QAQA'")

answer_list = []
for ques_id, question in zip(ques_id_list, query_list):
  print(ques_id)  # progress: one line per question
  answer = call_freshprompt(model_name, question, check_premise=check_premise)
  answer_list.append(answer)

qa_data_dict = {
    "ques_id": ques_id_list,
    "question": query_list,
    "true_ans": ans_list,
    **extra_columns,
    "final_answer": answer_list,
}

df = pd.DataFrame(qa_data_dict)
print(df.head())
# NOTE: no separator between dataset and model in the file name (e.g.
# "QAQAmistral-small-latest_response.csv"); left unchanged so output paths
# stay the same as before.
df.to_csv(dataset + model_name + "_response.csv", index=False)

0
1
   ques_id                     question  \
0        0   what did pete burns die of   
1        1  what kind of fish is salmon   

                                            true_ans all_assumptions_valid  \
0  Pete Burns died following a sudden cardiac arr...             all_valid   
1  Salmon is a kind of ray-finned fish in the fam...             all_valid   

                                        final_answer  
0  Pete Burns, the lead singer of the band Dead o...  
1  Salmon is a type of ray-finned fish in the fam...  
