## Flan-t5

In [None]:
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model = AutoModelForSeq2SeqLM.from_pretrained('google/flan-t5-xl')
tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-xl')

In [None]:
# get special ids in vocab
print(tokenizer(['yes', 'no', 'Yes', 'No', 'A', 'B', 'C', 'D'])['input_ids'])
# [[4273, 1], [150, 1], [2163, 1], [465, 1], [71, 1], [272, 1], [205, 1], [309, 1]]
constrained_idx1 = [4273, 150]
constrained_idx2 = [2163, 465]

# get the scores of certain tokens
inputs = tokenizer('Is the earth smaller then the basketball ?', return_tensors="pt")
outputs = model.generate(inputs["input_ids"],
                         num_beams=1, 
                         return_dict_in_generate=True,
                         output_scores=True)
print(outputs.sequences)
print(tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True))
print(tokenizer.batch_decode(outputs.sequences))
logits = outputs.scores[1] # skip the bos token
constrained_logits = logits[:, constrained_idx1]
print(torch.softmax(constrained_logits, dim=-1)) 
constrained_logits = logits[:, constrained_idx2]
print(torch.softmax(constrained_logits, dim=-1)) 
constrained_logits = logits[:, constrained_idx1 + constrained_idx2]
print(torch.softmax(constrained_logits, dim=-1)) 

## InstructBLIP-flan-t5

In [None]:
from PIL import Image

import torch
from transformers import AutoProcessor, InstructBlipForConditionalGeneration
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

instructblip_processor = AutoProcessor.from_pretrained("Salesforce/instructblip-flan-t5-xl")
instructblip_model = InstructBlipForConditionalGeneration.from_pretrained(
    "Salesforce/instructblip-flan-t5-xl").to(device)  # instructblip-flan-t5-xl
instructblip_model.eval()

constrained_index = instructblip_processor(text=['Yes', 'No', 'yes', 'no'])["input_ids"]
constrained_index = [x[0] for x in constrained_index]

In [24]:
image = Image.open('camel1.png')
# image.thumbnail((640, 640), Image.Resampling.LANCZOS)
instructblip_inputs = instructblip_processor(
    text='Does "there is any animal in the stocking" correcly describe the image ?',
    images=image.convert('RGB'),
    return_tensors='pt',
).to(device)
outputs = instructblip_model.generate(**instructblip_inputs, 
                                            num_beams=1,
                                            return_dict_in_generate=True,
                                            output_scores=True)
# responses = self.instructblip_processor.batch_decode(outputs, skip_special_tokens=True)
logits = outputs.scores[1]
print(instructblip_processor.batch_decode(outputs.sequences))
contrained_logits1 = logits[:, constrained_index[:2]]
contrained_logits1 = torch.softmax(contrained_logits1, dim=-1)
print(contrained_logits1) # probs of Yes No
contrained_logits2 = logits[:, constrained_index[2:]]
contrained_logits2 = torch.softmax(contrained_logits2, dim=-1)
print(contrained_logits2) # probs of yes no

['<pad> no</s>']
tensor([[0.4022, 0.5978]], device='cuda:0')
tensor([[0.2335, 0.7665]], device='cuda:0')
