Upgrade sampler capability #209

Merged
merged 34 commits on Dec 13, 2020
Commits (34)
10b62d9
Add corruption scheme configuration possibility for negative samplers
lvermue Dec 10, 2020
6b48579
Fix variable naming errors
lvermue Dec 10, 2020
5d47a41
Remove unnecessary code
lvermue Dec 10, 2020
f3ff3cf
Add negative sample filtering
lvermue Dec 10, 2020
d7e108a
Trigger CI
PyKEEN-bot Dec 11, 2020
ade1143
Update basic_negative_sampler.py
cthoyt Dec 11, 2020
396f27d
Merge branch 'master' into upgrade_sampler
cthoyt Dec 11, 2020
1f006ed
Update docstrings
lvermue Dec 11, 2020
6695397
Add negative samples filtering for pairwise losses
lvermue Dec 11, 2020
dee3fa4
Trigger CI
PyKEEN-bot Dec 11, 2020
39ae601
Code refactoring
lvermue Dec 12, 2020
07b784c
Merge branch 'upgrade_sampler' of https://github.com/pykeen/pykeen in…
lvermue Dec 12, 2020
b591ba9
Update sampler tests
lvermue Dec 12, 2020
5e8c020
Fix handling of Bernoulli sampler
lvermue Dec 12, 2020
b7f9fe1
Add filtering for Bernoulli sampler and refactoring
lvermue Dec 12, 2020
a0d3e4f
Fix flake8
lvermue Dec 12, 2020
2c5b717
Add unit tests for negative sample filtering
lvermue Dec 12, 2020
20f333d
Fix flake8
lvermue Dec 12, 2020
cb47432
Trigger CI
PyKEEN-bot Dec 12, 2020
7af702d
Update src/pykeen/sampling/basic_negative_sampler.py
lvermue Dec 12, 2020
d3852cd
Code refactoring
lvermue Dec 12, 2020
e346db8
Add lookup dict to list comprehension
lvermue Dec 12, 2020
bf50831
Update bernoulli_negative_sampler.py
cthoyt Dec 12, 2020
97a76fc
Make filter_negative_triples public
cthoyt Dec 12, 2020
9ec8b14
Pass flake8
cthoyt Dec 12, 2020
8087c53
Simplify slicing
lvermue Dec 12, 2020
d687519
Merge branch 'upgrade_sampler' of https://github.com/pykeen/pykeen in…
lvermue Dec 12, 2020
94e6140
Update docstring
lvermue Dec 12, 2020
593a786
Trigger CI
PyKEEN-bot Dec 12, 2020
d28ad30
Code refactoring
lvermue Dec 13, 2020
a5825ac
Fix flake8
lvermue Dec 13, 2020
a0766fe
Refactor code
lvermue Dec 13, 2020
c8ca629
Adjust unit tests to new structure
lvermue Dec 13, 2020
8500c83
Trigger CI
PyKEEN-bot Dec 13, 2020
60 changes: 45 additions & 15 deletions src/pykeen/sampling/basic_negative_sampler.py
@@ -2,6 +2,8 @@

"""Negative sampling algorithm based on the work of of Bordes *et al.*."""

from typing import Optional, Tuple

import torch

from .negative_sampler import NegativeSampler
@@ -31,31 +33,59 @@ class BasicNegativeSampler(NegativeSampler):
num_negs_per_pos=dict(type=int, low=1, high=100, q=10),
)

def sample(self, positive_batch: torch.LongTensor) -> torch.LongTensor:
def sample(self, positive_batch: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.Tensor]]:
"""Generate negative samples from the positive batch."""
if self.num_negs_per_pos > 1:
positive_batch = positive_batch.repeat(self.num_negs_per_pos, 1)

# Bind number of negatives to sample
num_negs = positive_batch.shape[0]

# Equally corrupt head and tail
split_idx = num_negs // 2
# Equally corrupt all sides
split_idx = num_negs // len(self._corruption_indices)

# Copy positive batch for corruption.
# Do not detach, as no gradients should flow into the indices.
negative_batch = positive_batch.clone()

# Sample random entities as replacement
negative_entities = torch.randint(high=self.num_entities - 1, size=(num_negs,), device=positive_batch.device)

# Replace heads – To make sure we don't replace the head with the original value,
# we shift all values greater than or equal to the original value up by one;
# for that reason we draw the random value from [0, num_entities - 1)
filter_same_head = (negative_entities[:split_idx] >= positive_batch[:split_idx, 0])
negative_batch[:split_idx, 0] = negative_entities[:split_idx] + filter_same_head.long()
# Corrupt tails
filter_same_tail = (negative_entities[split_idx:] >= positive_batch[split_idx:, 2])
negative_batch[split_idx:, 2] = negative_entities[split_idx:] + filter_same_tail.long()

return negative_batch
if 0 in self._corruption_indices or 2 in self._corruption_indices:
negative_entities = torch.randint(
high=self.num_entities - 1,
size=(num_negs,),
device=positive_batch.device,
)

# Sample random relations as replacement, if requested
if 1 in self._corruption_indices:
negative_relations = torch.randint(
high=self.num_relations - 1,
size=(num_negs,),
device=positive_batch.device,
)

for index, start in zip(self._corruption_indices, range(0, num_negs, split_idx)):
stop = min(start + split_idx, num_negs)
if index == 1:
# Corrupt relations
negative_batch[start:stop, index] = negative_relations[start:stop]
else:
# Corrupt heads or tails
negative_batch[start:stop, index] = negative_entities[start:stop]

# Replace {heads, relations, tails} – To make sure we don't replace the {head, relation, tail} with the
# original value, we shift all values greater than or equal to the original value up by one;
# for that reason we draw the random value from [0, num_{entities, relations} - 1)
if not self.filtered:
negative_batch[start:stop, index] += (
negative_batch[start:stop, index] >= positive_batch[start:stop, index]
).long()

# If filtering is activated, all negative triples that are positive in the training dataset will be removed
if self.filtered:
batch_filter = self._filter_negative_triples(negative_batch=negative_batch)
negative_batch = negative_batch[batch_filter]
else:
batch_filter = None

return negative_batch, batch_filter
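
Aside: the shift trick in the corruption loop above can be isolated into a few lines. The following is a minimal sketch (the helper name corrupt_column is illustrative, not part of PyKEEN) of why drawing from [0, n - 1) and shifting every draw that is greater than or equal to the original value up by one yields a uniform sample over the other n - 1 IDs:

import torch

def corrupt_column(column: torch.LongTensor, num_candidates: int) -> torch.LongTensor:
    """Replace each ID with a uniformly drawn *different* ID from [0, num_candidates)."""
    # Draw from the half-open range [0, num_candidates - 1); shifting draws that are
    # >= the original value up by one skips the original without rejection sampling.
    draws = torch.randint(high=num_candidates - 1, size=column.shape, device=column.device)
    return draws + (draws >= column).long()

triples = torch.tensor([[0, 0, 1], [2, 1, 3]])
corrupted_heads = corrupt_column(triples[:, 0], num_candidates=4)
assert (corrupted_heads != triples[:, 0]).all()

Note that the sampler above only applies the shift in the unfiltered case; with filtering enabled, false negatives are removed afterwards instead.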
49 changes: 47 additions & 2 deletions src/pykeen/sampling/negative_sampler.py
@@ -3,7 +3,7 @@
"""Basic structure for a negative sampler."""

from abc import ABC, abstractmethod
from typing import Any, ClassVar, Mapping, Optional
from typing import Any, ClassVar, Mapping, Optional, Set, Tuple

import torch

@@ -25,14 +25,27 @@ def __init__(
self,
triples_factory: TriplesFactory,
num_negs_per_pos: Optional[int] = None,
filtered: bool = False,
corruption_scheme: Optional[Set[str]] = None,
) -> None:
"""Initialize the negative sampler with the given entities.

:param triples_factory: The factory holding the triples to sample from
:param num_negs_per_pos: Number of negative samples to make per positive triple. Defaults to 1.
:param filtered: Whether proposed corrupted triples that are in the training data should be filtered.
Defaults to False.
:param corruption_scheme: Which sides of the triple ('h', 'r', 't') should be corrupted. Defaults to head and tail ('h', 't').
"""
self.triples_factory = triples_factory
self.num_negs_per_pos = num_negs_per_pos if num_negs_per_pos is not None else 1
self.filtered = filtered
self.corruption_scheme = corruption_scheme or ('h', 't')
# Set the indices
self._corruption_indices = [0 if side == 'h' else 1 if side == 'r' else 2 for side in self.corruption_scheme]
# Copy the mapped triples to the device for efficient filtering
if filtered:
self.mapped_triples = self.triples_factory.mapped_triples
self._filter_init = False

@classmethod
def get_normalized_name(cls) -> str:
@@ -44,7 +57,39 @@ def num_entities(self) -> int: # noqa: D401
"""The number of entities to sample from."""
return self.triples_factory.num_entities

@property
def num_relations(self) -> int: # noqa: D401
"""The number of relations to sample from."""
return self.triples_factory.num_relations

@abstractmethod
def sample(self, positive_batch: torch.LongTensor) -> torch.LongTensor:
def sample(self, positive_batch: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.Tensor]]:
"""Generate negative samples from the positive batch."""
raise NotImplementedError

def _filter_negative_triples(self, negative_batch: torch.LongTensor) -> torch.Tensor:
"""Filter all proposed negative samples that are positive in the training dataset.

Normally there is a low probability that a proposed negative sample is actually a positive triple in the
training dataset and thus acts as a false negative. This can serve as a kind of regularization, since it adds
noise to the training data. However, the degree of regularization is hard to control, because the amount of
noise depends on the ratio of true triples for a given entity-relation or entity-entity pair. A researcher may
therefore want to exclude the possibility of false negatives among the proposed negative triples.
"""
# Make sure the mapped triples are on the right device
if not self._filter_init:
self.mapped_triples = self.mapped_triples.to(negative_batch.device)
self._filter_init = True
# Check which heads of the mapped triples are also in the negative triples
head_filter = (self.mapped_triples[:, 0:1].view(1, -1) == negative_batch[:, 0:1]).max(axis=0)[0]
# Reduce the search space by only using possible matches that at least contain the head we look for
sub_mapped_triples = self.mapped_triples[head_filter]
# Check in this subspace which relations of the mapped triples are also in the negative triples
relation_filter = (sub_mapped_triples[:, 1:2].view(1, -1) == negative_batch[:, 1:2]).max(axis=0)[0]
# Reduce the search space by only using possible matches that at least contain head and relation we look for
sub_mapped_triples = sub_mapped_triples[relation_filter]
# Create a filter indicating which of the proposed negative triples are positive in the training dataset
final_filter = (sub_mapped_triples[:, 2:3].view(1, -1) == negative_batch[:, 2:3]).max(axis=1)[0]
# Return only those proposed negative triples that are not positive in the training dataset
return ~final_filter
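
To see the staged filter in action, here is a toy walk-through with standalone tensors (a sketch of the logic above, not the sampler class itself). Each stage tests column membership independently, so the mask can be conservative: a genuine negative whose head, relation, and tail each appear in different surviving training triples could also be flagged.

import torch

# Toy training triples (h, r, t) and proposed negative triples.
mapped_triples = torch.tensor([[0, 0, 1], [1, 0, 2]])
negative_batch = torch.tensor([
    [0, 0, 1],  # false negative: identical to a training triple
    [0, 0, 2],  # genuine negative
])

# Stage 1: keep training triples whose head occurs among the negative heads.
head_filter = (mapped_triples[:, 0:1].view(1, -1) == negative_batch[:, 0:1]).max(dim=0)[0]
sub = mapped_triples[head_filter]
# Stage 2: of those, keep triples whose relation occurs among the negative relations.
relation_filter = (sub[:, 1:2].view(1, -1) == negative_batch[:, 1:2]).max(dim=0)[0]
sub = sub[relation_filter]
# Stage 3: flag every negative whose tail occurs among the surviving tails.
final_filter = (sub[:, 2:3].view(1, -1) == negative_batch[:, 2:3]).max(dim=1)[0]

print(negative_batch[~final_filter])  # tensor([[0, 0, 2]])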
9 changes: 8 additions & 1 deletion src/pykeen/training/slcwa.py
@@ -93,7 +93,7 @@ def _process_batch(
positive_batch = batch[start:stop].to(device=self.device)

# Create negative samples
neg_samples = self.negative_sampler.sample(positive_batch=positive_batch)
neg_samples, neg_samples_filter = self.negative_sampler.sample(positive_batch=positive_batch)

# Ensure they reside on the device (should hold already for most simple negative samplers, e.g.
# BasicNegativeSampler, BernoulliNegativeSampler)
@@ -110,6 +110,7 @@
positive_scores,
negative_scores,
label_smoothing,
neg_samples_filter,
)
return loss

@@ -118,11 +119,15 @@ def _mr_loss_helper(
positive_scores: torch.FloatTensor,
negative_scores: torch.FloatTensor,
_label_smoothing=None,
_batch_filter=None,
) -> torch.FloatTensor:
# Repeat positive scores (necessary for more than one negative per positive)
if self.num_negs_per_pos > 1:
positive_scores = positive_scores.repeat(self.num_negs_per_pos, 1)

if _batch_filter is not None:
positive_scores = positive_scores[_batch_filter]

return self.model.compute_mr_loss(
positive_scores=positive_scores,
negative_scores=negative_scores,
@@ -133,6 +138,7 @@ def _self_adversarial_negative_sampling_loss_helper(
positive_scores: torch.FloatTensor,
negative_scores: torch.FloatTensor,
_label_smoothing=None,
_batch_filter=None,
) -> torch.FloatTensor:
"""Compute self adversarial negative sampling loss."""
return self.model.compute_self_adversarial_negative_sampling_loss(
@@ -145,6 +151,7 @@ def _label_loss_helper(
positive_scores: torch.FloatTensor,
negative_scores: torch.FloatTensor,
label_smoothing: float,
_batch_filter=None,
) -> torch.FloatTensor:
# Stack predictions
predictions = torch.cat([positive_scores, negative_scores], dim=0)
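
With filtering enabled, fewer than num_negs_per_pos * batch_size negative scores reach the loss, so the margin-ranking path masks the tiled positive scores with the same filter to keep each surviving negative paired with its own positive. A minimal sketch of that bookkeeping (the tensor values are made up, and the names only mirror the helper above rather than the PyKEEN API):

import torch

num_negs_per_pos = 2
positive_scores = torch.tensor([[0.9], [0.8]])          # one score per positive triple
batch_filter = torch.tensor([True, False, True, True])  # second return value of sample()
negative_scores = torch.tensor([[0.1], [0.3], [0.4]])   # scored on the already-filtered negatives

# Tile the positives exactly as sample() tiled the positive batch,
# then apply the same mask so the pairs stay aligned.
positive_scores = positive_scores.repeat(num_negs_per_pos, 1)[batch_filter]
assert positive_scores.shape == negative_scores.shape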