In [None]:
import contextlib
import math
import os
import re
import string
from typing import Dict, List # Import Dict and List from typing module

import autogen
#from autogen import AssistantAgent, UserProxyAgent
#from autogen.code_utils import  infer_lang


# Read data from file.
# Path to the review data file, resolved relative to the current working directory.
file_path = 'restaurant-data.txt'
# Alternative location if the data lives under data/external/:
#file_path = os.path.join('data', 'external', 'restaurant-data.txt')

# Fail fast with a clear error when the data file is missing, before
# load_restaurant_data() below tries to read it.
if not os.path.exists(file_path):
        raise FileNotFoundError(f"File '{file_path}' not found.")

# Define the functions for fetching data, analyzing reviews, and calculating scores

def load_restaurant_data(file_path: str) -> Dict[str, List[str]]:
    """Load restaurant reviews from a text file, grouped by restaurant name.

    Each line containing ". " is split at its FIRST ". ". The text before it
    is the restaurant name, lower-cased and stripped so lookups are
    case-insensitive. The text after it is one review, stripped. Lines
    without ". " are skipped.

    Args:
        file_path: Path to the UTF-8 encoded review file.

    Returns:
        Mapping of normalized restaurant name -> list of its reviews, in file
        order. If the file does not exist, an error is printed and an empty
        dict is returned.
    """
    data: Dict[str, List[str]] = {}
    try:
        # Explicit encoding: the platform default (e.g. cp1252 on Windows) can
        # fail on, or mis-decode, non-ASCII characters in review text.
        with open(file_path, 'r', encoding='utf-8') as file:
            for line in file:
                # A ". " separator is required to split the name from the review.
                if '. ' not in line:
                    continue
                name, review = line.split('. ', 1)
                data.setdefault(name.lower().strip(), []).append(review.strip())
    except FileNotFoundError:
        print(f"Error: File '{file_path}' not found.")

    return data

# Load all reviews once; fetch_restaurant_data() looks names up in this dict.
restaurant_data = load_restaurant_data(file_path)

In [None]:
def fetch_restaurant_data(restaurant_name: str) -> Dict[str, List[str]]:
    """Return the reviews loaded for a single restaurant.

    Looks the name up verbatim in the module-level ``restaurant_data`` dict.
    Its keys are lower-cased, so callers should pass a normalized name.
    A message is printed when there are no reviews.

    Args:
        restaurant_name: Restaurant key to look up.

    Returns:
        A one-entry dict ``{restaurant_name: reviews}``; ``reviews`` is an
        empty list when the name is unknown.
    """
    matched_reviews = restaurant_data.get(restaurant_name, [])
    if len(matched_reviews) == 0:
        print(f"No reviews found for {restaurant_name}.")
    return dict([(restaurant_name, matched_reviews)])

def analyze_reviews(reviews: List[str]) -> Dict[str, List[int]]:
    """Extract a food score and a customer-service score (1-5) from each review.

    Scores come from one adjective scale shared by both aspects
    (awful/horrible/disgusting = 1 ... awesome/incredible/amazing = 5).
    Words are matched case-insensitively after surrounding punctuation is
    stripped, so "good," and "Amazing." are recognised.

    Both aspects use the same vocabulary, so the aspect is decided by
    position. The first score word in a review is the food score; the second
    is the customer-service score.
    NOTE(review): this assumes each review mentions food before customer
    service -- confirm against the data file.
    Reviews with fewer than two score words are skipped, so the two returned
    lists always have the same length and stay index-aligned.

    Args:
        reviews: Free-text reviews for a single restaurant.

    Returns:
        ``{"food_scores": [...], "customer_service_scores": [...]}``, with one
        entry per usable review, in input order.
    """
    # Identical scale for food and customer service (see docstring).
    score_words = {
        "awful": 1, "horrible": 1, "disgusting": 1,
        "bad": 2, "unpleasant": 2, "offensive": 2,
        "average": 3, "uninspiring": 3, "forgettable": 3,
        "good": 4, "enjoyable": 4, "satisfying": 4,
        "awesome": 5, "incredible": 5, "amazing": 5,
    }

    food_scores: List[int] = []
    customer_service_scores: List[int] = []
    for review in reviews:
        # Strip punctuation: a plain split() leaves "good," / "amazing." with
        # trailing punctuation, so they would never match a scale key.
        found = [
            score_words[token]
            for token in (word.strip(string.punctuation).lower() for word in review.split())
            if token in score_words
        ]
        if len(found) < 2:
            continue
        food_scores.append(found[0])
        customer_service_scores.append(found[1])

    return {
        "food_scores": food_scores,
        "customer_service_scores": customer_service_scores,
    }
