In [1]:
# Libraries
import time
from pathlib import Path

import requests
import torch
from PIL import Image
from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration

# Constants
LR_PROMPT_PATH = "../prompt/list-then-rewrite.txt"
QG_PROMPT_PATH = "../prompt/question-generation.txt"


def read_prompt(path):
    """Return the entire contents of the prompt template file at ``path``.

    The file is decoded explicitly as UTF-8 so the template text is the same
    on every platform; without an explicit encoding ``open`` falls back to the
    locale default, which is not UTF-8 everywhere.

    Args:
        path: Location of the template file, relative to the notebook.

    Returns:
        The file contents as a single string, unmodified.

    Raises:
        OSError: If the file cannot be opened (e.g. it does not exist).
    """
    with open(path, "r", encoding="utf-8") as file:
        return file.read()


# Prompt templates; placeholders are filled in later with ``str.format``.
LR_PROMPT = read_prompt(LR_PROMPT_PATH)
QG_PROMPT = read_prompt(QG_PROMPT_PATH)

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# --- Model setup -----------------------------------------------------------
MODEL_NAME = "Salesforce/instructblip-vicuna-7b"

model = InstructBlipForConditionalGeneration.from_pretrained(MODEL_NAME)
processor = InstructBlipProcessor.from_pretrained(MODEL_NAME)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Decoding settings shared by both generation stages so they cannot drift apart.
# NOTE(review): with do_sample=False decoding is beam search; top_p and
# temperature are sampling-only knobs and are presumably inert here. They are
# kept so the call stays identical to the previous configuration.
GENERATION_KWARGS = dict(
    do_sample=False,
    num_beams=5,
    max_length=256,
    min_length=1,
    top_p=0.9,
    repetition_penalty=1.5,
    length_penalty=1.0,
    temperature=1,
)

# --- Input image -----------------------------------------------------------
# Remote alternative:
# url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/docs/_static/Confusing-Pictures.jpg"
# image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
imageFile = "../dataset/003.jpg"
# Normalise to 3-channel RGB (as the remote loader above does) so grayscale,
# CMYK or RGBA files do not reach the processor with an unexpected channel
# count. The context manager closes the file handle once pixels are loaded.
with Image.open(imageFile) as opened_image:
    image = opened_image.convert("RGB")


# --- Stage 1: List-then-Rewrite --------------------------------------------
lr_input = processor(
    images=image,
    text=LR_PROMPT.format(number="10"),
    return_tensors="pt",
).to(device)

# perf_counter is monotonic, so the measurement cannot be skewed by system
# clock adjustments. The timed span covers generation plus decoding.
lr_started = time.perf_counter()
lr_output = model.generate(**lr_input, **GENERATION_KWARGS)
lr_generated_out = processor.batch_decode(lr_output, skip_special_tokens=True)[0].strip()
lr_time = time.perf_counter() - lr_started


# --- Stage 2: Question Generation ------------------------------------------
# The description produced by stage 1 is substituted into the QG template.
qg_input = processor(
    images=image,
    text=QG_PROMPT.format(desc=lr_generated_out),
    return_tensors="pt",
).to(device)

qg_started = time.perf_counter()
qg_output = model.generate(**qg_input, **GENERATION_KWARGS)
qg_generated_out = processor.batch_decode(qg_output, skip_special_tokens=True)[0].strip()
qg_time = time.perf_counter() - qg_started

Loading checkpoint shards: 100%|██████████| 4/4 [02:26<00:00, 36.71s/it]


In [None]:
# NOTE(review): the directory and the "LAV2_" prefix refer to LLaMA-Adapter-V2,
# but this notebook runs InstructBLIP. Left unchanged to keep the output
# location stable; confirm this is the intended destination.
RESULT_DIR = Path("../result/LLaMa-Adapter-V2")

# Path.stem drops only the final extension, so names containing extra dots
# ("a.v1.jpg", "a.v2.jpg") no longer collapse to the same result file.
IMAGE_STEM = Path(imageFile).stem
LR_RESULT_FILENAME = f"LAV2_LR_result_{IMAGE_STEM}.txt"
QG_RESULT_FILENAME = f"LAV2_QG_result_{IMAGE_STEM}.txt"


def write_result(filename, generated_text, elapsed_seconds):
    """Write one stage's generated text and its processing time to RESULT_DIR.

    The file holds the generated text, a blank line, then a
    ``Processing time : <seconds>s`` line. An existing file with the same
    name is overwritten.

    Args:
        filename: Name of the file to create inside ``RESULT_DIR``.
        generated_text: Decoded model output to store.
        elapsed_seconds: Measured duration of the stage, in seconds.

    Raises:
        OSError: If the directory or file cannot be created or written.
    """
    # Create the directory up front so a missing folder does not discard the
    # result of a long-running generation.
    RESULT_DIR.mkdir(parents=True, exist_ok=True)
    # Explicit UTF-8: model output may contain characters the locale's
    # default encoding cannot represent.
    with open(RESULT_DIR / filename, "w", encoding="utf-8") as file:
        file.write(generated_text)
        file.write("\n\n")
        file.write(f"Processing time : {elapsed_seconds}s")


write_result(LR_RESULT_FILENAME, lr_generated_out, lr_time)
write_result(QG_RESULT_FILENAME, qg_generated_out, qg_time)