diff --git a/lavis/models/blip_models/blip_caption.py b/lavis/models/blip_models/blip_caption.py index 26f0690a..095699fb 100644 --- a/lavis/models/blip_models/blip_caption.py +++ b/lavis/models/blip_models/blip_caption.py @@ -183,6 +183,10 @@ def generate( prompt = self.tokenizer(prompt, return_tensors="pt").to(self.device) prompt.input_ids[:, 0] = self.tokenizer.bos_token_id prompt.input_ids = prompt.input_ids[:, :-1] + + # prepare prompt for beam search + if not use_nucleus_sampling: + prompt.input_ids = torch.repeat_interleave(prompt.input_ids, num_beams, dim=0) # get decoded text decoder_out = self.text_decoder.generate_from_encoder(