def calculate_overall_score(restaurant_name: str, food_scores: List[int], customer_service_scores: List[int]) -> Dict[str, float]:
    """Combine per-review food and service scores into one overall score.

    Each review contributes sqrt(food**2 * service). The mean contribution is
    divided by sqrt(125), the contribution of a 5/5 review, and multiplied by
    10. A restaurant whose reviews are all 5/5 therefore scores exactly 10.

    Args:
        restaurant_name: Key for the returned dict.
        food_scores: Food score per review.
        customer_service_scores: Service score per review, index-aligned with
            ``food_scores``. It is indexed up to ``len(food_scores)``, so it
            must be at least that long.

    Returns:
        ``{restaurant_name: score}`` with the score rounded to 3 decimals, or
        0.0 when there are no reviews.
    """
    count = len(food_scores)
    if not count:
        return {restaurant_name: 0.0}
    per_review = (
        math.sqrt(food_scores[idx] ** 2 * customer_service_scores[idx])
        for idx in range(count)
    )
    total = sum(per_review)
    scaled = total / (count * math.sqrt(125)) * 10
    return {restaurant_name: round(scaled, 3)}

def normalize_name(query: str) -> str:
    """Extract the restaurant name from a natural-language query and normalize it.

    Supported shapes, checked in this order:
      * "... for <name>?"                 -> text after the word "for"
      * "... restaurant <name> overall?"  -> text after the word "restaurant"
    Any other query is treated as the bare restaurant name. Previously such a
    query crashed with UnboundLocalError.

    "for" and "restaurant" are matched as whole words, case-insensitively.
    The old substring test also fired inside names like "California" or
    "Comfort". Afterwards all '?' characters and a trailing word "overall"
    are removed. The result is stripped and lower-cased, the same
    normalization load_restaurant_data applies to names.

    Args:
        query: User question, e.g. "What is the overall score for Taco Bell?".

    Returns:
        The lower-cased restaurant name, e.g. "taco bell".
    """
    for_match = re.search(r"\bfor\s+(.+)", query, flags=re.IGNORECASE)
    restaurant_match = re.search(r"\brestaurant\s+(.+)", query, flags=re.IGNORECASE)
    if for_match:
        name = for_match.group(1)
    elif restaurant_match:
        name = restaurant_match.group(1)
    else:
        name = query

    name = name.replace('?', '').strip()
    # Drop a trailing "overall" ("How good is the restaurant X overall?").
    name = re.sub(r"\s*\boverall$", "", name, flags=re.IGNORECASE)
    return name.strip().lower()

def get_data_fetch_agent_prompt(restaurant_query: str) -> str:
    """Build the instruction text for the data-fetch agent.

    Args:
        restaurant_query: Restaurant name to fetch reviews for. Callers in
            this notebook pass the already-normalized name.

    Returns:
        A one-sentence prompt naming the restaurant.
    """
    template = "Fetch reviews for the restaurant named {}."
    return template.format(restaurant_query)

