In [1]:
from RCSYS_utils import *
from RCSYS_models import *
from utils import *

import os
# Never hardcode credentials: reuse OPENAI_API_KEY if it is already exported in
# the environment and only prompt for it interactively when it is missing/empty.
# (Unconditionally assigning a placeholder would clobber a real key.)
if not os.environ.get('OPENAI_API_KEY'):
    from getpass import getpass
    os.environ['OPENAI_API_KEY'] = getpass('OpenAI API key: ')

  from .autonotebook import tqdm as notebook_tqdm


### Benchmark Setup

In [2]:
# Load the preprocessed benchmark graph. Later cells read data['user'] and
# data['food'] ('tags', 'prompt') and data[('user', 'eats', 'food')].edge_index.
# NOTE(review): torch.load unpickles arbitrary Python objects — only load trusted
# files. Recent PyTorch releases default to weights_only=True, which may reject a
# full graph object; if so, pass weights_only=False for this trusted local file
# (TODO confirm against the installed torch version).
# NOTE(review): `torch` is only in scope here through the star imports of the
# first cell; the explicit `import torch` appears later in the notebook.
data = torch.load('../processed_data/benchmark_all.pt')

In [3]:
# Global NumPy seed for reproducibility (the DataFrame.sample calls in this
# notebook additionally pass random_state explicitly).
np.random.seed(42)

# Count how often each food node is the target of a ('user', 'eats', 'food') edge.
edge_index = data[('user', 'eats', 'food')].edge_index
food_ids, counts = edge_index[1].unique(return_counts=True)

# Foods consumed more than 1000 times are treated as "frequent".
is_frequent_food = counts > 1000
frequent_food_ids = food_ids[is_frequent_food].tolist()

# Vectorized membership test: True for edges whose food is NOT frequent.
# (A per-edge Python `in` over a list scales as edges x frequent foods.)
mask = ~torch.isin(edge_index[1], food_ids[is_frequent_food])

# Filter the edge index based on the mask.
# NOTE(review): filtered_edge_index is not consumed by the tag-based filtering
# that follows, which re-reads the unfiltered edge_index — so this frequency
# filter currently has no effect on the benchmark. TODO confirm whether step 1
# was meant to start from filtered_edge_index.
filtered_edge_index = edge_index[:, mask]

# Step 1: keep only (user, food) edges where both endpoints' tag rows sum to
# more than 5 (presumably > 5 active binary health tags — TODO confirm tags are 0/1).
user_tags, food_tags = data['user']['tags'], data['food']['tags']
user_mask = user_tags.sum(dim=1) > 5
food_mask = food_tags.sum(dim=1) > 5

edge_label_index = data[('user', 'eats', 'food')].edge_index
user_indices, food_indices = edge_label_index[0], edge_label_index[1]
filtered_edges = edge_label_index[:, user_mask[user_indices] & food_mask[food_indices]]

# Step 2: per-edge ids, prompt texts and tag vectors, aligned by position.
# Note: user_tags / food_tags are rebound here from per-node tensors to
# per-edge NumPy arrays (one row per kept edge).
user_ids, food_ids = filtered_edges[0].numpy(), filtered_edges[1].numpy()

user_prompts = [data['user']['prompt'][uid] for uid in user_ids]
food_prompts = [data['food']['prompt'][fid] for fid in food_ids]

user_tags = data['user']['tags'][user_ids].numpy()
food_tags = data['food']['tags'][food_ids].numpy()

# "<user prompt> <food prompt>" for every kept edge.
combined_prompts = [f"{up} {fp}" for up, fp in zip(user_prompts, food_prompts)]

# Tag names, assumed to follow the column order of the 'tags' tensors: for each
# nutrient, its 'low' tag immediately followed by its 'high' tag (the pairwise
# comparison below relies on this low/high adjacency).
tag_names = [
    f'{level} {nutrient}'
    for nutrient in (
        'calorie', 'protein', 'carb', 'sugar', 'fiber', 'saturated fat',
        'cholesterol', 'sodium', 'calcium', 'phosphorus', 'potassium',
        'iron', 'folic acid', 'vitamin c', 'vitamin d', 'vitamin b12',
    )
    for level in ('low', 'high')
]

