# In [None]:
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix

def evaluate_plagiarism_detector_by_category(model_function, test_data, es, index_name, word2vec_model, bert_model, tokenizer):
    """
    Evaluate the plagiarism detection model's accuracy broken down by category.

    Every sentence in ``test_data`` is classified with ``model_function``; the
    dictionary key a sentence is listed under is used as its true label.
    Per-category metrics are computed only after *all* sentences have been
    classified, so a category's precision accounts for false positives coming
    from every other category.

    Parameters:
    - model_function: Function that classifies plagiarism (classify_plagiarism).
      Called as ``model_function(sentence, es, index_name, word2vec_model,
      bert_model, tokenizer)`` and must return a 4-tuple
      ``(predicted_class, matched_sentence, features, similarity)``.
    - test_data: Dictionary mapping each category (the true label) to a list of test sentences
    - es, index_name, word2vec_model, bert_model, tokenizer: Passed through unchanged to model_function

    Returns:
    - Dictionary with keys:
      - 'categories': per category, a dict with 'accuracy', 'precision',
        'recall', 'f1_score', 'total_samples', 'correct_predictions' and
        'detailed_results' (one dict per input sentence)
      - 'overall': 'accuracy' plus support-weighted 'precision', 'recall' and
        'f1_score' (sklearn ``average='weighted'``, ``zero_division=0``)
      - 'confusion_matrix': sklearn confusion matrix with rows/columns ordered as 'labels'
      - 'labels': sorted union of true and predicted labels
      When ``test_data`` contains no sentences, all overall metrics are 0.0,
      'labels' is empty and 'confusion_matrix' is an empty (0, 0) array.
    """
    overall_true = []
    overall_pred = []
    detailed_by_category = {}

    # Pass 1: classify every sentence and record what the model returned.
    for category, sentences in test_data.items():
        category_results = []
        for input_sentence in sentences:
            predicted_class, matched_sentence, features, similarity = model_function(
                input_sentence, es, index_name, word2vec_model, bert_model, tokenizer
            )

            overall_true.append(category)
            overall_pred.append(predicted_class)

            category_results.append({
                'input': input_sentence,
                'predicted_label': predicted_class,
                'matched_sentence': matched_sentence,
                'similarity': similarity,
                'features': features
            })
        detailed_by_category[category] = category_results

    # Pass 2: per-category metrics. This must run after pass 1 has finished:
    # the precision denominator counts predictions of `category` made on
    # sentences from *any* category, including categories iterated later.
    results_by_category = {}
    for category, category_results in detailed_by_category.items():
        total_samples = len(category_results)
        # Every sentence in this list has true label `category`, so a correct
        # prediction here is exactly a true positive for this category.
        correct_predictions = sum(1 for r in category_results if r['predicted_label'] == category)
        predicted_as_category = sum(1 for pred in overall_pred if pred == category)

        accuracy = correct_predictions / total_samples if total_samples else 0.0
        precision = correct_predictions / predicted_as_category if predicted_as_category else 0.0
        # With a single true label per category, recall equals the per-category accuracy.
        recall = accuracy
        f1 = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0.0

        results_by_category[category] = {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1_score': f1,
            'total_samples': total_samples,
            'correct_predictions': correct_predictions,
            'detailed_results': category_results
        }

    labels = sorted(set(overall_true) | set(overall_pred))
    if overall_true:
        overall_accuracy = accuracy_score(overall_true, overall_pred)
        precision, recall, f1, _ = precision_recall_fscore_support(
            overall_true, overall_pred, average='weighted', zero_division=0
        )
        conf_matrix = confusion_matrix(overall_true, overall_pred, labels=labels)
    else:
        # Nothing was evaluated; sklearn rejects an empty label list, so report
        # zeros and an empty matrix instead of raising.
        overall_accuracy = precision = recall = f1 = 0.0
        conf_matrix = np.zeros((0, 0), dtype=int)

    return {
        'categories': results_by_category,
        'overall': {
            'accuracy': overall_accuracy,
            'precision': precision,
            'recall': recall,
            'f1_score': f1
        },
        'confusion_matrix': conf_matrix,
        'labels': labels
    }

