Skip to content

Commit

Permalink
Add Particle Filter for sequences
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Feb 29, 2024
1 parent c548a8a commit b33e645
Show file tree
Hide file tree
Showing 3 changed files with 259 additions and 33 deletions.
2 changes: 1 addition & 1 deletion outlines/generate/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def sequence_generator(
allowed_tokens = get_allowed_tokens(fsms, fsm_states)
biased_logits = bias_logits(logits, allowed_tokens)
next_token_ids, ancestors, sequence_weights = sampler(
biased_logits, sequence_weights, rng
logits, biased_logits, sequence_weights, rng
)

token_ids = update_token_ids(token_ids, next_token_ids, ancestors)
Expand Down
195 changes: 171 additions & 24 deletions outlines/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ class Sampler(Protocol):

def __call__(
self,
next_token_logits: torch.DoubleTensor,
logits: torch.DoubleTensor,
biased_logits: torch.DoubleTensor,
sequence_weights: torch.DoubleTensor,
rng: torch.Generator,
) -> torch.DoubleTensor:
Expand Down Expand Up @@ -38,17 +39,21 @@ def __init__(self):

def __call__(
self,
next_token_logits: torch.DoubleTensor,
logits: torch.DoubleTensor,
biased_logits: torch.DoubleTensor,
sequence_weights: torch.DoubleTensor,
_,
) -> torch.DoubleTensor:
"""Call the greedy sampler.
Parameters
----------
next_token_logits
A tensor of shape ``(n_seqs, vocab_size,)`` that represents the
logits
a tensor of shape ``(n_seqs, vocab_size,)`` that represents the
probability distribution of the next token over the vocabulary.
logits
a tensor of shape ``(n_seqs, vocab_size,)`` that represents the
biased probability distribution of the next token over the vocabulary.
sequence_weights
A tensor of shape ``(n_seqs,)`` that represents the cumulative
weight of each sequence.
Expand All @@ -63,12 +68,10 @@ def __call__(
cumulative weights of each sequence of shape ``(n_seqs,)``.
"""
logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1)
logprobs = torch.nn.functional.log_softmax(biased_logits, dim=-1)
next_token_ids = torch.argmax(logprobs, dim=-1, keepdim=True)

ancestors = torch.arange(
next_token_logits.shape[0], device=next_token_logits.device
)
ancestors = torch.arange(biased_logits.shape[0], device=biased_logits.device)
weights = sequence_weights + torch.gather(logprobs, 1, next_token_ids).squeeze()

return next_token_ids, ancestors, weights
Expand Down Expand Up @@ -113,17 +116,21 @@ def __init__(

def __call__(
self,
next_token_logits: torch.DoubleTensor,
logits: torch.DoubleTensor,
biased_logits: torch.DoubleTensor,
sequence_weights: torch.DoubleTensor,
rng: torch.Generator,
) -> Tuple[torch.DoubleTensor, torch.DoubleTensor, torch.DoubleTensor]:
"""Call the multinomial sampler.
Parameters
----------
next_token_logits
A tensor of shape ``(n_seqs, vocab_size,)`` that represents the
logits
a tensor of shape ``(n_seqs, vocab_size,)`` that represents the
probability distribution of the next token over the vocabulary.
biased_logits
a tensor of shape ``(n_seqs, vocab_size,)`` that represents the
biased probability distribution of the next token over the vocabulary.
sequence_weights
A tensor of shape ``(n_seqs,)`` that represents the cumulative
weight of each sequence.
Expand All @@ -138,16 +145,16 @@ def __call__(
cumulative weights of each sequence of shape ``(n_seqs,)``.
"""
altered_next_token_logits = next_token_logits
altered_biased_logits = biased_logits
for logit_processor in self.logits_processors:
altered_next_token_logits = logit_processor(next_token_logits)
altered_biased_logits = logit_processor(biased_logits)

probs = torch.nn.functional.softmax(altered_next_token_logits, dim=-1)
probs = torch.nn.functional.softmax(altered_biased_logits, dim=-1)
next_token_ids = torch.multinomial(probs, num_samples=1, generator=rng)

logprobs = torch.nn.functional.log_softmax(altered_next_token_logits, dim=-1)
logprobs = torch.nn.functional.log_softmax(altered_biased_logits, dim=-1)
ancestors = torch.arange(
altered_next_token_logits.shape[0], device=next_token_logits.device
altered_biased_logits.shape[0], device=biased_logits.device
)
weights = sequence_weights + torch.gather(logprobs, 1, next_token_ids).squeeze()

Expand Down Expand Up @@ -247,17 +254,21 @@ def __init__(self, beams: int = 1):

def __call__(
self,
next_token_logits: torch.DoubleTensor,
logits: torch.DoubleTensor,
biased_logits: torch.DoubleTensor,
sequence_weights: torch.DoubleTensor,
_,
) -> Tuple[torch.DoubleTensor, torch.DoubleTensor, torch.DoubleTensor]:
"""Call the beam search sampler.
Parameters
----------
next_token_logits
A tensor of shape ``(n_seqs, vocab_size,)`` that represents the
logits
a tensor of shape ``(n_seqs, vocab_size,)`` that represents the
probability distribution of the next token over the vocabulary.
biased_logits
a tensor of shape ``(n_seqs, vocab_size,)`` that represents the
biased probability distribution of the next token over the vocabulary.
sequence_weights
A tensor of shape ``(n_seqs,)`` that represents the cumulative
weight of each sequence.
Expand All @@ -272,13 +283,13 @@ def __call__(
cumulative weights of each sequence of shape ``(n_seqs,)``.
"""
logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1)
weights = logprobs + sequence_weights.unsqueeze(1).expand_as(next_token_logits)
logprobs = torch.nn.functional.log_softmax(biased_logits, dim=-1)
weights = logprobs + sequence_weights.unsqueeze(1).expand_as(biased_logits)

