In [1]:
import cv2
from PIL import Image
from pillow_heif import register_heif_opener
register_heif_opener()

from similarities import ClipSimilarity, SiftSimilarity
import networkx as nx

from tqdm import tqdm
import glob
import os
import fnmatch
import pprint
import numpy as np
import pickle
import subprocess

from sklearn.cluster import AgglomerativeClustering

In [2]:
class SingletonModelLoader:
    """Registry that keeps one loader (and therefore one model) per model name.

    Calling the class with a name that has been seen before returns the very
    same instance that was created the first time, so ``ClipSimilarity`` is
    constructed at most once per ``model_name_or_path``.
    """

    # model name/path -> loader instance carrying the constructed ``model``
    _instances = {}

    def __new__(cls, model_name_or_path):
        registry = cls._instances
        if model_name_or_path in registry:
            return registry[model_name_or_path]

        # First request for this name: build the model and remember the loader.
        loader = super().__new__(cls)
        loader.model = ClipSimilarity(model_name_or_path=model_name_or_path)
        registry[model_name_or_path] = loader
        return loader

    @classmethod
    def get_model(cls, model_name_or_path):
        """Return the model registered for ``model_name_or_path``, creating it on first use."""
        loader = cls(model_name_or_path)
        return loader.model

In [3]:
import os
import pickle
from PIL import Image

class CacheManager:
    """Disk cache that mirrors a media folder below ``ROOT_BASE``.

    For a source file ``path`` the cache location is derived by swapping the
    parent directory of ``root_path`` for ``ROOT_BASE + cache_path_prefix`` and
    then applying ``format_str`` (placeholders: ``{base}``, ``{ext}``,
    ``{cache_tag}``).  The resulting path is lower-cased.

    Two modes are supported, selected by ``format_str``:

    * no ``*`` in ``format_str`` -- one cache file per source file;
      ``generate_func(path)`` returns the single value to store.
    * ``*`` in ``format_str`` -- many cache files per source file;
      ``generate_func(path)`` is iterated and item ``i`` (1-based) is stored
      with ``*`` replaced by ``i``.

    Values are stored according to their type: ``str`` as UTF-8 text,
    ``PIL.Image.Image`` via ``Image.save(quality=50)``, anything else pickled.
    Files are read back according to their extension (image / ``.txt`` /
    pickle).

    SECURITY NOTE: non-image, non-text cache files are read with
    ``pickle.load``.  Only point this class at a cache directory you trust;
    unpickling a tampered file can execute arbitrary code.
    """

    ROOT_BASE = '~/.cache/ai_album/'

    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tiff')

    def __init__(self, cache_path_prefix, root_path, cache_tag, generate_func, format_str="{base}_{cache_tag}.cache"):
        self.cache_path_prefix = cache_path_prefix
        self.root_path = os.path.abspath(root_path)
        self.cache_tag = cache_tag
        self.generate_func = generate_func
        self.format_str = format_str

    def _format_cache_path(self, path, format_str, escape_base=False):
        """Map source ``path`` to a cache path using ``format_str``.

        With ``escape_base=True`` the part derived from the source path is
        passed through ``glob.escape`` so the result can be used as a glob
        pattern in which only the ``*`` coming from ``format_str`` is magic.
        """
        root_parent, _ = os.path.split(self.root_path.rstrip('/'))
        abs_path = os.path.abspath(path)

        # Only rewrite the leading directory.  ``str.replace`` would also
        # rewrite later occurrences (and every '/' when the parent is '/').
        if abs_path.startswith(root_parent):
            cache_path = self.ROOT_BASE + self.cache_path_prefix + abs_path[len(root_parent):]
        else:
            cache_path = abs_path

        base, ext = os.path.splitext(cache_path)
        if escape_base:
            base = glob.escape(base)
        formatted = format_str.format(base=base, ext=ext[1:], cache_tag=self.cache_tag)

        return os.path.expanduser(formatted).lower()

    def _get_cache_file_path(self, path):
        """Return the cache path (or ``*`` pattern) for source file ``path``."""
        return self._format_cache_path(path, self.format_str)

    @staticmethod
    def _save(data, path):
        """Write ``data`` to ``path`` using the storage format matching its type."""
        if isinstance(data, str):
            with open(path, 'w', encoding='utf-8') as file:
                file.write(data)
        elif isinstance(data, Image.Image):
            # JPEG cannot store alpha/palette modes (e.g. PNG sources).
            if path.lower().endswith(('.jpg', '.jpeg')) and data.mode not in ('RGB', 'L', 'CMYK'):
                data = data.convert('RGB')
            data.save(path, quality=50)
        else:
            with open(path, 'wb') as file:
                pickle.dump(data, file)

    @classmethod
    def _load_file(cls, path):
        """Read one cache file back, choosing the reader from its extension."""
        lowered = path.lower()
        if lowered.endswith(cls._IMAGE_EXTENSIONS):
            # Image.open is lazy: pixel data is read when the image is first used.
            return Image.open(path)
        if lowered.endswith('.txt'):
            with open(path, 'r', encoding='utf-8') as file:
                return file.read()
        # See the security note in the class docstring.
        with open(path, 'rb') as file:
            return pickle.load(file)

    def _load_many(self, path):
        """Wildcard mode: return all cached items for ``path`` in index order."""
        pattern = self._format_cache_path(path, self.format_str, escape_base=True)

        matched_files = glob.glob(pattern)
        if not matched_files:
            for index, item in enumerate(self.generate_func(path), start=1):
                item_path = self._format_cache_path(path, self.format_str.replace('*', str(index)))
                os.makedirs(os.path.dirname(item_path), exist_ok=True)
                self._save(item, item_path)
            matched_files = glob.glob(pattern)

        # All matches share prefix and suffix and differ only in the index, so
        # ordering by (length, text) yields numeric order: 1, 2, ..., 9, 10.
        # Plain lexicographic order would put 10 before 2 and break the
        # positional correspondence between caches built from one another.
        matched_files.sort(key=lambda file_path: (len(file_path), file_path))
        return [self._load_file(file_path) for file_path in matched_files]

    def _load_one(self, path):
        """Single-file mode: return the cached value for ``path``, generating it if needed."""
        cache_file_path = self._get_cache_file_path(path)

        if not os.path.exists(cache_file_path):
            data = self.generate_func(path)
            if data is None:
                # Nothing to cache; do not write a pickled None into e.g. a .jpg.
                return None
            os.makedirs(os.path.dirname(cache_file_path), exist_ok=True)
            self._save(data, cache_file_path)

        if os.path.exists(cache_file_path):
            return self._load_file(cache_file_path)

        return None

    def load(self, path):
        """Return the cached value(s) for source file ``path``.

        :param path: path of the source media file.
        :return: a list of items when ``format_str`` contains ``*`` (possibly
                 empty), otherwise the single cached value, or ``None`` when
                 ``generate_func`` produced nothing.
        """
        if '*' in self.format_str:
            return self._load_many(path)
        return self._load_one(path)




