Merge pull request #1019 from CatalinVoss/scorer_cleanup
Scorer cleanup
ajratner committed Oct 25, 2018
2 parents 97591ff + 9fb42b0 commit a5f5201
Showing 1 changed file with 23 additions and 25 deletions.
48 changes: 23 additions & 25 deletions snorkel/learning/utils.py
@@ -96,21 +96,30 @@ def get_train_idxs(self, rebalance=False, split=0.5, rand_state=None):


 class Scorer(object):
     """Abstract type for scorers"""
-    def __init__(self, test_candidates, test_labels, gold_candidate_set=None):
+    def __init__(self, test_candidates, test_labels, gold_candidate_set=None,
+        cardinality=None):
         """
         :param test_candidates: A *list of Candidates* corresponding to
             test_labels
         :param test_labels: A *csrLabelMatrix* of ground truth labels for the
             test candidates
         :param gold_candidate_set: (optional) A *CandidateSet* containing the
             full set of gold labeled candidates
+        :param cardinality: An *int* overwrite for the automatic cardinality
+            inference.
         """
         self.test_candidates = test_candidates
         self.test_labels = test_labels
         self.gold_candidate_set = gold_candidate_set
+        self.cardinality = cardinality

     def _get_cardinality(self, marginals):
-        """Get the cardinality based on the marginals returned by the model."""
+        """
+        Get the cardinality based on what we were told or the marginals returned
+        by the model.
+        """
+        if self.cardinality:
+            return self.cardinality
         if len(marginals.shape) == 1 or marginals.shape[1] < 3:
             cardinality = 2
         else:
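In effect, an explicit cardinality now takes precedence over the shape-based inference. A minimal standalone sketch of the resulting behavior (a hypothetical helper, not the committed class; it assumes the elided else branch returns the number of marginal columns):

import numpy as np

def get_cardinality(marginals, cardinality=None):
    # An explicit value short-circuits inference, as in the new branch above
    if cardinality:
        return cardinality
    # 1-D marginals, or a matrix with fewer than 3 columns, read as binary;
    # otherwise each column is assumed to be one of k categories
    if len(marginals.shape) == 1 or marginals.shape[1] < 3:
        return 2
    return marginals.shape[1]

print(get_cardinality(np.random.rand(10)))                  # 2 (1-D, binary)
print(get_cardinality(np.random.rand(10, 4)))               # 4 (inferred)
print(get_cardinality(np.random.rand(10), cardinality=3))   # 3 (override wins)

The override matters for cases like the last one, where 1-D inputs to a categorical task would otherwise be inferred as binary.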
@@ -124,12 +133,11 @@ def score(self, test_marginals, **kwargs):
         else:
             return self._score_categorical(test_marginals, **kwargs)

-    def _score_binary(self, test_marginals, train_marginals=None, b=0.5,
-        set_unlabeled_as_neg=True, display=True):
+    def _score_binary(self, test_marginals, b=0.5, set_unlabeled_as_neg=True,
+        display=True):
         raise NotImplementedError()

-    def _score_categorical(self, test_marginals, train_marginals=None,
-        display=True):
+    def _score_categorical(self, test_marginals, display=True):
         raise NotImplementedError()

     def summary_score(self, test_marginals, **kwargs):
@@ -139,22 +147,19 @@ def summary_score(self, test_marginals, **kwargs):

 class MentionScorer(Scorer):
     """Scorer for mention level assessment"""
-    def _score_binary(self, test_marginals, train_marginals=None, b=0.5,
-        set_unlabeled_as_neg=True, set_at_thresh_as_neg=True, display=True,
-        **kwargs):
+    def _score_binary(self, test_marginals, b=0.5, set_unlabeled_as_neg=True,
+        set_at_thresh_as_neg=True, display=True, **kwargs):
         """
         Return scoring metric for the provided marginals, as well as candidates
         in error buckets.

         :param test_marginals: array of marginals for test candidates
-        :param train_marginals (optional): array of marginals for training
-            candidates
         :param b: threshold for labeling
         :param set_unlabeled_as_neg: set test labels at the decision threshold
             of b as negative labels
         :param set_at_b_as_neg: set marginals at the decision threshold exactly
             as negative predictions
-        :param display: show calibration plots?
+        :param display: print stats?

         Note that even when the test_marginals are in the range [0, 1] (like our
         default b assumes), we still require the test_label to be in the set
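For intuition, a rough self-contained sketch of binary bucketing at threshold b (a hypothetical simplification: the committed method also consults the gold candidate set, and the {-1, 1} label convention here is an assumption):

import numpy as np

def bucket_binary(test_marginals, test_labels, b=0.5, set_at_thresh_as_neg=True):
    tp, fp, tn, fn = set(), set(), set(), set()
    for i, p in enumerate(test_marginals):
        # Marginals strictly above b are positive; values exactly at b are
        # treated as negative when set_at_thresh_as_neg is True
        pred = 1 if (p > b or (p == b and not set_at_thresh_as_neg)) else -1
        label = test_labels[i]  # assumed convention: labels in {-1, 1}
        if pred == 1:
            (tp if label == 1 else fp).add(i)
        else:
            (tn if label == -1 else fn).add(i)
    return tp, fp, tn, fn

tp, fp, tn, fn = bucket_binary(np.array([0.9, 0.2, 0.5]), np.array([1, 1, -1]))
print(len(tp), len(fp), len(tn), len(fn))  # 1 0 1 1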
Expand Down Expand Up @@ -193,7 +198,6 @@ def _score_binary(self, test_marginals, train_marginals=None, b=0.5,
else:
fn.add(candidate)
if display:

# Calculate scores unadjusted for TPs not in our candidate set
print_scores(len(tp), len(fp), len(tn), len(fn),
title="Scores (Un-adjusted)")
@@ -205,31 +209,25 @@ def _score_binary(self, test_marginals, train_marginals=None, b=0.5,
                 print("\n")
                 print_scores(len(tp), len(fp), len(tn), len(fn)+len(gold_fn),
                     title="Corpus Recall-adjusted Scores")
-
-            # If training and test marginals provided print calibration plots
-            if train_marginals is not None and test_marginals is not None:
-                print("\nCalibration plot:")
-                calibration_plots(train_marginals, test_marginals,
-                    np.asarray(test_label_array))
         return tp, fp, tn, fn

-    def _score_categorical(self, test_marginals, train_marginals=None,
-        display=True, **kwargs):
+    def _score_categorical(self, test_marginals, display=True, **kwargs):
         """
         Return scoring metric for the provided marginals, as well as candidates
         in error buckets.

         :param test_marginals: array of marginals for test candidates
-        :param train_marginals (optional): array of marginals for training
-            candidates
-        :param display: show calibration plots?
+        :param display: print stats?
         """
         test_label_array = []
         correct = set()
         incorrect = set()

         # Get predictions
-        test_pred = test_marginals.argmax(axis=1) + 1
+        if len(test_marginals.shape) > 1:
+            test_pred = test_marginals.argmax(axis=1) + 1
+        else:
+            test_pred = test_marginals

         # Bucket the candidates for error analysis
         for i, candidate in enumerate(self.test_candidates):
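The new branch lets categorical scoring accept either a 2-D marginals matrix or an already-hardened 1-D prediction vector. A minimal illustration of the same logic in isolation (hypothetical helper and inputs):

import numpy as np

def get_predictions(test_marginals):
    # A 2-D marginals matrix is reduced by argmax (+1, since categories
    # appear to be 1-indexed); a 1-D array is passed through as-is
    if len(test_marginals.shape) > 1:
        return test_marginals.argmax(axis=1) + 1
    return test_marginals

print(get_predictions(np.array([[0.1, 0.7, 0.2],
                                [0.6, 0.3, 0.1]])))  # [2 1]
print(get_predictions(np.array([2, 1, 3])))          # [2 1 3]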