def _compare_tag_pairs(u_tags, f_tags):
    """Return (matching, contradicting) tag-name lists for one (user, food) edge.

    Tags are compared per nutrient pair (low at an even position, high right
    after it); only the first applicable rule per pair is applied:
      - both sides carry the low tag      -> low tag is matching;
      - both sides carry the high tag     -> high tag is matching;
      - user low, food high               -> food's high tag is contradicting;
      - user high, food low               -> food's low tag is contradicting.
    """
    matching, contradicting = [], []
    for low in range(0, len(tag_names), 2):
        high = low + 1
        if u_tags[low] == f_tags[low] == 1:
            matching.append(tag_names[low])
        elif u_tags[high] == f_tags[high] == 1:
            matching.append(tag_names[high])
        elif u_tags[low] == 1 and f_tags[high] == 1:
            contradicting.append(tag_names[high])
        elif u_tags[high] == 1 and f_tags[low] == 1:
            contradicting.append(tag_names[low])
    return matching, contradicting


# Per-edge matching / contradicting tag lists, aligned with user_tags / food_tags.
tag_pairs = [_compare_tag_pairs(u, f) for u, f in zip(user_tags, food_tags)]
matching_tags = [pair[0] for pair in tag_pairs]
contradicting_tags = [pair[1] for pair in tag_pairs]

# Ground-truth explanation: "Yes" only when matches strictly outnumber
# contradictions; ties (including 0-0) become "No" with the contradicting list.
ground_truth = [
    f"Yes, because the food is {', '.join(m)}" if len(m) > len(c)
    else f"No, because the food is {', '.join(c)}"
    for m, c in zip(matching_tags, contradicting_tags)
]

# Create DataFrame: one row per filtered (user, food) edge.
df = pd.DataFrame({
    'user_id': user_ids,
    'food_id': food_ids,
    'user_tags': list(user_tags),
    'food_tags': list(food_tags),
    'user_food_combined_prompts': combined_prompts,
    'matching_tags': matching_tags,
    'contradicting_tags': contradicting_tags,
    'ground_truth': ground_truth
})

# Step 3: balanced benchmark — N_PER_CLASS edges where matching tags outnumber
# contradicting ones ("Yes") and N_PER_CLASS where the reverse holds ("No").
# Ties are excluded from both pools.
N_PER_CLASS = 100
n_matching = df['matching_tags'].apply(len)
n_contradicting = df['contradicting_tags'].apply(len)

matching_gt_contradict_df = df[n_matching > n_contradicting]
contradicting_gt_matching_df = df[n_contradicting > n_matching]

# DataFrame.sample raises ValueError if a pool has fewer than N_PER_CLASS rows.
sample_matching = matching_gt_contradict_df.sample(n=N_PER_CLASS, random_state=42)
sample_contradicting = contradicting_gt_matching_df.sample(n=N_PER_CLASS, random_state=42)

# Combine the samples: "Yes" rows first, then "No" rows; df's index labels are kept.
final_sample = pd.concat([sample_matching, sample_contradicting])

from itertools import islice

# Build two baseline context columns for every sampled edge:
#  - user_food_liked (LLM2ER context): states the known healthy/unhealthy label
#    and lists up to 5 other foods the same user ate (within filtered_edges);
#  - food_considered_healthy (XRec context): lists up to 5 other foods taken
#    from "Yes"-labelled edges.
user_food_liked = []
food_considered_healthy = []

# Loop-invariant: foods on edges labelled healthy (matching > contradicting),
# in edge order (a food may repeat). Computed once instead of once per row.
# NOTE(review): these edges come from *all* users, while the prompt text says
# "healthy to this user". TODO confirm whether this should be restricted to
# edges of row['user_id'].
healthy_food_indices = [
    fid for fid, m, c in zip(food_ids, matching_tags, contradicting_tags) if len(m) > len(c)
]

