diff --git a/torchchat/generate.py b/torchchat/generate.py
index d69989161..87902e180 100644
--- a/torchchat/generate.py
+++ b/torchchat/generate.py
@@ -353,27 +353,34 @@ def prefill(
         width = x.size(1)
         assert input_pos.size(0) == width
 
-        if batch is not None:
+        if self.model.config.model_type == ModelType.Flamingo:
+            assert batch is not None, "Flamingo requires batch"
+
             # TODO: Verify sequential prefill works with multimodal models
-            tokens = batch["tokens"]
+            is_multimodal = True
             if 'encoder_input' in batch:
                 encoder_input = batch['encoder_input']
+                encoder_mask = batch["encoder_mask"]
+                is_multimodal = True
             else:
                 encoder_input = None
+                encoder_mask = None
+                is_multimodal = False
 
-            seq_len = tokens.size(1)
+            seq_len = x.size(1)
             mask = batch["causal_mask"][None, :seq_len]
-            encoder_mask = batch["encoder_mask"]
             input_pos = input_pos.view(1, -1)
-            logits = model(tokens=tokens, mask=mask, encoder_input=encoder_input, input_pos=input_pos, encoder_mask=encoder_mask)[:, -1]
+            logits = model(tokens=x, mask=mask, encoder_input=encoder_input, input_pos=input_pos, encoder_mask=encoder_mask)[:, -1]
+
+            if is_multimodal:
+                batch["encoder_mask"] = batch["encoder_mask"][:, -1:]
+
             return tune_sample(logits, temperature=0, top_k=500)
         elif sequential_prefill:
             for i in range(width):
                 x_sliced, ip_sliced = x[:, i].view(-1, 1), input_pos[i].view(-1)
                 # logging.debug(f" x: {x_sliced}, input_pos: {ip_sliced}")
-                logits = model(x_sliced, ip_sliced)  # (x[:, i], input_pos[i])
-        elif self.model.config.model_type == ModelType.Flamingo:
-            assert False, "Flamingo requires batch"
+                logits = model(x_sliced, ip_sliced)  # (x[:, i], input_pos[i])
         else:
             # input_pos: [B, S]
             logits = model(x, input_pos)
@@ -397,7 +404,7 @@ def decode_one_token(
         if model.config.model_type == ModelType.Flamingo:
             assert batch is not None, "Flamingo requires batch"
             mask = batch["causal_mask"][None, input_pos.item(), None, :]
-            encoder_mask = batch["encoder_mask"][:, -1:]
+            encoder_mask = batch["encoder_mask"] if "encoder_mask" in batch else None
             logits = model(x, encoder_mask=encoder_mask, mask=mask, input_pos=input_pos)[:, -1:]
         else:
             logits = model(x, input_pos)
@@ -733,18 +740,22 @@ def chat(
         if generator_args.chat_mode:
             print("Starting Interactive Chat")
 
-        if generator_args.image_prompts is not None:
-            print("Image prompts", generator_args.image_prompts)
+        if self.model.config.model_type == ModelType.Flamingo:
+
+            is_multimodal = generator_args.image_prompts is not None
+            content = [{"type": "text", "content": generator_args.prompt}]
+
+            if is_multimodal:
+                print("Image prompts", generator_args.image_prompts)
+
+                # Support for just the first image prompt for now
+                images = [Image.open(generator_args.image_prompts[0])]
+                content = [{"type": "image", "content": images[0]}] + content
 
-            # Support for just the first image prompt for now
-            images = [Image.open(generator_args.image_prompts[0])]
             messages = [
                 Message(
                     role="user",
-                    content=[
-                        {"type": "image", "content": images[0]},
-                        {"type": "text", "content": generator_args.prompt},
-                    ],
+                    content=content,
                     eot=True,
                 ),
                 Message(role="assistant", content=""),
@@ -752,12 +763,26 @@ def chat(
 
             transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))
 
-            with torch.device(device=self.builder_args.device), set_default_dtype(self.dtype):
+            device = torch.device(device=self.builder_args.device)
+
+            with device, set_default_dtype(self.dtype):
                 data = transform({"messages": messages}, inference=True)
-                batch = padded_collate_tiled_images_and_mask([data], pad_direction="left", pad_max_images=1)
-                # set_default_dtype can not handle the dtype of the image tensor inside the batch; need to manually cast it
-                batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(self.dtype)
-                seq_len = len(data["tokens"])
+
+                if is_multimodal:
+                    batch = padded_collate_tiled_images_and_mask(
+                        [data], pad_direction="left", pad_max_images=1
+                    )
+                    encoded = batch.pop("tokens").to(device).view(-1)
+                    seq_len = encoded.size(0)
+                    batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
+                    batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(self.dtype)
+                else:
+                    encoded = torch.tensor(
+                        data["tokens"], device=device
+                    ).view(-1)
+                    seq_len = encoded.size(0)
+                    batch = {}
+
                 total_response_length = seq_len + generator_args.max_new_tokens
                 batch["causal_mask"] = torch.tril(
                     torch.ones(
@@ -765,9 +790,6 @@ def chat(
                         dtype=torch.bool,
                     )
                 )
-                batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
-                encoded = batch["tokens"].view(-1)
-
         else:
            encoded = self.encode_tokens(
                generator_args.prompt, bos=True, device=self.builder_args.device
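
Note: prefill and decode_one_token above both slice the single causal mask
that chat() pre-builds with torch.tril over the full response length. A
minimal sketch of those two slices, using only plain torch (the sizes are
invented for illustration; this is not torchchat code):

    import torch

    # Invented sizes, for illustration only.
    seq_len, max_new_tokens = 5, 3
    total_response_length = seq_len + max_new_tokens

    # chat(): one lower-triangular mask covering prompt + generated tokens.
    causal_mask = torch.tril(
        torch.ones(total_response_length, total_response_length, dtype=torch.bool)
    )

    # prefill(): one query row per prompt token.
    prefill_mask = causal_mask[None, :seq_len]                  # [1, 5, 8]

    # decode_one_token(): a single query row at the current position.
    input_pos = torch.tensor([seq_len])
    decode_mask = causal_mask[None, input_pos.item(), None, :]  # [1, 1, 8]

    print(prefill_mask.shape, decode_mask.shape)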
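Note: the encoder_mask handed to the Flamingo model likewise shrinks across
the generation loop: chat() trims it to the prompt length, prefill() keeps
only the last query row, and decode_one_token() reuses that single row for
every generated token. A sketch of the shapes involved, assuming a
[batch, text_len, encoder_len] layout (an assumption; the exact torchtune
layout is not shown in this diff):

    import torch

    # Assumed layout: [batch, text_len, encoder_len].
    text_len, encoder_len = 6, 4
    encoder_mask = torch.ones(1, text_len, encoder_len, dtype=torch.bool)

    # chat(): trim to the actual prompt length before prefill.
    seq_len = 5
    encoder_mask = encoder_mask[:, :seq_len]   # [1, 5, 4]

    # prefill(): keep only the final query row once the prompt is consumed.
    encoder_mask = encoder_mask[:, -1:]        # [1, 1, 4]

    # decode_one_token(): this single row is reused for each new token.
    print(encoder_mask.shape)                  # torch.Size([1, 1, 4])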