In [None]:
!pip install einops flash_attn

In [None]:
from transformers import AutoProcessor, AutoModelForCausalLM  
from PIL import Image
import requests
import copy
import torch
%matplotlib inline  

In [None]:
# Checkpoint id shared by the processor and the model; both are loaded with
# the repository's custom code enabled (trust_remote_code=True).
model_id = 'microsoft/Florence-2-base'

processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    torch_dtype='auto',
)
# Switch to evaluation mode, then move the weights onto the GPU.
model = model.eval().cuda()

In [None]:
def run_example(task_prompt, text_input=None):
    """Run one Florence-2 task on the notebook-level ``image``.

    This function reads three names from the notebook namespace rather than
    taking them as arguments: ``image`` (the picture to process), ``processor``
    and ``model``. They must be defined before it is called.

    Args:
        task_prompt: Task token string (e.g. ``'<DETAILED_CAPTION>'``). It is
            used as the prompt prefix and is also passed as ``task`` to the
            processor's post-processing step.
        text_input: Optional extra text. When given, it is concatenated
            directly after ``task_prompt`` (no separator is inserted).

    Returns:
        The value returned by ``processor.post_process_generation`` for the
        first decoded sequence.
    """
    prompt = task_prompt if text_input is None else task_prompt + text_input

    # Follow the model's own device and dtype instead of hard-coding
    # 'cuda' / float16: the model is loaded with torch_dtype='auto', so its
    # dtype is decided by the checkpoint, and the inputs have to match it.
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device, model.dtype)

    # The tensors are already on the model's device after the .to() above.
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        early_stopping=False,
        do_sample=False,
        num_beams=3,
    )

    # Special tokens are kept in the decoded text that is handed to the
    # post-processing step.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(
        generated_text,
        task=task_prompt,
        image_size=(image.width, image.height),
    )

    return parsed_answer

In [None]:
import os
import glob
from tqdm import tqdm

# Root folder of the dataset; each entry is expected to contain a
# `keyframes/` folder holding one sub-folder of .jpg files per video.
keyframes_dir = '/kaggle/input/keyframe-extra-aic2024/Keyframes-extra'

# all_keyframe_paths[data_part][video_id] -> sorted list of keyframe .jpg paths
all_keyframe_paths = dict()
for part in sorted(os.listdir(keyframes_dir)):
    # Build the path from the directory name exactly as listed, rather than
    # re-assembling it from the derived key, and skip anything that does not
    # have a `keyframes/` folder (stray files, unrelated directories).
    data_part_path = f'{keyframes_dir}/{part}/keyframes'
    if not os.path.isdir(data_part_path):
        continue

    # Key: the last two '_'-separated tokens for a three-token name,
    # otherwise only the last token.
    parts = part.split('_')
    data_part = parts[-2] + "_" + parts[-1] if len(parts) == 3 else parts[-1]

    video_paths = dict()
    for video_dir in sorted(os.listdir(data_part_path)):
        # The video id is the text after the last '_' in the folder name.
        video_id = video_dir.split('_')[-1]
        video_paths[video_id] = sorted(glob.glob(f'{data_part_path}/{video_dir}/*.jpg'))
    all_keyframe_paths[data_part] = video_paths

In [None]:
# Caption every keyframe of every video and write one text file per video,
# with one caption per line in keyframe order.
task_prompt = '<DETAILED_CAPTION>'
for key, video_keyframe_paths in tqdm(list(all_keyframe_paths.items())):
    video_ids = sorted(video_keyframe_paths.keys())

    directory_path = f"/kaggle/working/caption_encoded_extra/{key}/"
    os.makedirs(directory_path, exist_ok=True)

    for video_id in tqdm(video_ids):
        print(video_id)
        caption = []
        for image_path in tqdm(video_keyframe_paths[video_id]):
            # Close the file handle as soon as the pixels are loaded so that
            # descriptors do not pile up over the whole dataset; convert()
            # forces the load and normalises non-RGB files to three channels.
            with Image.open(image_path) as image_file:
                # `image` is the notebook-level name that run_example reads.
                image = image_file.convert('RGB')
            # The parsed result is indexed by the same task prompt.
            res = run_example(task_prompt)[task_prompt]
            caption.append(res)

        # Saving the video context txt
        with open(f"{directory_path}{video_id}.txt", "w", encoding="utf-8") as f:
            for item in caption:
                f.write("%s\n" % item)


print("Hoàn thành")