In [1]:
import os
# Runtime configuration, set before any project module is imported further
# down. These variables are presumably read by the `models_api` / `app`
# packages (not visible here) — confirm there if they need to change.
os.environ.update(
    PLATFORM="vertexai",
    GCLOUD_PROJECT_ID="wmt-mtech-assortment-ml-prod",
)

In [2]:
# Move up one directory (presumably from the notebook folder to the repository
# root) so the local `app` and `models_api` packages imported below resolve.
# Guarded so re-running this cell does not climb one more level each time:
# once the `app` package is visible from the current directory, we are
# already where the imports expect to be.
if not os.path.isdir("app"):
    os.chdir("../")

In [3]:
# Evaluation prompts shared by every retrieval method below. Each prompt is
# used as a dict key in `results[<method>]`, so prompts must be unique.
questions = [
    "Which Brand had a significant PoD growth?",
    "Which Brand had a significant PoD decline?",
    "Which Stores had a significant PoD growth and what are their sales lift?",
    "Which Stores had a significant PoD decline and what are their units lift?",
    "Which item has the most gain in sales?",
    "Which item has the most drop in units?",
    "Why did an item get deleted?",
    "Why did an item get added?",
    "Why did the facings decrease for an item?",
    "Why did an item expand?",
    "How many linear inches are used for recommendation in a store?",
    "How much space is used in a store compared to current mod?",
    "What is % of PnH violations for a relay compared to current mod?",
    "How many PoDs have more DoS than current mod?",
    "How much incrementality is an item bringing in?",
    "An item contracted in which stores?"
]
# Duplicate prompts would silently collapse into one entry per method.
assert len(set(questions)) == len(questions), "duplicate prompts in `questions`"

In [4]:
# Collected retrieval output: method label ("KGR", "DQA") -> {prompt -> payload}.
results = dict()

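## Knowledge-graph retrieval (`KGR`)

Retrieve chunks for every question with `get_relevant_chunks`; results are stored under `results["KGR"]`.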

In [5]:
from models_api.vectorize import vectorDB
from app.knowledge import get_relevant_chunks

In [6]:
# Vector store searched by the knowledge-graph retrieval cell
# (handed to get_relevant_chunks there).
CHUNKS_COLLECTION = "chunks"
chunks_db = vectorDB(name=CHUNKS_COLLECTION)

In [7]:
# Retrieval budget for knowledge-graph retrieval, named rather than inlined.
# top_n=3 is intended to match max_items=3 in the DQA cell, so both methods
# are capped at the same number of chunks per question (assuming the two
# parameters mean the same thing — TODO confirm in app.knowledge).
# max_depth / max_breadth presumably bound the graph expansion; confirm in
# get_relevant_chunks before tuning.
KGR_PARAMS = dict(top_n=3, max_depth=2, max_breadth=4)

# prompt -> payload from get_relevant_chunks; the tabulate step iterates its
# "context" entries.
results["KGR"] = {
    prompt: get_relevant_chunks(prompt, chunks_db, **KGR_PARAMS)
    for prompt in questions
}

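## Document question-answering (`DQA`)

Generate relevant chunks for every question with `generate_relevant_chunks` (which returns JSON); results are stored under `results["DQA"]`.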

In [5]:
import json
from app.knowledge import generate_relevant_chunks

In [11]:
# Generation budget for document question-answering. max_items=3 is intended
# to mirror the KGR top_n=3 so the comparison is like-for-like (TODO confirm
# semantics in app.knowledge). min_items=0 presumably permits an empty
# "context" for a question.
DQA_PARAMS = dict(min_items=0, max_items=3)


def _dqa_payload(prompt: str) -> dict:
    """Run DQA for one prompt and decode the JSON string it returns.

    Args:
        prompt: question passed to generate_relevant_chunks.

    Returns:
        The decoded JSON object (the tabulate step reads its "context" key).

    Raises:
        json.JSONDecodeError: if the reply is not valid JSON. The message names
            the offending prompt, so a failure partway through the batch can be
            traced to its question.
    """
    raw = generate_relevant_chunks(prompt, **DQA_PARAMS)
    try:
        return json.loads(raw)
    except json.JSONDecodeError as err:
        # Same exception type, with the prompt added for context.
        raise json.JSONDecodeError(
            f"DQA reply for {prompt!r} is not valid JSON: {err.msg}", err.doc, err.pos
        ) from err


results["DQA"] = {prompt: _dqa_payload(prompt) for prompt in questions}

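## Tabulate and export

Flatten both methods' retrieved chunks into one table (one row per prompt, method and chunk) and write it to `retrieval_comparison.csv`.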

In [11]:
import pandas as pd

In [12]:
# Flatten every method's payloads into one record per retrieved chunk:
# the identifying `prompt` and `method` columns come first, followed by the
# chunk's own fields. Assumes each payload is a mapping whose "context" is a
# list of dict-like chunks — TODO confirm for both retrieval functions.
# NOTE(review): a prompt whose "context" is empty contributes no row, so it
# will not appear for that method in the exported comparison.
table = []
for method, _results in results.items():
    for prompt, chunks in _results.items():
        for _chunk in chunks["context"]:
            row = {"prompt": prompt, "method": method, **_chunk}
            # Re-assert the identifiers so a chunk field that happens to be
            # named "prompt" or "method" cannot silently relabel the row.
            # Re-assigning an existing key keeps its (leading) column position.
            row["prompt"], row["method"] = prompt, method
            table.append(row)

In [13]:
# Build the comparison frame from the row records and order it so each
# prompt's chunks are grouped by method, then by the chunk's "no." field.
sort_keys = ["prompt", "method", "no."]
table = (
    pd.DataFrame(table)
    .sort_values(by=sort_keys)
)

In [14]:
# Spread each row's "supporting" value into its own flat columns
# (supporting_0, supporting_1, ... for list values, supporting_<key> for dict
# values — the element type is not visible here, TODO confirm), blank out the
# gaps, and write the comparison table.
# index=False: after sort_values the index is just the pre-sort insertion
# order, which carries no meaning and would otherwise appear as an unnamed
# first column in the CSV.
supporting_cols = (
    table["supporting"]
    .apply(pd.Series)
    .rename(columns=lambda col: f"supporting_{col}")
    .fillna("")
)
pd.concat(
    [table.drop(columns=["supporting"]), supporting_cols],
    axis=1,
).to_csv("retrieval_comparison.csv", index=False)