From 67680f88c8e6f15d006cfede450e452eac58dc80 Mon Sep 17 00:00:00 2001 From: Bo Li Date: Wed, 12 Apr 2023 12:11:23 +0000 Subject: [PATCH] fix on blip_caption.py: prepare prompt for beam search, avoid tensor mismatch in later XMLBert decoder. --- lavis/models/blip_models/blip_caption.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lavis/models/blip_models/blip_caption.py b/lavis/models/blip_models/blip_caption.py index 26f0690a..095699fb 100644 --- a/lavis/models/blip_models/blip_caption.py +++ b/lavis/models/blip_models/blip_caption.py @@ -183,6 +183,10 @@ def generate( prompt = self.tokenizer(prompt, return_tensors="pt").to(self.device) prompt.input_ids[:, 0] = self.tokenizer.bos_token_id prompt.input_ids = prompt.input_ids[:, :-1] + + # prepare prompt for beam search + if not use_nucleus_sampling: + prompt.input_ids = torch.repeat_interleave(prompt.input_ids, num_beams, dim=0) # get decoded text decoder_out = self.text_decoder.generate_from_encoder(