for index, row in final_sample.iterrows():
    user_id = row['user_id']
    food_id = row['food_id']

    # Up to 5 other foods this user ate (first ones in edge order; may repeat).
    user_food_indices = filtered_edges[1][filtered_edges[0] == user_id].numpy()
    other_food_prompts = list(islice(
        (str(data['food']['prompt'][i]) for i in user_food_indices if i != food_id), 5))

    # NOTE(review): `user_id` is the graph node index; the profile text inside
    # user_food_combined_prompts may use a different identifier — TODO confirm.
    food_health_status = "healthy" if "Yes" in row['ground_truth'] else "unhealthy"
    liked_prompt = f"Keep in mind that We now know that this given food is {food_health_status} to the user! Analyze the reason of this food: {row['user_food_combined_prompts']}. This food is liked by the User {user_id}, User {user_id} also liked the following foods: {'; '.join(other_food_prompts)}"
    user_food_liked.append(liked_prompt)

    # Up to 5 other "healthy" foods, excluding this row's food.
    other_healthy_food_prompts = list(islice(
        (str(data['food']['prompt'][i]) for i in healthy_food_indices if i != food_id), 5))
    healthy_prompt = f"{row['user_food_combined_prompts']} Other food choices considered healthy to this user include: {'; '.join(other_healthy_food_prompts)}. Please use the information to analyze the given food."
    food_considered_healthy.append(healthy_prompt)

final_sample['user_food_liked'] = user_food_liked
final_sample['food_considered_healthy'] = food_considered_healthy

# Build the "Ours" context column (new_prompt): the user's active health tags,
# five example foods with their labels and reasons, the known label of the
# target food, and the combined user/food profile text.
new_prompt_column = []
user_health_tags = []  # collected per row; not stored in final_sample

# Loop-invariant: with a fixed random_state on an unchanged final_sample the
# same 5 example rows are drawn for every target, so build their text once.
# Requires final_sample to have at least 5 rows.
# NOTE(review): the examples are drawn from final_sample itself, so up to 5
# evaluated rows see their own ground-truth reason in their prompt. Consider
# excluding the current row (e.g. final_sample.drop(index)) before re-running;
# kept as-is so results stay comparable.
other_foods = final_sample.sample(n=5, random_state=42).reset_index()
other_foods_ground_truth = []
for i, other_row in other_foods.iterrows():
    food_health_status = "healthy" if "Yes" in other_row['ground_truth'] else "unhealthy"
    other_foods_ground_truth.append(f"Food {i+1} is considered {food_health_status}, because the food is {other_row['ground_truth'].split(', because the food is ')[-1]}")
other_foods_text = "; ".join(other_foods_ground_truth)

for index, row in final_sample.iterrows():
    # Names of the user's active health tags.
    user_health_tags_text = ', '.join([tag_names[i] for i in range(len(tag_names)) if row['user_tags'][i] == 1])
    user_health_tags.append(user_health_tags_text)

    user_health_text = f"Given user's health condition, please focus on the following nutrient aspect in the analysis: if the food is {user_health_tags_text}"
    food_health_status = "healthy" if "Yes" in row['ground_truth'] else "unhealthy"

    # Combine into new prompt
    new_prompt = f"{user_health_text}. And here are 5 additional foods as additional information. {other_foods_text}. Important! Keep in mind that We now know that this given food is {food_health_status} to the user! Please use the information to analyze this food: {row['user_food_combined_prompts']}"
    new_prompt_column.append(new_prompt)

final_sample['new_prompt'] = new_prompt_column

In [27]:
# Inspect the first generated LLM2ER-style context prompt.
final_sample['user_food_liked'].iat[0]

'Yes, User Node 34469: The user information is as follows: female, age 22, White, household income level (the higher the better): 7, education status: Some college or AA degree. Food Node 92410310: The food description is: Soft drink, cola. This food belongs to the category: Soft drinks. The ingredients in this food are: Beverages, carbonated, cola, regular.. This food is liked by the User 1956, User 1956 also liked the following foods: Food Node 94000100: The food description is: Water, tap. This food belongs to the category: Tap water. The ingredients in this food are: Beverages, water, tap, drinking.; Food Node 57123000: The food description is: Cereal (General Mills Cheerios). This food belongs to the category: Ready-to-eat cereal, lower sugar (=<21.2g/100g). The ingredients in this food are: Cereals ready-to-eat, GENERAL MILLS, CHEERIOS, Vitamin C as ingredient, Cereals, oats, regular and quick, not fortified, dry, Sugars, granulated, Salt, table, iodized, Vitamin B composite in c

