In [22]:
import os
from pprint import pprint
from dotenv import load_dotenv
from langchain_core.runnables import RunnableParallel, RunnableLambda
from langchain.sql_database import SQLDatabase
from langchain.chat_models import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from db.comp_food_database import engine
from db.comp_food_database import SessionLocal
from sqlalchemy import text
from rags.protein_amount.retrieval import (
    get_restaurant_or_brand_prompt,
    get_query,
    food_items_prompt
)
from rags.protein_amount.generation import generation_prompt

# Enable automatic re-import of edited project modules before each cell runs.
# %reload_ext is used instead of %load_ext so that re-running this cell is
# idempotent: it loads the extension when it is not loaded yet and reloads it
# otherwise, without printing the "already loaded" notice.
%reload_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [23]:
# Pick the dotenv file from the ENVIRONMENT variable, falling back to
# 'development' when it is unset, e.g. `.env.development`.
environment = os.environ.get('ENVIRONMENT', 'development')

# Kept as the cell's last expression so the notebook displays load_dotenv's
# return value.
load_dotenv(f'.env.{environment}')

True

In [24]:
# Instantiate SessionLocal (imported from db.comp_food_database); the
# retrieval cell below executes its SQL through this object.
# NOTE(review): presumably a SQLAlchemy session — confirm. No visible cell
# closes it, so consider calling conn.close() once the notebook is done
# with the database.
conn = SessionLocal()

In [25]:
def _build_chat_model(model_name):
    """Return a ChatOpenAI client for `model_name`.

    Every client in this notebook shares the same settings: temperature 0.0
    and the API key read from the OPENAI_API_KEY environment variable at
    construction time.
    """
    return ChatOpenAI(
        model=model_name,
        temperature=0.0,
        openai_api_key=os.getenv('OPENAI_API_KEY'),
    )


# One gpt-4o client plus two separate gpt-4o-mini clients; the mini clients
# are used by the two branches of the retrieval chain further down.
llm4o = _build_chat_model("gpt-4o")
llm4omini1 = _build_chat_model("gpt-4o-mini")
llm4omini2 = _build_chat_model("gpt-4o-mini")

# SQLDatabase wrapper over the project engine, limited via include_tables to
# the food tables and their *_fts companions.
_FOOD_TABLES = [
    'restaurant_menu_foods',
    'restaurant_menu_foods_fts',
    'branded_foods',
    'branded_foods_fts',
    'non_branded_foods',
    'non_branded_foods_fts',
]
db = SQLDatabase(engine, include_tables=_FOOD_TABLES)

In [26]:
# Free-text food description that drives the whole pipeline.
input_text = 'Mac and cheese'

# Two prompt -> model branches are evaluated on the same input; their outputs
# are handed to get_query as a mapping keyed by branch name.
retrieval_chain = (
    RunnableParallel(
        restaurant_or_brand=get_restaurant_or_brand_prompt | llm4omini1,
        food_items=food_items_prompt | llm4omini2,
    )
    | RunnableLambda(get_query)
)

# get_query's return value unpacks into the SQL string and a companion text.
query, retrieval_text = retrieval_chain.invoke({'text': input_text})

# NOTE(review): `query` comes out of a chain that includes model calls and is
# executed as-is — confirm that get_query constrains what SQL can be produced.
result = conn.execute(text(query))
rows = result.fetchall()

# `data` is the column-name tuple followed by the fetched rows, or an empty
# list when the query matched nothing.
if rows:
    data = [tuple(result.keys()), *rows]
else:
    data = []

In [27]:
# Inspect the retrieval step: the fetched rows (header tuple first), then the
# SQL that was executed and the text returned alongside it, one per line.
pprint(data)
print(query, retrieval_text, sep='\n')

[('description', 'serving_size', 'protein_amount', 'rank'),
 ('Macaroni, vegetable, enriched, dry', 84.0, 11.04, -6.342855388814356),
 ('Macaroni, vegetable, enriched, cooked', 134.0, 6.07, -6.342855388814356),
 ("CRACKER BARREL, macaroni n' cheese", 175.0, 11.36, -5.994720289488606),
 ('Babyfood, macaroni and cheese, toddler', 113.0, 3.96, -5.994720289488606),
 ('Macaroni and Cheese, canned entree', 244.0, 8.25, -5.994720289488606),
 ('Macaroni and Cheese, canned, microwavable', 213.0, 12.74, -5.994720289488606),
 ('Macaroni and cheese, frozen entree', 137.0, 7.67, -5.994720289488606),
 ('Macaroni or noodles with cheese', 230.0, 19.96, -5.994720289488606),
 ('Pasta mix, classic cheeseburger macaroni, unprepared', 123.0, 14.27, -5.682812426652694),
 ('Babyfood, dinner, macaroni and cheese, junior', 28.35, 0.74, -5.682812426652694),
 ('Babyfood, dinner, macaroni and cheese, strained', 28.35, 0.89, -5.682812426652694),
 ('Beef and macaroni with cheese sauce', 246.0, 22.44, -5.68281242665

In [13]:

# Generation step: prompt -> model -> plain string.
generation_chain = generation_prompt | llm4omini2 | StrOutputParser()

# The prompt is fed the retrieved rows, the retrieval text and the user's
# original wording.
generation_inputs = {
    'data': data,
    'text': retrieval_text,
    'original_input': input_text,
}
result = generation_chain.invoke(generation_inputs)

print(result)

```json
[
    {
        "food_item": "Macaroni",
        "protein_amount": "6.07",
        "protein_unit": "grams"
    },
    {
        "food_item": "Cheese",
        "protein_amount": "3.38",
        "protein_unit": "grams"
    }
]
```
