# LLaVa Experiment

In [18]:
import yaml
import torch
import os

def load_config(config_path,config_name):
    with open(os.path.join(config_path, config_name)) as file:
        config = yaml.safe_load(file)
    return config

config = load_config("../","config.yaml")

In [19]:
import json

res = {}

with open('sample_eval.json') as f:
    d = json.load(f)

for i in d:
    if i['img_id'] not in res:
        res[i['img_id']] = 1
    else:
        res[i['img_id']] += 1

print(res)

{'sample_001': 30, 'sample_008': 10, 'sample_007': 10, 'sample_005': 20, 'sample_009': 10, 'sample_002': 20}


In [20]:
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

In [21]:
# Constants & Seed
SEED = config["seed"]
torch.manual_seed(SEED)

# Inputs
IMG_PATH = config["image_path"]
IMG_ID = IMG_PATH.split('/')[-1].split('.')[0]

# Prompt
LR_PROMPT_TYPE = config['prompt']['lr']['filename']
QG_PROMPT_TYPE = config["prompt"]["qg"]["filename"]
STR8_PROMPT_TYPE = config["prompt"]["str8"]["filename"]
STORY_PROMPT_TYPE = config["prompt"]["qgstory"]["story_filename"]
QGSTORY_PROMPT_TYPE = config["prompt"]["qgstory"]["qg_filename"]

LR_PROMPT_PATH = f"{config['prompt']['lr']['parent']}{LR_PROMPT_TYPE}"
QG_PROMPT_PATH = f'{config["prompt"]["qg"]["parent"]}{QG_PROMPT_TYPE}'
STR8_PROMPT_PATH = f"{config['prompt']['str8']['parent']}{STR8_PROMPT_TYPE}"
STORY_PROMPT_PATH = f"{config['prompt']['qgstory']['parent']}{STORY_PROMPT_TYPE}"
QGSTORY_PROMPT_PATH = f"{config['prompt']['qgstory']['parent']}{QGSTORY_PROMPT_TYPE}"

with open(LR_PROMPT_PATH, "r") as file:
    LR_PROMPT= file.read()
with open(QG_PROMPT_PATH,"r") as file:
    QG_PROMPT = file.read()
with open(STR8_PROMPT_PATH,"r") as file:
    STR8_PROMPT = file.read()
with open(STORY_PROMPT_PATH,"r") as file:
    STORY_PROMPT = file.read()
with open(QGSTORY_PROMPT_PATH,"r") as file:
    QGSTORY_PROMPT = file.read()

# Params
MODEL_NAME = config["llava"]["model_name"]
MODEL_PATH = config["llava"]["model_path"]
PAIR_NUM = config["llava"]["model_params"]["pair_count"]

# Result
LR_RESULT_PARENT_PATH = config["llava"]["result"]["lr_path"].format(model_name = MODEL_NAME)
QG_RESULT_PARENT_PATH = config["llava"]["result"]["qg_path"].format(model_name = MODEL_NAME)
JSON_RESULT_PARENT_PATH = config["llava"]["result"]["json_path"].format(model_name = MODEL_NAME)
STORY_RESULT_PARENT_PATH = config["llava"]["result"]["story_path"].format(model_name = MODEL_NAME)
QGSTORY_RESULT_PARENT_PATH = config["llava"]["result"]["qgstory_path"].format(model_name = MODEL_NAME)
STR8_RESULT_PARENT_PATH = config["llava"]["result"]["str8_path"].format(model_name = MODEL_NAME)

In [8]:
from PIL import Image
import transformers
from transformers import AutoProcessor, LlavaForConditionalGeneration

