In [None]:
import os

from ace_of_splades.utils import get_openai_api_key

# Relative path to the notebook's data directory (no interpolation needed,
# so a plain string rather than an f-string).
data_path = "../data"

# Fetch the OpenAI API key through the project helper.
# NOTE(review): the lookup mechanism (env var, .env file, ...) is not visible
# here — confirm in ace_of_splades.utils.get_openai_api_key.
openai_key = get_openai_api_key()

In [None]:
from ace_of_splades.data import get_movies_dataset

# Load the movie dataset that will be indexed in LanceDB below.
# NOTE(review): `local=True` presumably reads a local copy instead of
# downloading — TODO confirm in ace_of_splades.data.get_movies_dataset.
movies = get_movies_dataset(local=True)

In [None]:
# Display the loaded dataset for a quick inspection.
# NOTE(review): if `movies` is a pandas DataFrame, `movies.head()` would avoid
# rendering the entire table in the notebook output.
movies

In [None]:
from sentence_transformers import SentenceTransformer

# Sentence-embedding model used to turn search queries into vectors.
# NOTE(review): presumably this must be the same model that produced the
# vectors stored in the dataset — confirm before changing it.
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
encoder = SentenceTransformer(EMBEDDING_MODEL_NAME)

In [None]:
import lancedb

# Connect to the LanceDB database stored at <data_path>/movies_embeddings.
uri = f"{data_path}/movies_embeddings"
db = lancedb.connect(uri)

In [None]:
# Build the "movies" table from the dataset. mode="overwrite" replaces any
# existing table with that name, so re-running this cell rebuilds it instead
# of failing on a duplicate table name.
movies_table = db.create_table("movies", movies, mode="overwrite")

In [None]:
def get_records(
    query, *, encoder=encoder, db_table=movies_table, max_results=10
):
    """Return the movies in ``db_table`` most similar to ``query``.

    The query text is embedded with ``encoder`` and used as a vector search
    against ``db_table``.

    Args:
        query: free-text question to search for.
        encoder: model exposing ``encode(text)``; its output must support
            ``.tolist()``.
        db_table: LanceDB table to search.
        max_results: maximum number of rows to return.

    Returns:
        A list of dicts with the selected movie columns plus LanceDB's
        ``_distance`` score.
    """
    query_vector = encoder.encode(query).tolist()
    return (
        db_table.search(query_vector)
        # Honour the caller's limit (it used to be hard-coded to 10).
        .limit(max_results)
        .select(
            [
                "Release Year",
                "Title",
                "Origin/Ethnicity",
                "Director",
                "Cast",
                "Genre",
                "Plot",
                "_distance",
            ]
        )
        .to_list()
    )


question = (
    "What should I see tonight? I love Sci-Fi movies but I have seen most of "
    "the classics, such as Star Wars."
)

# Retrieve candidate movies for the question and display them.
docs = get_records(question, max_results=5)
results = list(docs)
results

In [None]:
# Default system prompt defining the assistant's persona for the chat model.
GEEK_SYSTEM = """
  You are a DVD record store assistant and your goal is to recommend a good movie for the user to watch.

  You are a movie expert and a real geek: you love sci-fi movies and tend to get excited when you talk about them.
  Nevertheless, no matter what, you always want to make your customers happy.
"""

In [None]:
# User-message template: {context} receives the formatted movie suggestions
# and {question} the user's original question.
prompt_template = """
  Here are some suggested movies (ranked by relevance) to help you with your choice.
  {context}

  Use these suggestions to answer this question:
  {question}
"""

# Per-movie template, filled once per retrieved record; the rendered entries
# are concatenated to build {context} in prompt_template.
# Placeholders: title, release_year, director, cast, genre, plot.
context_template = """
Title: {title}
Release date: {release_year}
Director: {director}
Cast: {cast}
Genre: {genre}
Overview: {plot}
"""


def format_records_into_context(records, *, template):
    """Render retrieved movie records into a single context string.

    Each record is formatted with ``template`` and the pieces are concatenated
    in their original order.

    Args:
        records: iterable of mappings with the keys "Title", "Release Year",
            "Director", "Cast", "Genre" and "Plot".
        template: ``str.format`` template using the placeholders ``{title}``,
            ``{release_year}``, ``{director}``, ``{cast}``, ``{genre}`` and
            ``{plot}``.

    Returns:
        The concatenated string; ``""`` when ``records`` is empty.

    Raises:
        KeyError: if a record lacks one of the keys above, or the template
            uses a placeholder not listed above.
    """
    # Use the arguments, not notebook globals: the previous version iterated a
    # leftover global `results` and ignored `template`.
    return "".join(
        template.format(
            title=rec["Title"],
            release_year=rec["Release Year"],
            director=rec["Director"],
            cast=rec["Cast"],
            genre=rec["Genre"],
            plot=rec["Plot"],
        )
        for rec in records
    )

In [None]:
import openai

# One client reused for every request. The key fetched in the setup cell is
# passed explicitly; if it is None the client falls back to the
# OPENAI_API_KEY environment variable.
client = openai.OpenAI(api_key=openai_key)


def ask(
    question,
    *,
    max_results=10,
    system=GEEK_SYSTEM,
    prompt_template=prompt_template,
    context_template=context_template,
    db_table=movies_table,
    model="gpt-3.5-turbo",
):
    """Answer ``question`` with retrieval-augmented generation.

    Retrieves up to ``max_results`` similar movies from ``db_table``, renders
    them with ``context_template``, inserts that context and the question into
    ``prompt_template`` and sends it to ``model`` with ``system`` as the system
    prompt. Prints the model's reply followed by the retrieved context.

    Args:
        question: the user's question.
        max_results: number of movies to retrieve as context.
        system: system prompt for the chat model.
        prompt_template: template with ``{context}`` and ``{question}``.
        context_template: per-movie template passed to
            ``format_records_into_context``.
        db_table: LanceDB table to search.
        model: OpenAI chat model name.

    Returns:
        The raw chat-completion response object.
    """
    # Forward the caller's table (previously the global movies_table was
    # always used, silently ignoring this argument).
    records = get_records(
        query=question, max_results=max_results, db_table=db_table
    )
    context = format_records_into_context(records, template=context_template)
    prompt = prompt_template.format(question=question, context=context)

    answer = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": system},
            {"role": "user", "content": prompt},
        ],
    )

    print(answer.choices[0].message.content)
    print(context)

    return answer


answer = ask(question=question)