def get_review_analysis_prompt(reviews: List[str]) -> str:
    """Build the instruction text for the review-analysis agent.

    Args:
        reviews: Reviews to analyze. They are embedded via ``str()``, i.e. as
            a Python list literal.

    Returns:
        A header line, a newline, then the rendered review list.
    """
    header = "Analyze the following reviews and extract food and customer service scores:"
    return header + "\n" + str(reviews)

def get_scoring_agent_prompt(food_scores: List[int], customer_service_scores: List[int]) -> str:
    """Build the instruction text for the scoring agent.

    The prompt lists both score arrays and asks the agent to call
    calculate_overall_score with them.

    Args:
        food_scores: Per-review food scores.
        customer_service_scores: Per-review customer-service scores.

    Returns:
        A multi-line prompt; sections are separated by blank lines.
    """
    prompt_lines = [
        "Given the following scores for food and customer service, calculate the overall score for the restaurant.",
        "",
        f"Food scores: {food_scores}",
        f"Customer service scores: {customer_service_scores}",
        "",
        "Use the provided scores to call the calculate_overall_score function and compute the overall score.",
    ]
    return "\n".join(prompt_lines)

In [None]:
# Extract the restaurant name from a sample query.
user_query = "What is the overall score for Krispy Kreme?"
print(normalize_name(user_query))

In [None]:
normalized_query = normalize_name(user_query)
# Look up the restaurant's reviews.
review_M = fetch_restaurant_data(normalized_query)
review_M1 = list(review_M.values())[0] # Access the list of reviews
print(f"1. review M1: {review_M1}") # Reviews for the restaurant named in user_query

# Prompt for the data-fetch agent.
fetch_prompt = get_data_fetch_agent_prompt(normalized_query)
print(f"2. fetch_prompt  :{fetch_prompt}")

# Prompt for the review-analysis agent, then the scores it should produce.
analysis_prompt = get_review_analysis_prompt(review_M1)
print(f"3. analysis_prompt  :{analysis_prompt}")
scores = analyze_reviews(review_M1)
print(f"4. scores: {scores}")
# Prompt for the scoring agent.
scores_prompt = get_scoring_agent_prompt(scores["food_scores"], scores["customer_service_scores"])
print(f"5. scores_prompt  :{scores_prompt}")

overall_score = calculate_overall_score(normalized_query, scores["food_scores"], scores["customer_service_scores"])
# Always show 3 decimals (e.g. "10.000", not "10.0"), matching the format the grader expects.
print(f"6. Overall score for {user_query}: {overall_score[normalized_query]:.3f}\n")

# Helpers for running the evaluation (next cells).


In [None]:
import sys, os
import re
import json
#from main import main 
from typing import List

class TerminalColors:
    """ANSI escape sequences used to colour test-result lines in a terminal."""

    RESET = '\033[0m'   # return to the terminal's default colour
    GREEN = '\033[92m'  # bright green, used for passed tests
    RED = '\033[91m'    # bright red, used for failed tests

# State saved by suppress_prints() so restore_prints() can undo it exactly:
# (stream that was active before, devnull handle that replaced it), or None.
_suppressed_stdout = None


def suppress_prints() -> None:
    """Silence print() by pointing sys.stdout at os.devnull.

    The stream that was active before is remembered so restore_prints() can
    reinstate it. Under Jupyter that stream is typically the kernel's output
    stream rather than sys.__stdout__. Calling this again while already
    suppressed is a no-op, so no extra devnull handles are opened.
    """
    global _suppressed_stdout
    if _suppressed_stdout is not None:
        return
    devnull = open(os.devnull, 'w')
    _suppressed_stdout = (sys.stdout, devnull)
    sys.stdout = devnull


def restore_prints() -> None:
    """Undo suppress_prints(): reinstate the previous stdout and close the devnull handle.

    If suppress_prints() is not currently active, falls back to the previous
    behaviour of switching to sys.__stdout__.
    """
    global _suppressed_stdout
    if _suppressed_stdout is None:
        sys.stdout = sys.__stdout__
        return
    previous, devnull = _suppressed_stdout
    _suppressed_stdout = None
    sys.stdout = previous
    # Close the handle opened by suppress_prints(); previously it was leaked.
    devnull.close()