# Flatten scores to (n_batch, n_samples * vocab_size)
# and find the top-k weights for each batch.
batch_size = next_token_logits.shape[0] // self.samples
vocab_size = next_token_logits.shape[-1]
batch_size = biased_logits.shape[0] // self.samples
vocab_size = biased_logits.shape[-1]
weights = weights.view(batch_size, self.samples * vocab_size)

# If the weights are all equal to 0 we are at the beginning of the search
Expand All @@ -296,7 +307,7 @@ def __call__(

# Re-shape the weights, next_token_ids and ancestors to (n_batch * n_samples, 1)
first_batch_idx = torch.arange(
0, batch_size * self.samples, self.samples, device=next_token_logits.device
0, batch_size * self.samples, self.samples, device=biased_logits.device
).unsqueeze(1)
ancestors = ancestors + first_batch_idx

Expand All @@ -308,3 +319,139 @@ def __call__(


beam_search = BeamSearchSampler


def multinomial_resampling(
rng: torch.Generator, weights: torch.Tensor, num_samples: int
):
"""Standard multinomial resampling for Sequential Monte Carlo.
This resampling function has very high variance.
Parameters
----------
rng
`torch.Generator` instance to use for resampling the particles
weights
The weights of the particles. Shape (n_batch, n_particles * vocab_size)
num_samples
The number of particles to sample
Returns
-------
The ids of the particles to keep.
"""

probs = torch.exp(weights)
indices = torch.multinomial(probs, num_samples=num_samples, generator=rng)
weights = torch.gather(weights, 1, indices)

return weights, indices


class ParticleFilter:
"""Particle Filtering algorithm.
This sampling algorithm is similar to Beam Search, except we use a
non-deterministic resampling function to downsample instead of only keeping
particles with the largest probability. In the simplest case the downsampling
function is multinomial sampling. You may want to use other resampling functions
as multinomial resampling is known to lead to very large variance.
Since we mask the logits before sampling we need to apply a correction to the
log-probability before sampling. Indeed, sampling from the distribution
of sequences that respect the constraint corresponds to sampling from
the next-token distribution and then assigning 0 weight to the sequences
that violate the constraint. We can see the process implemented here as
using the distribution with masked tokens as a locally optimal proposal, and
we need to correct for this choice of proposal.
Particle filters often only resample particles when the variability of the
weights is too large. Here we resample at every step.
Note
----
This implementation could be merged with Beam Search by using `top-k` as
a resampling function.
Attributes
----------
samples
The number of samples taken for each input sequence.
ess_threshold
The Effective Sample Size threshold below which the particles are
resampled.
"""

def __init__(
self,
particles: int = 1,
ess_threshold=0.5,
resampling_fn=multinomial_resampling,
):
self.samples = particles
self.resampling_fn = resampling_fn

def __call__(
self,
logits: torch.DoubleTensor,
biased_logits: torch.DoubleTensor,
sequence_weights: torch.DoubleTensor,
rng: torch.Generator,
) -> Tuple[torch.DoubleTensor, torch.DoubleTensor, torch.DoubleTensor]:
"""Call the particle filter.
Parameters
----------
logits
a tensor of shape ``(n_seqs, vocab_size,)`` that represents the
probability distribution of the next token over the vocabulary.
biased_logits
a tensor of shape ``(n_seqs, vocab_size,)`` that represents the
biased probability distribution of the next token over the vocabulary.
sequence_weights
A tensor of shape ``(n_seqs,)`` that represents the cumulative
weight of each sequence.
rng
A random number generator.
Returns
-------
A tuple with an array that contains the ids of the sampled tokens of
shape ``(n_seqs, 1)``, an array that contains the ancestors of each
sampled id of shape ``(n_seqs,)`` and an array that contains the updated
cumulative weights of each sequence of shape ``(n_seqs,)``.
"""

original_logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
proposal_logprobs = torch.nn.functional.log_softmax(biased_logits, dim=-1)
weights = original_logprobs - proposal_logprobs
weights[weights == math.inf] = -math.inf # Cannot sample the masked logits

# Flatten weights to (n_batch, n_samples * vocab_size)
batch_size = biased_logits.shape[0] // self.samples
vocab_size = biased_logits.shape[-1]
weights = weights.view(batch_size, self.samples * vocab_size)

weights, indices = self.resampling_fn(rng, weights, self.samples)

ancestors = torch.div(indices, vocab_size, rounding_mode="floor")
next_token_ids = indices % vocab_size

# Re-shape the weights, next_token_ids and ancestors to (n_batch * n_samples, 1)
first_batch_idx = torch.arange(
0, batch_size * self.samples, self.samples, device=biased_logits.device
).unsqueeze(1)
ancestors = ancestors + first_batch_idx

ancestors = ancestors.view(self.samples * batch_size)
weights = weights.view(self.samples * batch_size)
next_token_ids = next_token_ids.view(self.samples * batch_size, 1)

return next_token_ids, ancestors, weights


smc = ParticleFilter
Loading

0 comments on commit b33e645

Please sign in to comment.