Disables token sampling when temperature set to 0 #200

Closed
wants to merge 10 commits
2 changes: 1 addition & 1 deletion clients/python/lorax/types.py
@@ -88,7 +88,7 @@ def valid_seed(cls, v):

@validator("temperature")
def valid_temp(cls, v):
if v is not None and v <= 0:
if v is not None and v < 0:
Contributor review comment: Nit: please change validation error to read "temperature must be non-negative".

raise ValidationError("`temperature` must be strictly positive")
return v

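If the reviewer's nit were applied, the validator might read roughly as below. This is a hypothetical sketch, not part of the diff: it assumes the pydantic v1 validator style used in types.py and raises a plain ValueError (the usual way to signal a failure inside a pydantic validator) so the standalone snippet runs as written, whereas the file above raises ValidationError.

from typing import Optional

from pydantic import BaseModel, validator


class Parameters(BaseModel):
    temperature: Optional[float] = None

    @validator("temperature")
    def valid_temp(cls, v):
        # Allow 0 (greedy decoding); reject only negative values and word the
        # error accordingly, per the review comment.
        if v is not None and v < 0:
            raise ValueError("`temperature` must be non-negative")
        return v


Parameters(temperature=0)   # accepted: greedy decoding
Parameters(temperature=-1)  # rejected: raises a validation error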
4 changes: 3 additions & 1 deletion docs/reference/python_client.md
@@ -97,7 +97,8 @@ class Parameters:
adapter_source: Optional[str]
# API token for accessing private adapters
api_token: Optional[str]
# Activate logits sampling
# Activate logits sampling.
# Ignored if the temperature parameter is 0 (deterministic token selection)
do_sample: bool
# Maximum number of generated tokens
max_new_tokens: int
@@ -111,6 +112,7 @@ class Parameters:
# Random sampling seed
seed: Optional[int]
# The value used to modulate the logits distribution.
# Setting the value to 0 enables deterministic token selection
temperature: Optional[float]
# The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_k: Optional[int]
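A short usage sketch of the documented behaviour from the client side. The endpoint URL and prompt are illustrative assumptions, not part of this PR; with temperature=0 the do_sample flag is effectively ignored, so repeated calls should return the same greedy completion.

from lorax import Client

client = Client("http://127.0.0.1:8080")  # assumed local LoRAX endpoint

response = client.generate(
    "What is deep learning?",
    max_new_tokens=32,
    do_sample=True,   # ignored because temperature is 0
    temperature=0,    # deterministic (greedy) token selection
)
print(response.generated_text)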
2 changes: 1 addition & 1 deletion router/src/validation.rs
@@ -162,7 +162,7 @@ impl Validation {
}

let temperature = temperature.unwrap_or(1.0);
if temperature <= 0.0 {
if temperature < 0.0 {
return Err(ValidationError::Temperature);
}

2 changes: 1 addition & 1 deletion server/lorax_server/utils/logits_process.py
@@ -26,7 +26,7 @@ def __init__(
):
self.warpers = []

if temperature is not None and temperature != 1.0:
if temperature is not None and temperature != 1.0 and temperature != 0:
temperature = float(temperature)
self.warpers.append(TemperatureLogitsWarper(temperature))
if top_k is not None and top_k != 0:
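For intuition on why a temperature of 0 is routed to greedy decoding instead of being handed to TemperatureLogitsWarper (which divides the logits by the temperature), here is a small standalone sketch with made-up logits: as the temperature shrinks, the softmax distribution collapses onto the argmax token, and dividing by 0 itself is undefined.

import torch

logits = torch.tensor([2.0, 1.0, 0.5])

# Smaller temperatures concentrate probability mass on the argmax token,
# so greedy selection is the limit of temperature -> 0.
for t in (1.0, 0.5, 0.1, 0.01):
    print(t, torch.softmax(logits / t, dim=-1))

print("greedy choice:", torch.argmax(logits).item())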
19 changes: 17 additions & 2 deletions server/lorax_server/utils/tokens.py
@@ -1,5 +1,6 @@
import re
import torch
import warnings

from transformers import (
RepetitionPenaltyLogitsProcessor,
@@ -62,8 +63,10 @@ def __init__(
else None
)

# Temperature = 1 does not change the logits; do not use a warper.
# Temperature = 0 means deterministic token selection; do not warp.
has_warpers = (
(temperature is not None and temperature != 1.0)
(temperature is not None and temperature != 1.0 and temperature != 0)
or (top_k is not None and top_k != 0)
or (top_p is not None and top_p < 1.0)
or (typical_p is not None and typical_p < 1.0)
@@ -76,6 +79,13 @@ def __init__(
self.static_warper = None

sampling = do_sample or has_warpers

# Do not sample if temperature is 0, even if the do_sample flag is set to True;
# warn the user that sampling has been disabled.
if sampling and temperature == 0:
sampling = False
warnings.warn("Temperature is set to 0, token sampling will be disabled")

self.choice = Sampling(seed, device) if sampling else Greedy()

def __call__(self, input_ids, scores):
@@ -253,8 +263,11 @@ def __init__(
)

if any([x != 1.0 for x in temperature]):
# Set the sampling flag for each index:
# an index is only promoted to sampling if its temperature is neither 0 nor 1.
do_sample = [
sample or x != 1.0 for x, sample in zip(temperature, do_sample)
sample or (x != 1.0 and x != 0)
for x, sample in zip(temperature, do_sample)
]
warpers.append(
HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)
@@ -274,8 +287,10 @@ def __init__(

self.warpers = warpers

# Sample tokens from the distribution if any sampling flags are set to True.
if any(do_sample):
self.choice = HeterogeneousSampling(do_sample, seeds, device)
# All sampling flags are False; fall back to greedy (deterministic) decoding.
else:
self.choice = Greedy()

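To make the per-request flag logic above concrete, here is a small sketch with made-up values mirroring the list comprehension in the diff: a request is only promoted to sampling when its temperature is neither 0 nor 1, while a request that explicitly asked to sample keeps its flag.

# Illustrative values only; these are not taken from the PR.
temperature = [1.0, 0.7, 0.0]
do_sample = [False, False, False]

do_sample = [
    sample or (t != 1.0 and t != 0)
    for t, sample in zip(temperature, do_sample)
]
print(do_sample)  # [False, True, False] -> the temperature-0 request stays greedy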
132 changes: 132 additions & 0 deletions server/tests/conftest.py
@@ -1,6 +1,12 @@
import pytest
import torch

from transformers import AutoTokenizer

from lorax_server.pb import generate_pb2
from lorax_server.models.causal_lm import CausalLM, CausalLMBatch
from lorax_server.utils.lora import AdapterBatchData
from lorax_server.utils.tokenizer import TokenizerManager
from lorax_server.utils.tokens import NextTokenChooser


@pytest.fixture
@@ -18,3 +24,129 @@ def default_pb_parameters():
@pytest.fixture
def default_pb_stop_parameters():
return generate_pb2.StoppingCriteriaParameters(stop_sequences=[], max_new_tokens=10)


@pytest.fixture(scope="session")
def default_causal_lm():
return CausalLM("gpt2")


@pytest.fixture(scope="session")
def gpt2_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
tokenizer.pad_token_id = 50256
return tokenizer


@pytest.fixture
def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="Test",
prefill_logprobs=True,
truncate=100,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,
)


@pytest.fixture
def default_pb_batch(default_pb_request):
return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1)


@pytest.fixture
def default_causal_lm_batch(default_pb_batch, gpt2_tokenizer):
return CausalLMBatch.from_pb(
default_pb_batch,
gpt2_tokenizer,
TokenizerManager(),
torch.float32,
torch.device("cpu"),
)


# Check that generation works normally with temperature = 0
def test_generate_token_temperature_zero(default_causal_lm, default_causal_lm_batch):
sequence_length = len(default_causal_lm_batch.all_input_ids[0])
batch = default_causal_lm_batch

# set all token choosers in batch to be deterministic with Temperature = 0
determ_token_choosers = [
NextTokenChooser(temperature=0) for _ in range(len(batch.next_token_choosers))
]
batch.next_token_choosers = determ_token_choosers
# generate tokens for the batch
generations, next_batch = default_causal_lm.generate_token(batch)

# same assertions as testing generate token, causal lm
assert len(generations) == len(next_batch)
assert isinstance(next_batch, CausalLMBatch)

assert len(next_batch.all_input_ids) == len(next_batch)
assert len(next_batch.all_input_ids[0]) == sequence_length + 1


# Generates tokens with deterministic choosers and checks that the output
# tokens have the highest probability in the distribution.
def test_deterministic_tokens_temperature_zero(
default_causal_lm, default_causal_lm_batch
):
# Replicates the steps inside CausalLM.generate_token in order to access the
# logits distribution and compare log probabilities.
batch = default_causal_lm_batch

# set all token choosers in batch to be deterministic with Temperature = 0
determ_token_choosers = [
NextTokenChooser(temperature=0) for _ in range(len(batch.next_token_choosers))
]
batch.next_token_choosers = determ_token_choosers

attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]

adapter_data = AdapterBatchData.from_meta(
batch.adapter_meta, default_causal_lm.batched_lora_weights
)

logits, _ = default_causal_lm.forward(
batch.input_ids,
attention_mask,
batch.position_ids,
batch.past_key_values,
adapter_data,
)

# Zipped iterator
iterator = zip(
batch.requests,
batch.input_lengths,
batch.prefix_offsets,
batch.read_offsets,
logits,
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_input_ids,
)

# For each member of the batch
for i, (
request,
input_length,
prefix_offset,
read_offset,
logits,
next_token_chooser,
stopping_criteria,
all_input_ids,
) in enumerate(iterator):
# Select next token
next_token_id, logprobs = next_token_chooser(
all_input_ids.view(1, -1), logits[-1:, :]
)

# Generated token
next_token_logprob = logprobs[-1, next_token_id]

# A deterministic model with Temperature = 0 should always choose
# the highest logprob token
assert next_token_logprob == max(logprobs[-1])
3 changes: 3 additions & 0 deletions server/tests/utils/test_tokens.py
@@ -1,3 +1,6 @@
from transformers import AutoTokenizer


from lorax_server.utils.tokens import (
StopSequenceCriteria,
StoppingCriteria,