# Gradio App Simulation

In [1]:
from openai.embeddings_utils import get_embedding, distances_from_embeddings
import pandas as pd
import tiktoken
import openai
import ast
import os

In [2]:
# Read the key from the environment rather than pasting it into the notebook,
# where it would be saved (and possibly committed) along with the code.
# Falls back to "" (the previous placeholder) when the variable is unset.
openai.api_key = os.environ.get("OPENAI_API_KEY", "")
EMBEDDING_MODEL_NAME = "text-embedding-ada-002"
# NOTE(review): OpenAI has retired text-davinci-003; gpt-3.5-turbo-instruct is
# the suggested Completions-API replacement — confirm against the current model list.
COMPLETION_MODEL_NAME = "text-davinci-003"
# NOTE(review): not referenced by any cell shown here; presumably used when the
# embeddings CSV was generated — confirm or remove.
BATCH_SIZE = 128

In [3]:


# The CSV stores each embedding as the text of a Python list literal; parse it
# back into a real list so the vectors can be compared numerically.
df = pd.read_csv('data/embeddings.csv').assign(
    embeddings=lambda frame: frame["embeddings"].apply(ast.literal_eval)
)

In [4]:
def get_rows_sorted_by_relevance(question, df):
    """
    Return a copy of ``df`` ordered from MOST to LEAST relevant to ``question``.

    The question is embedded with ``EMBEDDING_MODEL_NAME`` and compared with
    each row's ``embeddings`` value using cosine distance. The distances are
    stored in a new ``distances`` column. A smaller distance means a more
    relevant row, so rows are sorted by ascending distance. The input
    dataframe is not modified.

    Parameters
    ----------
    question : str
        Text to rank the rows against.
    df : pandas.DataFrame
        Must contain an ``embeddings`` column of embedding vectors
        (e.g. lists of floats).

    Returns
    -------
    pandas.DataFrame
        A new dataframe with an added ``distances`` column, sorted ascending.
    """
    question_embedding = get_embedding(question, engine=EMBEDDING_MODEL_NAME)

    distances = distances_from_embeddings(
        question_embedding,
        df["embeddings"].values,
        distance_metric="cosine",
    )

    # assign() returns a new frame, so the caller's df is left untouched.
    # Ascending order puts the closest (most relevant) rows first.
    return df.assign(distances=distances).sort_values("distances", ascending=True)

In [5]:
get_rows_sorted_by_relevance(question='hello my friend', df=df)

Unnamed: 0,text,embeddings,distances
546,borough: Staten Island\n\nneighborhood tabulat...,"[0.013599338941276073, -0.011916662566363811, ...",0.273225
37,borough: Staten Island\n\nneighborhood tabulat...,"[0.017743214964866638, -0.023271799087524414, ...",0.274539
434,borough: Brooklyn\n\nneighborhood tabulation a...,"[0.012627826072275639, -0.022713862359523773, ...",0.275098
81,borough: Bronx\n\nneighborhood tabulation area...,"[-0.0031556966714560986, -0.011681584641337395...",0.275897
514,borough: Manhattan\n\nneighborhood tabulation ...,"[0.010790945030748844, -0.029447169974446297, ...",0.276315
...,...,...,...
76,borough: Manhattan\n\nneighborhood tabulation ...,"[0.009699555113911629, -0.024010714143514633, ...",0.305439
25,borough: Brooklyn\n\nneighborhood tabulation a...,"[-0.011681396514177322, -0.03041571006178856, ...",0.306587
402,borough: Queens\n\nneighborhood tabulation are...,"[-0.008398461155593395, -0.029509568586945534,...",0.308329
303,borough: Queens\n\nneighborhood tabulation are...,"[0.0063171652145683765, -0.01053432747721672, ...",0.310527


In [6]:
def create_prompt(question, df, max_token_count):
    """
    Build a Completion-model prompt that answers ``question`` from ``df``.

    Rows are taken in relevance order (via ``get_rows_sorted_by_relevance``)
    and added to the prompt's context. Adding stops before the first row that
    would push the prompt past ``max_token_count`` tokens. The budget covers
    the template text, the question, each row's text, and the ``###``
    separators placed between rows. Counts are summed per piece, so they
    approximate the token count of the final string.

    Parameters
    ----------
    question : str
        The user's question; inserted verbatim into the prompt.
    df : pandas.DataFrame
        Rows with ``text`` and ``embeddings`` columns.
    max_token_count : int
        Token budget for the returned prompt, counted with ``cl100k_base``.

    Returns
    -------
    str
        The formatted prompt. The context section is empty if not even the
        most relevant row fits.
    """
    # NOTE(review): cl100k_base matches the embedding model; the completion
    # model may use a different tokenizer, so the budget is approximate with
    # respect to its context window — confirm.
    tokenizer = tiktoken.get_encoding("cl100k_base")

    def count_tokens(text):
        # disallowed_special=() treats strings like "<|endoftext|>" (which a
        # chat user can type) as ordinary text instead of raising ValueError.
        return len(tokenizer.encode(text, disallowed_special=()))

    prompt_template = """
Answer the question based on the context below, and if the question
can't be answered based on the context, say "I don't know"

Context: 

{}

---

Question: {}
Answer:"""
    separator = "\n\n###\n\n"

    # Cost of the prompt with an empty context: template plus question.
    current_token_count = count_tokens(prompt_template.format("", question))
    separator_token_count = count_tokens(separator)

    context = []
    for text in get_rows_sorted_by_relevance(question, df)["text"].values:
        # Every row after the first is preceded by a separator in the final
        # prompt, so charge for that separator too.
        row_token_count = count_tokens(text)
        if context:
            row_token_count += separator_token_count

        # Rows arrive in relevance order; stop at the first one that doesn't
        # fit rather than skipping ahead to less relevant rows.
        if current_token_count + row_token_count > max_token_count:
            break

        context.append(text)
        current_token_count += row_token_count

    return prompt_template.format(separator.join(context), question)

