# In[ ]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
)

# Read the model's prediction file into a DataFrame.
file_path = 'predictions/filtered_relations_test_clean_removed_conclusion_source_predictions.csv'
data = pd.read_csv(file_path)

# Gold-standard labels and model outputs, as two aligned Series.
actual_relations, predicted_relations = (
    data['actual_relation'],
    data['predicted_relation'],
)

# Relation classes evaluated below; this order fixes the rows/columns of the
# report and of the confusion matrix.
relation_types = ["attack", "no-relation", "support"]

def save_classification_report(predictions, labels, output_dir="./", class_names=None):
    """Print a per-class classification report and write it to a text file.

    Args:
        predictions: Predicted relation labels, one per example.
        labels: Ground-truth relation labels, aligned with ``predictions``.
        output_dir: Directory that receives ``classification_report.txt``;
            it is created if it does not exist yet.
        class_names: Relation labels to report on, in display order.
            Defaults to the module-level ``relation_types``.
    """
    import os

    if class_names is None:
        class_names = relation_types

    # Pin the class set explicitly. Otherwise sklearn infers the classes
    # from the data: ``target_names`` then fails to match when a class is
    # absent or the model emits an unexpected label, or it is silently
    # attached to the wrong rows when the counts happen to coincide.
    # (Note: sklearn's ``labels=`` is the class list; this function's own
    # ``labels`` argument is y_true.)
    report = classification_report(
        labels, predictions,
        labels=class_names,
        target_names=class_names,
    )

    os.makedirs(output_dir, exist_ok=True)
    report_path = os.path.join(output_dir, "classification_report.txt")
    with open(report_path, "w", encoding="utf-8") as f:
        f.write(report)

    print("Classification Report:")
    print(report)

def plot_confusion_matrix(y_true, y_pred, output_path="confusion_matrix.png",
                          title=None, class_names=None):
    """Draw a heatmap of confusion-matrix counts, save it, and show it.

    Args:
        y_true: Ground-truth relation labels.
        y_pred: Predicted relation labels, aligned with ``y_true``.
        output_path: File the figure is saved to (before it is shown).
        title: Plot title. Defaults to the gemma-3-27b-it zero-shot title
            used by this evaluation.
        class_names: Class labels fixing the row/column order. Defaults to
            the module-level ``relation_types``.
    """
    if class_names is None:
        class_names = relation_types
    if title is None:
        title = ('Confusion Matrix for Relation Classification '
                 '(gemma-3-27b-it, without case file context, zero-shot)')

    # Rows are actual classes, columns are predicted classes, both in
    # ``class_names`` order; labels outside that list are not counted.
    cm = confusion_matrix(y_true, y_pred, labels=class_names)

    fig, ax = plt.subplots(figsize=(10, 8))
    # annot=True with fmt='d' writes each integer count into its cell.
    sns.heatmap(cm, annot=True, fmt='d', cmap="Blues", ax=ax,
                xticklabels=[f"Predicted {rel}" for rel in class_names],
                yticklabels=[f"Actual {rel}" for rel in class_names])
    ax.set_xlabel('Prediction')
    ax.set_ylabel('Ground Truth')
    ax.set_title(title)
    fig.tight_layout()
    # Save before showing: some backends clear the figure on show().
    fig.savefig(output_path)
    plt.show()
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close(fig)

# Print/write the per-class report, then draw and save the heatmap.
save_classification_report(predicted_relations, actual_relations)
plot_confusion_matrix(actual_relations, predicted_relations)

# Same confusion matrix as a labelled table for console output:
# rows are ground-truth classes, columns are predicted classes.
cm_counts = confusion_matrix(actual_relations, predicted_relations, labels=relation_types)
row_names = ["Actual " + rel for rel in relation_types]
col_names = ["Predicted " + rel for rel in relation_types]
cm_df = pd.DataFrame(cm_counts, index=row_names, columns=col_names)

print("\nConfusion Matrix:", cm_df, sep="\n")