model = LlavaForConditionalGeneration.from_pretrained(
    MODEL_PATH, 
    torch_dtype=torch.float16, 
    low_cpu_mem_usage=True, 
#     load_in_4bit=True
).to(0)
processor = AutoProcessor.from_pretrained(MODEL_PATH)

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 6/6 [00:00<00:00,  6.80it/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [22]:
def inference_llava(model, processor, prompt, img, max_new_tokens=1500, do_sample=False, skip_special_tokens=True) -> str:
    complete_prompt = f"USER: <image>\n{prompt}\nASSISTANT:"
    
    inputs = processor(
        complete_prompt, 
        img, 
        return_tensors = 'pt'
    ).to(0, torch.float16)
    
    raw_output = model.generate(
        **inputs, 
        max_new_tokens = max_new_tokens, 
        do_sample = do_sample
    )
    
    output = processor.decode(raw_output[0], skip_special_tokens = skip_special_tokens)
    output_trunc = output[output.index("ASSISTANT:") + 11:]
    
    return output_trunc

def exec_time(to, tt) -> str:
    time_difference = tt - to

    hours, remainder = divmod(time_difference.seconds, 3600)
    minutes, seconds = divmod(remainder, 60)

    result_format = f"{hours}h{minutes}m{seconds}s"
    
    return result_format

In [23]:
STORY_PROMPT

'Generate a 750 words description about the image. AVOID using any external assumptions or informations!'

In [24]:
from datetime import datetime

raw_image = Image.open(IMG_PATH)

# t_lrqg_o = datetime.now()
# lr_out = inference_llava(
#     model, processor, 
#     LR_PROMPT.format(number = PAIR_NUM), 
#     raw_image
# )

# qg_out = inference_llava(
#     model, processor,
#     QG_PROMPT.format(desc = lr_out, number = PAIR_NUM),
#     raw_image
# )
# t_lrqg_t = datetime.now()



# t_str8_o = datetime.now()
# str8_out = inference_llava(
#     model, processor,
#     STR8_PROMPT.format(number = PAIR_NUM),
#     raw_image
# )
# t_str8_t = datetime.now()



t_qgstory_o = datetime.now()
story_out = inference_llava(
    model, processor,
    STORY_PROMPT,
    raw_image
)

qgstory_out = inference_llava(
    model, processor,
    QGSTORY_PROMPT.format(story = story_out, number = PAIR_NUM),
    raw_image
)
t_qgstory_t = datetime.now()

timestamp = datetime.now().strftime("%m_%d_%Y-%H:%M:%S")
# lrqg_exec_time = exec_time(t_lrqg_o, t_lrqg_t)
# qgstory_exec_time = exec_time(t_qgstory_o, t_qgstory_t)
# str8_exec_time = exec_time(t_str8_t, t_str8_t)
qgstory_exec_time = exec_time(t_qgstory_o, t_qgstory_t)



FILENAME = f"{IMG_ID}_{timestamp}.txt" 
LR_RESULT_PATH = LR_RESULT_PARENT_PATH + FILENAME
QG_RESULT_PATH = QG_RESULT_PARENT_PATH + FILENAME
STORY_RESULT_PATH = STORY_RESULT_PARENT_PATH + FILENAME
QGSTORY_RESULT_PATH = QGSTORY_RESULT_PARENT_PATH + FILENAME
STR8_RESULT_PATH = STR8_RESULT_PARENT_PATH + FILENAME


# with open(LR_RESULT_PATH,"w") as f:
#     f.write(lr_out)
# with open(QG_RESULT_PATH,"w") as f:
#     f.write(qg_out)
#     f.write("\n\n")
#     f.write(lrqg_exec_time)
# with open(STORY_RESULT_PATH,"w") as f:
#     f.write(story_out)
#     f.write("\n\n")
#     f.write(lrqg_exec_time)
# with open(QGSTORY_RESULT_PATH,"w") as f:
#     f.write(qgstory_out)
#     f.write("\n\n")
#     f.write(qgstory_exec_time)
# with open(STR8_RESULT_PATH,"w") as f:
#     f.write(str8_out)
#     f.write("\n\n")
#     f.write(str8_exec_time)

with open("sample_0012_qgstory.txt","w") as f:
    f.write(qgstory_out)
    f.write("\n\n")
    f.write(qgstory_exec_time)

In [27]:
story_out

'The image depicts a bustling city scene with a large group of people gathered around a food truck. The food truck is parked in the middle of the scene, and several people are standing around it, likely waiting to order or enjoying their meals. \n\nIn addition to the food truck, there are multiple traffic lights scattered throughout the scene, indicating that the area is well-regulated for vehicular and pedestrian traffic. A few cars can be seen in the background, further emphasizing the urban setting.\n\nThere are also a few bicycles parked or being ridden in the area, adding to the lively atmosphere. A backpack is visible on the ground, possibly belonging to one of the people in the scene. Overall, the image captures a vibrant city environment with people enjoying their time and engaging in various activities.'

In [29]:
from math import ceil
import random
from tqdm import tqdm
import numpy as np
from datetime import datetime

def batch_inference(image_path, model, processor, total_pair_count, pair_per_batch = 10):
    raw_image = Image.open(image_path)

    NUM_BATCH = ceil(total_pair_count / pair_per_batch)
    LAST_BATCH = total_pair_count % pair_per_batch
    
    total_out = ""
    
    for batch in tqdm(range(NUM_BATCH)):
        if batch == NUM_BATCH - 1:
            pair_per_batch = LAST_BATCH

        story_out = inference_llava(
            model, processor,
            STORY_PROMPT,
            raw_image
            
        )

        qgstory_out = inference_llava(
            model, processor,
            QGSTORY_PROMPT.format(story = story_out, number = pair_per_batch),
            raw_image
        )
        
        total_out += qgstory_out + "\n"
        
        print(f"BATCH-{batch}")
        print(qgstory_out)
        
        SEED = batch
        set_seed(SEED)

    timestamp = datetime.now().strftime("%m_%d_%Y-%H:%M:%S")
    FILENAME = f"{IMG_ID}_{timestamp}.txt" 
    QGSTORY_RESULT_PATH = QGSTORY_RESULT_PARENT_PATH + FILENAME
    SEED = config["seed"]
    
    print("FINAL")
    print(total_out)
    
    with open(QGSTORY_RESULT_PATH,"w") as f:
        f.write(total_out)
        f.write("\n\n")

batch_inference(IMG_PATH, model, processor, 15)

 50%|█████     | 1/2 [00:25<00:25, 25.18s/it]

BATCH-0
1. What is the main attraction in the image?
S. Food truck
L. The main attraction in the image is the food truck, which has drawn a large crowd of people who are standing around it, waiting to order or enjoying their meals.

2. Where is the food truck parked?
S. In the middle of the scene
L. The food truck is parked in the middle of the scene, surrounded by a large group of people.

3. Who are the people in the image?
S. Pedestrians
L. The people in the image are pedestrians, who are gathered around the food truck, likely waiting to order or enjoying their meals.

4. How many traffic lights are visible in the scene?
S. 5
L. There are five traffic lights visible in the scene, indicating that the area is well-regulated for vehicular and pedestrian traffic.

5. What are the people in the image doing?
S. Standing around food truck
L. The people in the image are standing around the food truck, likely waiting to order or enjoying their meals.

6. What is the purpose of the traffic li

100%|██████████| 2/2 [00:40<00:00, 20.32s/it]

BATCH-1
1. What is the main attraction in the image?
S. Food truck
L. The main attraction in the image is the food truck, which has drawn a large crowd of people who are standing around it, waiting to order or enjoying their meals.

2. Where is the food truck parked?
S. In the middle of the scene
L. The food truck is parked in the middle of the scene, surrounded by a large group of people.

3. Who are the people in the image?
S. Pedestrians
L. The people in the image are pedestrians, who are gathered around the food truck, likely enjoying their time and engaging in various activities.

4. How many traffic lights are visible in the scene?
S. 5
L. There are five traffic lights visible in the scene, indicating that the area is well-regulated for vehicular and pedestrian traffic.

5. Is there any bicycle in the image?
S. Yes
L. Yes, there are a few bicycles parked or being ridden in the area, adding to the lively atmosphere.
FINAL
1. What is the main attraction in the image?
S. Food truck





In [14]:
total_out

NameError: name 'total_out' is not defined

In [None]:
from datetime import datetime


raw_image = Image.open(IMG_PATH)


t_lrqg_o = datetime.now()
lr_out = inference_llava(
    model, processor, 
    LR_PROMPT.format(number = PAIR_NUM), 
    raw_image
)

qg_out = inference_llava(
    model, processor,
    QG_PROMPT.format(desc = lr_out, number = PAIR_NUM),
    raw_image
)
t_lrqg_t = datetime.now()



t_str8_o = datetime.now()
str8_out = inference_llava(
    model, processor,
    STR8_PROMPT.format(number = PAIR_NUM),
    raw_image
)
t_str8_t = datetime.now()



t_qgstory_o = datetime.now()
story_out = inference_llava(
    model, processor,
    STORY_PROMPT,
    raw_image
)

qgstory_out = inference_llava(
    model, processor,
    QGSTORY_PROMPT.format(story = story_out, number = PAIR_NUM),
    raw_image
)
t_qgstory_t = datetime.now()

timestamp = datetime.now().strftime("%m_%d_%Y-%H:%M:%S")
lrqg_exec_time = exec_time(t_lrqg_o, t_lrqg_t)
qgstory_exec_time = exec_time(t_qgstory_o, t_qgstory_t)
str8_exec_time = exec_time(t_str8_t, t_str8_t)


FILENAME = f"{IMG_ID}_{timestamp}.txt" 
LR_RESULT_PATH = LR_RESULT_PARENT_PATH + FILENAME
QG_RESULT_PATH = QG_RESULT_PARENT_PATH + FILENAME
STORY_RESULT_PATH = STORY_RESULT_PARENT_PATH + FILENAME
QGSTORY_RESULT_PATH = QGSTORY_RESULT_PARENT_PATH + FILENAME
STR8_RESULT_PATH = STR8_RESULT_PARENT_PATH + FILENAME


with open(LR_RESULT_PATH,"w") as f:
    f.write(lr_out)
with open(QG_RESULT_PATH,"w") as f:
    f.write(qg_out)
    f.write("\n\n")
    f.write(lrqg_exec_time)
with open(STORY_RESULT_PATH,"w") as f:
    f.write(story_out)
    f.write("\n\n")
    f.write(lrqg_exec_time)
with open(QGSTORY_RESULT_PATH,"w") as f:
    f.write(qgstory_out)
    f.write("\n\n")
    f.write(qgstory_exec_time)
with open(STR8_RESULT_PATH,"w") as f:
    f.write(str8_out)
    f.write("\n\n")
    f.write(str8_exec_time)

In [7]:
PAIR_NUM

30

In [20]:
import re

total_data = []

for num in ["009", "007", "005", "002", "001"]:
    with open(f"sample_{num}_qgstory.txt","r") as file:
        # Read the entire file content
        input_string = file.read()

    # Define regular expression patterns
    pattern_question = re.compile(r'\d+\.\s(.+?)\n')
    pattern_short_answer = re.compile(r'S\.\s(.+?)\n')
    pattern_long_answer = re.compile(r'L\.\s(.+?)\n')

    # Find matches using regular expressions
    questions = pattern_question.findall(input_string)
    short_answers = pattern_short_answer.findall(input_string)
    long_answers = pattern_long_answer.findall(input_string)


    # Zip the results into a list of JSON objects
    data = [
        {"id":f"sample_{num}","question": q, "short_answer": sa, "reasoned_answer": la}
        for q, sa, la in zip(questions, short_answers, long_answers)
    ]
    
    total_data += data

print(len(total_data))
with open(f"sdg_out.json", 'w') as json_file:
    json.dump(total_data, json_file, indent=2)

75