In [4]:
# Persist the benchmark (ids, tags, labels and all context columns) for reuse.
# NOTE(review): array/list-valued columns (user_tags, food_tags, matching_tags,
# contradicting_tags) are written as their string repr and will not round-trip
# through read_csv as arrays/lists.
final_sample.to_csv(path_or_buf='../processed_data/benchmark_reasoning.csv', index=False)

### Querying

In [4]:
from openai import OpenAI

import torch
import pandas as pd
import numpy as np
from transformers import BertTokenizer, BertModel
import bert_score
from nltk.translate.bleu_score import sentence_bleu
import time
from tqdm import tqdm

import warnings
warnings.filterwarnings("ignore")


In [5]:
# OpenAI client; with no arguments it reads OPENAI_API_KEY from the environment
# (set in the first cell).
client = OpenAI()

# Instruction prefix that query_GPT prepends to every context. It pins the
# answer format ("<Yes or No>, because the food is: ...") so replies can be
# compared against the ground_truth strings.
# NOTE(review): the prompt contains a typo ("becuase"); it is left untouched so
# model inputs stay identical to previous runs.
base_text = """
    Act as a nutritionist, your task is to use your knowledge to judge if the given food should be considered a healthy option to the user given their profiles. 
    Important Note: You must strictly provide your answer in the following format without further explanations or other words:  <Yes or No>, because the food is: <high or low> in <nutrients>, <choose between high or low> in <nutrients>,…. (E.g. Yes, becuase the food is low in calories, low in sodium and high in protein).
    Here is the input: 
"""

In [6]:
def query_GPT(final_sample, context_column, output_path, sleep_seconds=10,
              output_dir='../processed_data/reasoning/', max_retries=3):
    """Query the chat model once per row and save the replies to CSV.

    Each prompt is the global ``base_text`` followed by ``row[context_column]``.
    Replies are returned in row order and written to
    ``os.path.join(output_dir, output_path)`` with a single ``GPT_results`` column.

    Args:
        final_sample: DataFrame whose rows are queried in order.
        context_column: name of the text column appended to ``base_text``.
        output_path: CSV file name, relative to ``output_dir``.
        sleep_seconds: pause between consecutive requests (crude client-side
            rate limiting); also the base delay of the retry backoff.
        output_dir: directory for the CSV; created if it does not exist.
        max_retries: extra attempts per prompt after a failed API call; the
            last exception is re-raised once they are exhausted.

    Returns:
        list: the model's reply text for each row.
    """
    # One prompt per row, in the DataFrame's row order.
    prompted_text_list = [base_text + context for context in final_sample[context_column]]

    def query(prompt):
        # Single chat-completion call; returns the assistant's message text.
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "Act as a nutritionist. Ananlyze if a given food is healthy to a user and why."},
                {"role": "user", "content": prompt},
            ],
        )
        return response.choices[0].message.content

    def query_with_retry(prompt):
        # Retry transient API failures with exponential backoff instead of
        # aborting a long run (and losing every reply collected so far).
        attempts = max(0, max_retries) + 1
        for attempt in range(attempts):
            try:
                return query(prompt)
            except Exception:
                if attempt == attempts - 1:
                    raise
                time.sleep(sleep_seconds * 2 ** attempt)

    GPT_reasoning = []
    last_position = len(prompted_text_list) - 1
    for position, prompt in enumerate(tqdm(prompted_text_list)):
        GPT_reasoning.append(query_with_retry(prompt))
        # Rate limiting between requests; no need to wait after the last one.
        if sleep_seconds > 0 and position < last_position:
            time.sleep(sleep_seconds)

    # Create the directory first so a long run cannot fail at the very end.
    os.makedirs(output_dir, exist_ok=True)
    gpt_results_csv_path = os.path.join(output_dir, output_path)
    pd.DataFrame({'GPT_results': GPT_reasoning}).to_csv(gpt_results_csv_path, index=False)

    return GPT_reasoning


