In [56]:
import os
import time

import bert_score
import numpy as np
import pandas as pd
import torch
from nltk.translate.bleu_score import sentence_bleu
from tqdm import tqdm
from transformers import BertTokenizer, BertModel

import warnings
warnings.filterwarnings("ignore")

In [57]:
from llamaapi import LlamaAPI
import json

# Read the API token from the LLAMA_API_TOKEN environment variable so the secret
# never has to be pasted into (and saved with) the notebook. The original
# placeholder string is kept as the fallback, so behaviour is unchanged when the
# variable is not set.
llama = LlamaAPI(os.environ.get('LLAMA_API_TOKEN', 'YOUR API TOKEN'))

In [58]:
def query_llama(llama_reasoning, final_sample, context_column, output_path):
    """Query Llama once per row of ``final_sample`` and save the replies to CSV.

    Each prompt is the module-level ``base_text`` followed by that row's
    ``context_column`` value. Replies are appended to ``llama_reasoning`` in
    place, so a caller that interrupts the run still holds every reply that
    was collected up to that point.

    Parameters
    ----------
    llama_reasoning : list
        Accumulator that receives one entry per prompt; mutated in place.
    final_sample : pandas.DataFrame
        Frame that contains ``context_column``.
    context_column : str
        Column whose value is appended to ``base_text`` to build each prompt.
    output_path : str
        File name of the CSV, relative to ``../processed_data/reasoning/``.

    Returns
    -------
    list
        The same ``llama_reasoning`` object that was passed in. A request that
        raises is recorded as the string ``'API Error'``.
    """
    gpt_results_csv_path = '../processed_data/reasoning/' + output_path
    # Create the output folder before spending time on API calls, so a missing
    # directory cannot make the final to_csv fail after the whole run.
    os.makedirs(os.path.dirname(gpt_results_csv_path), exist_ok=True)

    prompted_text_list = [
        base_text + context for context in final_sample[context_column]
    ]

    def query(prompt):
        """Send one chat request and return the text of the first choice."""
        api_request_json = {
          "model": "llama3-70b",
          "messages": [
            # NOTE(review): "Ananlyze" is a typo, but this text is sent to the
            # model verbatim, so it is left untouched to keep runs comparable.
            {"role": "system", "content": "Act as a nutritionist. Ananlyze if a given food is healthy to a user and why."},
            {"role": "user", "content": prompt},
          ]
        }
        response = llama.run(api_request_json)
        return response.json()['choices'][0]['message']['content']

    for prompt in tqdm(prompted_text_list):
        try:
            reasoning = query(prompt)
        except Exception as exc:
            # Catch Exception rather than using a bare except, so that
            # KeyboardInterrupt / SystemExit can still stop the long loop.
            reasoning = 'API Error'
            print('API Error:', repr(exc))
        # Record the reply before pausing, so an interrupt during the sleep
        # does not discard it.
        llama_reasoning.append(reasoning)
        # Fixed pause between consecutive requests.
        time.sleep(10)

    # The CSV holds the whole accumulator under the 'GPT_results' column.
    gpt_results_df = pd.DataFrame({'GPT_results': llama_reasoning})
    gpt_results_df.to_csv(gpt_results_csv_path, index=False)

    return llama_reasoning


def _as_text(values):
    """Return ``values`` as a list of strings, mapping ``None``/NaN to ``''``."""
    return [
        '' if value is None or (isinstance(value, float) and value != value) else str(value)
        for value in values
    ]


def evaluate(reasoning, df):
    """Score generated texts against ``df['ground_truth']`` and print a summary.

    Computes the BERTScore F1 (``bert-base-uncased``) and a whitespace-tokenised
    sentence-level BLEU for every (ground truth, generated) pair, then prints a
    dict with the mean and standard deviation of both metrics.

    Parameters
    ----------
    reasoning : list
        Generated texts, in the same order as the rows of ``df``.
    df : pandas.DataFrame
        Frame that contains a ``ground_truth`` column.

    Returns
    -------
    None
        The summary dict is printed, not returned.
    """
    # Replies reloaded with pd.read_csv come back as NaN (a float) when the
    # stored cell was empty; normalise to text so `.split()` below cannot raise
    # AttributeError on such entries.
    ground_truths = _as_text(df['ground_truth'].tolist())
    gpt_results = _as_text(reasoning)

    # BERTScore: only the F1 component is used.
    _, _, F1 = bert_score.score(gpt_results, ground_truths, lang="en", model_type="bert-base-uncased")
    bert_scores = F1.numpy()
    bert_score_mean = np.mean(bert_scores)
    bert_score_std = np.std(bert_scores)

    # BLEU: one reference per candidate, both split on whitespace.
    gpt_scores = [sentence_bleu([gt.split()], gpt.split()) for gt, gpt in zip(ground_truths, gpt_results)]
    bleu_score_mean = np.mean(gpt_scores)
    bleu_score_std = np.std(gpt_scores)

    summary = {
        'BERT_score_mean': bert_score_mean,
        'BERT_score_std': bert_score_std,
        'BLEU_score_mean': bleu_score_mean,
        'BLEU_score_std': bleu_score_std
    }
    print(summary)