In [4]:
# Sampling step for video frames; multiplied by the video's fps to get a frame stride.
INTERVAL = 10


def format_filename(file_path):
    """Return the base name of ``file_path``, shortened to 30 characters at most.

    Names longer than 30 characters keep their first 27 characters followed
    by ``...``; shorter names are returned unchanged.
    """
    name = os.path.basename(file_path)
    if len(name) <= 30:
        return name
    return name[:27] + '...'

class VideoManager:
    """Sample frames from videos, embed them and pick a representative key frame.

    Frames and their embeddings are stored through two wildcard
    ``CacheManager`` instances (``thumbnail_frames_*.jpg`` and
    ``sim_emb_*.emb`` inside a per-video directory).
    """

    # Used for the sampling stride when the container reports no usable fps.
    FALLBACK_FPS = 30

    def __init__(self, folder_path, model_name='OFA-Sys/chinese-clip-vit-huge-patch14'):
        self.folder_path = folder_path
        # File-system friendly model name.  Note: it is not part of the cache
        # paths built below, so caches are shared between models.
        self.model_name = model_name.replace('/', '_')
        self.similarity_model = SingletonModelLoader.get_model(model_name)
        self.frame_cache_manager = CacheManager(cache_path_prefix=".similarity_cache/video/",
                                                root_path=folder_path,
                                                cache_tag="frames",
                                                generate_func=self._extract_and_cache_frames,
                                                format_str='{base}/thumbnail_{cache_tag}_*.jpg')
        self.emb_cache_manager = CacheManager(cache_path_prefix=".similarity_cache/video/",
                                              root_path=folder_path,
                                              cache_tag="emb",
                                              generate_func=self._generate_embeddings,
                                              format_str='{base}/sim_{cache_tag}_*.emb')

    def extract_key_frame(self, path):
        """Return the sampled frame that is most similar, on average, to all other frames.

        :param path: path of the video file.
        :return: the chosen frame as returned by the frame cache, or ``None``
                 when no frames/embeddings are available.
        """
        file_name = format_filename(path)
        print(f"Extracting frames from video '{file_name}'...")

        embeddings = self.emb_cache_manager.load(path)
        pil_frames = self.frame_cache_manager.load(path)

        if len(embeddings) == 0 or len(pil_frames) == 0:
            return None
        if len(embeddings) == 1:
            # A single frame has no peers to compare against (avoids 0 division).
            return pil_frames[0]

        cos_sim = self.similarity_model.score_functions['cos_sim']
        others = len(embeddings) - 1

        # Start from "no candidate" rather than 0 so that a frame is still
        # selected when every average similarity is zero or negative.
        best_avg_similarity = None
        key_frame = None

        for i, emb_a in enumerate(embeddings):
            total_similarity = 0
            for j, emb_b in enumerate(embeddings):
                if i != j:
                    total_similarity += cos_sim(emb_a, emb_b)

            avg_similarity = total_similarity / others
            if best_avg_similarity is None or avg_similarity > best_avg_similarity:
                best_avg_similarity = avg_similarity
                key_frame = pil_frames[i]

        return key_frame

    def extract_frames(self, video_path):
        """Return the cached sampled frames of ``video_path`` (extracting them on first use)."""
        return self.frame_cache_manager.load(video_path)

    def _generate_embeddings(self, video_path):
        """Embed all sampled frames of ``video_path`` with the similarity model."""
        frames = self.extract_frames(video_path)
        batch_embeddings = self.similarity_model.get_embeddings(frames, show_progress_bar=True, batch_size=8)
        return batch_embeddings

    def _extract_and_cache_frames(self, video_path):
        """Yield thumbnails of ``video_path``: one every ``INTERVAL`` seconds plus the last frame.

        This is a generator; the caller (the frame ``CacheManager``) is the one
        that writes the yielded images to disk.
        """
        cap = cv2.VideoCapture(video_path)
        try:
            fps = cap.get(cv2.CAP_PROP_FPS)
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if not fps > 0:
                # Covers 0 and NaN, which would otherwise give a zero stride
                # and a modulo-by-zero below.
                fps = self.FALLBACK_FPS
            frame_interval = max(1, int(fps * INTERVAL))

            frame_count = 0
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break

                if frame_count % frame_interval == 0 or frame_count == total_frames - 1:
                    pil_image = self._cv_frame_to_pil_image(frame)
                    pil_image.thumbnail((1280, 720))
                    yield pil_image

                frame_count += 1
        finally:
            # Also runs when the consumer stops iterating early or an error occurs.
            cap.release()

    def _create_cache_directory(self, video_path):
        """Create and return a ``.similarity_cache/video/<name>_cache`` directory.

        The directory is relative to the current working directory.  Frame
        extraction does not use it (frames are stored by the cache manager);
        the helper is kept for callers that may still rely on it.
        """
        cache_dir = os.path.splitext(os.path.basename(video_path))[0]
        cache_dir = f".similarity_cache/video/{cache_dir}_cache"
        os.makedirs(cache_dir, exist_ok=True)
        return cache_dir

    @staticmethod
    def _cv_frame_to_pil_image(frame):
        """Convert an OpenCV BGR frame into an RGB ``PIL.Image``."""
        cv2_image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        pil_image = Image.fromarray(cv2_image_rgb)
        return pil_image

    

