Add custom metric class for reporting Joint model metrics (facebookresearch#1339)

Summary:
Pull Request resolved: facebookresearch#1339

Adding a multi-label metric class to support reporting all multi-label and multi-class metrics for joint PyText models.
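For context, a minimal sketch of how per-label-set results are grouped in the new container. This is not part of the diff: the field layout mirrors the `MultiLabelSoftClassificationMetrics` class added below, while the label-set names and values are hypothetical.

```python
from typing import Dict, NamedTuple, Optional


class MultiLabelSoftClassificationMetrics(NamedTuple):
    # Field layout copied from the class added in pytext/metrics/__init__.py below.
    average_precision: Dict[str, float]
    recall_at_precision: Dict[str, Dict[str, Dict[float, float]]]
    decision_thresh_at_precision: Dict[str, Dict[str, Dict[float, float]]]
    precision_at_recall: Dict[str, Dict[str, Dict[float, float]]]
    decision_thresh_at_recall: Dict[str, Dict[str, Dict[float, float]]]
    roc_auc: Optional[Dict[Optional[str], Optional[Dict[str, Optional[float]]]]]


# One entry per label set of the joint model (names and numbers are made up).
metrics = MultiLabelSoftClassificationMetrics({}, {}, {}, {}, {}, {})
metrics.average_precision["doc_label"] = 0.91
metrics.average_precision["word_label"] = 0.84
```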

Reviewed By: seayoung1112

Differential Revision: D21077306

fbshipit-source-id: c868433b8814a8a3b99373ee51e1fe010658af90
shivanipods authored and facebook-github-bot committed Apr 29, 2020
1 parent 0b8b38e commit 8bd2f75
Showing 2 changed files with 49 additions and 15 deletions.
23 changes: 11 additions & 12 deletions pytext/metric_reporters/word_tagging_metric_reporter.py
@@ -11,6 +11,7 @@
AllConfusions,
Confusions,
LabelPrediction,
MultiLabelSoftClassificationMetrics,
PRF1Metrics,
compute_classification_metrics,
compute_multi_label_multi_class_soft_metrics,
@@ -27,6 +28,9 @@
from .metric_reporter import MetricReporter


NAN_LABELS = ["__UNKNOWN__", "__PAD__"]


def get_slots(word_names):
slots = {
Node(label=slot.label, span=Span(slot.start, slot.end))
@@ -97,6 +101,8 @@ class MultiLabelSequenceTaggingMetricReporter(MetricReporter):
def __init__(self, label_names, pad_idx, channels, label_vocabs=None):
super().__init__(channels)
self.label_names = label_names
# Right now the assumption is that we use the same pad idx for all
# labels. TODO: extend it to support label-specific pad idxs
self.pad_idx = pad_idx
self.label_vocabs = label_vocabs

@@ -110,8 +116,6 @@ def from_config(cls, config, tensorizers):
)

def calculate_metric(self):
if len(self.all_scores) == 0:
return {}
list_score_pred_expect = []
for label_idx in range(0, len(self.label_names)):
list_score_pred_expect.append(
@@ -143,18 +147,13 @@ def batch_context(self, raw_batch, batch):

@staticmethod
def get_model_select_metric(metrics):
if isinstance(metrics, dict):
if isinstance(metrics, MultiLabelSoftClassificationMetrics):
# There are multiclass precision/recall labels
# Compute average precision
avg_precision = 0.0
for _, metric in metrics.items():
if metric:
avg_precision += sum(
v.average_precision
for k, v in metric.items()
if v.average_precision > 0
) / (len(metric.keys()) * 1.0)
avg_precision = avg_precision / (len(metrics.keys()) * 1.0)
normalize_count = sum(1 for k in metrics.average_precision.keys()) * 1.0
avg_precision = (
sum(v for k, v in metrics.average_precision.items()) / normalize_count
)
else:
avg_precision = metrics.accuracy
return avg_precision
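A small standalone sketch of the new model-selection computation above: the per-label-set average precisions stored in `metrics.average_precision` are themselves averaged. The label-set names and values below are hypothetical, not taken from the commit.

```python
# Per-label-set average precision, as stored in metrics.average_precision
# (hypothetical values for illustration).
average_precision = {"doc_label": 0.91, "word_label": 0.84}

# Same arithmetic as the new get_model_select_metric branch above.
normalize_count = len(average_precision) * 1.0
avg_precision = sum(average_precision.values()) / normalize_count
print(avg_precision)  # (0.91 + 0.84) / 2 = 0.875
```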
41 changes: 38 additions & 3 deletions pytext/metrics/__init__.py
@@ -21,6 +21,7 @@
from pytext.utils.ascii_table import ascii_table


NAN_LABELS = ["__UNKNOWN__", "__PAD__"]
RECALL_AT_PRECISION_THRESHOLDS = [0.2, 0.4, 0.6, 0.8, 0.9]
PRECISION_AT_RECALL_THRESHOLDS = [0.2, 0.4, 0.6, 0.8, 0.9]

@@ -95,6 +96,19 @@ class SoftClassificationMetrics(NamedTuple):
roc_auc: Optional[float]


class MultiLabelSoftClassificationMetrics(NamedTuple):
"""
Classification scores that are independent of thresholds.
"""

average_precision: Dict[str, float]
recall_at_precision: Dict[str, Dict[str, Dict[float, float]]]
decision_thresh_at_precision: Dict[str, Dict[str, Dict[float, float]]]
precision_at_recall: Dict[str, Dict[str, Dict[float, float]]]
decision_thresh_at_recall: Dict[str, Dict[str, Dict[float, float]]]
roc_auc: Optional[Dict[Optional[str], Optional[Dict[str, Optional[float]]]]]


class MacroPRF1Scores(NamedTuple):
"""
Macro precision/recall/F1 scores (averages across each label).
@@ -757,9 +771,10 @@ def compute_multi_label_multi_class_soft_metrics(
predictions: Sequence[Sequence[LabelListPrediction]],
label_names: Sequence[str],
label_vocabs: Sequence[Sequence[str]],
loss: float,
recall_at_precision_thresholds: Sequence[float] = RECALL_AT_PRECISION_THRESHOLDS,
precision_at_recall_thresholds: Sequence[float] = PRECISION_AT_RECALL_THRESHOLDS,
) -> Dict[int, SoftClassificationMetrics]:
) -> MultiLabelSoftClassificationMetrics:
"""
Computes multi-label soft classification metrics with multi-class accommodation
@@ -777,10 +792,30 @@
Returns:
MultiLabelSoftClassificationMetrics with soft metrics for each label set.
"""
soft_metrics = {}
soft_metrics = MultiLabelSoftClassificationMetrics({}, {}, {}, {}, {}, {})
for label_idx, label_vocab in enumerate(label_vocabs):
label = list(label_names)[label_idx]
soft_metrics[label] = compute_soft_metrics(predictions[label_idx], label_vocab)
soft_metrics_ = compute_soft_metrics(predictions[label_idx], label_vocab)
temp_avg_precision_ = {k: v.average_precision for k, v in soft_metrics_.items()}
soft_metrics.average_precision[label] = sum(
v for k, v in temp_avg_precision_.items() if k not in NAN_LABELS
) / (
sum(1 for k, v in temp_avg_precision_.items() if k not in NAN_LABELS) * 1.0
)
soft_metrics.recall_at_precision[label] = {
k: v.recall_at_precision for k, v in soft_metrics_.items()
}
soft_metrics.decision_thresh_at_precision[label] = {
k: v.decision_thresh_at_precision for k, v in soft_metrics_.items()
}
soft_metrics.precision_at_recall[label] = {
k: v.precision_at_recall for k, v in soft_metrics_.items()
}
soft_metrics.decision_thresh_at_recall[label] = {
k: v.decision_thresh_at_recall for k, v in soft_metrics_.items()
}
soft_metrics.roc_auc[label] = {k: v.roc_auc for k, v in soft_metrics_.items()}

return soft_metrics
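To illustrate the per-label-set aggregation above, here is the NAN_LABELS filtering for average precision in isolation. This is a sketch with made-up class names and scores, not code from the commit: padding and unknown classes are dropped before averaging, mirroring the logic in compute_multi_label_multi_class_soft_metrics.

```python
NAN_LABELS = ["__UNKNOWN__", "__PAD__"]

# Per-class average precision for one label set of the joint model
# (hypothetical values).
temp_avg_precision = {"__UNKNOWN__": 0.0, "__PAD__": 0.0, "weather": 0.9, "music": 0.7}

# Exclude padding/unknown classes, then average what remains.
kept = {k: v for k, v in temp_avg_precision.items() if k not in NAN_LABELS}
label_avg_precision = sum(kept.values()) / (len(kept) * 1.0)
print(label_avg_precision)  # (0.9 + 0.7) / 2 = 0.8
```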


