Merge pull request #3 from afeldman-nm/enc_dec_t5
fix _make_tensor_with_pad args change which broke decoder scenarios
js8544 committed Mar 6, 2024
2 parents 4bf056b + 9c03760 commit 9f20ccf
Showing 5 changed files with 13 additions and 115 deletions.
34 changes: 0 additions & 34 deletions vllm/model_executor/layers/attention.py
@@ -195,9 +195,6 @@ def forward(
key = key.unflatten(0, (batch_size, seq_len))
value = value.unflatten(0, (batch_size, seq_len))

# print("query.shape: ", query.shape)
# print("key.shape: ", key.shape)
# print("value.shape: ", value.shape)
out = xops.memory_efficient_attention_forward(
query,
key,
@@ -292,7 +289,6 @@ def paged_attention(
num_seqs, num_heads, head_size = query.shape
max_num_partitions = ((max_context_len + _PARTITION_SIZE - 1) //
_PARTITION_SIZE)
# print("max_num_partitions: ", max_num_partitions)
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of
@@ -303,25 +299,6 @@
use_v1 = max_context_len <= 8192 and (max_num_partitions == 1
or num_seqs * num_heads > 512)
if use_v1:
# print("v1")
# print("output: ", output)
# print("query: ", query)
# print("num_kv_heads: ", num_kv_heads)
# print("scale: ", scale)
# print("block_tables: ", block_tables)
# print("context_lens: ", context_lens)
# print("block_size: ", block_size)
# print("max_context_len: ", max_context_len)
# print("alibi_slopes: ", alibi_slopes)
# print("custom_bias: ", custom_bias)
# print("key_cache shape: ", key_cache.shape)
# print("value_cache shape: ", value_cache.shape)
# for block_table in block_tables:
# for block in block_table:
# print(f"key_cache at {block} shape: ", key_cache[block].shape)
# print(f"key_cache at {block}: ", key_cache[block])
# print(f"value_cache at {block} shape: ", value_cache[block].shape)
# print(f"value_cache at {block}: ", value_cache[block])
# Run PagedAttention V1.
ops.paged_attention_v1(
output,
@@ -339,7 +316,6 @@
kv_cache_dtype,
)
else:
# print("v2")
# Run PagedAttention V2.
assert _PARTITION_SIZE % block_size == 0
tmp_output = torch.empty(
@@ -353,16 +329,6 @@
device=output.device,
)
max_logits = torch.empty_like(exp_sums)
# print("output: ", output)
# print("query: ", query)
# print("num_kv_heads: ", num_kv_heads)
# print("scale: ", scale)
# print("block_tables: ", block_tables)
# print("context_lens: ", context_lens)
# print("block_size: ", block_size)
# print("max_context_len: ", max_context_len)
# print("alibi_slopes: ", alibi_slopes)
# print("custom_bias: ", custom_bias)
ops.paged_attention_v2(
output,
exp_sums,
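
For context (not part of this diff): the PagedAttention V1/V2 dispatch heuristic referenced by the NOTE above can be restated as a small pure function. The helper name below and the _PARTITION_SIZE = 512 constant are assumptions for illustration, not code from this commit.

```python
# Context sketch: the V1/V2 dispatch heuristic, restated in isolation.
# _PARTITION_SIZE = 512 is assumed here; check the actual constant in vLLM.
_PARTITION_SIZE = 512

def use_paged_attention_v1(max_context_len: int, num_seqs: int,
                           num_heads: int) -> bool:
    """Return True when the cheaper V1 kernel is preferable."""
    max_num_partitions = ((max_context_len + _PARTITION_SIZE - 1)
                          // _PARTITION_SIZE)
    # V1 skips V2's reduction step; it wins for short contexts when there is
    # a single partition or enough (seq, head) pairs to keep the GPU busy.
    return max_context_len <= 8192 and (max_num_partitions == 1
                                        or num_seqs * num_heads > 512)

# A short-context batch dispatches to V1; a long context with a tiny batch
# falls through to V2.
assert use_paged_attention_v1(max_context_len=256, num_seqs=8, num_heads=32)
assert not use_paged_attention_v1(max_context_len=4096, num_seqs=1, num_heads=32)
```
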
37 changes: 0 additions & 37 deletions vllm/model_executor/layers/enc_dec_attention.py
@@ -73,20 +73,12 @@ def forward(
query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
key = key.view(batch_size, seq_len, self.num_heads, self.head_size)
value = value.view(batch_size, seq_len, self.num_heads, self.head_size)
# print("query shape: ", query.shape)
if input_metadata.attn_bias is None:
input_metadata.attn_bias = BlockDiagonalCausalMask.from_seqlens(
[seq_len] * batch_size)
# When using custom attention bias, xformers requires the bias to
# be sliced from a tensor whose length is a multiple of 8.
# padded_len = (seq_len + 7) // 8 * 8
# pad_len = padded_len - seq_len
# input_metadata.attn_bias = F.pad(input_metadata.attn_bias, (0, pad_len))
# print("attention bias padded shape: ", input_metadata.attn_bias.shape)

input_metadata.attn_bias = input_metadata.attn_bias[:, :, :, :seq_len]

# print("attention bias shape: ", input_metadata.attn_bias.shape)
# Normal attention
out = xops.memory_efficient_attention_forward(
query,
@@ -135,42 +127,28 @@ def forward(
Output tensor.
"""

# print("key shape pre view: ", key.shape)
# print("value shape pre view: ", value.shape)

batch_size, seq_len, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_heads, self.head_size)
value = value.view(-1, self.num_heads, self.head_size)
# print("key shape: ", key.shape)
# print("key: ", key)
# print("value shape: ", value.shape)
# print("value: ", value)
# print("slot mapping: ", input_metadata.slot_mapping[:, -1].flatten())
# Reshape the keys and values and store them in the cache.
# If key_cache and value_cache are not provided, the new key and value
# vectors will not be cached. This happens during the initial memory
# profiling run.
if key_cache is not None and value_cache is not None:
# print("key_cache before: ", key_cache)
# print("value_cache before: ", value_cache)

cache_ops.reshape_and_cache(
key, value, key_cache, value_cache,
input_metadata.slot_mapping[:, -1].flatten().contiguous(),
input_metadata.kv_cache_dtype)

# print("key_cache after: ", key_cache)
# print("value_cache after: ", value_cache)

max_prompt_len = input_metadata.prompt_lens.max().item()
block_size = value_cache.shape[3]
prompt_table_len = (max_prompt_len + block_size - 1) // block_size
block_tables = input_metadata.block_tables[:,
prompt_table_len:].contiguous(
)
# print("decoder self attention block_tables", block_tables)
output = paged_attention(
query=query,
key_cache=key_cache,
@@ -222,18 +200,10 @@ def forward(
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
if key is not None:
# print("key shape pre view: ", key.shape)
key = key.view(-1, self.num_heads, self.head_size)
# print("key_shape: ", key.shape)
# print("key sum", key.sum((1, 2)))
if value is not None:
# print("value shape pre view: ", value.shape)
value = value.view(-1, self.num_heads, self.head_size)
# print("value_shape: ", value.shape)
# print("value sum", value.sum((1, 2)))

# print("slot mapping: ", input_metadata.slot_mapping[:, :-1].flatten().shape)
# print("slot mapping: ", input_metadata.slot_mapping[:, :-1].flatten())
# Reshape the keys and values and store them in the cache.
# It only happens during the first pass.
if (input_metadata.is_prompt and key_cache is not None
@@ -248,14 +218,7 @@ def forward(
input_metadata.kv_cache_dtype,
)

# for slot in input_metadata.slot_mapping[:, :-1].flatten():
# if slot != -1:
# block_number = slot//16;
# block_offset = slot%16;
# print(f"key_cache sum at {slot}: ", key_cache[block_number, :, :, block_offset, :].sum())
# print(f"value_cache sum at {slot}: ", value_cache[block_number, :, :, block_offset].sum())
max_prompt_len = input_metadata.prompt_lens.int().max().item()
# print("max_prompt_len: ", max_prompt_len)
block_size = value_cache.shape[3]
prompt_table_len = (max_prompt_len + block_size - 1) // block_size
block_tables = input_metadata.block_tables[:, :
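
For context (illustrative only): the decoder self-attention hunks above derive their block table by skipping the cache blocks that hold the encoder prompt's keys and values. A minimal sketch of that prompt_table_len arithmetic, using a hypothetical helper name and toy shapes:

```python
import torch

def split_block_tables(block_tables: torch.Tensor, max_prompt_len: int,
                       block_size: int):
    """Sketch: split a combined block table into the prompt (cross-attention)
    part and the decoder self-attention part, mirroring the arithmetic in
    enc_dec_attention.py. Not vLLM code."""
    # Number of cache blocks the encoder prompt occupies, rounded up.
    prompt_table_len = (max_prompt_len + block_size - 1) // block_size
    prompt_tables = block_tables[:, :prompt_table_len].contiguous()
    decoder_tables = block_tables[:, prompt_table_len:].contiguous()
    return prompt_tables, decoder_tables

# Example: batch of 2, block_size 16, prompt length 40.
# 40 tokens need ceil(40 / 16) = 3 prompt blocks per sequence.
tables = torch.arange(10, dtype=torch.int32).repeat(2, 1)
prompt, dec = split_block_tables(tables, max_prompt_len=40, block_size=16)
assert prompt.shape == (2, 3) and dec.shape == (2, 7)
```
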
36 changes: 4 additions & 32 deletions vllm/model_executor/layers/sampler.py
@@ -57,9 +57,6 @@ def forward(
sampling_metadata: SamplingMetadata,
embedding_bias: Optional[torch.Tensor] = None,
) -> Optional[SamplerOutput]:
# print("hidden_states shape: ", hidden_states.shape)
# print("hidden_states: ", hidden_states)

# Get the hidden states that we use for sampling.
if self.logits_as_hidden_states:
logits = hidden_states
@@ -70,9 +67,6 @@ def forward(
# Get the logits for the next tokens.
logits = self._get_logits(hidden_states, embedding, embedding_bias)

# print("Logits shape: ", logits.shape)
# print("Logits: ", logits)

# Only perform sampling in the driver worker.
# Note: `_get_logits` is still distributed across TP workers because
# the `embedding` weight is distributed across TP workers.
@@ -83,12 +77,9 @@ def forward(
assert logits is not None
_, vocab_size = logits.shape

# print("Logits shape: ", logits.shape)
# print("Logits: ", logits)
# Apply logits processors (if any).
logits = _apply_logits_processors(logits, sampling_metadata)
# print("Logits shape: ", logits.shape)
# print("Logits: ", logits)

# Prepare sampling tensors with pinned memory to avoid blocking.
(sampling_tensors, do_penalties, do_top_p_top_k,
do_min_p) = SamplingTensors.from_sampling_metadata(
@@ -101,23 +92,18 @@ def forward(
sampling_tensors.presence_penalties,
sampling_tensors.frequency_penalties,
sampling_tensors.repetition_penalties)
# print("Logits shape: ", logits.shape)
# print("Logits: ", logits)

# Apply temperature scaling.
# Use in-place division to avoid creating a new tensor.
logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1))
# print("Logits shape: ", logits.shape)
# print("Logits: ", logits)

if do_top_p_top_k:
logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
sampling_tensors.top_ks)
# print("Logits shape: ", logits.shape)
# print("Logits: ", logits)

if do_min_p:
logits = _apply_min_p(logits, sampling_tensors.min_ps)

# print("Logits shape: ", logits.shape)
# print("Logits: ", logits)
# We use float32 for probabilities and log probabilities.
# Compute the probabilities.
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
@@ -126,18 +112,10 @@ def forward(
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

# Sample the next tokens.
# print("Probs shape: ", probs.shape)
# print("Probs: ", probs)
# print("Logprobs shape: ", logprobs.shape)
# print("Logprobs: ", logprobs)

sample_results = _sample(probs, logprobs, sampling_metadata)
# Get the logprobs query results.
# print("Sample results: ", sample_results)
prompt_logprobs, sample_logprobs = _get_logprobs(
logprobs, sampling_metadata, sample_results)
# print("Prompt logprobs: ", prompt_logprobs)
# print("Sample logprobs: ", sample_logprobs)
return _build_sampler_output(sample_results, sampling_metadata,
prompt_logprobs, sample_logprobs)

@@ -400,8 +378,6 @@ def _sample(
logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> List[Tuple[List[int], List[int]]]:
# print("probs: ", probs)
# print("logprobs: ", logprobs)
categorized_seq_group_ids = {t: [] for t in SamplingType}
categorized_sample_indices = sampling_metadata.categorized_sample_indices
for i, seq_group in enumerate(sampling_metadata.seq_groups):
@@ -417,15 +393,11 @@ def _sample(
# The first loop can run without waiting on GPU<->CPU sync.
for sampling_type in SamplingType:
sample_indices = categorized_sample_indices[sampling_type]
# print("sampling_type: ", sampling_type)
# print("sample_indices: ", sample_indices)
num_tokens = len(sample_indices)
if num_tokens == 0:
continue
seq_group_ids = categorized_seq_group_ids[sampling_type]
# print("seq_group_ids: ", seq_group_ids)
seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
# print("seq_groups: ", seq_groups)
is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
is_prompts, sample_indices)
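
For context (not part of this diff): the sampler hunks above apply penalties, scale logits by temperature in place, optionally apply top-k/top-p and min-p, and only then compute float32 probabilities. The function below is a simplified, self-contained stand-in for that ordering, not the vLLM sampler; the helper name and shapes are illustrative.

```python
import torch

def sample_next_token(logits: torch.Tensor,
                      temperature: float = 1.0) -> torch.Tensor:
    """Simplified sketch of the sampler's core steps (not vLLM code)."""
    # In-place temperature scaling, as in the diff, avoids an extra tensor;
    # here we clone first so the caller's logits are left untouched.
    logits = logits.clone()
    logits.div_(max(temperature, 1e-5))
    # Probabilities are computed in float32 for numerical stability.
    probs = torch.softmax(logits, dim=-1, dtype=torch.float)
    # Multinomial sampling over the vocabulary, one token per sequence.
    return torch.multinomial(probs, num_samples=1)

# Example: batch of 2 sequences over a toy vocabulary of 8 tokens.
next_ids = sample_next_token(torch.randn(2, 8), temperature=0.7)
assert next_ids.shape == (2, 1)
```
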
2 changes: 0 additions & 2 deletions vllm/model_executor/models/t5.py
@@ -564,10 +564,8 @@ def forward(

def sample(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata):
# logger.info(f"decoder_outputs: {decoder_outputs}")
next_tokens = self.sampler(self.shared.weight, hidden_states,
sampling_metadata)
# logger.info(f"next_tokens: {next_tokens}")
return next_tokens

def load_weights(
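
For context (illustrative only): T5 ties its output projection to the shared embedding, which is why sample() passes self.shared.weight to the sampler as the logits projection. A toy illustration of logits from tied embeddings, with made-up shapes:

```python
import torch

# Illustrative shapes only (roughly t5-small sized); not the vLLM code path.
hidden_states = torch.randn(2, 512)        # [num_tokens, d_model]
shared_weight = torch.randn(32128, 512)    # [vocab_size, d_model], tied embedding
logits = hidden_states @ shared_weight.t() # [num_tokens, vocab_size]
assert logits.shape == (2, 32128)
```
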
19 changes: 9 additions & 10 deletions vllm/worker/model_runner.py
@@ -83,8 +83,13 @@ def __init__(
# Set enforce_eager to True for Neuron backend, to avoid capturing graph
if self.device_config.is_neuron:
self.model_config.enforce_eager = True
self.is_encoder_decoder = getattr(self.model_config.hf_config,
"is_encoder_decoder", False)

# Unpack HF is_encoder_decoder config attribute
# NOTE: must handle "self.model_config is None" case imposed by certain tests i.e. test_prepare_prompt()
# In the None case, default to is_encoder_decoder == False since vLLM decoder-only mode is known to handle
# the None case correctly.
self.is_encoder_decoder = False if self.model_config is None else \
getattr(self.model_config.hf_config, "is_encoder_decoder", False)

def load_model(self) -> None:
self.model = get_model(self.model_config,
@@ -216,8 +221,6 @@ def _prepare_prompt(
block_tables.append(block_table)
max_block_table_len = max(max_block_table_len,
len(block_table))
# print("slot_mapping: ", slot_mapping)
# print("block_tables: ", block_tables)
max_prompt_len = max(subquery_lens)
input_tokens = _make_tensor_with_pad(input_tokens,
max_prompt_len,
@@ -273,7 +276,7 @@ def _prepare_prompt(
max_len=max_prompt_block_table_len,
pad=0,
dtype=torch.int,
)
device=self.device)
start_loc_tensor = torch.arange(0,
len(prompt_lens) * max_prompt_len,
max_prompt_len,
@@ -521,20 +524,18 @@ def _prepare_sample(
dtype=torch.long,
target_device=self.device,
pin_memory=pin_memory)
# print("selected_token_indices: ", selected_token_indices)
categorized_sample_indices = {
t: _async_h2d(seq_ids,
dtype=torch.int,
target_device=self.device,
pin_memory=pin_memory)
for t, seq_ids in categorized_sample_indices.items()
}
# print("categorized_sample_indices: ", categorized_sample_indices)

seq_data: Dict[int, SequenceData] = {}
for seq_group_metadata in seq_group_metadata_list:
seq_data.update(seq_group_metadata.seq_data)
# print("selected_token_indices: ", selected_token_indices)

sampling_metadata = SamplingMetadata(
seq_groups=seq_groups,
seq_data=seq_data,
@@ -656,8 +657,6 @@ def execute_model(
kv_caches=kv_caches,
input_metadata=input_metadata,
)
# print("hidden_states shape: ", hidden_states.shape)
# print("hidden_states: ", hidden_states)

# Sample the next token.
output = self.model.sample(
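
The substantive fix in this commit is the device=self.device argument added to the block-table _make_tensor_with_pad call, matching the helper's changed signature that had broken the decoder path. The stand-in below is only a guess at what such a padding helper does, for illustration; it is not vLLM's _make_tensor_with_pad.

```python
from typing import List
import torch

def make_tensor_with_pad(data: List[List[int]], max_len: int, pad: int,
                         dtype: torch.dtype, device) -> torch.Tensor:
    """Illustrative stand-in: right-pad each row to max_len and build the
    tensor directly on the requested device (the argument whose omission
    broke the decoder scenarios this commit fixes)."""
    padded = [row + [pad] * (max_len - len(row)) for row in data]
    return torch.tensor(padded, dtype=dtype, device=device)

# Example: padded block tables for two sequences of different lengths.
block_tables = make_tensor_with_pad([[3, 7], [5]], max_len=4, pad=0,
                                    dtype=torch.int, device="cpu")
assert block_tables.shape == (2, 4)
```
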