In [10]:
class ImageSimilarity:
    """Index the images and videos of a folder for similarity search and clustering.

    On construction the folder is scanned, then an embedding and a caption are
    computed (or read from cache) for every media file, and cosine similarities
    are pre-computed for pairs that fall into the same batch of 100 files.
    Videos are represented by a key frame chosen by ``VideoManager``.
    """

    IMAGE_PATTERNS = ('*.jpg', '*.jpeg', '*.png', '*.heic', '*.heif')
    VIDEO_PATTERNS = ('*.mp4', '*.avi', '*.webm', '*.mkv', '*.mov')
    THUMBNAIL_SIZE = (1280, 720)

    def __init__(self, folder_path, model_name='OFA-Sys/chinese-clip-vit-huge-patch14', batch_size=8, show_progress_bar=True, **kwargs):
        """
        :param folder_path: folder that is scanned recursively for media files.
        :param model_name: name/path of the similarity model to load.
        :param batch_size: stored on the instance.
        :param show_progress_bar: stored on the instance.
        :param kwargs: stored on the instance as ``self.kwargs``.
        """
        print("Initializing ImageSimilarity...")

        self.model_name = model_name.replace('/', '_')  # Replace '/' in model name for file paths
        self.similarity_model = SingletonModelLoader.get_model(model_name)
        self.batch_size = batch_size
        self.show_progress_bar = show_progress_bar
        self.kwargs = kwargs  # Store any additional keyword arguments

        self.folder_path = folder_path
        self.media_fps = self._load_image_paths(folder_path)
        # Use the same model for video frames as for still images; without the
        # argument VideoManager silently falls back to its own default model.
        self.video_mng = VideoManager(folder_path, model_name=model_name)
        # NOTE(review): ImageQuestionAnswerer is not defined in the visible part
        # of this notebook -- it must be defined or imported before this class
        # is instantiated.
        self.qa = ImageQuestionAnswerer("blip_caption", "large_coco")

        # Initialize CacheManagers
        self.thumbnail_cache_manager = CacheManager(cache_path_prefix=".similarity_cache/img/",
                                                    root_path=folder_path,
                                                    cache_tag="thumbnail",
                                                    generate_func=self._compute_and_save_thumbnail,
                                                    format_str="{base}_thumbnail.jpg")
        self.embedding_cache_manager = CacheManager(cache_path_prefix=".similarity_cache/img/",
                                                    root_path=folder_path,
                                                    cache_tag="emb",
                                                    generate_func=self._generate_embedding,
                                                    format_str="{base}.emb")
        self.caption_cache_manager = CacheManager(cache_path_prefix=".similarity_cache/img/",
                                                    root_path=folder_path,
                                                    cache_tag="caption",
                                                    generate_func=self._generate_caption,
                                                    format_str="{base}_caption.txt")

        print("Loaded similarity model and image file paths.")
        self._initialize()
        self.similarity_cache = self._cache_similarities()
        print("Initialization complete.")

    def _load_image_paths(self, folder_path):
        """Return all media files below ``folder_path``: sorted images, then sorted videos."""
        img_fps, vid_fps = [], []
        for root, _, files in os.walk(folder_path):
            for f in files:
                if self._is_image(f):
                    img_fps.append(os.path.join(root, f))
                elif self._is_video(f):
                    vid_fps.append(os.path.join(root, f))

        return sorted(img_fps) + sorted(vid_fps)

    def _is_image(self, path):
        """True when ``path`` has one of the supported image extensions (case-insensitive)."""
        lowered = path.lower()
        return any(fnmatch.fnmatch(lowered, pattern) for pattern in self.IMAGE_PATTERNS)

    def _is_video(self, path):
        """True when ``path`` has one of the supported video extensions (case-insensitive)."""
        lowered = path.lower()
        return any(fnmatch.fnmatch(lowered, pattern) for pattern in self.VIDEO_PATTERNS)

    def _compute_and_save_thumbnail(self, image_path):
        """Produce the thumbnail for a media file.

        Images are shrunk in place to fit ``THUMBNAIL_SIZE``; videos are
        represented by their key frame.  Returns ``None`` for other files.
        Despite the name, writing to disk is done by the calling
        ``CacheManager``.
        """
        if self._is_image(image_path):
            with Image.open(image_path) as img:
                img.thumbnail(self.THUMBNAIL_SIZE)
                # thumbnail() does nothing (and reads no pixels) when the image
                # already fits, so detach a fully loaded copy before the file
                # is closed on leaving the ``with`` block.
                thumbnail = img.copy()
            return thumbnail

        if self._is_video(image_path):
            return self.video_mng.extract_key_frame(image_path)

        return None

    def _initialize(self):
        """Warm the embedding and caption caches for every media file."""
        print("Initializing embeddings...")
        for fp in tqdm(self.media_fps, desc="Initializing embeddings"):
            _ = self.embedding_cache_manager.load(fp)

        print("Initializing captions...")
        for fp in tqdm(self.media_fps, desc="Initializing captions"):
            _ = self.caption_cache_manager.load(fp)

    def _generate_embedding(self, image_path):
        """Embed the thumbnail of ``image_path`` with the similarity model."""
        img = self.thumbnail_cache_manager.load(image_path)
        emb = self.similarity_model.get_embeddings([img])[0]  # Extract the first (and only) embedding
        return emb

    def _generate_caption(self, image_path):
        """Caption the thumbnail of ``image_path`` and return the first caption."""
        img = self.thumbnail_cache_manager.load(image_path)
        return self.qa.caption(img, max_length=90, min_length=30)[0]

    def _cache_similarities(self):
        """Pre-compute cosine similarities for pairs within the same batch of 100 files.

        :return: dict mapping ``(index_a, index_b)`` (indices into
                 ``self.media_fps``, both orders) to the similarity score.
                 Pairs from different batches are not included; they are
                 computed lazily by ``get_similarity_with_file_path``.
        """
        cache = {}
        batch_size = 100
        cos_sim = self.similarity_model.score_functions['cos_sim']

        print("Caching similarities...")
        pbar = tqdm(total=len(self.media_fps), desc="Caching similarities")
        for i in range(0, len(self.media_fps), batch_size):
            batch_fps = self.media_fps[i:i + batch_size]

            # Load embeddings for the batch
            embeddings = [self.embedding_cache_manager.load(fp) for fp in batch_fps]

            # Compute similarity matrix for the batch
            for j in range(len(batch_fps)):
                for k in range(j + 1, len(batch_fps)):  # Avoid duplicate computations
                    sim_score = cos_sim(embeddings[j], embeddings[k])
                    cache[(i + j, i + k)] = sim_score
                    cache[(i + k, i + j)] = sim_score

            pbar.update(len(batch_fps))

        pbar.close()
        return cache

    def get_similarity_with_file_path(self, file_path_a, file_path_b):
        """Return the cosine similarity of two indexed files.

        Scores that were not pre-computed (files from different batches, or a
        file compared with itself) are computed on demand and remembered.

        :return: the similarity score, or ``None`` when either path is not in
                 ``self.media_fps``.
        """
        try:
            idx_a = self.media_fps.index(file_path_a)
            idx_b = self.media_fps.index(file_path_b)
        except ValueError:
            return None  # File path not found

        cached = self.similarity_cache.get((idx_a, idx_b))
        if cached is not None:
            return cached

        emb_a = self.embedding_cache_manager.load(file_path_a)
        emb_b = self.embedding_cache_manager.load(file_path_b)
        sim_score = self.similarity_model.score_functions['cos_sim'](emb_a, emb_b)
        self.similarity_cache[(idx_a, idx_b)] = sim_score
        self.similarity_cache[(idx_b, idx_a)] = sim_score
        return sim_score

    def cluster_images_with_multilevel_hierarchical(self, distance_levels=None):
        """
        Cluster all indexed media in a multi-level hierarchy.

        Level 0 clusters everything with ``distance_levels[0]``; each resulting
        cluster is clustered again with the next threshold, and so on.

        :param distance_levels: A list of distance thresholds, one per level
                                (defaults to ``[0.05]``).
        :return: Nested dictionary keyed by cluster id.  Leaves are lists of
                 file paths; a cluster that shrinks to a single file before the
                 last level is stored directly as a one-element list.
        """
        embeddings = [self.embedding_cache_manager.load(fp) for fp in self.media_fps]

        if distance_levels is None:
            distance_levels = [0.05]  # Default value

        # Pair each embedding with its corresponding file path
        paired_data = list(zip(self.media_fps, embeddings))

        # Starting with all paired data as the initial cluster
        initial_cluster = {0: paired_data}

        def file_paths(cluster_data):
            return [fp for fp, _ in cluster_data]

        # Function to recursively apply clustering
        def recursive_clustering(current_clusters, level):
            if level >= len(distance_levels):
                # At the final level, return the file paths instead of (file path, embedding) pairs
                return {cluster_id: file_paths(cluster_data) for cluster_id, cluster_data in current_clusters.items()}

            new_clusters = {}
            for cluster_id, cluster_data in current_clusters.items():
                if len(cluster_data) > 1:
                    # Extract embeddings for clustering
                    cluster_embeddings = [emb for _, emb in cluster_data]
                    clustering = AgglomerativeClustering(distance_threshold=distance_levels[level], n_clusters=None)
                    clustering.fit(cluster_embeddings)

                    sub_clusters = {}
                    for idx, label in enumerate(clustering.labels_):
                        # Plain ints keep the keys readable when printed.
                        sub_clusters.setdefault(int(label), []).append(cluster_data[idx])

                    new_clusters[cluster_id] = recursive_clustering(sub_clusters, level + 1)
                else:
                    # Nothing left to split.  Store paths only, like every other
                    # leaf, so consumers never see (path, embedding) tuples.
                    new_clusters[cluster_id] = file_paths(cluster_data)

            return new_clusters

        # Apply recursive clustering starting from level 0
        final_clusters = recursive_clustering(initial_cluster, 0)
        return final_clusters

    def cluster_images_with_hierarchical(self, embeddings=None, distance_threshold=0.05):
        """
        Single-level Ward clustering of the indexed media.

        :param embeddings: embeddings in the same order as ``self.media_fps``;
                           loaded from the embedding cache when omitted.
        :param distance_threshold: linkage distance above which clusters are
                                   not merged (default 0.05).  The original
                                   author's note suggests 0.5 for detailed and
                                   2 for coarse clusters.
        :return: dict mapping cluster label to a list of file paths.
        """
        if embeddings is None:
            embeddings = [self.embedding_cache_manager.load(fp) for fp in self.media_fps]

        # Convert list of embeddings to a numpy array
        embeddings_array = np.array(embeddings)

        # Apply Hierarchical Clustering
        hierarchical_cluster = AgglomerativeClustering(n_clusters=None, distance_threshold=distance_threshold, linkage='ward')
        hierarchical_cluster.fit(embeddings_array)

        # Group file paths by cluster labels
        clusters = {}
        for idx, label in enumerate(hierarchical_cluster.labels_):
            clusters.setdefault(label, []).append(self.media_fps[idx])

        return clusters

