diff --git a/pytext/metric_reporters/word_tagging_metric_reporter.py b/pytext/metric_reporters/word_tagging_metric_reporter.py
index e404b63c0..a5723a7bf 100644
--- a/pytext/metric_reporters/word_tagging_metric_reporter.py
+++ b/pytext/metric_reporters/word_tagging_metric_reporter.py
@@ -11,6 +11,7 @@
     AllConfusions,
     Confusions,
     LabelPrediction,
+    MultiLabelSoftClassificationMetrics,
     PRF1Metrics,
     compute_classification_metrics,
     compute_multi_label_multi_class_soft_metrics,
@@ -27,6 +28,9 @@
 from .metric_reporter import MetricReporter
 
 
+NAN_LABELS = ["__UNKNOWN__", "__PAD__"]
+
+
 def get_slots(word_names):
     slots = {
         Node(label=slot.label, span=Span(slot.start, slot.end))
@@ -97,6 +101,8 @@ class MultiLabelSequenceTaggingMetricReporter(MetricReporter):
     def __init__(self, label_names, pad_idx, channels, label_vocabs=None):
         super().__init__(channels)
         self.label_names = label_names
+        # Right now the assumption is that we use the same pad idx for all
+        # labels. #TODO Extend it to use multiple label specific pad idxs
         self.pad_idx = pad_idx
         self.label_vocabs = label_vocabs
 
@@ -110,8 +116,6 @@ def from_config(cls, config, tensorizers):
         )
 
     def calculate_metric(self):
-        if len(self.all_scores) == 0:
-            return {}
         list_score_pred_expect = []
         for label_idx in range(0, len(self.label_names)):
             list_score_pred_expect.append(
@@ -143,18 +147,13 @@ def batch_context(self, raw_batch, batch):
 
     @staticmethod
     def get_model_select_metric(metrics):
-        if isinstance(metrics, dict):
+        if isinstance(metrics, MultiLabelSoftClassificationMetrics):
             # There are multiclass precision/recall labels
             # Compute average precision
-            avg_precision = 0.0
-            for _, metric in metrics.items():
-                if metric:
-                    avg_precision += sum(
-                        v.average_precision
-                        for k, v in metric.items()
-                        if v.average_precision > 0
-                    ) / (len(metric.keys()) * 1.0)
-            avg_precision = avg_precision / (len(metrics.keys()) * 1.0)
+            normalize_count = sum(1 for k in metrics.average_precision.keys()) * 1.0
+            avg_precision = (
+                sum(v for k, v in metrics.average_precision.items()) / normalize_count
+            )
         else:
             avg_precision = metrics.accuracy
         return avg_precision
diff --git a/pytext/metrics/__init__.py b/pytext/metrics/__init__.py
index 0c5002b55..cdf894ee6 100644
--- a/pytext/metrics/__init__.py
+++ b/pytext/metrics/__init__.py
@@ -21,6 +21,7 @@
 from pytext.utils.ascii_table import ascii_table
 
 
+NAN_LABELS = ["__UNKNOWN__", "__PAD__"]
 RECALL_AT_PRECISION_THRESHOLDS = [0.2, 0.4, 0.6, 0.8, 0.9]
 PRECISION_AT_RECALL_THRESHOLDS = [0.2, 0.4, 0.6, 0.8, 0.9]
 
@@ -95,6 +96,19 @@ class SoftClassificationMetrics(NamedTuple):
     roc_auc: Optional[float]
 
 
+class MultiLabelSoftClassificationMetrics(NamedTuple):
+    """
+    Classification scores that are independent of thresholds.
+    """
+
+    average_precision: Dict[str, float]
+    recall_at_precision: Dict[str, Dict[str, Dict[float, float]]]
+    decision_thresh_at_precision: Dict[str, Dict[str, Dict[float, float]]]
+    precision_at_recall: Dict[str, Dict[str, Dict[float, float]]]
+    decision_thresh_at_recall: Dict[str, Dict[str, Dict[float, float]]]
+    roc_auc: Optional[Dict[Optional[str], Optional[Dict[str, Optional[float]]]]]
+
+
 class MacroPRF1Scores(NamedTuple):
     """
     Macro precision/recall/F1 scores (averages across each label).
@@ -757,9 +771,10 @@ def compute_multi_label_multi_class_soft_metrics(
     predictions: Sequence[Sequence[LabelListPrediction]],
     label_names: Sequence[str],
     label_vocabs: Sequence[Sequence[str]],
+    loss: float,
     recall_at_precision_thresholds: Sequence[float] = RECALL_AT_PRECISION_THRESHOLDS,
     precision_at_recall_thresholds: Sequence[float] = PRECISION_AT_RECALL_THRESHOLDS,
-) -> Dict[int, SoftClassificationMetrics]:
+) -> MultiLabelSoftClassificationMetrics:
     """
     Computes multi-label soft classification metrics with multi-class accommodation
 
@@ -777,10 +792,30 @@ def compute_multi_label_multi_class_soft_metrics(
     Returns:
         Dict from label strings to their corresponding soft metrics.
     """
-    soft_metrics = {}
+    soft_metrics = MultiLabelSoftClassificationMetrics({}, {}, {}, {}, {}, {})
     for label_idx, label_vocab in enumerate(label_vocabs):
         label = list(label_names)[label_idx]
-        soft_metrics[label] = compute_soft_metrics(predictions[label_idx], label_vocab)
+        soft_metrics_ = compute_soft_metrics(predictions[label_idx], label_vocab)
+        temp_avg_precision_ = {k: v.average_precision for k, v in soft_metrics_.items()}
+        soft_metrics.average_precision[label] = sum(
+            v for k, v in temp_avg_precision_.items() if k not in NAN_LABELS
+        ) / (
+            sum(1 for k, v in temp_avg_precision_.items() if k not in NAN_LABELS) * 1.0
+        )
+        soft_metrics.recall_at_precision[label] = {
+            k: v.recall_at_precision for k, v in soft_metrics_.items()
+        }
+        soft_metrics.decision_thresh_at_precision[label] = {
+            k: v.decision_thresh_at_precision for k, v in soft_metrics_.items()
+        }
+        soft_metrics.precision_at_recall[label] = {
+            k: v.precision_at_recall for k, v in soft_metrics_.items()
+        }
+        soft_metrics.decision_thresh_at_recall[label] = {
+            k: v.decision_thresh_at_recall for k, v in soft_metrics_.items()
+        }
+        soft_metrics.roc_auc[label] = {k: v.roc_auc for k, v in soft_metrics_.items()}
+
     return soft_metrics
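
For reference, the aggregation introduced by this patch can be summarized as follows. This is a minimal, standalone sketch rather than PyText code: ClassSoftMetrics, MultiLabelMetricsSketch, collapse_label, model_select_metric, and the label/class names are all made up for illustration, and only the average_precision field of the real SoftClassificationMetrics / MultiLabelSoftClassificationMetrics is modeled. Per label set, compute_multi_label_multi_class_soft_metrics averages the per-class average precision while skipping the NAN_LABELS ("__UNKNOWN__", "__PAD__"); get_model_select_metric then takes the mean of those per-label values.

from typing import Dict, NamedTuple

NAN_LABELS = ["__UNKNOWN__", "__PAD__"]


class ClassSoftMetrics(NamedTuple):
    # Stand-in for pytext's SoftClassificationMetrics; only the field we need.
    average_precision: float


class MultiLabelMetricsSketch(NamedTuple):
    # Mirrors MultiLabelSoftClassificationMetrics.average_precision only.
    average_precision: Dict[str, float]


def collapse_label(per_class: Dict[str, ClassSoftMetrics]) -> float:
    """Average per-class average precision, skipping __UNKNOWN__/__PAD__."""
    kept = [v.average_precision for k, v in per_class.items() if k not in NAN_LABELS]
    return sum(kept) / (len(kept) * 1.0)


def model_select_metric(metrics: MultiLabelMetricsSketch) -> float:
    """Mean of the per-label averages, as get_model_select_metric now computes."""
    return sum(metrics.average_precision.values()) / (
        len(metrics.average_precision) * 1.0
    )


# Hypothetical label sets and class names, purely for illustration.
per_label = {
    "slot_labels": {
        "__PAD__": ClassSoftMetrics(0.0),
        "city": ClassSoftMetrics(0.8),
        "date": ClassSoftMetrics(0.6),
    },
    "intent_labels": {
        "__UNKNOWN__": ClassSoftMetrics(0.0),
        "book_flight": ClassSoftMetrics(0.9),
    },
}
aggregated = MultiLabelMetricsSketch(
    average_precision={name: collapse_label(cls) for name, cls in per_label.items()}
)
print(aggregated.average_precision)   # {'slot_labels': ~0.7, 'intent_labels': 0.9}
print(model_select_metric(aggregated))  # ~0.8, the mean of 0.7 and 0.9

Replacing the nested dict with a typed NamedTuple is what lets get_model_select_metric dispatch on isinstance(metrics, MultiLabelSoftClassificationMetrics) and read precomputed per-label averages, with the NAN_LABELS filtering handled once inside compute_multi_label_multi_class_soft_metrics.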