def evaluate(GPT_reasoning, df, verbose=True):
    """Score generated explanations against ``df['ground_truth']``.

    For each (generated, ground-truth) pair, in order, computes the BERTScore F1
    (``bert-base-uncased``) and a whitespace-tokenized sentence BLEU.

    Args:
        GPT_reasoning: generated texts, one per row of ``df`` and in the same
            order. Missing values (None / NaN, e.g. after reloading a CSV) are
            scored as empty strings instead of crashing.
        df: DataFrame with a ``ground_truth`` text column.
        verbose: if True (default), print the summary dict as before.

    Returns:
        dict: BERT_score_mean, BERT_score_std, BLEU_score_mean, BLEU_score_std.

    Raises:
        ValueError: if the number of generated texts differs from ``len(df)``.
    """
    ground_truths = df['ground_truth'].tolist()
    gpt_results = ['' if pd.isna(text) else str(text) for text in GPT_reasoning]
    if len(gpt_results) != len(ground_truths):
        raise ValueError(
            f"got {len(gpt_results)} generated texts for {len(ground_truths)} ground-truth rows")

    # BERTScore returns (precision, recall, F1); only F1 is reported.
    f1 = bert_score.score(gpt_results, ground_truths, lang="en", model_type="bert-base-uncased")[2]
    bert_scores = f1.numpy()

    # Unsmoothed sentence BLEU: pairs without higher-order n-gram overlap can
    # score ~0, which pulls the mean down.
    bleu_scores = [sentence_bleu([gt.split()], gpt.split()) for gt, gpt in zip(ground_truths, gpt_results)]

    summary = {
        'BERT_score_mean': np.mean(bert_scores),
        'BERT_score_std': np.std(bert_scores),
        'BLEU_score_mean': np.mean(bleu_scores),
        'BLEU_score_std': np.std(bleu_scores)
    }
    if verbose:
        print(summary)
    return summary

In [25]:
# Optional cache: uncomment to reload previously saved replies instead of
# re-querying the API (a 200-prompt run took ~35 minutes, per the progress bar
# of the last cell). If used, skip the query_GPT calls below and pass these
# lists straight to evaluate().
# raw_reasoning = pd.read_csv('../processed_data/reasoning/raw_GPT.csv')['GPT_results'].tolist()
# xrec_reasoning = pd.read_csv('../processed_data/reasoning/XRec_GPT.csv')['GPT_results'].tolist()
# llm2er_reasoning = pd.read_csv('../processed_data/reasoning/LLM2ER_GPT.csv')['GPT_results'].tolist()
# our_reasoning = pd.read_csv('../processed_data/reasoning/Ours_GPT.csv')['GPT_results'].tolist()

In [28]:
# Raw baseline: the model sees only the combined user + food profile text.
raw_reasoning = query_GPT(final_sample, 'user_food_combined_prompts', 'raw_GPT.csv')
evaluate(raw_reasoning, final_sample)



{'BERT_score_mean': 0.6903412, 'BERT_score_std': 0.04729, 'BLEU_score_mean': 0.24047353478215364, 'BLEU_score_std': 0.0883334715206406}


In [29]:
# XRec baseline: profile text plus up to five foods from other "Yes"-labelled edges.
xrec_reasoning = query_GPT(final_sample, 'food_considered_healthy', 'XRec_GPT.csv')
evaluate(xrec_reasoning, final_sample)



{'BERT_score_mean': 0.6992225, 'BERT_score_std': 0.03730052, 'BLEU_score_mean': 0.25539204386224, 'BLEU_score_std': 0.0773054282350995}


In [30]:
# LLM2ER baseline: known label, profile text and up to five other foods the same user ate.
llm2er_reasoning = query_GPT(final_sample, 'user_food_liked', 'LLM2ER_GPT.csv')
evaluate(llm2er_reasoning, final_sample)



{'BERT_score_mean': 0.6948765, 'BERT_score_std': 0.046850063, 'BLEU_score_mean': 0.24030639585726427, 'BLEU_score_std': 0.11347167381448237}


In [7]:
# Ours: user health tags, five labelled example foods, the known label and the profile text.
our_reasoning = query_GPT(final_sample, 'new_prompt', 'Ours_GPT.csv')
evaluate(our_reasoning, final_sample)

100%|██████████| 200/200 [35:19<00:00, 10.60s/it]


{'BERT_score_mean': 0.7494171, 'BERT_score_std': 0.075974554, 'BLEU_score_mean': 0.2868334327549, 'BLEU_score_std': 0.12916104869960582}