In [None]:
# Usage
# folder_path = './data/YeonWooPinkBikini/'
# Folder that is scanned recursively for images and videos.
folder_path = './data/blue_minidress/'
# Constructing the object does all the heavy work: it loads the model, fills
# the embedding and caption caches for every file and pre-computes the
# similarity cache, so this cell can take a long time on a cold cache.
s = ImageSimilarity(folder_path,
                    model_name='OFA-Sys/chinese-clip-vit-huge-patch14',
                    # Alternative model kept for reference:
                    # model_name='OFA-Sys/chinese-clip-vit-base-patch16',
                    batch_size=16,
                    show_progress_bar=True)
# file_path_a = './data/YeonWooPinkBikini/Yeon-Woo-Pink-Bikini-telegram[asiansts]-021.jpg'
# file_path_b = './data/YeonWooPinkBikini/Yeon-Woo-Pink-Bikini-telegram[asiansts]-031.jpg'
# print(s.get_similarity_with_file_path(file_path_a, file_path_b))  # Example usage

Initializing ImageSimilarity...
Loaded similarity model and image file paths.
Initializing embeddings...


Initializing embeddings: 100%|██████████████████████████████████████| 100/100 [00:00<00:00, 1344.83it/s]


Initializing captions...


Initializing captions:   3%|█▎                                          | 3/100 [01:49<57:28, 35.55s/it]

