# In [2]:
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report, precision_recall_fscore_support
import seaborn as sns
import numpy as np
from itertools import chain

# In [3]:
class CRFEvaluator:
    """Plot token-level evaluation results for a sequence-labelling (CRF) model.

    All gold and predicted labels in a test set are pooled into flat lists.
    The evaluator then renders either a confusion-matrix heatmap or a grouped
    precision / recall / F1 bar chart per label.
    """

    # B-/I-/O tag set used when the caller does not pass ``labels``.
    DEFAULT_LABELS = (
        'B-Quantity', 'B-Pizza', 'I-Pizza', 'B-Topping', 'B-Size',
        'I-Size', 'O', 'B-Crust', 'I-Crust',
    )

    def __init__(self, model, labels=None):
        """Wrap *model* for evaluation.

        Args:
            model: object with a ``predict(X)`` method, called by
                ``_get_predictions`` on the test inputs.
            labels: label order for the confusion-matrix rows/columns and the
                bar-chart groups. Defaults to ``DEFAULT_LABELS``.
        """
        self.model = model
        # Copy into a fresh list so callers' sequences are never aliased.
        self.labels = list(self.DEFAULT_LABELS if labels is None else labels)

    def evaluate(self, test_data, evaluation_type):
        """Render the requested evaluation plot for *test_data*.

        Args:
            test_data: forwarded unchanged to ``_get_predictions``.
            evaluation_type: ``'confusion_matrix'`` or
                ``'classification_report'``.

        Raises:
            ValueError: if *evaluation_type* is not a supported name.
        """
        if evaluation_type == 'confusion_matrix':
            self._confusion_matrix(test_data)
        elif evaluation_type == 'classification_report':
            self._classification_report(test_data)
        else:
            raise ValueError("Unsupported evaluation type. Supported types are 'confusion_matrix' and 'classification_report'.")

    def _get_predictions(self, test_data):
        """Return flat ``(y_true, y_pred)`` label lists for *test_data*.

        NOTE(review): assumes *test_data* is an ``(X_test, y_test)`` pair.
        ``y_test`` should hold one label sequence per input sequence, and
        ``model.predict`` should return the same nesting. Confirm against
        callers. sklearn's metrics expect 1-D label arrays, so the
        per-sequence lists are concatenated.
        """
        X_test, y_test = test_data
        y_pred = self.model.predict(X_test)
        return self._flatten(y_test), self._flatten(y_pred)

    @staticmethod
    def _flatten(sequences):
        """Concatenate nested label sequences into a single list.

        An already-flat list of strings is returned as a list unchanged.
        Without this check, ``chain.from_iterable`` would split each label
        string into single characters.
        """
        items = list(sequences)
        if all(isinstance(item, str) for item in items):
            return items
        return list(chain.from_iterable(items))

    def _confusion_matrix(self, test_data):
        """Show a labelled confusion-matrix heatmap and return the matrix."""
        y_true, y_pred = self._get_predictions(test_data)
        cm = confusion_matrix(y_true, y_pred, labels=self.labels)
        fig, ax = plt.subplots(figsize=(10, 10))
        sns.heatmap(cm, annot=True, fmt='d', ax=ax, cmap="Blues",
                    xticklabels=self.labels, yticklabels=self.labels)
        # Label the Axes we drew on explicitly. Pyplot's "current Axes" is
        # not guaranteed to be this one once the colorbar axes exist.
        ax.set_ylabel('Actual')
        ax.set_xlabel('Predicted')
        ax.set_title('Confusion Matrix')
        plt.show()
        return cm

    def _classification_report(self, test_data):
        """Show per-label precision, recall and F1 as grouped bars.

        Returns:
            ``(precision, recall, f1)`` arrays aligned with ``self.labels``.
            Labels that never occur score 0 because ``zero_division=0``.
        """
        y_true, y_pred = self._get_predictions(test_data)
        precision, recall, f1, _support = precision_recall_fscore_support(
            y_true, y_pred, labels=self.labels, zero_division=0)

        positions = np.arange(len(self.labels))
        width = 0.3  # three bars of this width fit within one tick spacing

        fig, ax = plt.subplots()
        # Place the bars left of, on, and right of each tick respectively.
        for offset, scores, name in ((-width, precision, 'Precision'),
                                     (0.0, recall, 'Recall'),
                                     (width, f1, 'F1-score')):
            ax.bar(positions + offset, scores, width, label=name)

        ax.set_ylabel('Scores')
        ax.set_title('Scores by group and evaluation metric')
        ax.set_xticks(positions)
        ax.set_xticklabels(self.labels, rotation=45, ha='right')
        ax.legend()

        fig.tight_layout()
        plt.show()
        return precision, recall, f1
