Add Particle Filter #673

Closed
wants to merge 1 commit into from
2 changes: 1 addition & 1 deletion outlines/generate/generator.py
@@ -77,7 +77,7 @@ def sequence_generator(
allowed_tokens = get_allowed_tokens(fsms, fsm_states)
biased_logits = bias_logits(logits, allowed_tokens)
next_token_ids, ancestors, sequence_weights = sampler(
biased_logits, sequence_weights, rng
logits, biased_logits, sequence_weights, rng
)

token_ids = update_token_ids(token_ids, next_token_ids, ancestors)
195 changes: 171 additions & 24 deletions outlines/samplers.py
@@ -9,7 +9,8 @@ class Sampler(Protocol):

def __call__(
self,
next_token_logits: torch.DoubleTensor,
logits: torch.DoubleTensor,
biased_logits: torch.DoubleTensor,
sequence_weights: torch.DoubleTensor,
rng: torch.Generator,
) -> torch.DoubleTensor:
@@ -38,17 +39,21 @@ def __init__(self):

def __call__(
self,
next_token_logits: torch.DoubleTensor,
logits: torch.DoubleTensor,
biased_logits: torch.DoubleTensor,
sequence_weights: torch.DoubleTensor,
_,
) -> torch.DoubleTensor:
"""Call the greedy sampler.

Parameters
----------
next_token_logits
A tensor of shape ``(n_seqs, vocab_size,)`` that represents the
logits
A tensor of shape ``(n_seqs, vocab_size,)`` that represents the
probability distribution of the next token over the vocabulary.
biased_logits
A tensor of shape ``(n_seqs, vocab_size,)`` that represents the
biased probability distribution of the next token over the vocabulary.
sequence_weights
A tensor of shape ``(n_seqs,)`` that represents the cumulative
weight of each sequence.
@@ -63,12 +68,10 @@ def __call__(
cumulative weights of each sequence of shape ``(n_seqs,)``.

"""
logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1)
logprobs = torch.nn.functional.log_softmax(biased_logits, dim=-1)
next_token_ids = torch.argmax(logprobs, dim=-1, keepdim=True)

ancestors = torch.arange(
next_token_logits.shape[0], device=next_token_logits.device
)
ancestors = torch.arange(biased_logits.shape[0], device=biased_logits.device)
weights = sequence_weights + torch.gather(logprobs, 1, next_token_ids).squeeze()

return next_token_ids, ancestors, weights
@@ -113,17 +116,21 @@ def __init__(

def __call__(
self,
next_token_logits: torch.DoubleTensor,
logits: torch.DoubleTensor,
biased_logits: torch.DoubleTensor,
sequence_weights: torch.DoubleTensor,
rng: torch.Generator,
) -> Tuple[torch.DoubleTensor, torch.DoubleTensor, torch.DoubleTensor]:
"""Call the multinomial sampler.

Parameters
----------
next_token_logits
A tensor of shape ``(n_seqs, vocab_size,)`` that represents the
logits
A tensor of shape ``(n_seqs, vocab_size,)`` that represents the
probability distribution of the next token over the vocabulary.
biased_logits
A tensor of shape ``(n_seqs, vocab_size,)`` that represents the
biased probability distribution of the next token over the vocabulary.
sequence_weights
A tensor of shape ``(n_seqs,)`` that represents the cumulative
weight of each sequence.
@@ -138,16 +145,16 @@ def __call__(
cumulative weights of each sequence of shape ``(n_seqs,)``.

"""
altered_next_token_logits = next_token_logits
altered_biased_logits = biased_logits
for logit_processor in self.logits_processors:
altered_next_token_logits = logit_processor(next_token_logits)
altered_biased_logits = logit_processor(altered_biased_logits)

probs = torch.nn.functional.softmax(altered_next_token_logits, dim=-1)
probs = torch.nn.functional.softmax(altered_biased_logits, dim=-1)
next_token_ids = torch.multinomial(probs, num_samples=1, generator=rng)

logprobs = torch.nn.functional.log_softmax(altered_next_token_logits, dim=-1)
logprobs = torch.nn.functional.log_softmax(altered_biased_logits, dim=-1)
ancestors = torch.arange(
altered_next_token_logits.shape[0], device=next_token_logits.device
altered_biased_logits.shape[0], device=biased_logits.device
)
weights = sequence_weights + torch.gather(logprobs, 1, next_token_ids).squeeze()

@@ -247,17 +254,21 @@ def __init__(self, beams: int = 1):

def __call__(
self,
next_token_logits: torch.DoubleTensor,
logits: torch.DoubleTensor,
biased_logits: torch.DoubleTensor,
sequence_weights: torch.DoubleTensor,
_,
) -> Tuple[torch.DoubleTensor, torch.DoubleTensor, torch.DoubleTensor]:
"""Call the beam search sampler.

Parameters
----------
next_token_logits
A tensor of shape ``(n_seqs, vocab_size,)`` that represents the
logits
A tensor of shape ``(n_seqs, vocab_size,)`` that represents the
probability distribution of the next token over the vocabulary.
biased_logits
A tensor of shape ``(n_seqs, vocab_size,)`` that represents the
biased probability distribution of the next token over the vocabulary.
sequence_weights
A tensor of shape ``(n_seqs,)`` that represents the cumulative
weight of each sequence.
@@ -272,13 +283,13 @@ def __call__(
cumulative weights of each sequence of shape ``(n_seqs,)``.

"""
logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1)
weights = logprobs + sequence_weights.unsqueeze(1).expand_as(next_token_logits)
logprobs = torch.nn.functional.log_softmax(biased_logits, dim=-1)
weights = logprobs + sequence_weights.unsqueeze(1).expand_as(biased_logits)

# Flatten scores to (n_batch, n_samples * vocab_size)
# and find the top-k weights for each batch.
batch_size = next_token_logits.shape[0] // self.samples
vocab_size = next_token_logits.shape[-1]
batch_size = biased_logits.shape[0] // self.samples
vocab_size = biased_logits.shape[-1]
weights = weights.view(batch_size, self.samples * vocab_size)

# If the weights are all equal to 0 we are at the beginning of the search
@@ -296,7 +307,7 @@ def __call__(

# Re-shape the weights, next_token_ids and ancestors to (n_batch * n_samples, 1)
first_batch_idx = torch.arange(
0, batch_size * self.samples, self.samples, device=next_token_logits.device
0, batch_size * self.samples, self.samples, device=biased_logits.device
).unsqueeze(1)
ancestors = ancestors + first_batch_idx

@@ -308,3 +319,139 @@ def __call__(


beam_search = BeamSearchSampler


def multinomial_resampling(
rng: torch.Generator, weights: torch.Tensor, num_samples: int
):
"""Standard multinomial resampling for Sequential Monte Carlo.

This resampling function has very high variance.

Parameters
----------
rng
`torch.Generator` instance to use for resampling the particles
weights
The log-weights of the particles, of shape ``(n_batch, n_particles * vocab_size)``.
num_samples
The number of particles to sample

Returns
-------
A tuple with the log-weights of the resampled particles and the ids of the
particles to keep, each of shape ``(n_batch, num_samples)``.

"""

probs = torch.exp(weights)
indices = torch.multinomial(probs, num_samples=num_samples, generator=rng)
weights = torch.gather(weights, 1, indices)

return weights, indices
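

# Not part of this PR: a sketch of a lower-variance alternative with the same
# ``(rng, weights, num_samples) -> (weights, indices)`` interface expected by
# the particle filter's ``resampling_fn`` argument. The name
# ``systematic_resampling`` is hypothetical.
def systematic_resampling(
    rng: torch.Generator, weights: torch.Tensor, num_samples: int
):
    """Systematic resampling for Sequential Monte Carlo.

    Draws a single uniform offset per batch and steps through the cumulative
    distribution at evenly spaced points, which reduces the variance of the
    resampling step compared to multinomial resampling.

    """
    # Normalize the log-weights; masked entries (-inf) get zero probability
    probs = torch.softmax(weights, dim=-1)
    cumulative = torch.cumsum(probs, dim=-1)

    # One uniform draw per batch, then evenly spaced positions in [0, 1)
    offsets = (
        torch.rand(weights.shape[0], 1, generator=rng, device=weights.device)
        / num_samples
    )
    positions = offsets + torch.arange(num_samples, device=weights.device) / num_samples
    positions = positions.to(cumulative.dtype)

    # Invert the empirical CDF to get the ids of the particles to keep
    indices = torch.searchsorted(cumulative, positions)
    indices = indices.clamp(max=weights.shape[-1] - 1)
    weights = torch.gather(weights, 1, indices)

    return weights, indices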


class ParticleFilter:
"""Particle Filtering algorithm.

This sampling algorithm is similar to Beam Search, except that we use a
non-deterministic resampling function to downsample the particles instead of
keeping only those with the largest probability. In the simplest case the
resampling function is multinomial sampling. You may want to use another
resampling function, as multinomial resampling is known to lead to very high
variance.

Since we mask the logits before sampling, we need to apply a correction to the
log-probabilities. Indeed, sampling from the distribution of sequences that
respect the constraint corresponds to sampling from the unconstrained
next-token distribution and then assigning zero weight to the sequences that
violate the constraint. The process implemented here can be seen as using the
distribution with masked tokens as a locally optimal proposal, and we need to
correct the weights for this choice of proposal.

Particle filters often resample the particles only when the variance of the
weights becomes too large, as measured by the Effective Sample Size. Here we
resample at every step.

Note
----
This implementation could be merged with Beam Search by using `top-k` as
a resampling function.

Attributes
----------
samples
The number of samples taken for each input sequence.
ess_threshold
The Effective Sample Size threshold below which the particles are
resampled.

"""

def __init__(
self,
particles: int = 1,
ess_threshold=0.5,
resampling_fn=multinomial_resampling,
):
self.samples = particles
self.ess_threshold = ess_threshold  # documented above but not used yet: we resample at every step
self.resampling_fn = resampling_fn

def __call__(
self,
logits: torch.DoubleTensor,
biased_logits: torch.DoubleTensor,
sequence_weights: torch.DoubleTensor,
rng: torch.Generator,
) -> Tuple[torch.DoubleTensor, torch.DoubleTensor, torch.DoubleTensor]:
"""Call the particle filter.

Parameters
----------
logits
A tensor of shape ``(n_seqs, vocab_size,)`` that represents the
probability distribution of the next token over the vocabulary.
biased_logits
A tensor of shape ``(n_seqs, vocab_size,)`` that represents the
biased probability distribution of the next token over the vocabulary.
sequence_weights
A tensor of shape ``(n_seqs,)`` that represents the cumulative
weight of each sequence.
rng
A random number generator.

Returns
-------
A tuple with an array that contains the ids of the sampled tokens of
shape ``(n_seqs, 1)``, an array that contains the ancestors of each
sampled id of shape ``(n_seqs,)`` and an array that contains the updated
cumulative weights of each sequence of shape ``(n_seqs,)``.

"""

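# Importance-weight correction (see the class docstring): the masked
# distribution is the proposal, so each candidate token is weighted by
# log p(token) - log q(token), the original minus the proposal log-probability.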
original_logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
proposal_logprobs = torch.nn.functional.log_softmax(biased_logits, dim=-1)
weights = original_logprobs - proposal_logprobs
weights[weights == math.inf] = -math.inf # Cannot sample the masked logits

# Flatten weights to (n_batch, n_samples * vocab_size)
batch_size = biased_logits.shape[0] // self.samples
vocab_size = biased_logits.shape[-1]
weights = weights.view(batch_size, self.samples * vocab_size)

weights, indices = self.resampling_fn(rng, weights, self.samples)

ancestors = torch.div(indices, vocab_size, rounding_mode="floor")
next_token_ids = indices % vocab_size

# Re-shape the weights, next_token_ids and ancestors to (n_batch * n_samples, 1)
first_batch_idx = torch.arange(
0, batch_size * self.samples, self.samples, device=biased_logits.device
).unsqueeze(1)
ancestors = ancestors + first_batch_idx

ancestors = ancestors.view(self.samples * batch_size)
weights = weights.view(self.samples * batch_size)
next_token_ids = next_token_ids.view(self.samples * batch_size, 1)

return next_token_ids, ancestors, weights


smc = ParticleFilter
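
A minimal usage sketch, assuming the particle filter is exposed to the generation API in the same way as the existing greedy, multinomial and beam_search samplers; the model name, regular expression and prompt below are placeholders, not taken from this PR.

from outlines import generate, models, samplers

model = models.transformers("gpt2")     # placeholder model
sampler = samplers.smc(particles=4)     # four particles per input sequence

# Constrained generation: the masked logits act as the proposal distribution
generator = generate.regex(model, r"[0-9]{3}", sampler=sampler)
answer = generator("Pick a three-digit number: ")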