# BakLLaVa (Mistral-7B) Experiment

In [6]:
import yaml
import torch
import os

def load_config(config_path, config_name):
    """Load and parse a YAML configuration file.

    Args:
        config_path: Directory that contains the config file.
        config_name: File name of the config inside ``config_path``.

    Returns:
        The document parsed by ``yaml.safe_load`` (indexed like a mapping below,
        e.g. ``config["seed"]``).
    """
    # Explicit UTF-8 so parsing does not depend on the platform's locale encoding.
    with open(os.path.join(config_path, config_name), encoding="utf-8") as file:
        return yaml.safe_load(file)


config = load_config("../", "config.yaml")

In [7]:
# Constants & Seed
SEED = config["seed"]
torch.manual_seed(SEED)

# Inputs
IMG_PATH = config["image_path"]
# File-name stem used to label result files. basename/splitext keep multi-dot names
# intact ("cat.v2.jpg" -> "cat.v2", where split('.')[0] gave "cat") and honour os.sep.
IMG_ID = os.path.splitext(os.path.basename(IMG_PATH))[0]

# Prompt templates, read as UTF-8 text from <parent>/<filename> given in the config.
# os.path.join only inserts a separator when `parent` lacks one, so a parent with a
# trailing slash resolves exactly as plain concatenation did.
LR_PROMPT_TYPE = config["prompt"]["lr"]["filename"]
QG_PROMPT_TYPE = config["prompt"]["qg"]["filename"]
LR_PROMPT_PATH = os.path.join(config["prompt"]["lr"]["parent"], LR_PROMPT_TYPE)
QG_PROMPT_PATH = os.path.join(config["prompt"]["qg"]["parent"], QG_PROMPT_TYPE)
with open(LR_PROMPT_PATH, "r", encoding="utf-8") as file:
    LR_PROMPT = file.read()
with open(QG_PROMPT_PATH, "r", encoding="utf-8") as file:
    QG_PROMPT = file.read()

# Params
MODEL_PATH = config["bakllava"]["model_path"]
PAIR_NUM = config["bakllava"]["params"]["pair_count"]

# Result
LR_RESULT_PARENT_PATH = config["result"]["bakllava"]["lr_path"]
QG_RESULT_PARENT_PATH = config["result"]["bakllava"]["qg_path"]
JSON_PARENT_PATH = config["result"]["bakllava"]["json_path"]

In [8]:
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

# Processor (tokenizer + image preprocessing) and weights come from the same checkpoint.
processor = AutoProcessor.from_pretrained(MODEL_PATH)

# Half-precision weights, loaded with reduced peak CPU memory, then placed on device 0.
# NOTE(review): a `load_in_4bit=True` variant was tried earlier; bitsandbytes-quantized
# models generally cannot be moved with `.to()` — confirm before re-enabling it here.
model = (
    LlavaForConditionalGeneration
    .from_pretrained(MODEL_PATH, torch_dtype=torch.float16, low_cpu_mem_usage=True)
    .to(0)
)

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  2.94it/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [29]:
# Chat-formatted LR prompt: the template's `{number}` slot is filled with 5 and wrapped
# in the "USER: <image> ... ASSISTANT:" turn layout, one segment per line.
lr_prompt = "\n".join(["USER: <image>", LR_PROMPT.format(number=5), "ASSISTANT:"])


In [45]:
from datetime import datetime

raw_image = Image.open(IMG_PATH)


def _generate_reply(prompt, image, max_new_tokens):
    """Greedily generate a reply to one chat-formatted prompt and return only the reply.

    Args:
        prompt: Full prompt in the "USER: <image> ... ASSISTANT:" layout.
        image: PIL image passed to the processor alongside the prompt.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded continuation, special tokens removed, surrounding whitespace stripped.
    """
    inputs = processor(text=prompt, images=image, return_tensors='pt').to(0, torch.float16)
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    # generate() returns the prompt ids followed by the continuation. Slicing by prompt
    # length (rather than searching the decoded text for "ASSISTANT:") cannot drop a
    # character or latch onto a marker that appears inside the prompt itself.
    new_token_ids = output_ids[0][inputs["input_ids"].shape[1]:]
    return processor.decode(new_token_ids, skip_special_tokens=True).strip()


def _write_text(path, text):
    """Write `text` to `path` as UTF-8, creating the parent directory if it is missing."""
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        f.write(text)


# LR -> QG chain (the QG prompt embeds the LR reply). These must actually run in this
# cell because their results are saved below; leaving them commented out made the cell
# fail on a fresh kernel or silently save stale replies from an earlier run.
lr_output_trunc = _generate_reply(lr_prompt, raw_image, max_new_tokens=200)

qg_prompt = f"USER: <image>\n{QG_PROMPT.format(desc = lr_output_trunc, number = PAIR_NUM)}\nASSISTANT:"
qg_output_trunc = _generate_reply(qg_prompt, raw_image, max_new_tokens=200)

# Free-form description of the image.
story_prompt = "USER: <image>\nGenerate a 400 words description about the image. AVOID using any external assumptions or informations!\nASSISTANT:"
story_output_trunc = _generate_reply(story_prompt, raw_image, max_new_tokens=500)

# Question-answer pairs asked directly about the image.
str8_prompt = "USER: <image>\nWhat are 5 possible questions-answers about the image?\nASSISTANT:"
str8_output_trunc = _generate_reply(str8_prompt, raw_image, max_new_tokens=500)

# Question-answer pairs grounded in the generated description.
qgstory_prompt = f"USER: <image>\n<PASSAGE>\n{story_output_trunc}\n\n\nBased on the <PASSAGE>, what are 5 possible questions & answers?\nASSISTANT:"
qgstory_output_trunc = _generate_reply(qgstory_prompt, raw_image, max_new_tokens=500)


# One timestamp for every file of this run; calling datetime.now() per file could
# straddle a minute boundary and give the outputs different suffixes.
timestamp = datetime.now().strftime('%d%m%H%M')
LR_RESULT_FILENAME = f"{IMG_ID}_{timestamp}.txt"
QG_RESULT_FILENAME = f"{IMG_ID}_{timestamp}.txt"
LR_RESULT_PATH = os.path.join(LR_RESULT_PARENT_PATH, LR_RESULT_FILENAME)
QG_RESULT_PATH = os.path.join(QG_RESULT_PARENT_PATH, QG_RESULT_FILENAME)

_write_text(LR_RESULT_PATH, lr_output_trunc)
_write_text(QG_RESULT_PATH, qg_output_trunc)
_write_text(f"../result/inference/BakLLaVa/story/{IMG_ID}_{timestamp}_story.txt", story_output_trunc)
_write_text(f"../result/inference/BakLLaVa/qgstory/{IMG_ID}_{timestamp}_qgstory.txt", qgstory_output_trunc)
_write_text(f"../result/inference/BakLLaVa/str8/{IMG_ID}_{timestamp}_str8.txt", str8_output_trunc)