In [59]:
# Instruction text that query_llama prepends to every row's context to build
# the user prompt (it is read there as a module-level name).
# NOTE(review): the text contains typos (e.g. "becuase"); it is sent to the
# model verbatim, so editing it would change the prompts and likely the scores.
base_text = """
    Act as a nutritionist, your task is to use your knowledge to judge if the given food should be considered a healthy option to the user given their profiles. 
    Important Note: You must STRICTLY provide your answer in the following format WITHOUT any further explanations or other words:  <Yes or No>, because the food is: <high or low> in <nutrients>, <choose between high or low> in <nutrients>,…. (E.g. Yes, becuase the food is low in calories, low in sodium and high in protein).
    Here is the input: 
"""
# Benchmark frame used by every run below; evaluate() reads its 'ground_truth'
# column and each run reads one prompt-context column from it.
final_sample = pd.read_csv('../processed_data/benchmark_reasoning.csv')

In [60]:
# Raw prompts: context comes from the 'user_food_combined_prompts' column.
# The accumulator list is created first and handed to query_llama, which
# appends to it in place and returns the same list.
llama_reasoning = []
raw_reasoning = query_llama(
    llama_reasoning,
    final_sample,
    context_column='user_food_combined_prompts',
    output_path='raw_llama3.csv',
)
evaluate(raw_reasoning, final_sample)

100%|██████████| 200/200 [41:16<00:00, 12.38s/it]


{'BERT_score_mean': 0.69880235, 'BERT_score_std': 0.10955263, 'BLEU_score_mean': 0.2621309652455708, 'BLEU_score_std': 0.09954877984334148}


In [61]:
# XRec-style prompts: context comes from the 'food_considered_healthy' column.
# The list is bound before the call so it stays reachable while query_llama
# appends to it in place.
xrec_reasoning = []
xrec_reasoning = query_llama(
    xrec_reasoning,
    final_sample,
    context_column='food_considered_healthy',
    output_path='XRec_llama3.csv',
)
evaluate(xrec_reasoning, final_sample)

100%|██████████| 200/200 [40:25<00:00, 12.13s/it]


{'BERT_score_mean': 0.70065427, 'BERT_score_std': 0.10495191, 'BLEU_score_mean': 0.2069737978992119, 'BLEU_score_std': 0.12186578613605656}


In [62]:
# LLM2ER-style prompts: context comes from the 'user_food_liked' column.
# The list is bound before the call so it stays reachable while query_llama
# appends to it in place.
llm2er_reasoning = []
llm2er_reasoning = query_llama(
    llm2er_reasoning,
    final_sample,
    context_column='user_food_liked',
    output_path='LLM2ER_llama3.csv',
)
evaluate(llm2er_reasoning, final_sample)

100%|██████████| 200/200 [40:24<00:00, 12.12s/it]


{'BERT_score_mean': 0.7427185, 'BERT_score_std': 0.03158122, 'BLEU_score_mean': 0.21532947428836777, 'BLEU_score_std': 0.09428363458290431}


In [63]:
# Prompts from the 'new_prompt' column; replies are saved to Ours_llama3.csv.
# The list is bound before the call so it stays reachable while query_llama
# appends to it in place.
our_reasoning = []
our_reasoning = query_llama(
    our_reasoning,
    final_sample,
    context_column='new_prompt',
    output_path='Ours_llama3.csv',
)
evaluate(our_reasoning, final_sample)

100%|██████████| 200/200 [40:34<00:00, 12.17s/it]


{'BERT_score_mean': 0.84127486, 'BERT_score_std': 0.07312379, 'BLEU_score_mean': 0.3698287272917877, 'BLEU_score_std': 0.19553086488228746}


In [64]:
# Reload previously saved replies from the *_llama2.csv files.
# NOTE: these assignments rebind the same four names that the cells above
# filled with the llama3 replies, so every evaluate() call after this cell
# scores the reloaded lists instead.
raw_reasoning = (
    pd.read_csv('../processed_data/reasoning/raw_llama2.csv')
    ['GPT_results']
    .tolist()
)
xrec_reasoning = (
    pd.read_csv('../processed_data/reasoning/XRec_llama2.csv')
    ['GPT_results']
    .tolist()
)
llm2er_reasoning = (
    pd.read_csv('../processed_data/reasoning/LLM2ER_llama2.csv')
    ['GPT_results']
    .tolist()
)
our_reasoning = (
    pd.read_csv('../processed_data/reasoning/Ours_llama2.csv')
    ['GPT_results']
    .tolist()
)

In [65]:
# Score the raw-prompt replies (reloaded from raw_llama2.csv in the cell above).
evaluate(reasoning=raw_reasoning, df=final_sample)



{'BERT_score_mean': 0.5581448, 'BERT_score_std': 0.064839534, 'BLEU_score_mean': 0.05527894270450269, 'BLEU_score_std': 0.07889836575127629}


In [66]:
# Score the XRec replies (reloaded from XRec_llama2.csv two cells above).
evaluate(reasoning=xrec_reasoning, df=final_sample)



{'BERT_score_mean': 0.53682, 'BERT_score_std': 0.048586834, 'BLEU_score_mean': 0.01013901214911789, 'BLEU_score_std': 0.02342373600577563}


In [67]:
# Score the LLM2ER replies (reloaded from LLM2ER_llama2.csv earlier).
evaluate(reasoning=llm2er_reasoning, df=final_sample)



{'BERT_score_mean': 0.5560213, 'BERT_score_std': 0.04305631, 'BLEU_score_mean': 0.015683256726849937, 'BLEU_score_std': 0.028369411627361955}


In [68]:
# Score the 'new_prompt' replies (reloaded from Ours_llama2.csv earlier).
evaluate(reasoning=our_reasoning, df=final_sample)



{'BERT_score_mean': 0.56366456, 'BERT_score_std': 0.058120113, 'BLEU_score_mean': 0.014524161710214391, 'BLEU_score_std': 0.02932298963902566}
