diff --git a/torchchat/generate.py b/torchchat/generate.py
index 87902e180..c38fcaff5 100644
--- a/torchchat/generate.py
+++ b/torchchat/generate.py
@@ -732,24 +732,21 @@ def _callback(self, x, *, buffer, done_generating):
             print("".join(buffer), end="", flush=True)
             buffer.clear()
         # print(, end='', flush=True)
-
-    def chat(
-        self,
-        generator_args: GeneratorArgs,
-    ):
-        if generator_args.chat_mode:
-            print("Starting Interactive Chat")
+
+    def _gen_model_input(self, prompt: str, image_prompts: Optional[List[str | Image.Image]] = None, max_new_tokens: Optional[int] = None) -> Tuple:
+        assert image_prompts is None or len(image_prompts) == 1, "At most one image is supported at the moment"
+        if image_prompts and isinstance(image_prompts[0], str):
+            images = [Image.open(image_prompts[0])]
+        else:
+            images = image_prompts
 
         if self.model.config.model_type == ModelType.Flamingo:
+            assert max_new_tokens is not None, "max_new_tokens must be specified for Flamingo models"
-            is_multimodal = generator_args.image_prompts is not None
-            content = [{"type": "text", "content": generator_args.prompt}]
+            is_multimodal = images is not None
+            content = [{"type": "text", "content": prompt}]
 
             if is_multimodal:
-                print("Image prompts", generator_args.image_prompts)
-
-                # Support for just the first image prompt for now
-                images = [Image.open(generator_args.image_prompts[0])]
                 content = [{"type": "image", "content": images[0]}] + content
 
             messages = [
@@ -783,7 +780,7 @@ def chat(
             seq_len = encoded.size(0)
             batch = {}
 
-            total_response_length = seq_len + generator_args.max_new_tokens
+            total_response_length = seq_len + max_new_tokens
             batch["causal_mask"] = torch.tril(
                 torch.ones(
                     size=(total_response_length, total_response_length),
@@ -792,10 +789,22 @@ def chat(
             )
         else:
             encoded = self.encode_tokens(
-                generator_args.prompt, bos=True, device=self.builder_args.device
+                prompt, bos=True, device=self.builder_args.device
            )
-            logging.debug(encoded)
             batch = None
+
+        logging.debug(encoded)
+        return encoded, batch
+
+
+    def chat(
+        self,
+        generator_args: GeneratorArgs,
+    ):
+        if generator_args.chat_mode:
+            print("Starting Interactive Chat")
+
+        encoded, batch = self._gen_model_input(generator_args.prompt, generator_args.image_prompts, generator_args.max_new_tokens)
 
         model_size = sum(
             [
diff --git a/torchchat/usages/openai_api.py b/torchchat/usages/openai_api.py
index aa63782fb..8cdd8849d 100644
--- a/torchchat/usages/openai_api.py
+++ b/torchchat/usages/openai_api.py
@@ -19,15 +19,15 @@
 
 from PIL import Image
 
-from torchtune.data import Message, padded_collate_tiled_images_and_mask
-
-from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
-
 from torchchat.cli.download import is_model_downloaded, load_model_configs
 from torchchat.generate import Generator, GeneratorArgs
 
 from torchchat.utils.build_utils import device_sync
 
+from torchtune.data import Message, padded_collate_tiled_images_and_mask
+
+from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
+
 
 """Dataclasses defined around the objects used the OpenAI API Chat specification.
 
@@ -296,79 +296,44 @@ def __init__(self, *args, **kwargs):
             f"{self.builder_args.device}_{self.builder_args.precision}"
         )
 
-    def _openai_messages_to_torchtune_messages(
-        self, messages: List[_AbstractMessage]
+    def _gen_model_inputs_from_openai_completion_request(
+        self, completion_request: CompletionRequest
     ) -> List[Message]:
-        """Convert a list of OpenAI API messages to a list of TorchTune messages.
+        """Generate model inputs from an OpenAI completion request.
 
         Args:
-            messages: A list of OpenAI API messages.
+            completion_request: Request object with prompt and other parameters.
 
         Returns:
-            A list of Torchtune Messages.
+            Model inputs: a tuple of (encoded prompt, optional multimodal batch).
         """
-        torchtune_messages = []
+        messages = completion_request.messages
+
+        prompt = None
+        images = None
+
         for message in messages:
             torchtune_contents = []
             if isinstance(message["content"], list):
                 for content_dict in message["content"]:
-                    converted_content = []
                     if content_dict["type"] == "text":
-                        converted_content.append(
-                            {"type": "text", "content": content_dict["text"]}
-                        )
+                        assert (
+                            prompt is None
+                        ), "At most one text prompt is supported for each request"
+                        prompt = content_dict["text"]
                     elif content_dict["type"] == "image_url":
+                        assert (
+                            images is None
+                        ), "At most one image is supported at the moment"
+
                         base64_decoded = base64.b64decode(
-                            content_dict["image_url"].split(";base64,")[1]
-                        )
-                        image = Image.open(BytesIO(base64_decoded))
-                        converted_content.append(
-                            {
-                                "type": "image",
-                                "content": image,
-                            }
+                            content_dict["image_url"].split(";base64,")[1]
                         )
-            torchtune_messages.append(
-                Message(role=message["role"], content=converted_content, eot=False)
-            )
-        return torchtune_messages
+                        images = [Image.open(BytesIO(base64_decoded))]
 
-    def _openai_messages_to_torchtune(
-        self, messages: List[_AbstractMessage]
-    ) -> List[Message]:
-        """Convert a list of OpenAI API messages to a list of TorchTune messages.
+        assert prompt is not None, "Text prompt must be specified in the request"
 
-        Args:
-            messages: A list of OpenAI API messages.
-
-        Returns:
-            A list of Torchtune Messages.
-        """
-        torchtune_messages = []
-        for message in messages:
-            torchtune_contents = []
-            if isinstance(message["content"], list):
-                for content in message["content"]:
-                    if isinstance(content, dict):
-                        if content["type"] == "image_url":
-                            torchtune_contents.append({"type": "image"})
-                        elif content["type"] == "image_file":
-                            torchtune_contents.append({"type": "image"})
-                        elif content["type"] == "text":
-                            torchtune_contents.append(
-                                {"type": "text", "content": content["text"]}
-                            )
-                    elif isinstance(content, str):
-                        torchtune_contents.append({"type": "text", "text": content})
-            else:
-                torchtune_contents.append(
-                    {"type": "text", "content": message["content"]}
-                )
-            torchtune_messages.append(
-                Message(role=message["role"], content=torchtune_contents, eot=False)
-            )
-        torchtune_messages.append(Message(role="assistant", content="", eot=False))
-        return torchtune_messages
+        return self._gen_model_input(prompt, images, completion_request.max_tokens)
 
     def chunked_completion(self, completion_request: CompletionRequest):
         """Handle a chat completion request and yield a chunked response.
@@ -396,63 +361,13 @@ def chunked_completion(self, completion_request: CompletionRequest):
 
         # Initialize counters for chunk responses and encode the prompt.
         id = str(uuid.uuid4())
-        idx = 0
-        images = []
 
-        device_sync(device=self.builder_args.device)
-        for message in completion_request.messages:
-            contents = message["content"]
-            if isinstance(contents, list):
-                for content in message["content"]:
-                    if content["type"] == "image_url":
-                        base64_decoded = base64.b64decode(
-                            content["image_url"].split(";base64,")[1]
-                        )
-                        images.append(Image.open(BytesIO(base64_decoded)))
-        print("images:", len(images), flush=True)
 
-        if len(images) > 0:
-            transform = llama3_2_vision_transform(
-                str(self.tokenizer_args.tokenizer_path)
-            )
-            torchtune_messages = self._openai_messages_to_torchtune_messages(
-                completion_request.messages
-            )
-            data = transform(
-                {"images": images, "messages": torchtune_messages}, inference=True
-            )
-            seq_len = len(data["tokens"])
-            total_response_length = seq_len + completion_request.max_tokens
-            causal_mask = torch.tril(
-                torch.ones(
-                    size=(total_response_length, total_response_length),
-                    dtype=torch.bool,
-                )
-            )
-            input_pos = torch.arange(total_response_length)
-
-            with torch.no_grad():
-                with torch.device(self.builder_args.device):
-                    batch = padded_collate_tiled_images_and_mask([data], pad_direction="left", pad_max_images=1)
-                    batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(self.builder_args.precision)
-                    batch["causal_mask"] = causal_mask
-                    batch["input_pos"] = input_pos[None, :seq_len]
-                    batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
-
-            #batch = padded_collate([data], self.builder_args.device)
-            encoded = batch["tokens"].view(-1)
-        else:
-            tokens = self.chat_formatter.encode_dialog_prompt(
-                dialog=[
-                    {"role": message["role"], "content": message["content"]}
-                    for message in completion_request.messages
-                ]
-            )
-            print("tokens:", self.tokenizer.decode(tokens), flush=True)
-            encoded = torch.tensor(
-                tokens, dtype=torch.int, device=self.builder_args.device
-            )
-            batch = None
+        encoded, batch = self._gen_model_inputs_from_openai_completion_request(
+            completion_request
+        )
+
+        idx = 0
 
         start_pos = 0
         generator_args = GeneratorArgs(
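
A minimal usage sketch (not part of the patch) of the unified helper on both paths, assuming an already-built Generator instance named `gen`; the image path is hypothetical:

    # Text-only prompt: images is None, so the plain tokenizer path runs and batch is None.
    encoded, batch = gen._gen_model_input("Tell me a story", None, 300)

    # Multimodal prompt (ModelType.Flamingo): at most one image is accepted, and
    # max_new_tokens must be passed so the causal mask can be sized to
    # seq_len + max_new_tokens.
    encoded, batch = gen._gen_model_input(
        "What is in this image?",
        ["./cat.jpg"],  # hypothetical path; str entries are opened with PIL
        max_new_tokens=300,
    )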
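A companion sketch of the request shape `_gen_model_inputs_from_openai_completion_request` now accepts: at most one text part and at most one image_url part per request, with the image inlined as a base64 data URL (file name illustrative):

    import base64

    with open("./cat.jpg", "rb") as f:  # hypothetical image file
        b64 = base64.b64encode(f.read()).decode()

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this image?"},
                # The helper splits on ";base64,", base64-decodes the remainder,
                # and opens it with PIL.
                {"type": "image_url", "image_url": f"data:image/jpeg;base64,{b64}"},
            ],
        }
    ]
    # A second text part or a second image part trips the new asserts.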