92 changes: 42 additions & 50 deletions tests/conftest.py
@@ -5,16 +5,17 @@

import pytest
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import (AutoModelForCausalLM, AutoProcessor, AutoTokenizer,
LlavaConfig, LlavaForConditionalGeneration)

from vllm import LLM, SamplingParams
from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
from vllm.distributed import destroy_model_parallel
from vllm.inputs import PromptInputs
from vllm.inputs import TextPrompt
from vllm.logger import init_logger
from vllm.sequence import MultiModalData
from vllm.sequence import MultiModalData, SampleLogprobs

logger = init_logger(__name__)

@@ -188,10 +189,11 @@ def generate(
prompts: List[str],
images: Optional[List[Image.Image]] = None,
**kwargs,
) -> List[Tuple[List[int], str]]:
outputs: List[Tuple[List[int], str]] = []
) -> List[Tuple[List[List[int]], List[str]]]:
if images:
assert len(prompts) == len(images)

outputs: List[Tuple[List[List[int]], List[str]]] = []
for i, prompt in enumerate(prompts):
processor_kwargs: Dict[str, Any] = {
"text": prompt,
@@ -201,17 +203,13 @@ def generate(
processor_kwargs["images"] = images[i]

inputs = self.processor(**processor_kwargs)
inputs = {
key: value.cuda() if value is not None else None
for key, value in inputs.items()
}

output_ids = self.model.generate(
**inputs,
**inputs.to("cuda"),
use_cache=True,
**kwargs,
)
output_str = self.tokenizer.batch_decode(
output_str = self.processor.batch_decode(
output_ids,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
@@ -224,23 +222,22 @@ def generate_greedy(
self,
prompts: List[str],
max_tokens: int,
images: Optional["torch.Tensor"] = None,
images: Optional[List[Image.Image]] = None,
) -> List[Tuple[List[int], str]]:
outputs = self.generate(prompts,
do_sample=False,
max_new_tokens=max_tokens,
images=images)
for i in range(len(outputs)):
output_ids, output_str = outputs[i]
outputs[i] = (output_ids[0], output_str[0])
return outputs

return [(output_ids[0], output_str[0])
for output_ids, output_str in outputs]
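
# For reference, generate() now returns a batched pair per prompt (all sampled
# token-id lists, all decoded strings), and generate_greedy() unwraps the single
# greedy sample. An illustrative sketch of that shape, not part of the diff,
# using made-up token ids:
#
#     batched = [([[101, 7592, 102]], ["hello"])]  # one prompt, one sample
#     flattened = [(ids[0], strs[0]) for ids, strs in batched]
#     assert flattened == [([101, 7592, 102], "hello")]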

def generate_beam_search(
self,
prompts: List[str],
beam_width: int,
max_tokens: int,
) -> List[Tuple[List[int], str]]:
) -> List[Tuple[List[List[int]], List[str]]]:
outputs = self.generate(prompts,
do_sample=False,
max_new_tokens=max_tokens,
@@ -282,9 +279,7 @@ def generate_greedy_logprobs(
if self.model.get_output_embeddings().bias is not None:
logits += self.model.get_output_embeddings(
).bias.unsqueeze(0)
logprobs = torch.nn.functional.log_softmax(logits,
dim=-1,
dtype=torch.float32)
logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
seq_logprobs.append(logprobs)
all_logprobs.append(seq_logprobs)
return all_logprobs
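
# Switching to the F alias is cosmetic; the computation is unchanged. A small
# self-contained sketch of what log_softmax yields for one decoding step
# (shapes and values are illustrative, not part of the diff):
#
#     import torch
#     import torch.nn.functional as F
#
#     logits = torch.tensor([[2.0, 0.5, -1.0, 0.0, 1.5]])  # (batch=1, vocab=5)
#     logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
#     # Exponentiating recovers a proper probability distribution over the vocab.
#     assert torch.allclose(logprobs.exp().sum(dim=-1), torch.tensor([1.0]))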
@@ -294,10 +289,10 @@ def generate_greedy_logprobs_limit(
prompts: List[str],
max_tokens: int,
num_logprobs: int,
) -> List[Tuple[List[int], str]]:
all_logprobs = []
all_output_ids = []
all_output_strs = []
) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
all_logprobs: List[List[Dict[int, float]]] = []
all_output_ids: List[List[int]] = []
all_output_strs: List[str] = []

for prompt in prompts:
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
@@ -310,7 +305,7 @@ def generate_greedy_logprobs_limit(
return_dict_in_generate=True,
)

seq_logprobs = []
seq_logprobs: List[torch.Tensor] = []
for _, hidden_states in enumerate(output.hidden_states):
last_hidden_states = hidden_states[-1][0]
logits = torch.matmul(
@@ -321,13 +316,11 @@ def generate_greedy_logprobs_limit(
None) is not None:
logits += self.model.get_output_embeddings(
).bias.unsqueeze(0)
logprobs = torch.nn.functional.log_softmax(logits,
dim=-1,
dtype=torch.float32)
logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
seq_logprobs.append(logprobs)

# convert to dict
seq_logprobs_lst = []
seq_logprobs_lst: List[Dict[int, float]] = []
for tok_idx, tok_logprobs in enumerate(seq_logprobs):
# drop prompt logprobs
if tok_idx == 0:
@@ -372,13 +365,13 @@ def __init__(
tokenizer_name: Optional[str] = None,
# Use smaller max model length, otherwise bigger model cannot run due
# to kv cache size limit.
max_model_len=1024,
max_model_len: int = 1024,
dtype: str = "half",
disable_log_stats: bool = True,
tensor_parallel_size: int = 1,
block_size: int = 16,
enable_chunked_prefill: bool = False,
swap_space=4,
swap_space: int = 4,
**kwargs,
) -> None:
self.model = LLM(
@@ -399,32 +392,31 @@ def generate(
self,
prompts: List[str],
sampling_params: SamplingParams,
images: Optional["torch.Tensor"] = None,
) -> List[Tuple[List[int], str]]:
images: Optional[torch.Tensor] = None,
) -> List[Tuple[List[List[int]], List[str]]]:
if images is not None:
assert len(prompts) == images.shape[0]
assert len(prompts) == len(images)

prompt_inputs: List[PromptInputs] = []
prompt_inputs: List[TextPrompt] = []
for i, prompt in enumerate(prompts):
image = None if images is None else images[i:i + 1]
mm_data = None if image is None else MultiModalData(
type=MultiModalData.Type.IMAGE,
data=image,
)
prompt = TextPrompt(prompt=prompt)
if images is not None:
prompt["multi_modal_data"] = MultiModalData(
type=MultiModalData.Type.IMAGE,
data=images[i:i + 1],
)

prompt_inputs.append({
"prompt": prompt,
"multi_modal_data": mm_data,
})
prompt_inputs.append(prompt)

req_outputs = self.model.generate(prompt_inputs,
sampling_params=sampling_params)
outputs = []

outputs: List[Tuple[List[List[int]], List[str]]] = []
for req_output in req_outputs:
prompt_str = req_output.prompt
prompt_ids = req_output.prompt_token_ids
req_sample_output_ids = []
req_sample_output_strs = []
req_sample_output_ids: List[List[int]] = []
req_sample_output_strs: List[str] = []
for sample in req_output.outputs:
output_str = sample.text
output_ids = sample.token_ids
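
# The VllmRunner now feeds TextPrompt objects (with optional MultiModalData)
# to LLM.generate() instead of the removed PromptInputs dicts. A minimal usage
# sketch, not part of the diff, assuming `llm` is an already-constructed
# vllm.LLM for a vision-language model and `pixel_values` is a preprocessed
# image tensor (both names are illustrative):
#
#     prompt = TextPrompt(prompt="USER: <image>\nWhat is in the picture? ASSISTANT:")
#     prompt["multi_modal_data"] = MultiModalData(type=MultiModalData.Type.IMAGE,
#                                                 data=pixel_values)
#     outputs = llm.generate([prompt],
#                            sampling_params=SamplingParams(max_tokens=32))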
@@ -437,12 +429,12 @@ def generate_w_logprobs(
self,
prompts: List[str],
sampling_params: SamplingParams,
) -> List[Tuple[List[int], str]]:
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
assert sampling_params.logprobs is not None

req_outputs = self.model.generate(prompts,
sampling_params=sampling_params)
outputs = []
outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
for req_output in req_outputs:
for sample in req_output.outputs:
output_str = sample.text
@@ -467,7 +459,7 @@ def generate_greedy_logprobs(
prompts: List[str],
max_tokens: int,
num_logprobs: int,
) -> List[Tuple[List[int], str]]:
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
greedy_logprobs_params = SamplingParams(temperature=0.0,
max_tokens=max_tokens,
logprobs=num_logprobs)
@@ -481,7 +473,7 @@ def generate_beam_search(
prompts: List[str],
beam_width: int,
max_tokens: int,
) -> List[Tuple[List[int], str]]:
) -> List[Tuple[List[List[int]], List[str]]]:
beam_search_params = SamplingParams(n=beam_width,
use_beam_search=True,
temperature=0.0,