From e295dbb23de1e989f1550cc42d4df4d671b3fb3a Mon Sep 17 00:00:00 2001
From: Adam Kelch
Date: Mon, 22 Jan 2024 11:39:48 -0500
Subject: [PATCH 01/10] disable token sampling when temperature is 0

---
 server/lorax_server/utils/tokens.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/server/lorax_server/utils/tokens.py b/server/lorax_server/utils/tokens.py
index 1d4a483d3..78c18fa35 100644
--- a/server/lorax_server/utils/tokens.py
+++ b/server/lorax_server/utils/tokens.py
@@ -75,7 +75,8 @@ def __init__(
         else:
             self.static_warper = None
 
-        sampling = do_sample or has_warpers
+        # sample based on flags and if temperature isn't 0
+        sampling = (do_sample or has_warpers) and temperature != 0
         self.choice = Sampling(seed, device) if sampling else Greedy()
 
     def __call__(self, input_ids, scores):
@@ -253,8 +254,11 @@ def __init__(
         )
 
         if any([x != 1.0 for x in temperature]):
+            # set sample flags for each index
+            # do not sample this index if temperature is 0
             do_sample = [
-                sample or x != 1.0 for x, sample in zip(temperature, do_sample)
+                (sample or x != 1.0) and x != 0
+                for x, sample in zip(temperature, do_sample)
             ]
             warpers.append(
                 HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)
             )
@@ -274,8 +278,10 @@ def __init__(
 
         self.warpers = warpers
 
+        # sample tokens from distribution if any sample flags are set True
         if any(do_sample):
             self.choice = HeterogeneousSampling(do_sample, seeds, device)
+        # all tokens are set false, do Greedy / determinsitc sampling
         else:
             self.choice = Greedy()
 

From 18079e62ff4fe53f00175f0022050635f2fb8bfb Mon Sep 17 00:00:00 2001
From: Adam Kelch
Date: Mon, 22 Jan 2024 16:12:16 -0500
Subject: [PATCH 02/10] adjust sample flag setting for token choosers

---
 server/lorax_server/utils/tokens.py | 19 ++++++++++++++-----
 1 file changed, 14 insertions(+), 5 deletions(-)

diff --git a/server/lorax_server/utils/tokens.py b/server/lorax_server/utils/tokens.py
index 78c18fa35..b565cbffd 100644
--- a/server/lorax_server/utils/tokens.py
+++ b/server/lorax_server/utils/tokens.py
@@ -1,5 +1,6 @@
 import re
 import torch
+import warnings
 
 from transformers import (
     RepetitionPenaltyLogitsProcessor,
@@ -62,8 +63,10 @@ def __init__(
             else None
         )
 
+        # Temperature = 1 does not change logits; do not use warper
+        # Temperature = 0 invokes deterministic token choosing; do not warp
         has_warpers = (
-            (temperature is not None and temperature != 1.0)
+            (temperature is not None and temperature != 1.0 and temperature != 0)
             or (top_k is not None and top_k != 0)
             or (top_p is not None and top_p < 1.0)
             or (typical_p is not None and typical_p < 1.0)
         )
@@ -75,8 +78,14 @@ def __init__(
         else:
             self.static_warper = None
 
-        # sample based on flags and if temperature isn't 0
-        sampling = (do_sample or has_warpers) and temperature != 0
+        sampling = do_sample or has_warpers
+
+        # do not sample if temperature is 0, even if do_sample flag is set True
+        # warn user about determinstic sampling
+        if sampling and temperature == 0:
+            sampling = False
+            warnings.warn("Temperature is set to 0, token sampling will be disabled")
+
         self.choice = Sampling(seed, device) if sampling else Greedy()
 
     def __call__(self, input_ids, scores):
@@ -255,9 +264,9 @@ def __init__(
         if any([x != 1.0 for x in temperature]):
             # set sample flags for each index
-            # do not sample this index if temperature is 0
+            # do not sample this index if temperature is 0 or 1
             do_sample = [
-                (sample or x != 1.0) and x != 0
+                sample or (x != 1.0 and x != 0)
                 for x, sample in zip(temperature, do_sample)
             ]
             warpers.append(
                 HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)
             )
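The subtle part of the two patches above is the per-request flag update in the heterogeneous chooser. The following is a standalone sketch (plain Python, not part of the patch series) of how the list comprehension rewritten in PATCH 02 treats different temperatures in a batch:

# Illustration only: mirrors the do_sample update from PATCH 02.
# A request ends up sampled only if it asked for sampling explicitly or
# uses a temperature other than 1.0 and 0.
temperature = [0.0, 0.7, 1.0]
do_sample = [False, False, False]

do_sample = [
    sample or (x != 1.0 and x != 0)
    for x, sample in zip(temperature, do_sample)
]

# With do_sample left False, only the 0.7 request is sampled;
# temperatures 0 and 1.0 keep greedy decoding.
assert do_sample == [False, True, False]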
From 25227bb8e7ea4a1099016ae27d9434bb096c02aa Mon Sep 17 00:00:00 2001
From: Adam Kelch
Date: Mon, 22 Jan 2024 16:31:46 -0500
Subject: [PATCH 03/10] temperature validation changes

---
 router/src/validation.rs                    | 2 +-
 server/lorax_server/utils/logits_process.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/router/src/validation.rs b/router/src/validation.rs
index 8f985cdcc..de95b073a 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -162,7 +162,7 @@ impl Validation {
         }
 
         let temperature = temperature.unwrap_or(1.0);
-        if temperature <= 0.0 {
+        if temperature < 0.0 {
             return Err(ValidationError::Temperature);
         }
 
diff --git a/server/lorax_server/utils/logits_process.py b/server/lorax_server/utils/logits_process.py
index f424eae40..7387e808a 100644
--- a/server/lorax_server/utils/logits_process.py
+++ b/server/lorax_server/utils/logits_process.py
@@ -26,7 +26,7 @@ def __init__(
     ):
         self.warpers = []
 
-        if temperature is not None and temperature != 1.0:
+        if temperature is not None and temperature != 1.0 and temperature != 0:
             temperature = float(temperature)
             self.warpers.append(TemperatureLogitsWarper(temperature))
         if top_k is not None and top_k != 0:

From 027bf3099d7d5bb882423653587e4410893d0fb1 Mon Sep 17 00:00:00 2001
From: Adam Kelch
Date: Mon, 22 Jan 2024 16:36:14 -0500
Subject: [PATCH 04/10] fix comments

---
 server/lorax_server/utils/tokens.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/server/lorax_server/utils/tokens.py b/server/lorax_server/utils/tokens.py
index b565cbffd..db6589193 100644
--- a/server/lorax_server/utils/tokens.py
+++ b/server/lorax_server/utils/tokens.py
@@ -81,7 +81,7 @@ def __init__(
         sampling = do_sample or has_warpers
 
         # do not sample if temperature is 0, even if do_sample flag is set True
-        # warn user about determinstic sampling
+        # warn user about deterministic sampling
         if sampling and temperature == 0:
             sampling = False
             warnings.warn("Temperature is set to 0, token sampling will be disabled")
@@ -290,7 +290,7 @@ def __init__(
         # sample tokens from distribution if any sample flags are set True
         if any(do_sample):
             self.choice = HeterogeneousSampling(do_sample, seeds, device)
-        # all tokens are set false, do Greedy / determinsitc sampling
+        # all tokens are set false, do Greedy / deterministic sampling
         else:
             self.choice = Greedy()
 

From 4258d3f6cf398162283b0fce0e65d2401379d4a8 Mon Sep 17 00:00:00 2001
From: Adam Kelch
Date: Mon, 22 Jan 2024 21:23:19 -0500
Subject: [PATCH 05/10] add test for deterministic token choosing when temp = 0

---
 server/tests/utils/test_tokens.py | 69 +++++++++++++++++++++++++++++++
 1 file changed, 69 insertions(+)

diff --git a/server/tests/utils/test_tokens.py b/server/tests/utils/test_tokens.py
index e68956ffe..634a679b3 100644
--- a/server/tests/utils/test_tokens.py
+++ b/server/tests/utils/test_tokens.py
@@ -2,8 +2,16 @@
     StopSequenceCriteria,
     StoppingCriteria,
     FinishReason,
+    NextTokenChooser,
 )
 
+from lorax_server.utils.lora import AdapterBatchData
+
+from tests.models.test_model import get_test_model
+from lorax_server.pb import generate_pb2
+from lorax_server.models.causal_lm import CausalLM, CausalLMBatch
+from tests.models.test_causal_lm import default_causal_lm, default_causal_lm_batch
+
 
 def test_stop_sequence_criteria():
     criteria = StopSequenceCriteria("/test;")
@@ -42,3 +50,64 @@ def test_stopping_criteria_max():
     assert criteria(1, "") == (False, None)
     assert criteria(1, "") == (False, None)
     assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)
+
+
+def test_deterministic_tokens_temperature_zero():
+    # Inside of CausalLM.generate_token, used to access
+    # logit distribution and compare log prob
+    batch = default_causal_lm_batch
+
+    # set all token choosers in batch to be deterministic with Temperature = 0
+    determ_token_choosers = [
+        NextTokenChooser(temperature=0) for _ in range(len(batch.next_token_choosers))
+    ]
+    batch.next_token_choosers = determ_token_choosers
+
+    attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
+
+    adapter_data = AdapterBatchData.from_meta(
+        batch.adapter_meta, default_causal_lm.batched_lora_weights
+    )
+
+    logits, _ = default_causal_lm.forward(
+        batch.input_ids,
+        attention_mask,
+        batch.position_ids,
+        batch.past_key_values,
+        adapter_data,
+    )
+
+    # Zipped iterator
+    iterator = zip(
+        batch.requests,
+        batch.input_lengths,
+        batch.prefix_offsets,
+        batch.read_offsets,
+        logits,
+        batch.next_token_choosers,
+        batch.stopping_criterias,
+        batch.all_input_ids,
+    )
+
+    # For each member of the batch
+    for i, (
+        request,
+        input_length,
+        prefix_offset,
+        read_offset,
+        logits,
+        next_token_chooser,
+        stopping_criteria,
+        all_input_ids,
+    ) in enumerate(iterator):
+        # Select next token
+        next_token_id, logprobs = next_token_chooser(
+            all_input_ids.view(1, -1), logits[-1:, :]
+        )
+
+        # Generated token
+        next_token_logprob = logprobs[-1, next_token_id]
+
+        # A deterministic model with Temperature = 0 should always choose
+        # the highest logprob token
+        assert next_token_logprob == max(logprobs[-1])

From 24da64eeb7f3cd1180d9dd836a01146ec314b7fa Mon Sep 17 00:00:00 2001
From: Adam Kelch
Date: Tue, 23 Jan 2024 10:18:55 -0500
Subject: [PATCH 06/10] add additional test

---
 server/tests/utils/test_tokens.py | 26 ++++++++++++++++++++++++--
 1 file changed, 24 insertions(+), 2 deletions(-)

diff --git a/server/tests/utils/test_tokens.py b/server/tests/utils/test_tokens.py
index 634a679b3..06e1377dd 100644
--- a/server/tests/utils/test_tokens.py
+++ b/server/tests/utils/test_tokens.py
@@ -6,8 +6,6 @@
 )
 
 from lorax_server.utils.lora import AdapterBatchData
-
-from tests.models.test_model import get_test_model
 from lorax_server.pb import generate_pb2
 from lorax_server.models.causal_lm import CausalLM, CausalLMBatch
 from tests.models.test_causal_lm import default_causal_lm, default_causal_lm_batch
@@ -52,6 +50,30 @@ def test_stopping_criteria_max():
     assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)
 
 
+# check generations work normally with temperature = 0
+def test_generate_token_temperature_zero():
+    sequence_length = len(default_causal_lm_batch.all_input_ids[0])
+    batch = default_causal_lm_batch
+
+    # set all token choosers in batch to be deterministic with Temperature = 0
+    determ_token_choosers = [
+        NextTokenChooser(temperature=0) for _ in range(len(batch.next_token_choosers))
+    ]
+    batch.next_token_choosers = determ_token_choosers
+    # generate tokens from next batch
+    generations, next_batch = default_causal_lm.generate_token(default_causal_lm_batch)
+
+    # same assertions as testing generate token, causal llm
+    assert len(generations) == len(next_batch)
+    assert len(generations) == len(next_batch)
+    assert isinstance(next_batch, CausalLMBatch)
+
+    assert len(next_batch.all_input_ids) == len(next_batch)
+    assert len(next_batch.all_input_ids[0]) == sequence_length + 1
+
+
+# generates tokens with determinstic choosers,
+# checks that output tokens have highest probability in distribution
 def test_deterministic_tokens_temperature_zero():
     # Inside of CausalLM.generate_token, used to access
     # logit distribution and compare log prob

From f3b912b86e409b85439ee4dd385ee2cde87caf5f Mon Sep 17 00:00:00 2001
From: Adam Kelch
Date: Tue, 23 Jan 2024 10:28:05 -0500
Subject: [PATCH 07/10] clarify effects of 0 temperature parameter in docs

---
 docs/reference/python_client.md | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/docs/reference/python_client.md b/docs/reference/python_client.md
index 9e238d5e9..85143cb2f 100644
--- a/docs/reference/python_client.md
+++ b/docs/reference/python_client.md
@@ -97,7 +97,8 @@ class Parameters:
     adapter_source: Optional[str]
     # API token for accessing private adapters
     api_token: Optional[str]
-    # Activate logits sampling
+    # Activate logits sampling.
+    # Ignored if temperature parameter = 0 (deterministic token choosing)
     do_sample: bool
     # Maximum number of generated tokens
    max_new_tokens: int
@@ -111,6 +112,7 @@ class Parameters:
     # Random sampling seed
     seed: Optional[int]
     # The value used to module the logits distribution.
+    # Setting value to 0 invokes deterministic token choosing
     temperature: Optional[float]
     # The number of highest probability vocabulary tokens to keep for top-k-filtering.
     top_k: Optional[int]

From b7f2f922e29cddd3b79bf06a9ef88c3f47a6230f Mon Sep 17 00:00:00 2001
From: Adam Kelch
Date: Tue, 23 Jan 2024 12:55:33 -0500
Subject: [PATCH 08/10] change temperature validation to include 0

---
 clients/python/lorax/types.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/clients/python/lorax/types.py b/clients/python/lorax/types.py
index fe880f5f5..7c2f1265c 100644
--- a/clients/python/lorax/types.py
+++ b/clients/python/lorax/types.py
@@ -88,7 +88,7 @@ def valid_seed(cls, v):
 
     @validator("temperature")
     def valid_temp(cls, v):
-        if v is not None and v <= 0:
+        if v is not None and v < 0:
             raise ValidationError("`temperature` must be strictly positive")
         return v
 
     @validator("top_k")

From c5049ac30dea70e49a2c2c501e499164be434db6 Mon Sep 17 00:00:00 2001
From: Adam Kelch
Date: Tue, 23 Jan 2024 13:12:27 -0500
Subject: [PATCH 09/10] move tests to conftest.py, refactor

---
 server/tests/conftest.py          | 132 ++++++++++++++++++++++++++++++
 server/tests/utils/test_tokens.py |  94 +--------------------
 2 files changed, 135 insertions(+), 91 deletions(-)

diff --git a/server/tests/conftest.py b/server/tests/conftest.py
index 7d77b6fe0..ee91365fc 100644
--- a/server/tests/conftest.py
+++ b/server/tests/conftest.py
@@ -1,6 +1,12 @@
 import pytest
 
+from transformers import AutoTokenizer
+
 from lorax_server.pb import generate_pb2
+from lorax_server.models.causal_lm import CausalLM, CausalLMBatch
+from lorax_server.utils.lora import AdapterBatchData
+from lorax_server.utils.tokenizer import TokenizerManager
+from lorax_server.utils.tokens import NextTokenChooser
 
 
 @pytest.fixture
@@ -18,3 +24,129 @@ def default_pb_stop_parameters():
     return generate_pb2.StoppingCriteriaParameters(stop_sequences=[], max_new_tokens=10)
+
+
+@pytest.fixture(scope="session")
+def default_causal_lm():
+    return CausalLM("gpt2")
+
+
+@pytest.fixture(scope="session")
+def gpt2_tokenizer():
+    tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
+    tokenizer.pad_token_id = 50256
+    return tokenizer
+
+
+@pytest.fixture
+def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
+    return generate_pb2.Request(
+        id=0,
+        inputs="Test",
+        prefill_logprobs=True,
+        truncate=100,
+        parameters=default_pb_parameters,
+        stopping_parameters=default_pb_stop_parameters,
+    )
+
+
+@pytest.fixture
+def default_pb_batch(default_pb_request):
+    return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1)
+
+
+@pytest.fixture
+def default_causal_lm_batch(default_pb_batch, gpt2_tokenizer):
+    return CausalLMBatch.from_pb(
+        default_pb_batch,
+        gpt2_tokenizer,
+        TokenizerManager(),
+        torch.float32,
+        torch.device("cpu"),
+    )
+
+
+# check generations work normally with temperature = 0
+def test_generate_token_temperature_zero(default_causal_lm, default_causal_lm_batch):
+    sequence_length = len(default_causal_lm_batch.all_input_ids[0])
+    batch = default_causal_lm_batch
+
+    # set all token choosers in batch to be deterministic with Temperature = 0
+    determ_token_choosers = [
+        NextTokenChooser(temperature=0) for _ in range(len(batch.next_token_choosers))
+    ]
+    batch.next_token_choosers = determ_token_choosers
+    # generate tokens from next batch
+    generations, next_batch = default_causal_lm.generate_token(default_causal_lm_batch)
+
+    # same assertions as testing generate token, causal lm
+    assert len(generations) == len(next_batch)
+    assert isinstance(next_batch, CausalLMBatch)
+
+    assert len(next_batch.all_input_ids) == len(next_batch)
+    assert len(next_batch.all_input_ids[0]) == sequence_length + 1
+
+
+# generates tokens with deterministic choosers,
+# checks that output tokens have highest probability in distribution
+def test_deterministic_tokens_temperature_zero(
+    default_causal_lm, default_causal_lm_batch
+):
+    # Inside of CausalLM.generate_token, used to access
+    # logit distribution and compare log prob
+    batch = default_causal_lm_batch
+
+    # set all token choosers in batch to be deterministic with Temperature = 0
+    determ_token_choosers = [
+        NextTokenChooser(temperature=0) for _ in range(len(batch.next_token_choosers))
+    ]
+    batch.next_token_choosers = determ_token_choosers
+
+    attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
+
+    adapter_data = AdapterBatchData.from_meta(
+        batch.adapter_meta, default_causal_lm.batched_lora_weights
+    )
+
+    logits, _ = default_causal_lm.forward(
+        batch.input_ids,
+        attention_mask,
+        batch.position_ids,
+        batch.past_key_values,
+        adapter_data,
+    )
+
+    # Zipped iterator
+    iterator = zip(
+        batch.requests,
+        batch.input_lengths,
+        batch.prefix_offsets,
+        batch.read_offsets,
+        logits,
+        batch.next_token_choosers,
+        batch.stopping_criterias,
+        batch.all_input_ids,
+    )
+
+    # For each member of the batch
+    for i, (
+        request,
+        input_length,
+        prefix_offset,
+        read_offset,
+        logits,
+        next_token_chooser,
+        stopping_criteria,
+        all_input_ids,
+    ) in enumerate(iterator):
+        # Select next token
+        next_token_id, logprobs = next_token_chooser(
+            all_input_ids.view(1, -1), logits[-1:, :]
+        )
+
+        # Generated token
+        next_token_logprob = logprobs[-1, next_token_id]
+
+        # A deterministic model with Temperature = 0 should always choose
+        # the highest logprob token
+        assert next_token_logprob == max(logprobs[-1])
diff --git a/server/tests/utils/test_tokens.py b/server/tests/utils/test_tokens.py
index 06e1377dd..7af2617dc 100644
--- a/server/tests/utils/test_tokens.py
+++ b/server/tests/utils/test_tokens.py
@@ -1,15 +1,12 @@
+from transformers import AutoTokenizer
+
+
 from lorax_server.utils.tokens import (
     StopSequenceCriteria,
     StoppingCriteria,
     FinishReason,
-    NextTokenChooser,
 )
 
-from lorax_server.utils.lora import AdapterBatchData
-from lorax_server.pb import generate_pb2
-from lorax_server.models.causal_lm import CausalLM, CausalLMBatch
-from tests.models.test_causal_lm import default_causal_lm, default_causal_lm_batch
-
 def test_stop_sequence_criteria():
     criteria = StopSequenceCriteria("/test;")
@@ -48,88 +45,3 @@ def test_stopping_criteria_max():
     assert criteria(1, "") == (False, None)
     assert criteria(1, "") == (False, None)
     assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)
-
-
-# check generations work normally with temperature = 0
-def test_generate_token_temperature_zero():
-    sequence_length = len(default_causal_lm_batch.all_input_ids[0])
-    batch = default_causal_lm_batch
-
-    # set all token choosers in batch to be deterministic with Temperature = 0
-    determ_token_choosers = [
-        NextTokenChooser(temperature=0) for _ in range(len(batch.next_token_choosers))
-    ]
-    batch.next_token_choosers = determ_token_choosers
-    # generate tokens from next batch
-    generations, next_batch = default_causal_lm.generate_token(default_causal_lm_batch)
-
-    # same assertions as testing generate token, causal llm
-    assert len(generations) == len(next_batch)
-    assert len(generations) == len(next_batch)
-    assert isinstance(next_batch, CausalLMBatch)
-
-    assert len(next_batch.all_input_ids) == len(next_batch)
-    assert len(next_batch.all_input_ids[0]) == sequence_length + 1
-
-
-# generates tokens with determinstic choosers,
-# checks that output tokens have highest probability in distribution
-def test_deterministic_tokens_temperature_zero():
-    # Inside of CausalLM.generate_token, used to access
-    # logit distribution and compare log prob
-    batch = default_causal_lm_batch
-
-    # set all token choosers in batch to be deterministic with Temperature = 0
-    determ_token_choosers = [
-        NextTokenChooser(temperature=0) for _ in range(len(batch.next_token_choosers))
-    ]
-    batch.next_token_choosers = determ_token_choosers
-
-    attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
-
-    adapter_data = AdapterBatchData.from_meta(
-        batch.adapter_meta, default_causal_lm.batched_lora_weights
-    )
-
-    logits, _ = default_causal_lm.forward(
-        batch.input_ids,
-        attention_mask,
-        batch.position_ids,
-        batch.past_key_values,
-        adapter_data,
-    )
-
-    # Zipped iterator
-    iterator = zip(
-        batch.requests,
-        batch.input_lengths,
-        batch.prefix_offsets,
-        batch.read_offsets,
-        logits,
-        batch.next_token_choosers,
-        batch.stopping_criterias,
-        batch.all_input_ids,
-    )
-
-    # For each member of the batch
-    for i, (
-        request,
-        input_length,
-        prefix_offset,
-        read_offset,
-        logits,
-        next_token_chooser,
-        stopping_criteria,
-        all_input_ids,
-    ) in enumerate(iterator):
-        # Select next token
-        next_token_id, logprobs = next_token_chooser(
-            all_input_ids.view(1, -1), logits[-1:, :]
-        )
-
-        # Generated token
-        next_token_logprob = logprobs[-1, next_token_id]
-
-        # A deterministic model with Temperature = 0 should always choose
-        # the highest logprob token
-        assert next_token_logprob == max(logprobs[-1])

From ed4845d71ebd3d7fb77cc23a485c94bd64ac6936 Mon Sep 17 00:00:00 2001
From: Adam Kelch
Date: Tue, 23 Jan 2024 17:41:30 -0500
Subject: [PATCH 10/10] update temperature validation

---
 clients/python/lorax/types.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/clients/python/lorax/types.py b/clients/python/lorax/types.py
index 7c2f1265c..c70deb3b6 100644
--- a/clients/python/lorax/types.py
+++ b/clients/python/lorax/types.py
@@ -89,7 +89,7 @@ def valid_seed(cls, v):
     @validator("temperature")
     def valid_temp(cls, v):
         if v is not None and v < 0:
-            raise ValidationError("`temperature` must be strictly positive")
+            raise ValidationError("`temperature` must be non-negative")
         return v
 
     @validator("top_k")
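Taken together, the series makes temperature=0 an accepted way to request fully deterministic (greedy) decoding, from client-side validation down to the token chooser. Below is a minimal pytest-style sketch of how the single-request behavior from PATCH 01/02 could be exercised directly; it assumes NextTokenChooser accepts the do_sample and temperature keyword arguments used in the patches above and that Greedy can be imported from lorax_server.utils.tokens, which may not match the module's actual layout.

import pytest

from lorax_server.utils.tokens import Greedy, NextTokenChooser


def test_temperature_zero_uses_greedy_choice():
    # Temperature 0 alone should select the greedy chooser.
    chooser = NextTokenChooser(temperature=0)
    assert isinstance(chooser.choice, Greedy)

    # An explicit do_sample=True combined with temperature 0 should emit the
    # warning added in PATCH 02 and still fall back to greedy decoding.
    with pytest.warns(UserWarning):
        chooser = NextTokenChooser(do_sample=True, temperature=0)
    assert isinstance(chooser.choice, Greedy)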