In [7]:
print(
    create_prompt(
        question='hello my friend',
        df=df,
        max_token_count=1800,
    )
)


Answer the question based on the context below, and if the question
can't be answered based on the context, say "I don't know"

Context: 

borough: Staten Island

neighborhood tabulation area name: St. George-New Brighton

food scrap dropoff site: H.E.A.L.T.H for Youths Skyline Community Garden

location: 1 Clyde Place

hosted by: H.E.A.L.T.H for Youths Skyline Community Garden

open months: Year Round

operation day hours: 24/7 (Start Time: 24/7 - End Time:  24/7)

website: www.health4youths.org

borough and community district: 501

nyc council district number: 49

latitude: 40.639901

longitude: -74.0902865

police precinct: 120

notes: Not accepted: meat, bones, or dairy

2010 cencus tract: 77.0

borough block lot: borough block lot missing

building identification number: building identification number missing

###

borough: Staten Island

neighborhood tabulation area name: West New Brighton-Silver Lake-Grymes Hill

food scrap dropoff site: Olivet Heavenly Harvest

location: 97 My

In [8]:
def answer_question(question, df, max_prompt_tokens=1800, max_answer_tokens=600):
    """
    Answer ``question`` from the rows of ``df`` using an OpenAI Completion model.

    Builds a retrieval-augmented prompt with ``create_prompt`` (at most
    ``max_prompt_tokens`` tokens). It then asks ``COMPLETION_MODEL_NAME`` for
    up to ``max_answer_tokens`` tokens of answer.

    Parameters
    ----------
    question : str
        The question to answer.
    df : pandas.DataFrame
        Rows with ``text`` and ``embeddings`` columns.
    max_prompt_tokens : int
        Token budget passed to ``create_prompt``.
    max_answer_tokens : int
        ``max_tokens`` for the completion request.

    Returns
    -------
    str
        The model's answer with surrounding whitespace stripped. Returns an
        empty string if building the prompt or the completion request raises;
        in that case the exception is printed.
    """
    try:
        # Prompt construction sits inside the try because it makes its own API
        # call (embedding the question). A failure there should degrade the
        # same way as a failed completion instead of propagating to callers
        # such as the Gradio ``predict`` callback.
        prompt = create_prompt(question, df, max_prompt_tokens)
        response = openai.Completion.create(
            model=COMPLETION_MODEL_NAME,
            prompt=prompt,
            max_tokens=max_answer_tokens,
        )
        return response["choices"][0]["text"].strip()
    except Exception as e:
        # Deliberate best-effort: report the error and return "" as documented.
        print(e)
        return ""

In [9]:
print(
    answer_question(
        question='Mention food scrap dropoffs near Queens that have open months year round.',
        df=df,
    )
)

31 St between 31 Ave and Broadway, S/E Corner of 31 Ave and Crescent St, 33 St between Broadway and 31 Ave, NW Corner of Queens Plaza North & 21 Street, SE Corner of Crescent St & 30th Dr, SE Corner of 31st Ave & Crescent St, NE Corner of 35 Avenue & 12 Street


In [10]:
import openai
import gradio as gr


def predict(message, history):
    """
    Chat callback: answer ``message`` from the module-level ``df``.

    ``history`` is accepted because ``gr.ChatInterface`` calls its function
    with ``(message, history)``. It is not used, so each message is answered
    independently of earlier turns.
    """
    # Reading the module-level ``df`` needs no ``global`` declaration; that is
    # only required when rebinding the name.
    return answer_question(message, df)


# Keep a handle so the server can be stopped with ``demo.close()`` before this
# cell is re-run; otherwise a stale server keeps its port bound.
demo = gr.ChatInterface(predict)
demo.launch()

Running on local URL:  http://127.0.0.1:7861

To create a public link, set `share=True` in `launch()`.