In [12]:
# Single-level alternative, kept for reference:
# clusters = s.cluster_images_with_hierarchical(distance_threshold=0.5)
# Two-level hierarchy: cluster everything with threshold 2 first, then split
# each resulting cluster again with threshold 0.5.
clusters = s.cluster_images_with_multilevel_hierarchical(distance_levels=[2, 0.5])
pprint.pprint(clusters)

{0: {0: {0: ['./data/blue_minidress/IMG_2718.HEIC',
             './data/blue_minidress/IMG_2719.HEIC',
             './data/blue_minidress/IMG_2720.HEIC',
             './data/blue_minidress/IMG_2721.HEIC',
             './data/blue_minidress/IMG_2722.HEIC',
             './data/blue_minidress/IMG_2726.HEIC'],
         1: ['./data/blue_minidress/IMG_2661.HEIC',
             './data/blue_minidress/IMG_2662.HEIC',
             './data/blue_minidress/IMG_2663.HEIC',
             './data/blue_minidress/IMG_2664.HEIC',
             './data/blue_minidress/IMG_2665.HEIC',
             './data/blue_minidress/IMG_2666.HEIC',
             './data/blue_minidress/IMG_2667.HEIC',
             './data/blue_minidress/IMG_2668.HEIC',
             './data/blue_minidress/IMG_2669.HEIC',
             './data/blue_minidress/IMG_2670.HEIC'],
         2: ['./data/blue_minidress/IMG_2723.HEIC',
             './data/blue_minidress/IMG_2724.HEIC',
             './data/blue_minidress/IMG_2725.HEIC']},
     1: 

In [13]:
import os
import shutil

def copy_files_to_clusters(clusters, target_path, current_path=""):
    """Materialise a (possibly nested) cluster dictionary as a directory tree.

    One sub-directory is created per cluster id; nested dictionaries become
    nested directories and the files of each leaf list are copied into the
    leaf's directory.  Files with the same base name inside one leaf overwrite
    each other.

    :param clusters: dict mapping cluster id to either a nested dict or a list
                     of file paths.  List entries may also be
                     ``(file_path, embedding)`` tuples, in which case only the
                     path is used.
    :param target_path: root directory of the output tree.
    :param current_path: path of the current cluster relative to
                         ``target_path`` (used by the recursion).
    """
    for cluster_id, contents in clusters.items():
        # Create a new subdirectory for the current cluster
        relative_dir = os.path.join(current_path, str(cluster_id))
        cluster_dir = os.path.join(target_path, relative_dir)
        os.makedirs(cluster_dir, exist_ok=True)

        if isinstance(contents, dict):
            # If the contents are a dictionary, recurse into it
            copy_files_to_clusters(contents, target_path, relative_dir)
        elif isinstance(contents, list):
            # If the contents are a list, copy the files into the current cluster directory
            for entry in contents:
                # Single-item clusters may arrive as (path, embedding) pairs.
                file_path = entry[0] if isinstance(entry, tuple) else entry
                shutil.copy(file_path, cluster_dir)

# Example usage
# Assuming 'clusters' is your multi-level hierarchical cluster dictionary
copy_files_to_clusters(clusters, './data/testoutput/')



## test

In [53]:
questions = {
    "explicit_content": "Does this photo contain explicit content?",
    "cloth_color": "which color or colors or nude does the girl wear?",
    "camera_angle": "which angle does the photo take from according to the girl?",
    "pose": "which pose does the girl perform?",
    "is_sex": "is the main character in the picture having sex currently",
    "is_blowjob": "is it doing a oral sex?",
    "is_doggy": "is it having sex in doggy style from behind?",
    "is_virginia_penetration": "is there virginia penetrated by penis?",
    "pose_hand": "what is the pose of her hand",
    "pose_thigh": "what is the pose of her thigh",
    "pose_crotch": "what is the pose of her crotch",
    "pose_leg": "what is the pose of her leg",
    "name": "give a informative filename for this photo that make it unique",
    # Add other keywords and questions as needed
}

In [57]:
# Usage example
# qa = ImageQuestionAnswerer("blip_vqa", "vqav2")
qa = ImageQuestionAnswerer("blip_caption", "large_coco")

# qa = ImageQuestionAnswerer("blip2_opt", "caption_coco_opt2.7b")

In [50]:
# img = s.thumbnail_cache_manager.load('./data/blue_minidress/IMG_2050.HEIC')
# img = s.thumbnail_cache_manager.load('./data/blue_minidress/IMG_2661.HEIC')
# img = s.thumbnail_cache_manager.load('./data/blue_minidress/IMG_2066.HEIC')
# img = s.thumbnail_cache_manager.load('./data/blue_minidress/IMG_2840.HEIC')
# img = s.thumbnail_cache_manager.load('./data/blue_minidress/IMG_2726.HEIC')
img = s.thumbnail_cache_manager.load('./data/blue_minidress/IMG_2077.MOV')
# img = s.thumbnail_cache_manager.load('./data/blue_minidress/IMG_2076.MOV')
# img = s.thumbnail_cache_manager.load('./data/blue_minidress/IMG_2007.HEIC')

In [55]:
for fp in s.media_fps:
    res = qa.caption(s.thumbnail_cache_manager.load(fp))
    print(fp, res)

./data/blue_minidress/IMG_2007.HEIC ['a woman in a gray dress standing in a living room']
./data/blue_minidress/IMG_2008.HEIC ['a woman sitting on a window sill in front of a window']
./data/blue_minidress/IMG_2009.HEIC ['a woman sitting on a window sill in front of a window']
./data/blue_minidress/IMG_2010.HEIC ['a woman is sitting on a window sill']
./data/blue_minidress/IMG_2011.HEIC ['a woman is sitting on a window sill']
./data/blue_minidress/IMG_2012.HEIC ['a woman sitting on a window sill in a room']
./data/blue_minidress/IMG_2013.HEIC ['a woman is sitting on a window sill']
./data/blue_minidress/IMG_2014.HEIC ['a woman is sitting on a window sill']


KeyboardInterrupt: 

In [58]:
for fp in s.media_fps:
    res = qa.caption(s.thumbnail_cache_manager.load(fp), max_length=90, min_length=30)
    print(fp, res)

./data/blue_minidress/IMG_2007.HEIC ['a woman in a gray dress standing in a living room next to a white couch and a window with a view of a city outside']
./data/blue_minidress/IMG_2008.HEIC ['a woman sitting on a window sill in front of a window with a view of a snowy mountain outside of her window and a coffee table']
./data/blue_minidress/IMG_2009.HEIC ['a woman sitting on a window sill in front of a large window with a view of a city and a coffee table in the corner']
./data/blue_minidress/IMG_2010.HEIC ['a woman sitting on a window sill in a room with a view of a city and a coffee table and a window sill']
./data/blue_minidress/IMG_2011.HEIC ['a woman sitting on a window sill in a room with a view of a mountain outside of the window and a pair of shoes on the window sill']
./data/blue_minidress/IMG_2012.HEIC ['a woman sitting on a window sill next to a coffee table and a pair of shoes on the floor in front of a window']
./data/blue_minidress/IMG_2013.HEIC ['a woman sitting on a wi


KeyboardInterrupt



In [51]:
%%time
for k, v in questions.items():
    res = qa.ask(img, v)
    print(v, res)

Does this photo contain explicit content? ['yes']
which color or colors or nude does the girl wear? ['gray']
which angle does the photo take from according to the girl? ['straight']
which pose does the girl perform? ['back']
is the main character in the picture having sex currently ['yes']
is it doing a oral sex? ['no']
is it having sex in doggy style from behind? ['yes']
is there virginia penetrated by penis? ['yes']
what is the pose of her hand ['on her butt']
what is the pose of her thigh ['bent']
what is the pose of her crotch ['bent']
what is the pose of her leg ['bent']
give a informative filename for this photo that make it unique ['naked woman']
CPU times: user 40.6 s, sys: 3.32 s, total: 43.9 s
Wall time: 41 s


In [23]:
%%time
# Ask all questions about one image in a single batched call.
# The question *texts* (dict values) are sent to the model; the dict keys are
# only short identifiers and were previously passed by mistake.
tmp = [(k, v) for k, v in questions.items()]
res = qa.asks(img, [v for k, v in tmp])
for (_, q), a in zip(tmp, res):
    print(q, a)


Does this photo contain explicit content? yes
which color or colors or nude does the girl wear? brown
which angle does the photo take from according to the girl? sideways
which pose does the girl perform? woman is laying down
is the main character in the picture having sex currently yes
is it doing a blowjob? yes
is it having sex in doggy style position? yes
CPU times: user 22 s, sys: 5.67 s, total: 27.7 s
Wall time: 22.5 s


In [63]:
qa.rank(img, 'which angle is the photo taken from based on the girl', 
        [
            'Front (0 degree)',
            'Side Profile (90 Degrees)',
            'Back (180 Degrees)',
            'Slight Side Back (150 Degrees)',
            'Three-Quarter View (45 Degrees)',
            'Slight Side Angle (15-30 Degrees)',
        ])

['Back (180 Degrees)']

In [66]:
qa.rank(img, 'which angle is the photo taken from based on the girl', 
        [
            "eye level",
            "high angle",
            "low angle",
            "bird's eye view",
            "worm's eye view",
            "over the shoulder shot",
            "close up",
            "frog view",
        ])

['low angle']