Skip to content

Commit

Permalink
user can specify the size of negatives to generate
Browse files Browse the repository at this point in the history
  • Loading branch information
GreatYYX committed Aug 2, 2018
1 parent 29e6099 commit c084d28
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions rltk/evaluation/ground_truth.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,15 +166,16 @@ def __next__(self):
def __len__(self):
return len(self._ground_truth_data)

def generate_negatives(self, dataset1: 'Dataset', dataset2: 'Dataset', score_function: Callable, size=-1):
size = len(self) if size == -1 else size # same size as positives
def generate_negatives(self, dataset1: 'Dataset', dataset2: 'Dataset',
score_function: Callable, num_of_negatives: int = -1):
num_of_negatives = len(self) if num_of_negatives == -1 else num_of_negatives
max_heap = []

for r1, r2 in get_record_pairs(dataset1, dataset2):
if not self.is_member(r1.id, r2.id):
s = score_function(r1, r2)
heapq.heappush(max_heap, (s, r1.id, r2.id))
if len(max_heap) > size:
if len(max_heap) > num_of_negatives:
heapq.heappop(max_heap)

for d in max_heap:
Expand All @@ -187,7 +188,8 @@ def generate_all_negatives(self, dataset1: 'Dataset', dataset2: 'Dataset'):
self.add_negative(r1.id, r2.id)

def generate_stratified_negatives(self, dataset1: 'Dataset', dataset2: 'Dataset',
classify: Callable, num_of_strata: int, random_seed: int = None):
classify: Callable, num_of_strata: int, random_seed: int = None,
num_of_negatives: int = -1):

# add positives and negatives to different clusters
strata = [{'p': [], 'n': []} for _ in range(num_of_strata)]
Expand All @@ -212,7 +214,7 @@ def generate_stratified_negatives(self, dataset1: 'Dataset', dataset2: 'Dataset'
sorted_strata_weights = OrderedDict(sorted(strata_weights.items(), key=itemgetter(1), reverse=True))

# find out the number of negatives to pick from each stratum
total_num = sum([len(s['p']) for s in strata])
total_num = sum([len(s['p']) for s in strata]) if not num_of_negatives else num_of_negatives
num_to_pick_from_each_stratum = [0] * num_of_strata
curr_strata_weights = copy.deepcopy(sorted_strata_weights)
for stratum_id in sorted_strata_weights.keys():
Expand Down

0 comments on commit c084d28

Please sign in to comment.