diff --git a/clients/python/lorax/types.py b/clients/python/lorax/types.py
index fe880f5f..c70deb3b 100644
--- a/clients/python/lorax/types.py
+++ b/clients/python/lorax/types.py
@@ -88,8 +88,8 @@ def valid_seed(cls, v):
 
     @validator("temperature")
     def valid_temp(cls, v):
-        if v is not None and v <= 0:
-            raise ValidationError("`temperature` must be strictly positive")
+        if v is not None and v < 0:
+            raise ValidationError("`temperature` must be non-negative")
         return v
 
     @validator("top_k")
diff --git a/docs/reference/python_client.md b/docs/reference/python_client.md
index 9e238d5e..85143cb2 100644
--- a/docs/reference/python_client.md
+++ b/docs/reference/python_client.md
@@ -97,7 +97,8 @@ class Parameters:
     adapter_source: Optional[str]
     # API token for accessing private adapters
    api_token: Optional[str]
-    # Activate logits sampling
+    # Activate logits sampling.
+    # Ignored if temperature is 0 (deterministic token selection)
     do_sample: bool
     # Maximum number of generated tokens
     max_new_tokens: int
@@ -111,6 +112,7 @@
     # Random sampling seed
     seed: Optional[int]
     # The value used to module the logits distribution.
+    # Setting this to 0 selects tokens deterministically (greedy decoding)
     temperature: Optional[float]
     # The number of highest probability vocabulary tokens to keep for top-k-filtering.
     top_k: Optional[int]
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 8f985cdc..de95b073 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -162,7 +162,7 @@ impl Validation {
         }
 
         let temperature = temperature.unwrap_or(1.0);
-        if temperature <= 0.0 {
+        if temperature < 0.0 {
             return Err(ValidationError::Temperature);
         }
 
diff --git a/server/lorax_server/utils/logits_process.py b/server/lorax_server/utils/logits_process.py
index f424eae4..7387e808 100644
--- a/server/lorax_server/utils/logits_process.py
+++ b/server/lorax_server/utils/logits_process.py
@@ -26,7 +26,7 @@ def __init__(
     ):
         self.warpers = []
 
-        if temperature is not None and temperature != 1.0:
+        if temperature is not None and temperature != 1.0 and temperature != 0:
             temperature = float(temperature)
             self.warpers.append(TemperatureLogitsWarper(temperature))
         if top_k is not None and top_k != 0:
diff --git a/server/lorax_server/utils/tokens.py b/server/lorax_server/utils/tokens.py
index 1d4a483d..db658919 100644
--- a/server/lorax_server/utils/tokens.py
+++ b/server/lorax_server/utils/tokens.py
@@ -1,5 +1,6 @@
 import re
 import torch
+import warnings
 
 from transformers import (
     RepetitionPenaltyLogitsProcessor,
@@ -62,8 +63,10 @@ def __init__(
             else None
         )
 
+        # Temperature = 1 does not change the logits; no warper is needed.
+        # Temperature = 0 means deterministic token selection; do not warp.
         has_warpers = (
-            (temperature is not None and temperature != 1.0)
+            (temperature is not None and temperature != 1.0 and temperature != 0)
             or (top_k is not None and top_k != 0)
             or (top_p is not None and top_p < 1.0)
             or (typical_p is not None and typical_p < 1.0)
@@ -76,6 +79,13 @@ def __init__(
             self.static_warper = None
 
         sampling = do_sample or has_warpers
+
+        # do not sample if temperature is 0, even if the do_sample flag is set;
+        # warn the user that sampling has been disabled
+        if sampling and temperature == 0:
+            sampling = False
+            warnings.warn("Temperature is set to 0; token sampling will be disabled")
+
         self.choice = Sampling(seed, device) if sampling else Greedy()
 
     def __call__(self, input_ids, scores):
@@ -253,8 +263,11 @@ def __init__(
         )
 
         if any([x != 1.0 for x in temperature]):
+            # update the per-index sample flags:
+            # a temperature of 0 or 1 does not by itself enable sampling
             do_sample = [
-                sample or x != 1.0 for x, sample in zip(temperature, do_sample)
+                sample or (x != 1.0 and x != 0)
+                for x, sample in zip(temperature, do_sample)
             ]
             warpers.append(
                 HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)
             )
@@ -274,8 +287,10 @@ def __init__(
 
         self.warpers = warpers
 
+        # sample from the distribution if any per-index sample flag is set
         if any(do_sample):
             self.choice = HeterogeneousSampling(do_sample, seeds, device)
+        # otherwise fall back to greedy (deterministic) decoding
         else:
             self.choice = Greedy()
 
diff --git a/server/tests/conftest.py b/server/tests/conftest.py
index 7d77b6fe..ee91365f 100644
--- a/server/tests/conftest.py
+++ b/server/tests/conftest.py
@@ -1,6 +1,13 @@
 import pytest
+import torch
+
+from transformers import AutoTokenizer
 
 from lorax_server.pb import generate_pb2
+from lorax_server.models.causal_lm import CausalLM, CausalLMBatch
+from lorax_server.utils.lora import AdapterBatchData
+from lorax_server.utils.tokenizer import TokenizerManager
+from lorax_server.utils.tokens import NextTokenChooser
 
 
 @pytest.fixture
@@ -18,3 +25,129 @@ def default_pb_parameters():
 @pytest.fixture
 def default_pb_stop_parameters():
     return generate_pb2.StoppingCriteriaParameters(stop_sequences=[], max_new_tokens=10)
+
+
+@pytest.fixture(scope="session")
+def default_causal_lm():
+    return CausalLM("gpt2")
+
+
+@pytest.fixture(scope="session")
+def gpt2_tokenizer():
+    tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
+    tokenizer.pad_token_id = 50256
+    return tokenizer
+
+
+@pytest.fixture
+def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
+    return generate_pb2.Request(
+        id=0,
+        inputs="Test",
+        prefill_logprobs=True,
+        truncate=100,
+        parameters=default_pb_parameters,
+        stopping_parameters=default_pb_stop_parameters,
+    )
+
+
+@pytest.fixture
+def default_pb_batch(default_pb_request):
+    return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1)
+
+
+@pytest.fixture
+def default_causal_lm_batch(default_pb_batch, gpt2_tokenizer):
+    return CausalLMBatch.from_pb(
+        default_pb_batch,
+        gpt2_tokenizer,
+        TokenizerManager(),
+        torch.float32,
+        torch.device("cpu"),
+    )
+
+
+# check that generation works normally with temperature = 0
+def test_generate_token_temperature_zero(default_causal_lm, default_causal_lm_batch):
+    sequence_length = len(default_causal_lm_batch.all_input_ids[0])
+    batch = default_causal_lm_batch
+
+    # make every token chooser in the batch deterministic with temperature = 0
+    determ_token_choosers = [
+        NextTokenChooser(temperature=0) for _ in range(len(batch.next_token_choosers))
+    ]
+    batch.next_token_choosers = determ_token_choosers
+    # run one decoding step on the batch
+    generations, next_batch = default_causal_lm.generate_token(default_causal_lm_batch)
+
+    # same assertions as the standard CausalLM generate_token test
+    assert len(generations) == len(next_batch)
+    assert isinstance(next_batch, CausalLMBatch)
+
+    assert len(next_batch.all_input_ids) == len(next_batch)
+    assert len(next_batch.all_input_ids[0]) == sequence_length + 1
+
+
+# generates tokens with deterministic choosers and checks that each
+# output token has the highest probability in the distribution
+def test_deterministic_tokens_temperature_zero(
+    default_causal_lm, default_causal_lm_batch
+):
+    # mirrors the inside of CausalLM.generate_token to access the
+    # logit distribution and compare log probabilities
+    batch = default_causal_lm_batch
+
+    # make every token chooser in the batch deterministic with temperature = 0
+    determ_token_choosers = [
+        NextTokenChooser(temperature=0) for _ in range(len(batch.next_token_choosers))
+    ]
+    batch.next_token_choosers = determ_token_choosers
+
+    attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
+
+    adapter_data = AdapterBatchData.from_meta(
+        batch.adapter_meta, default_causal_lm.batched_lora_weights
+    )
+
+    logits, _ = default_causal_lm.forward(
+        batch.input_ids,
+        attention_mask,
+        batch.position_ids,
+        batch.past_key_values,
+        adapter_data,
+    )
+
+    # Zipped iterator over the per-request batch state
+    iterator = zip(
+        batch.requests,
+        batch.input_lengths,
+        batch.prefix_offsets,
+        batch.read_offsets,
+        logits,
+        batch.next_token_choosers,
+        batch.stopping_criterias,
+        batch.all_input_ids,
+    )
+
+    # For each member of the batch
+    for i, (
+        request,
+        input_length,
+        prefix_offset,
+        read_offset,
+        logits,
+        next_token_chooser,
+        stopping_criteria,
+        all_input_ids,
+    ) in enumerate(iterator):
+        # Select the next token
+        next_token_id, logprobs = next_token_chooser(
+            all_input_ids.view(1, -1), logits[-1:, :]
+        )
+
+        # Log probability of the generated token
+        next_token_logprob = logprobs[-1, next_token_id]
+
+        # A deterministic chooser with temperature = 0 should always pick
+        # the highest-logprob token
+        assert next_token_logprob == max(logprobs[-1])
diff --git a/server/tests/utils/test_tokens.py b/server/tests/utils/test_tokens.py
index e68956ff..7af2617d 100644
--- a/server/tests/utils/test_tokens.py
+++ b/server/tests/utils/test_tokens.py
@@ -1,3 +1,6 @@
+from transformers import AutoTokenizer
+
+
 from lorax_server.utils.tokens import (
     StopSequenceCriteria,
     StoppingCriteria,
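For context, a minimal client-side sketch of the behavior this patch enables: with `temperature=0`, the server skips warping and sampling and decodes greedily, even when `do_sample` is set (a warning is emitted server-side). The endpoint URL and prompt below are placeholder assumptions, not part of the change.

```python
from lorax import Client

# assumes a LoRAX server is reachable locally; adjust the endpoint as needed
client = Client("http://127.0.0.1:8080")

# temperature=0 now requests deterministic (greedy) token selection;
# do_sample=True is ignored in this case and a warning is raised on the server
response = client.generate(
    "What is deep learning?",
    max_new_tokens=32,
    temperature=0,
    do_sample=True,
)
print(response.generated_text)
```

Repeating the same request should now return identical text, since no sampling is performed.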