This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Merged
torchchat/generate.py (41 changes: 25 additions & 16 deletions)
@@ -732,24 +732,21 @@ def _callback(self, x, *, buffer, done_generating):
print("".join(buffer), end="", flush=True)
buffer.clear()
# print(, end='', flush=True)

def chat(
self,
generator_args: GeneratorArgs,
):
if generator_args.chat_mode:
print("Starting Interactive Chat")
def _gen_model_input(self, prompt: str, image_prompts: Optional[List[str | Image.Image]] = None, max_new_tokens: Optional[int] = None) -> Tuple:
assert image_prompts is None or len(image_prompts) == 1, "At most one image is supported at the moment"
Contributor (review comment), suggested change:
- assert image_prompts is None or len(image_prompts) == 1, "At most one image is supported at the moment"
+ assert image_prompts is None or len(image_prompts) <= 1, "At most one image is supported at the moment"
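A minimal illustration of the reviewer's point, using made-up standalone checks that mirror the two assert variants: `== 1` rejects an explicitly empty image list even though no image was passed, while `<= 1` accepts it.

# Sketch only: standalone functions mirroring the two assert variants above.
def check_strict(image_prompts):
    assert image_prompts is None or len(image_prompts) == 1, "At most one image is supported at the moment"

def check_relaxed(image_prompts):
    assert image_prompts is None or len(image_prompts) <= 1, "At most one image is supported at the moment"

check_relaxed([])  # passes: an empty list simply means "no images"
check_strict([])   # raises AssertionError even though no image was supplied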

if image_prompts and isinstance(image_prompts[0], str):
images = [Image.open(image_prompts[0])]
else:
images = image_prompts

if self.model.config.model_type == ModelType.Flamingo:
assert max_new_tokens is not None, "max_new_tokens must be specified for Flamingo models"

is_multimodal = generator_args.image_prompts is not None
content = [{"type": "text", "content": generator_args.prompt}]
is_multimodal = images is not None
content = [{"type": "text", "content": prompt}]

if is_multimodal:
print("Image prompts", generator_args.image_prompts)

# Support for just the first image prompt for now
images = [Image.open(generator_args.image_prompts[0])]
content = [{"type": "image", "content": images[0]}] + content

messages = [
@@ -783,7 +780,7 @@ def chat(
seq_len = encoded.size(0)
batch = {}

total_response_length = seq_len + generator_args.max_new_tokens
total_response_length = seq_len + max_new_tokens
batch["causal_mask"] = torch.tril(
torch.ones(
size=(total_response_length, total_response_length),
@@ -792,10 +789,22 @@
)
else:
encoded = self.encode_tokens(
generator_args.prompt, bos=True, device=self.builder_args.device
prompt, bos=True, device=self.builder_args.device
)
logging.debug(encoded)
batch = None

logging.debug(encoded)
return encoded, batch


def chat(
self,
generator_args: GeneratorArgs,
):
if generator_args.chat_mode:
print("Starting Interactive Chat")

encoded, batch = self._gen_model_input(generator_args.prompt, generator_args.image_prompts, generator_args.max_new_tokens)

model_size = sum(
[
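The net effect of this file's changes is that prompt/image encoding moves out of chat() into the new _gen_model_input helper, which returns an (encoded, batch) pair. A rough usage sketch, assuming an already-built torchchat Generator instance (here called gen) and the signature shown in the hunk above; the prompt, image path, and token budget are placeholders:

# Sketch only: `gen` is assumed to be a constructed torchchat Generator.
encoded, batch = gen._gen_model_input(
    prompt="Describe this picture.",
    image_prompts=["ocean.jpg"],  # at most one image; may also be None
    max_new_tokens=128,           # required when the model is a Flamingo model
)
# `encoded` is the tokenized prompt; `batch` is None for text-only models and
# carries the image tensors plus causal mask for multimodal (Flamingo) models.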
torchchat/usages/openai_api.py (147 changes: 31 additions & 116 deletions)
@@ -19,15 +19,15 @@

from PIL import Image

from torchtune.data import Message, padded_collate_tiled_images_and_mask

from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform

from torchchat.cli.download import is_model_downloaded, load_model_configs
from torchchat.generate import Generator, GeneratorArgs

from torchchat.utils.build_utils import device_sync

from torchtune.data import Message, padded_collate_tiled_images_and_mask

from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform


"""Dataclasses defined around the objects used the OpenAI API Chat specification.

@@ -296,79 +296,44 @@ def __init__(self, *args, **kwargs):
f"{self.builder_args.device}_{self.builder_args.precision}"
)

def _openai_messages_to_torchtune_messages(
self, messages: List[_AbstractMessage]
def _gen_model_inputs_from_openai_completion_request(
self, completion_request: CompletionRequest
) -> List[Message]:
"""Convert a list of OpenAI API messages to a list of TorchTune messages.
"""Generate model inputs from an OpenAI completion request.

Args:
messages: A list of OpenAI API messages.
completion_request: Request object with prompt and other parameters.

Returns:
A list of Torchtune Messages.
Model inputs.
"""
torchtune_messages = []
messages = completion_request.messages

prompt = None
images = None

for message in messages:
torchtune_contents = []
if isinstance(message["content"], list):
for content_dict in message["content"]:
converted_content = []
if content_dict["type"] == "text":
converted_content.append(
{"type": "text", "content": content_dict["text"]}
)
assert (
prompt is None
), "At most one text prompt is supported for each request"
prompt = content_dict["text"]
elif content_dict["type"] == "image_url":
assert (
images is None
), "At most one image is supported at the moment"

base64_decoded = base64.b64decode(
content_dict["image_url"].split(";base64,")[1]
)
image = Image.open(BytesIO(base64_decoded))
converted_content.append(
{
"type": "image",
"content": image,
}
content_dict["image_url"].split(";base64,")[1]
)
torchtune_messages.append(
Message(role=message["role"], content=converted_content, eot=False)
)
return torchtune_messages
images = [Image.open(BytesIO(base64_decoded))]

def _openai_messages_to_torchtune(
self, messages: List[_AbstractMessage]
) -> List[Message]:
"""Convert a list of OpenAI API messages to a list of TorchTune messages.
assert prompt is not None, "Text prompt must be specified in the request"

Args:
messages: A list of OpenAI API messages.

Returns:
A list of Torchtune Messages.
"""
torchtune_messages = []
for message in messages:
torchtune_contents = []
if isinstance(message["content"], list):
for content in message["content"]:
if isinstance(content, dict):
if content["type"] == "image_url":
torchtune_contents.append({"type": "image"})
elif content["type"] == "image_file":
torchtune_contents.append({"type": "image"})
elif content["type"] == "text":
torchtune_contents.append(
{"type": "text", "content": content["text"]}
)
elif isinstance(content, str):
torchtune_contents.append({"type": "text", "text": content})
else:
torchtune_contents.append(
{"type": "text", "content": message["content"]}
)
torchtune_messages.append(
Message(role=message["role"], content=torchtune_contents, eot=False)
)
torchtune_messages.append(Message(role="assistant", content="", eot=False))
return torchtune_messages
return self._gen_model_input(prompt, images, completion_request.max_tokens)

def chunked_completion(self, completion_request: CompletionRequest):
"""Handle a chat completion request and yield a chunked response.
@@ -396,63 +361,13 @@ def chunked_completion(self, completion_request: CompletionRequest):
# Initialize counters for chunk responses and encode the prompt.
id = str(uuid.uuid4())

idx = 0
images = []

device_sync(device=self.builder_args.device)
for message in completion_request.messages:
contents = message["content"]
if isinstance(contents, list):
for content in message["content"]:
if content["type"] == "image_url":
base64_decoded = base64.b64decode(
content["image_url"].split(";base64,")[1]
)
images.append(Image.open(BytesIO(base64_decoded)))
print("images:", len(images), flush=True)
if len(images) > 0:
transform = llama3_2_vision_transform(
str(self.tokenizer_args.tokenizer_path)
)
torchtune_messages = self._openai_messages_to_torchtune_messages(
completion_request.messages
)
data = transform(
{"images": images, "messages": torchtune_messages}, inference=True
)
seq_len = len(data["tokens"])
total_response_length = seq_len + completion_request.max_tokens
causal_mask = torch.tril(
torch.ones(
size=(total_response_length, total_response_length),
dtype=torch.bool,
)
)
input_pos = torch.arange(total_response_length)

with torch.no_grad():
with torch.device(self.builder_args.device):
batch = padded_collate_tiled_images_and_mask([data], pad_direction="left", pad_max_images=1)
batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(self.builder_args.precision)
batch["causal_mask"] = causal_mask
batch["input_pos"] = input_pos[None, :seq_len]
batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]

#batch = padded_collate([data], self.builder_args.device)
encoded = batch["tokens"].view(-1)
else:
tokens = self.chat_formatter.encode_dialog_prompt(
dialog=[
{"role": message["role"], "content": message["content"]}
for message in completion_request.messages
]
)
print("tokens:", self.tokenizer.decode(tokens), flush=True)
encoded = torch.tensor(
tokens, dtype=torch.int, device=self.builder_args.device
)
batch = None

encoded, batch = self._gen_model_inputs_from_openai_completion_request(
completion_request
)

idx = 0
start_pos = 0

generator_args = GeneratorArgs(
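On the OpenAI side, the rewritten handler now pulls the single text prompt and the optional base64-encoded image out of completion_request.messages and defers to the shared _gen_model_input helper. A self-contained sketch of the message shape it expects and of the same data-URI parsing it performs; the prompt text and the tiny generated image are placeholders:

import base64
from io import BytesIO

from PIL import Image

# Build a tiny in-memory PNG and wrap it in a data URI, purely to exercise the
# parsing path shown in the diff above.
buf = BytesIO()
Image.new("RGB", (1, 1)).save(buf, format="PNG")
data_uri = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()

message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this image?"},
        {"type": "image_url", "image_url": data_uri},
    ],
}

# Same extraction the new handler performs: take the text entry as the prompt,
# split off the base64 payload, decode it, and open it as a PIL image.
prompt = message["content"][0]["text"]
payload = message["content"][1]["image_url"].split(";base64,")[1]
images = [Image.open(BytesIO(base64.b64decode(payload)))]
print(prompt, images[0].size)  # "What is in this image?" (1, 1)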