# In [None]:
import os
import glob
import torch
import numpy as np
from PIL import Image
from lavis.models import load_model_and_preprocess
from tqdm import tqdm

class FeatureExtractor:
    """Extract BLIP-2 image embeddings for every keyframe under a directory tree.

    The directory layout read by ``_collect_keyframe_paths`` is::

        keyframes_dir/<part>/<prefix>_<video_id>/*.jpg

    ``extract_features`` writes one ``<video_id>.npy`` per video under
    ``save_dir/<part>/``.
    """

    def __init__(self, model_name="blip2_feature_extractor", model_type="pretrain", keyframes_dir=None, device=None):
        """Load the LAVIS model and index the keyframes.

        Args:
            model_name: LAVIS model name passed to ``load_model_and_preprocess``.
            model_type: LAVIS model type passed to ``load_model_and_preprocess``.
            keyframes_dir: Root directory of keyframes; ``None`` yields an empty index.
            device: Torch device; defaults to ``'cuda'`` when available, else ``'cpu'``.
        """
        self.device = device if device else ('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")

        # Model plus its image/text preprocessors; the "eval" image transform is used below.
        self.model, self.vis_processors, self.txt_processors = load_model_and_preprocess(
            name=model_name,
            model_type=model_type,
            is_eval=True,
            device=self.device,
        )

        self.keyframes_dir = keyframes_dir
        self.keyframe_paths = self._collect_keyframe_paths()

    def _collect_keyframe_paths(self):
        """Index keyframes as ``{part: {video_id: [sorted *.jpg paths]}}``.

        ``video_id`` is the text after the last ``'_'`` in each video directory
        name. Non-directory entries at the part and video levels are ignored.

        Returns:
            Nested dict of keyframe paths; empty when ``keyframes_dir`` is ``None``.
        """
        keyframe_paths = {}
        if self.keyframes_dir is None:
            # os.listdir(None) would silently scan the current working directory.
            return keyframe_paths

        for part in sorted(os.listdir(self.keyframes_dir)):
            part_path = os.path.join(self.keyframes_dir, part)
            if not os.path.isdir(part_path):
                continue  # stray files (README, .DS_Store, ...) are not parts
            videos = {}
            for video_dir in sorted(os.listdir(part_path)):
                video_path = os.path.join(part_path, video_dir)
                if not os.path.isdir(video_path):
                    continue
                # NOTE(review): directories sharing the same suffix overwrite each other here.
                video_id = video_dir.split('_')[-1]
                # Escape the directory so names containing [, *, ? are matched literally.
                pattern = os.path.join(glob.escape(video_path), '*.jpg')
                videos[video_id] = sorted(glob.glob(pattern))
            keyframe_paths[part] = videos

        return keyframe_paths

    def preprocess_images(self, image_paths):
        """Load each image as RGB and apply the model's "eval" transform.

        Returns:
            List of tensors, each given a leading batch dimension of 1 via ``unsqueeze(0)``.
        """
        tensors = []
        for path in image_paths:
            # Context manager closes the underlying file handle instead of leaking it.
            with Image.open(path) as img:
                tensors.append(self.vis_processors["eval"](img.convert("RGB")).unsqueeze(0))
        return tensors

    def extract_features(self, save_dir, batch_size=4):
        """Embed every indexed keyframe and save one ``.npy`` array per video.

        Each saved array has one L2-normalized row per keyframe, in the order
        stored in ``keyframe_paths``. Videos with no keyframes are skipped.

        Args:
            save_dir: Output root; created if missing.
            batch_size: Number of images sent through the model per forward pass.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        os.makedirs(save_dir, exist_ok=True)

        # Mixed precision only applies on CUDA; on CPU the old cuda autocast just warned and disabled itself.
        use_amp = torch.device(self.device).type == "cuda"

        for part, videos in self.keyframe_paths.items():
            part_save_dir = os.path.join(save_dir, part)
            os.makedirs(part_save_dir, exist_ok=True)

            for video_id, image_paths in tqdm(videos.items(), desc=f"Processing {part}"):
                if not image_paths:
                    # np.vstack([]) would raise; there is nothing to save for this video.
                    tqdm.write(f"Skipping {part}/{video_id}: no keyframes found")
                    continue

                video_features = []
                for start in range(0, len(image_paths), batch_size):
                    batch = torch.cat(self.preprocess_images(image_paths[start:start + batch_size]))
                    batch = batch.to(self.device)
                    with torch.no_grad(), torch.autocast(device_type="cuda", enabled=use_amp):
                        # One forward pass per batch; keep the first token of image_embeds
                        # for every image (previously [0, 0, :] per single-image call).
                        output = self.model.extract_features({"image": batch}, mode="image")
                        features = output.image_embeds[:, 0, :]
                        features = torch.nn.functional.normalize(features, dim=-1)
                    video_features.append(features.cpu().numpy())

                save_path = os.path.join(part_save_dir, f"{video_id}.npy")
                np.save(save_path, np.vstack(video_features))

        print("Feature extraction and saving completed.")
