In [None]:
!pip install salesforce-lavis

In [8]:
# Import module
import os
import glob
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm
from lavis.models import load_model_and_preprocess

In [41]:
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# BLIP "base" feature extractor, loaded in eval mode directly onto `device`.
# Only the visual processors are kept; the third return value (presumably the
# text processors) is discarded.
model, vis_processors, _ = load_model_and_preprocess(
    name="blip_feature_extractor",
    model_type="base",
    is_eval=True,
    device=device,
)

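# Parse keyframe paths

Build `all_keyframe_paths[data_part][video_id]`, the sorted `.jpg` keyframes of every video found under `/kaggle/input/*/Keyframes/<data_part>_extra/`.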

In [2]:
INPUT_ROOT = '/kaggle/input'

# One "Keyframes" folder per attached Kaggle dataset. Datasets without one are
# skipped, because os.listdir on a missing folder raises FileNotFoundError. The
# listing is sorted because os.listdir order is arbitrary, and later folders
# overwrite earlier ones on duplicate (data_part, video_id) pairs.
keyframes_dir_list = [
    f'{INPUT_ROOT}/{x}/Keyframes'
    for x in sorted(os.listdir(INPUT_ROOT))
    if os.path.isdir(f'{INPUT_ROOT}/{x}/Keyframes')
]

# all_keyframe_paths[data_part][video_id] -> sorted list of keyframe .jpg paths
all_keyframe_paths = dict()

# Pass 1: register every data part (e.g. "L01" from a folder named "L01_extra").
for keyframe_dir in keyframes_dir_list:
    for part in sorted(os.listdir(keyframe_dir)):
        pieces = part.split('_')
        # Skip stray files and names without '_' (pieces[-2] would raise IndexError).
        if len(pieces) < 2 or not os.path.isdir(f'{keyframe_dir}/{part}'):
            continue
        all_keyframe_paths[pieces[-2]] = dict()

# Pass 2: fill each data part from its "<data_part>_extra" folder in every Keyframes dir.
for keyframe_dir in keyframes_dir_list:
    for data_part in sorted(all_keyframe_paths.keys()):
        data_part_path = f'{keyframe_dir}/{data_part}_extra'
        if not os.path.isdir(data_part_path):
            continue
        for video_dir in sorted(os.listdir(data_part_path)):
            video_path = f'{data_part_path}/{video_dir}'
            if not os.path.isdir(video_path):
                continue  # ignore stray files next to the per-video folders
            # The video id is the last '_'-separated piece of the folder name
            # (presumably e.g. "L01_V001" -> "V001"; naming scheme not verified here).
            video_id = video_dir.split('_')[-1]
            all_keyframe_paths[data_part][video_id] = sorted(glob.glob(f'{video_path}/*.jpg'))

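# Extract BLIP image features

Embed each video's keyframes in batches with the BLIP model loaded above. Save one array per video to `./BLIP_features/<data_part>/<video_id>.npy`.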

In [None]:
bs = 256  # keyframes per forward pass; lower it if the GPU runs out of memory
save_dir = './BLIP_features'
os.makedirs(save_dir, exist_ok=True)


def _load_image_batch(image_paths):
    """Open, RGB-convert and preprocess `image_paths`; return them stacked on `device`.

    Relies on the notebook-level `vis_processors` and `device` from the model cell.
    """
    tensors = []
    for image_path in image_paths:
        # `with` releases the file handle even if decoding or preprocessing raises.
        with Image.open(image_path) as image:
            tensors.append(vis_processors["eval"](image.convert("RGB")).unsqueeze(0))
    return torch.cat(tensors).to(device)


def _extract_video_features(keyframe_paths, batch_size):
    """Embed `keyframe_paths` with BLIP, `batch_size` images per forward pass.

    Returns a NumPy array whose row i is the float32 vector
    `image_embeds_proj[:, 0, :]` for `keyframe_paths[i]` (N rows for N paths).
    An empty path list yields `np.array([])` (shape (0,), float64), the same
    array `np.save` previously wrote for videos without keyframes.
    TODO(review): consider a (0, D) float32 array so empty videos concatenate cleanly.
    """
    rows = []
    for start in range(0, len(keyframe_paths), batch_size):
        images = _load_image_batch(keyframe_paths[start:start + batch_size])
        with torch.no_grad():
            feats = model.extract_features({"image": images}, mode="image").image_embeds_proj[:, 0, :]
        # One device->host copy per batch instead of one per keyframe.
        rows.extend(feats.cpu().numpy().astype(np.float32))
    return np.array(rows)


# Writes one file per video: ./BLIP_features/<data_part>/<video_id>.npy
for key, video_keyframe_paths in tqdm(all_keyframe_paths.items()):
    os.makedirs(os.path.join(save_dir, key), exist_ok=True)
    for video_id in tqdm(sorted(video_keyframe_paths), desc=key, leave=False):
        video_feats = _extract_video_features(video_keyframe_paths[video_id], bs)
        np.save(os.path.join(save_dir, key, f'{video_id}.npy'), video_feats)