In [1]:
import os
import glob
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
from open_clip import create_model_from_pretrained
from tqdm import tqdm

class FeatureExtractor:
    """Extract L2-normalised image embeddings for video keyframes with an open_clip model.

    Keyframes are read from ``keyframes_dir/<part>/<video_dir>/*.jpg``. For every
    video, the embeddings of its keyframes (stacked in sorted path order) are
    written to ``save_dir/<part>/<video_id>.npy``.
    """

    def __init__(self, model_name="hf-hub:timm/ViT-SO400M-14-SigLIP-384", keyframes_dir=None, device=None):
        """
        Args:
            model_name: open_clip identifier passed to ``create_model_from_pretrained``.
            keyframes_dir: root directory holding one sub-directory per part.
            device: torch device (string or ``torch.device``); defaults to 'cuda'
                when available, otherwise 'cpu'.

        Raises:
            ValueError: if ``keyframes_dir`` is not provided.
        """
        self.device = device if device else ('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")

        # Scan the keyframe tree before loading the model so that a missing or
        # wrong directory fails fast instead of after the weights are loaded.
        self.keyframes_dir = keyframes_dir
        self.keyframe_paths = self._collect_keyframe_paths()

        # Initialize model and preprocessor
        self.model, self.preprocess = create_model_from_pretrained(model_name)
        self.model = self.model.to(self.device)
        self.model.eval()

    def _collect_keyframe_paths(self):
        """
        Collects all keyframe paths from the specified directory.

        Only sub-directories are treated as parts / videos; stray files (e.g.
        ``.DS_Store``) are ignored. The video id is the suffix of the video
        directory name after its last underscore (``L01_V001`` -> ``V001``).
        Within a part, two directories with the same suffix would map to the
        same id and the later one (in sorted order) wins.

        Output:
            - A dictionary where each key corresponds to a part (e.g., 'part_x') and
              each value is another dictionary mapping video IDs to a sorted list of
              '.jpg' keyframe paths (possibly empty).

        Example Output Format:
            {
                'part_1': {
                    'video_001': ['/path/to/keyframe1.jpg', '/path/to/keyframe2.jpg', ...],
                    'video_002': [...],
                    ...
                },
                'part_2': {...},
                ...
            }

        Raises:
            ValueError: if ``self.keyframes_dir`` is None (``os.listdir(None)``
                would otherwise silently scan the current working directory).
        """
        if self.keyframes_dir is None:
            raise ValueError("keyframes_dir must be provided")

        keyframe_paths = {}
        for part in sorted(os.listdir(self.keyframes_dir)):
            part_path = os.path.join(self.keyframes_dir, part)
            if not os.path.isdir(part_path):
                continue

            videos = {}
            for video_dir in sorted(os.listdir(part_path)):
                video_path = os.path.join(part_path, video_dir)
                if not os.path.isdir(video_path):
                    continue
                video_id = video_dir.split('_')[-1]
                # Escape the directory part so names containing glob
                # metacharacters ('[', ']', '*', '?') are matched literally.
                pattern = os.path.join(glob.escape(video_path), '*.jpg')
                videos[video_id] = sorted(glob.glob(pattern))
            keyframe_paths[part] = videos

        return keyframe_paths

    def preprocess_images(self, image_paths):
        """Load, RGB-convert and preprocess images, returning one batched tensor on ``self.device``.

        Each file is opened in a ``with`` block so its handle is closed right away.
        ``image_paths`` must be non-empty (``torch.stack`` rejects an empty list).
        """
        tensors = []
        for path in image_paths:
            with Image.open(path) as img:
                tensors.append(self.preprocess(img.convert("RGB")))
        return torch.stack(tensors, dim=0).to(self.device)

    def extract_features(self, save_dir, batch_size=4, skip_existing=False):
        """Encode every collected keyframe and save one ``.npy`` file per video.

        Args:
            save_dir: output root; a sub-directory is created per part and each
                video is saved as ``<save_dir>/<part>/<video_id>.npy``.
            batch_size: number of images per forward pass (must be >= 1).
            skip_existing: if True, videos whose ``.npy`` file already exists are
                not recomputed (lets an interrupted run resume).

        Videos without any keyframe are skipped with a message instead of
        crashing the whole run in ``np.vstack``.

        Raises:
            ValueError: if ``batch_size`` < 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        os.makedirs(save_dir, exist_ok=True)
        # Mixed precision only makes sense on CUDA; torch.cuda.amp.autocast is deprecated.
        use_amp = torch.device(self.device).type == "cuda"

        for part, videos in self.keyframe_paths.items():
            part_save_dir = os.path.join(save_dir, part)
            os.makedirs(part_save_dir, exist_ok=True)

            for video_id, image_paths in tqdm(videos.items(), desc=f"Processing {part}"):
                save_path = os.path.join(part_save_dir, f"{video_id}.npy")
                if skip_existing and os.path.exists(save_path):
                    continue
                if not image_paths:
                    tqdm.write(f"Skipping {part}/{video_id}: no keyframes found")
                    continue

                video_features = []
                for start in range(0, len(image_paths), batch_size):
                    images = self.preprocess_images(image_paths[start:start + batch_size])

                    with torch.no_grad(), torch.autocast(device_type="cuda", enabled=use_amp):
                        features = self.model.encode_image(images)
                        features = F.normalize(features, dim=-1)

                    # Force float32 so the saved dtype does not depend on autocast.
                    video_features.append(features.float().cpu().numpy())

                # Save features for this video
                np.save(save_path, np.vstack(video_features))

        print("Feature extraction and saving completed.")


   


In [2]:
 # Specify the directories. Both derive from DATA_ROOT, which can be overridden
# with the AI_CHALLENGE_DATA_DIR environment variable so the notebook is not
# tied to one machine's mount point.
DATA_ROOT = os.environ.get(
    "AI_CHALLENGE_DATA_DIR", "/media/daoan/T7 Shield2/AI_Challenge_2024_DATA"
)
keyframes_dir = os.path.join(DATA_ROOT, "Keyframes")
save_dir = os.path.join(DATA_ROOT, "SIG_CLIP_features")
extractor = FeatureExtractor(keyframes_dir=keyframes_dir)

# Images per forward pass; larger batches need more device memory.
extractor.extract_features(save_dir, batch_size=16)

print("Feature extraction process completed.")
print(f"Features saved in: {save_dir}")

Using device: cuda


Processing L01_extra: 100%|██████████| 31/31 [11:53<00:00, 23.01s/it]
Processing L02_extra: 100%|██████████| 31/31 [12:28<00:00, 24.15s/it]
Processing L03_extra: 100%|██████████| 30/30 [11:31<00:00, 23.05s/it]
Processing L04_extra: 100%|██████████| 30/30 [12:33<00:00, 25.11s/it]
Processing L05_extra: 100%|██████████| 31/31 [12:16<00:00, 23.76s/it]
Processing L06_extra: 100%|██████████| 31/31 [13:56<00:00, 26.98s/it]
Processing L07_extra: 100%|██████████| 31/31 [12:41<00:00, 24.58s/it]
Processing L08_extra: 100%|██████████| 30/30 [12:58<00:00, 25.94s/it]
Processing L09_extra: 100%|██████████| 29/29 [12:32<00:00, 25.93s/it]
Processing L10_extra: 100%|██████████| 29/29 [13:11<00:00, 27.28s/it]
Processing L11_extra: 100%|██████████| 30/30 [12:00<00:00, 24.03s/it]
Processing L12_extra: 100%|██████████| 30/30 [13:02<00:00, 26.08s/it]
Processing L13_extra: 100%|██████████| 30/30 [11:23<00:00, 22.79s/it]
Processing L14_extra: 100%|██████████| 27/27 [11:34<00:00, 25.74s/it]
Processing L15_extra

Feature extraction and saving completed.
Feature extraction process completed.
Features saved in: /media/daoan/T7 Shield2/AI_Challenge_2024_DATA/SIG_CLIP_features



