## Script to Evaluate Email-Label Pairs
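### Criteria:
1. Spam: ACC, F1, Recall
2. Time_Sensitive: ACC, F1, Recall
3. Time Period: ACC (Start and End must both match exactly)
4. Type: ACC
5. Category: ACC
6. Format: ACC
7. Priority_Level: relaxed ACC (a prediction within one level of the true priority counts as correct)
8. Overall: weighted sum of the per-field scores (Spam is not weighted, but a wrong Spam prediction sets the overall score to 0)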

### Fields:
{

    "Spam": "Yes" / "No",
    "Subject": <string>,
    "Sender": <string>,
    "send_date": <string>,
    "Time_Sensitive": "Yes" / "No",
    "Start": "YYYY-MM-DD HH:MM",
    "End": "YYYY-MM-DD HH:MM",
    "Type": "Event" / "Reminder" / "N/A",
    "Category": "Work" / "Study" / "Leisure",
    "Format": "Online" / "In-person",
    "Location": <string>,
    "Action_Required": "Yes" / "No",
    "Priority_Level": "Low" / "Medium" / "High" / "Urgent"

}

In [75]:
from sklearn.metrics import accuracy_score, f1_score, recall_score
from datetime import datetime
import json

In [76]:
# from utils.evaluation_utils import evaluate_label_single, calculate_overall_metrics

In [77]:
def is_time_period_correct(pred_start, pred_end, true_start, true_end, time_format="%Y-%m-%d %H:%M"):
    """
    Strictly compare predicted and true time periods.

    Each endpoint that is a timestamp string is parsed with ``time_format``,
    so the comparison is on actual datetime values rather than raw text.
    Endpoints that are missing (``None`` or blank) or that cannot be parsed,
    such as a placeholder like "N/A" for an email without a time period, are
    compared as plain values instead of raising. As a result, two missing
    endpoints match each other, but a missing endpoint never matches a real
    timestamp.

    Parameters:
        pred_start (str | None): Predicted start time (e.g., "2024-11-20 09:17").
        pred_end (str | None): Predicted end time (e.g., "2024-11-20 11:25").
        true_start (str | None): True start time.
        true_end (str | None): True end time.
        time_format (str): ``datetime.strptime`` format of the timestamps.

    Returns:
        bool: True if both start and end times match exactly, False otherwise.
    """
    def _normalize(value):
        # Treat None and blank strings alike: "no time given".
        if value is None:
            return None
        if not isinstance(value, str):
            return value
        text = value.strip()
        if not text:
            return None
        try:
            return datetime.strptime(text, time_format)
        except ValueError:
            # Not a timestamp (e.g. "N/A"): compare the raw text instead.
            return text

    return (
        _normalize(pred_start) == _normalize(true_start)
        and _normalize(pred_end) == _normalize(true_end)
    )

def _score_label_pair(pred_item, true_item):
    """Score one (prediction, true label) pair; return (field_scores, spam_correct)."""
    priority_map = {"Low": 1, "Medium": 2, "High": 3, "Urgent": 4}
    field_scores = {}
    spam_correct = 1  # a pair without a Spam field is not penalised

    for field in true_item:
        # Only fields present in the prediction are scored.
        if field not in pred_item:
            continue
        predicted, expected = pred_item[field], true_item[field]
        if field in ("Time_Sensitive", "Type", "Category", "Format"):
            # Categorical fields: exact match
            field_scores[field] = int(predicted == expected)
        elif field in ("Start", "End"):
            # Start and End are scored together as one "Time Period" field;
            # skip the second endpoint once the pair has been scored.
            if "Time Period" not in field_scores:
                time_correct = is_time_period_correct(
                    pred_item.get("Start"),
                    pred_item.get("End"),
                    true_item.get("Start"),
                    true_item.get("End"),
                )
                field_scores["Time Period"] = int(time_correct)
        elif field == "Priority_Level":
            if predicted in priority_map and expected in priority_map:
                # Relaxed match: within one level counts as correct.
                distance = abs(priority_map[predicted] - priority_map[expected])
                field_scores[field] = int(distance <= 1)
            else:
                # Label outside the known scale: fall back to exact match.
                field_scores[field] = int(predicted == expected)
        elif field == "Spam":
            spam_correct = int(predicted == expected)
            field_scores["Spam"] = spam_correct

    return field_scores, spam_correct


def evaluate_label(pred, true, weights=None):
    """
    Evaluate a batch of predictions against their true labels.

    Each (prediction, true label) pair is scored field by field:
    Time_Sensitive/Type/Category/Format and Spam by exact match, Start/End
    together as "Time Period", and Priority_Level with a relaxed match
    (within one level is correct). Only fields present in both the
    prediction and the true label are scored. The per-pair weighted score
    ignores Spam, but it is forced to 0 when Spam is predicted wrongly.

    Parameters:
        pred (list of dict): Predicted labels.
        true (list of dict): True labels, aligned with ``pred`` by position.
        weights (dict): Weights for each field except Spam.

    Returns:
        dict: ``{"Field": {field: mean score over the pairs in which the field
        was scored}, "Overall Weighted Score": mean per-pair weighted score}``.
        For empty input, ``{"Field": {}, "Overall Weighted Score": 0}``.

    Raises:
        ValueError: If ``pred`` and ``true`` have different lengths.
    """
    if weights is None:
        # Default weights (Spam is NOT included in the score)
        weights = {
            "Time_Sensitive": 0.2,
            "Type": 0.15,
            "Category": 0.15,
            "Format": 0.1,
            "Time Period": 0.2,
            "Priority_Level": 0.2,
        }

    if len(pred) != len(true):
        raise ValueError(
            f"pred and true must have the same length ({len(pred)} != {len(true)})"
        )

    results = {"Field": {}, "Overall Weighted Score": 0}
    if not pred:
        return results

    field_sums = {}
    field_counts = {}
    total_weighted_score = 0
    for pred_item, true_item in zip(pred, true):
        field_scores, spam_correct = _score_label_pair(pred_item, true_item)
        for field, score in field_scores.items():
            field_sums[field] = field_sums.get(field, 0) + score
            field_counts[field] = field_counts.get(field, 0) + 1
        # Weighted score using fields except Spam; a missing field scores 0.
        weighted_score = sum(
            field_scores.get(field, 0) * weight for field, weight in weights.items()
        )
        total_weighted_score += weighted_score if spam_correct == 1 else 0

    results["Field"] = {
        field: field_sums[field] / field_counts[field] for field in field_sums
    }
    results["Overall Weighted Score"] = total_weighted_score / len(pred)
    return results

In [None]:
def evaluate_label_single(pred, true, weights=None):
    """
    Evaluate a single pair of prediction and true label.

    Every field present in ``true`` is scored 1 (correct) or 0 (incorrect).
    A field missing from ``pred`` is scored 0 rather than raising KeyError.
    Start/End are scored together as "Time Period". Priority_Level uses a
    relaxed match (within one level counts as correct) and falls back to an
    exact match for labels outside Low/Medium/High/Urgent. Spam is not
    weighted, but a wrong Spam prediction sets the overall score to 0.

    Parameters:
        pred (dict): Predicted label.
        true (dict): True label.
        weights (dict): Weights for each field except Spam.

    Returns:
        dict: {"Field": {field: 0 or 1}, "Overall Weighted Score": float}.
    """
    if weights is None:
        # Default weights (Spam is NOT included in the score)
        weights = {
            "Time_Sensitive": 0.2,
            "Type": 0.15,
            "Category": 0.15,
            "Format": 0.1,
            "Time Period": 0.2,
            "Priority_Level": 0.2,
        }

    priority_map = {"Low": 1, "Medium": 2, "High": 3, "Urgent": 4}

    # Initialize results
    results = {"Field": {}, "Overall Weighted Score": 0}
    field_scores = results["Field"]

    # Spam correctness gates the overall score (see below)
    spam_correct = 1 if pred.get("Spam") == true.get("Spam") else 0
    field_scores["Spam"] = spam_correct

    # Evaluate fields
    for field in true:
        if field in ("Time_Sensitive", "Type", "Category", "Format"):
            # Categorical fields: exact match; a missing prediction is wrong
            field_scores[field] = 1 if field in pred and pred[field] == true[field] else 0
        elif field in ("Start", "End"):
            # Start and End are scored together once, as "Time Period"
            if "Time Period" in field_scores:
                continue
            if "Start" in pred and "End" in pred:
                time_correct = is_time_period_correct(
                    pred["Start"], pred["End"], true.get("Start"), true.get("End")
                )
            else:
                time_correct = False
            field_scores["Time Period"] = 1 if time_correct else 0
        elif field == "Priority_Level":
            if field not in pred:
                field_scores[field] = 0
                continue
            predicted, expected = pred[field], true[field]
            if predicted in priority_map and expected in priority_map:
                # Relaxed match: within one level counts as correct
                distance = abs(priority_map[predicted] - priority_map[expected])
                field_scores[field] = 1 if distance <= 1 else 0
            else:
                # Unknown label: fall back to exact match instead of KeyError
                field_scores[field] = 1 if predicted == expected else 0

    # Calculate weighted score using fields except Spam
    weighted_score = sum(
        field_scores.get(field, 0) * weights.get(field, 0)
        for field in weights.keys()
    )
    results["Overall Weighted Score"] = weighted_score if spam_correct == 1 else 0

    return results

In [79]:
def calculate_overall_metrics(results_list):
    """
    Calculate overall ACC, F1, and Recall for binary fields, ACC for categorical fields,
    and an averaged weighted score.

    Each entry in ``results_list`` records only whether a field was predicted
    correctly (1) or not (0). For the binary fields the "true" vector passed
    to sklearn is therefore all ones. Under that encoding Recall equals ACC,
    and F1 is a function of ACC. Class-level F1/Recall for Spam or
    Time_Sensitive would require the raw Yes/No labels.

    Metrics for a field with no recorded entries, including the case of an
    empty ``results_list``, are reported as 0.0 instead of raising
    ZeroDivisionError.

    Parameters:
        results_list (list of dict): List of evaluation results for individual predictions.

    Returns:
        dict: Overall metrics for binary fields, categorical fields, and average weighted score.
    """
    # Initialize storage
    binary_fields = ["Spam", "Time_Sensitive"]
    categorical_fields = ["Time Period", "Type", "Category", "Format", "Priority_Level"]

    binary_true = {field: [] for field in binary_fields}
    binary_pred = {field: [] for field in binary_fields}

    categorical_correct = {field: 0 for field in categorical_fields}
    categorical_total = {field: 0 for field in categorical_fields}

    total_weighted_score = 0
    num_results = len(results_list)

    # Aggregate results
    for result in results_list:
        total_weighted_score += result["Overall Weighted Score"]
        for field in result["Field"]:
            if field in binary_fields:
                binary_true[field].append(1)  # True value is always "correct"
                binary_pred[field].append(result["Field"][field])  # Append prediction (0 or 1)
            elif field in categorical_fields:
                categorical_total[field] += 1
                if result["Field"][field] == 1:
                    categorical_correct[field] += 1

    # Calculate metrics for binary fields
    binary_metrics = {}
    for field in binary_fields:
        y_true, y_pred = binary_true[field], binary_pred[field]
        if not y_true:
            # No data for this field: report zeros rather than calling sklearn on empty input
            binary_metrics[field] = {"ACC": 0.0, "F1": 0.0, "Recall": 0.0}
            continue
        binary_metrics[field] = {
            "ACC": accuracy_score(y_true, y_pred),
            # zero_division=0 returns 0 silently when precision is undefined (no 1 predictions)
            "F1": f1_score(y_true, y_pred, zero_division=0),
            "Recall": recall_score(y_true, y_pred, zero_division=0),
        }

    # Calculate accuracy for categorical fields (0.0 when a field never appeared)
    categorical_metrics = {
        field: (
            categorical_correct[field] / categorical_total[field]
            if categorical_total[field]
            else 0.0
        )
        for field in categorical_fields
    }

    # Calculate averaged weighted score (0.0 for an empty results_list)
    averaged_weighted_score = total_weighted_score / num_results if num_results else 0.0

    return {
        "Binary Metrics": binary_metrics,
        "Categorical Metrics": categorical_metrics,
        "Averaged Weighted Score": averaged_weighted_score,
    }

In [80]:
# Load predicted and ground-truth labels; each file is expected to hold a
# JSON list of label dicts aligned by position (presumably - confirm schema).
pred_path = "./data/test_examples/test_pred.json"
true_path = "./data/test_examples/test_true.json"
# Explicit UTF-8 so loading does not depend on the platform's locale encoding.
with open(pred_path, 'r', encoding='utf-8') as file:
    pred_labels = json.load(file)
with open(true_path, 'r', encoding='utf-8') as file:
    true_labels = json.load(file)


In [81]:
# Score each prediction against the true label at the same position.
if len(pred_labels) != len(true_labels):
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    raise ValueError(
        f"Prediction/label count mismatch: {len(pred_labels)} predictions "
        f"vs {len(true_labels)} true labels"
    )
results = [
    evaluate_label_single(pred_label, true_label)
    for pred_label, true_label in zip(pred_labels, true_labels)
]

In [82]:
# Display the per-email evaluation results (one dict per prediction/label pair).
results

[{'Field': {'Spam': 1,
   'Time_Sensitive': 1,
   'Time Period': 0,
   'Type': 1,
   'Category': 1,
   'Format': 1,
   'Priority_Level': 1},
  'Overall Weighted Score': 0.8},
 {'Field': {'Spam': 1,
   'Time_Sensitive': 1,
   'Time Period': 1,
   'Type': 1,
   'Category': 1,
   'Format': 1,
   'Priority_Level': 1},
  'Overall Weighted Score': 1.0},
 {'Field': {'Spam': 1,
   'Time_Sensitive': 1,
   'Time Period': 1,
   'Type': 1,
   'Category': 1,
   'Format': 1,
   'Priority_Level': 1},
  'Overall Weighted Score': 1.0},
 {'Field': {'Spam': 0,
   'Time_Sensitive': 1,
   'Time Period': 1,
   'Type': 1,
   'Category': 1,
   'Format': 1,
   'Priority_Level': 1},
  'Overall Weighted Score': 0},
 {'Field': {'Spam': 1,
   'Time_Sensitive': 1,
   'Time Period': 1,
   'Type': 1,
   'Category': 1,
   'Format': 1,
   'Priority_Level': 1},
  'Overall Weighted Score': 1.0}]

In [83]:
# Aggregate the per-email results into dataset-level metrics and display them.
metrics = calculate_overall_metrics(results_list=results)
metrics

{'Binary Metrics': {'Spam': {'ACC': 0.8,
   'F1': 0.8888888888888888,
   'Recall': 0.8},
  'Time_Sensitive': {'ACC': 1.0, 'F1': 1.0, 'Recall': 1.0}},
 'Categorical Metrics': {'Time Period': 0.8,
  'Type': 1.0,
  'Category': 1.0,
  'Format': 1.0,
  'Priority_Level': 1.0},
 'Averaged Weighted Score': 0.76}