# Hand-written evaluation set. Each label is used as the true class for the
# sentences listed with it; insertion order is preserved by dict().
test_data_by_category = dict([
    ("Cut-Paste", [
        "artificial intelligence is transforming industries by automating tasks and improving efficiency",
        "climate change is a global challenge that requires immediate attention to mitigate its adverse effects",
        "the rapid advancement of technology has significantly impacted communication and information sharing",
        "renewable energy sources such as solar and wind are crucial for reducing dependence on fossil fuels",
        "data privacy and security have become major concerns in the digital age requiring stringent measures",
    ]),
    ("Light Paraphrasing", [
        "artificial intelligence is changing industries by automating tasks and improving efficiency",
        "climate change is a worldwide challenge that needs immediate action to reduce its negative effects",
        "the quick advancement of tech has greatly affected communication and information exchange",
        "renewable power sources like solar and wind are important for decreasing reliance on fossil fuels",
        "data security and privacy have become key issues in the digital era requiring strong measures",
    ]),
    ("Heavy Paraphrasing", [
        "AI technologies are revolutionizing how businesses operate by handling repetitive work and boosting productivity",
        "global warming presents a critical worldwide issue demanding prompt response to limit harmful consequences",
        "technological progress has fundamentally altered how people communicate and share information",
        "alternative energy like solar panels and wind turbines help reduce our dependence on oil and coal",
        "protecting personal information is increasingly critical as digital systems become more integrated in society",
    ]),
    ("No Match", [
        "the stock market showed significant growth in the technology sector last quarter",
        "proper nutrition and regular exercise are essential for maintaining good health",
        "the film won several awards for its exceptional screenplay and cinematography",
        "the space mission successfully landed on the moon and collected valuable samples",
        "the university announced a new scholarship program for international students",
    ]),
])

# Run the detector over the hand-written test set, then print an overall
# summary, the confusion matrix, per-category metrics and per-sentence details.
results = evaluate_plagiarism_detector_by_category(
    classify_plagiarism, test_data_by_category, es, index_name, word2vec_model, bert_model, tokenizer
)

# Single source of truth for the display order used by every per-category
# section below; categories missing from the results are skipped.
REPORT_CATEGORY_ORDER = ("Cut-Paste", "Light Paraphrasing", "Heavy Paraphrasing", "No Match")
reported_categories = [c for c in REPORT_CATEGORY_ORDER if c in results['categories']]

print("=" * 50)
print("OVERALL MODEL PERFORMANCE")
print("=" * 50)
print(f"Accuracy:  {results['overall']['accuracy']:.4f}")
print(f"Precision: {results['overall']['precision']:.4f}")
print(f"Recall:    {results['overall']['recall']:.4f}")
print(f"F1 Score:  {results['overall']['f1_score']:.4f}")
print("=" * 50)

print("\nConfusion Matrix:")
print(f"Labels: {results['labels']}")
# Matrix rows are ordered like results['labels'] (passed as labels= to confusion_matrix).
for label, row in zip(results['labels'], results['confusion_matrix']):
    print(f"{label}: {row}")

print("\n=== Per-Category Performance ===")
for category in reported_categories:
    metrics = results['categories'][category]
    print(f"\n{category}:")
    print(f"  Accuracy: {metrics['accuracy']:.4f} ({metrics['correct_predictions']}/{metrics['total_samples']})")
    print(f"  Precision: {metrics['precision']:.4f}")
    print(f"  Recall: {metrics['recall']:.4f}")
    print(f"  F1-Score: {metrics['f1_score']:.4f}")

print("\n=== Summary Table ===")
print(f"{'Category':<20} {'Accuracy':<10} {'Precision':<10} {'Recall':<10} {'F1-Score':<10}")
print("-" * 60)
for category in reported_categories:
    metrics = results['categories'][category]
    print(f"{category:<20} {metrics['accuracy']:.4f}     {metrics['precision']:.4f}     {metrics['recall']:.4f}     {metrics['f1_score']:.4f}")
print("-" * 60)
print(f"{'OVERALL':<20} {results['overall']['accuracy']:.4f}     {results['overall']['precision']:.4f}     {results['overall']['recall']:.4f}     {results['overall']['f1_score']:.4f}")

for category in reported_categories:
    print(f"\n=== Detailed Results for {category} ===")
    for case_number, result in enumerate(results['categories'][category]['detailed_results'], start=1):
        print(f"\nTest Case {case_number}:")
        print(f"Input: {result['input']}")
        print(f"Predicted Label: {result['predicted_label']}")
        matched = result['matched_sentence'] if result['matched_sentence'] else 'None'
        print(f"Matched: {matched}")
        # NOTE(review): matched_sentence and features are None-guarded, so
        # similarity is presumably None when nothing matched as well — confirm
        # against classify_plagiarism. ':.2f' would raise TypeError on None.
        similarity = result['similarity']
        similarity_str = f"{similarity:.2f}%" if similarity is not None else 'None'
        print(f"Similarity: {similarity_str}")
        features_str = str(result['features']) if result['features'] is not None else 'None'
        print(f"Features: {features_str}")