def contains_num_with_tolerance(text: str, pattern: float, tolerance: float=0) -> bool:
    """Report whether ``text`` contains a number within ``tolerance`` of ``pattern``.

    Only numbers written with at least 3 decimal places are considered, and
    only their first 3 decimals are read (e.g. "8.9443" is read as 8.944).
    A leading minus sign is not captured.

    Args:
        text: Output text to scan.
        pattern: Expected value.
        tolerance: Maximum allowed absolute difference (inclusive).

    Returns:
        True if at least one matching number is close enough.
    """
    candidates = (float(token) for token in re.findall(r'\d*\.\d{3}', text))
    return any(abs(value - pattern) <= tolerance for value in candidates)

In [None]:
def main(user_query: str):
    """Score one restaurant query end to end and print the result.

    Steps:
    1. Extract the restaurant name from the query.
    2. Fetch its reviews.
    3. Score each review for food and customer service.
    4. Combine those scores into an overall score.

    The score is printed with exactly three decimals. The grader
    (contains_num_with_tolerance) only recognises numbers written with at
    least three decimal places. Printing the raw rounded float would show
    e.g. "10.0" or "3.25", which never matches even when correct.

    Args:
        user_query: Natural-language question naming a restaurant.
    """
    normalized_query = normalize_name(user_query)
    # fetch_restaurant_data returns {name: reviews} keyed by the name passed in.
    reviews = fetch_restaurant_data(normalized_query)[normalized_query]
    scores = analyze_reviews(reviews)
    overall_score = calculate_overall_score(
        normalized_query, scores["food_scores"], scores["customer_service_scores"]
    )
    print(f"6. Overall score for {user_query}: {overall_score[normalized_query]:.3f}")

In [None]:
def public_tests():
    """Run the public grading queries through main() and print pass/fail per query.

    For each query, main()'s printed output is captured through
    runtime-log.txt. The file is overwritten per query, so afterwards it holds
    only the last query's output. Each capture is searched for the expected
    score within that query's tolerance using contains_num_with_tolerance().
    A final "N/M Tests Passed" summary is printed.
    """
    queries = [
    "What is the overall score for taco bell?",
    "What is the overall score for In N Out?",
    "How good is the restaurant Chick-fil-A overall?",
    "What is the overall score for Krispy Kreme?",
    ]
    print(queries)
    query_results = [3.25, 10.000, 10.000, 8.94]
    tolerances = [0.2, 0.2, 0.2, 0.15]
    contents = []

    for query in queries:
        # redirect_stdout puts back whatever stdout was active before (in a
        # notebook, the cell's output stream), even if main() raises.
        # Assigning sys.stdout by hand left it on a closed file after an
        # error, and restore_prints() then switched to sys.__stdout__, which
        # under Jupyter is not the notebook's output.
        with open("runtime-log.txt", "w", encoding="utf-8") as log_file:
            with contextlib.redirect_stdout(log_file):
                main(query)
        with open("runtime-log.txt", "r", encoding="utf-8") as log_file:
            contents.append(log_file.read())

    num_passed = 0
    for i, content in enumerate(contents):
        if not contains_num_with_tolerance(content, query_results[i], tolerance=tolerances[i]):
            print(TerminalColors.RED + f"Test {i+1} Failed." + TerminalColors.RESET, "Expected: ", query_results[i], "Query: ", queries[i])
        else:
            print(TerminalColors.GREEN + f"Test {i+1} Passed." + TerminalColors.RESET, "Expected: ", query_results[i], "Query: ", queries[i])
            num_passed += 1

    print(f"{num_passed}/{len(queries)} Tests Passed")