In [1]:
!pip install -q git+https://github.com/openai/CLIP.git

In [2]:
import torch
import clip
from sklearn.metrics import silhouette_score
from scipy.spatial.distance import cdist
import numpy as np
import cv2
import pandas as pd
from PIL import Image
from glob import glob
import os
from tqdm import tqdm

In [3]:
# Kaggle dataset locations: the folder holding the .mp4 files and the
# per-shot metadata table that accompanies them.
vid_root = '/kaggle/input/hcmc-aic2024/video'
vid_insights_path = '/kaggle/input/hcmc-aic2024/vid_insights.csv'

# The table is '-'-delimited. The helpers below read its 'vid_name', 'start',
# 'end' and 'keyframes' columns.
vid_insights = pd.read_csv(vid_insights_path, delimiter='-')
def get_vid_path(vid_name):
    """Return the path of the ``.mp4`` file named ``vid_name`` inside ``vid_root``."""
    file_name = vid_name + '.mp4'
    return os.path.join(vid_root, file_name)
def get_vid_shots_undone(vid_name):
    """Return the ``[start, end]`` bounds of the shots of ``vid_name`` that have no keyframes yet.

    A shot counts as "undone" when its ``keyframes`` cell in ``vid_insights``
    is exactly the string ``'[]'``.

    Returns:
        A NumPy array with one ``(start, end)`` row per matching shot.
    """
    same_video = vid_insights['vid_name'] == vid_name
    still_empty = vid_insights['keyframes'] == '[]'
    return vid_insights.loc[same_video & still_empty, ['start', 'end']].to_numpy()
def get_vid_shots(vid_name):
    """Return the ``[start, end]`` bounds of every shot listed for ``vid_name``.

    Returns:
        A NumPy array with one ``(start, end)`` row per shot found in
        ``vid_insights``.
    """
    same_video = vid_insights['vid_name'] == vid_name
    return vid_insights.loc[same_video, ['start', 'end']].to_numpy()

In [4]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
DEVICE

device(type='cuda')

# K-means init

In [5]:
@torch.jit.script
def calculate_distances(data: torch.Tensor, centers: torch.Tensor) -> torch.Tensor:
    """Return the Euclidean distance from every row of ``data`` to every row of ``centers``.

    The result has one row per sample and one column per center.
    """
    pairwise = torch.cdist(data, centers, p=2.0)
    return pairwise

In [6]:
@torch.jit.script
def assign_clusters(distances: torch.Tensor) -> torch.Tensor:
    """Return, for each row of ``distances``, the column index holding the smallest value."""
    nearest = distances.argmin(dim=1)
    return nearest

In [7]:
@torch.jit.script
def calculate_sse(data: torch.Tensor, centers: torch.Tensor, cluster_labels: torch.Tensor) -> float:
    """Return the sum of squared errors of ``data`` around its assigned centers.

    Args:
        data: Samples, one per row.
        centers: Candidate centers, one per row.
        cluster_labels: For each row of ``data``, the row index in ``centers``
            it is assigned to.

    Returns:
        The total, over all centers ``i``, of the squared element-wise
        differences between ``centers[i]`` and the samples labelled ``i``.
    """
    # One partial sum per center; a center with no assigned samples contributes 0.
    sse = torch.sum(torch.stack([
        torch.sum((data[cluster_labels == i] - centers[i]) ** 2)
        for i in range(len(centers))
    ]))
    # .item() extracts the Python scalar from the 0-dim tensor.
    return sse.item()

In [8]:
def initialize_kmeans_centers(data: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Greedily pick up to ``int(sqrt(n))`` rows of ``data`` as initial K-means centers.

    Centers are added one at a time. At each step every row that is not
    identical to an already chosen center is tried as the next center, and the
    row that yields the lowest total sum of squared errors (each sample charged
    to its nearest center) is kept. The first such row wins on ties.

    The candidate scores are computed for all rows at once from a pairwise
    squared-distance matrix instead of re-clustering once per candidate. This
    is the same criterion up to floating-point rounding, at the cost of an
    ``n x n`` temporary.

    Args:
        data: 2-D tensor with one sample per row.

    Returns:
        ``(labels, centers)``. ``centers`` stacks the chosen rows in the order
        they were picked, and ``labels`` holds, for each sample, the index of
        its nearest center. Fewer than ``int(sqrt(n))`` centers are returned
        when ``data`` does not contain that many distinct rows.

    Raises:
        ValueError: If ``data`` has no rows.
    """
    print("Initializing centers ...", end=' - ')
    n = len(data)
    if n == 0:
        raise ValueError("initialize_kmeans_centers requires at least one sample")
    sqrt_n = int(torch.sqrt(torch.tensor(n)))

    # Exact (non-matmul) pairwise distances keep the scores of near-identical
    # frames from being dominated by cancellation error.
    pairwise_sq = torch.cdist(data, data, compute_mode='donot_use_mm_for_euclid_dist') ** 2

    # Squared distance from each sample to its nearest chosen center so far.
    nearest_sq = torch.full((n,), float('inf'), dtype=pairwise_sq.dtype, device=data.device)
    # Rows that are identical to an already chosen center and must be skipped.
    taken = torch.zeros(n, dtype=torch.bool, device=data.device)
    centers = []

    while len(centers) < sqrt_n:
        if bool(taken.all()):
            # Every remaining row duplicates an existing center: nothing left to add.
            break

        # Column i holds the SSE obtained by adding row i as the next center.
        candidate_sse = torch.minimum(nearest_sq.unsqueeze(1), pairwise_sq).sum(dim=0)
        candidate_sse[taken] = float('inf')
        pick = int(torch.argmin(candidate_sse))

        centers.append(data[pick])
        nearest_sq = torch.minimum(nearest_sq, pairwise_sq[:, pick])
        taken |= (data == data[pick]).all(dim=1)

    stacked_centers = torch.stack(centers)
    final_labels = torch.argmin(torch.cdist(data, stacked_centers), dim=1)
    return final_labels, stacked_centers

# K-means improvement

In [9]:
def find_nearest_clusters(distances):
    """
    Return the ``(row, col)`` pair, with ``row < col``, of the smallest entry
    strictly above the diagonal of the square matrix ``distances``.
    """
    size = distances.shape[0]
    rows, cols = torch.triu_indices(size, size, offset=1)
    best = torch.argmin(distances[rows, cols])
    return rows[best].item(), cols[best].item()

In [10]:
def merge_clusters(clusters, merge_indices):
    """
    Relabel ``clusters`` so that label ``merge_indices[1]`` is absorbed into
    ``merge_indices[0]``, then shift every higher label down by one so the
    labels stay contiguous.
    """
    kept_id = merge_indices[0]
    absorbed_id = merge_indices[1]
    relabelled = torch.where(clusters == absorbed_id, kept_id, clusters)
    # Subtract 1 exactly where the label sits above the removed id.
    return relabelled - (relabelled > absorbed_id).to(relabelled.dtype)

In [11]:
def update_cluster_centers(features, clusters, k):
    """
    For each label ``0..k-1`` return the member sample that lies closest to the
    mean of that cluster, as a list of ``k`` tensors.
    """
    def _nearest_to_mean(cluster_id):
        members = features[clusters == cluster_id]
        centroid = members.mean(dim=0)
        offsets = (members - centroid).norm(dim=1)
        return members[offsets.argmin()]

    return [_nearest_to_mean(cluster_id) for cluster_id in range(k)]

In [12]:
def kmeans_silhouette(features):
    """Cluster ``features`` and keep the cluster count with the best silhouette score.

    Starting from the centers returned by ``initialize_kmeans_centers``, the
    two closest centers are merged repeatedly until two clusters remain. After
    each merge the clustering is scored with ``silhouette_score`` and the
    best-scoring one is remembered. The initial, un-merged clustering is only
    returned when no merge was scored.

    Args:
        features: 2-D tensor with one sample per row. It is converted to
            float32, and all intermediate tensors are kept on its device.

    Returns:
        A tuple ``(labels, centers, best_k, center_indices)``:
        ``labels`` is a NumPy array of per-sample cluster ids (float32),
        ``centers`` a NumPy array of the selected centers, ``best_k`` the
        selected number of clusters, and ``center_indices`` a list with, for
        each center, the row index in ``features`` of the nearest sample.
    """
    features = features.to(dtype=torch.float32)  # Ensure float32 dtype
    # Follow the input's device rather than the global DEVICE so that masks
    # built from `clusters` can always index `features`.
    device = features.device

    clusters, centers = initialize_kmeans_centers(features)
    clusters = torch.as_tensor(clusters, device=device, dtype=torch.float32)
    centers = torch.as_tensor(centers, device=device, dtype=torch.float32)

    # Trust the number of centers actually produced, not int(sqrt(n)).
    k = len(centers)
    best_k = k
    best_clusters = None
    best_avg_silhouette = -1
    best_centers = centers.clone()

    # silhouette_score works on NumPy data; copy the features off-device once.
    features_np = features.cpu().numpy()

    while k > 2:
        distances = torch.cdist(centers, centers)
        merge_indices = find_nearest_clusters(distances)
        clusters = merge_clusters(clusters, merge_indices)
        centers = torch.stack(update_cluster_centers(features, clusters, k - 1))

        k -= 1

        if len(torch.unique(clusters)) > 1:
            avg_silhouette = silhouette_score(features_np, clusters.cpu().numpy())

            if avg_silhouette > best_avg_silhouette:
                best_avg_silhouette = avg_silhouette
                best_k = k
                best_clusters = clusters.clone()
                best_centers = centers.clone()
        else:
            break

    if best_clusters is None:
        # No merged clustering was scored; fall back to the current state.
        best_clusters = clusters
        best_centers = centers

    # Map every selected center back to the row index of its nearest sample.
    center_indices = []
    for center in best_centers:
        distances = torch.norm(features - center, dim=1)
        center_indices.append(torch.argmin(distances).item())

    return best_clusters.cpu().numpy(), best_centers.cpu().numpy(), best_k, center_indices

# Redundancy

In [13]:
def color_histogram(img):
    """Return the flattened 8x8x8 joint colour histogram of a 3-channel image.

    Args:
        img: Image array as accepted by ``cv2.calcHist`` (three channels).

    Returns:
        A 1-D tensor with 512 bin counts.
    """
    # calcHist treats the upper bound of each range as exclusive, so it must be
    # 256 (not 255) for pixels whose channel value is 255 to be counted.
    hist = cv2.calcHist([img], [0, 1, 2], None, [8, 8, 8], [0, 256, 0, 256, 0, 256])
    return torch.from_numpy(hist.flatten())

In [14]:
def compute_color_histograms(video, keyframe_indices):
    """Return one colour histogram per entry of ``keyframe_indices``.

    Each frame is fetched by seeking ``video`` to the index and reading once.
    Row ``i`` of the result always corresponds to ``keyframe_indices[i]``: a
    frame that cannot be read gets an all-zero histogram instead of being
    dropped, so the rows stay aligned with the indices. An all-zero row has no
    non-zero bins, so ``filter_low_information_frames`` discards it.

    Args:
        video: Object exposing ``set(prop, value)`` and ``read()`` like
            ``cv2.VideoCapture``.
        keyframe_indices: Frame numbers to fetch.

    Returns:
        A 2-D tensor with ``len(keyframe_indices)`` rows. When no frame at all
        could be read (or the index list is empty) the tensor has zero columns.
    """
    histograms = []
    for frame_index in keyframe_indices:
        video.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
        ret, frame = video.read()
        # None marks a failed read; it is replaced by a zero row below.
        histograms.append(color_histogram(frame) if ret else None)

    template = next((hist for hist in histograms if hist is not None), None)
    if template is None:
        # Nothing decodable: no histogram is available to take a bin count from.
        return torch.zeros((len(histograms), 0), dtype=torch.float32)

    return torch.stack([
        torch.zeros_like(template) if hist is None else hist
        for hist in histograms
    ])

In [15]:
def filter_low_information_frames(histograms: torch.Tensor, keyframe_indices: list[int], threshold: int = 10) -> tuple[torch.Tensor, list[int]]:
    """Drop frames whose histogram has at most ``threshold`` non-zero bins.

    Returns:
        The retained histogram rows and the matching entries of
        ``keyframe_indices``, in their original order.
    """
    nonzero_bins = (histograms > 0).sum(dim=1)
    keep = nonzero_bins > threshold
    kept_indices = [frame for pos, frame in enumerate(keyframe_indices) if keep[pos]]
    return histograms[keep], kept_indices

In [16]:
@torch.jit.script
def calculate_similarity_matrix(histograms: torch.Tensor) -> torch.Tensor:
    """Return the pairwise cosine similarity between the rows of ``histograms``."""
    row_norms = torch.norm(histograms, dim=1, keepdim=True)
    unit_rows = histograms / row_norms
    return torch.mm(unit_rows, unit_rows.transpose(0, 1))

In [17]:
def remove_redundant_frames(similarity_matrix: torch.Tensor, keyframe_indices: list[int], threshold: float) -> list[int]:
    """Remove the later frame of every pair whose similarity exceeds ``threshold``.

    Every pair ``(i, j)`` with ``i < j`` is checked, and the surviving frame
    numbers are returned in ascending order.
    """
    count = len(keyframe_indices)
    dropped = {
        keyframe_indices[later]
        for earlier in range(count)
        for later in range(earlier + 1, count)
        if similarity_matrix[earlier, later] > threshold
    }
    return sorted(set(keyframe_indices).difference(dropped))

In [18]:
def redundancy(video, keyframe_indices: list[int], threshold: float = 0.94) -> list[int]:
    """Filter candidate keyframes by colour-histogram content and similarity.

    Histograms are computed for the candidates, low-information frames are
    discarded, and of any remaining pair whose histogram similarity exceeds
    ``threshold`` only the earlier-listed frame is kept.

    Returns:
        The surviving frame numbers as produced by ``remove_redundant_frames``.
    """
    raw_histograms = compute_color_histograms(video, keyframe_indices)
    kept_histograms, kept_indices = filter_low_information_frames(raw_histograms, keyframe_indices)
    return remove_redundant_frames(
        calculate_similarity_matrix(kept_histograms),
        kept_indices,
        threshold,
    )

# Keyframe extraction

In [19]:
def keyframe_extraction(shot: tuple[int, int], features: np.ndarray, video, threshold: float = 0.94) -> list[int]:
    """Select keyframes for one shot from its per-frame features.

    Args:
        shot: ``(start, end)`` pair; only ``start`` is used, as the offset that
            turns row positions in ``features`` into frame numbers.
        features: Per-frame feature matrix; row ``i`` is treated as frame
            ``start + i``.
        video: Capture object forwarded to ``redundancy`` to fetch the frames.
        threshold: Similarity threshold forwarded to ``redundancy``.

    Returns:
        Sorted frame numbers of the selected keyframes.
    """
    first_frame, _ = shot
    on_device = torch.as_tensor(features, device=DEVICE, dtype=torch.float32)
    # Only the sample indices of the cluster centers are needed here.
    center_rows = kmeans_silhouette(on_device)[3]
    candidates = [first_frame + offset for offset in center_rows]
    return sorted(redundancy(video, candidates, threshold))

# Plot keyframes

In [20]:
# import matplotlib.pyplot as plt

# def get_frame(video, frame_id):
#     video.set(cv2.CAP_PROP_POS_FRAMES, frame_id)
#     _, frame = video.read()
    
#     return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

# def fml_frames(video, shot): # first, mid, last frame
#     first_id, last_id = shot
#     mid_id = (last_id - first_id) // 2 + first_id
    
#     first_frame = get_frame(video, first_id)
#     mid_frame = get_frame(video, mid_id)
#     last_frame = get_frame(video, last_id)
    
#     return [(first_frame, first_id), (mid_frame, mid_id), (last_frame, last_id)]

# def show_fml_keyframes(keyframes):
#     len_keyframes =  len(keyframes)
    
#     plt.figure(figsize=(10, 5))
#     for i in range(len_keyframes):
#         frame = keyframes[i]
#         plt.subplot(1, len_keyframes, i+1)
#         plt.imshow(frame[0])
#         plt.title(f'Frame {frame[1]}')
        
#     plt.show()

# def show_lmske_keyframes(video, keyframes):
#     len_keyframes =  len(keyframes)
    
#     plt.figure(figsize=(10, 5))
#     for i in range(len_keyframes):
#         frame_i = keyframes[i]
#         plt.subplot(1, len_keyframes, i+1)
#         plt.imshow(get_frame(video, frame_i))
#         plt.title(f'Frame {frame_i}')
        
#     plt.show()

# Inference

In [21]:
# Load CLIP ViT-L/14 together with its image preprocessing transform. The
# device is passed explicitly so the model sits on the same DEVICE that the
# inference loop moves its image batches to.
model, preprocess = clip.load("ViT-L/14", device=DEVICE)

100%|████████████████████████████████████████| 890M/890M [00:05<00:00, 169MiB/s]


In [25]:
# from time import time
# start = time()
# lst_vids = sorted(vid_insights['vid_name'].unique())
# lst_vids = ['L01_V002']

# Videos handled in this run.
lst_vids = ['L10_V0'+ str(i).zfill(2) for i in range(1, 11)]


def _encode_shot_frames(video, first_frame, last_frame, batch_size=32):
    """Return CLIP image features for frames ``first_frame..last_frame`` (inclusive).

    The capture is positioned once and then read sequentially. Frames are
    preprocessed and encoded in batches of ``batch_size`` so the raw frames of a
    whole shot are never held in memory at once. Reading stops at the first
    frame that cannot be decoded, so row ``i`` of the result is always frame
    ``first_frame + i``. Returns ``None`` when no frame could be read.
    """
    video.set(cv2.CAP_PROP_POS_FRAMES, int(first_frame))
    encoded = []
    pending = []

    def _flush():
        if not pending:
            return
        images = torch.stack(pending).to(DEVICE)
        with torch.no_grad():
            encoded.append(model.encode_image(images))
        pending.clear()

    for _ in range(int(first_frame), int(last_frame) + 1):
        ok, frame = video.read()
        if not ok:
            break
        pending.append(preprocess(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))))
        if len(pending) == batch_size:
            _flush()
    _flush()

    return torch.cat(encoded) if encoded else None


# One (currently empty) list per processed video. The extracted keyframes
# themselves are written into vid_insights below.
keyframes = {}

for vid in lst_vids:
    vid_name = vid.split('/')[-1]
    vid_path = get_vid_path(vid_name)
    keyframes[vid] = []

    # Each video needs its own capture, opened from its own path and released
    # when that video is finished, even if a shot raises.
    video = cv2.VideoCapture(vid_path)
    try:
        if not video.isOpened():
            print(f"Cannot open {vid_path}; skipping")
            continue

#         vid_shots = get_vid_shots(vid_name)
        vid_shots = get_vid_shots_undone(vid_name)
        for bound in tqdm(vid_shots):
            print(bound, end=": ")

            features_tensor = _encode_shot_frames(video, bound[0], bound[1])
            if features_tensor is None:
                print("no frames decoded; skipping")
                continue

            shot_keyframes = keyframe_extraction(bound, features_tensor.cpu().numpy(), video)

            # Exactly one row of the table must describe this (video, start) shot.
            row = vid_insights.loc[(vid_insights['vid_name'] == vid_name) & (vid_insights['start'] == bound[0]), 'keyframes'].index.tolist()
            if len(row) == 1:
                vid_insights.at[row[0], 'keyframes'] = shot_keyframes
            else:
                raise Exception("Lỗi ngoài ý muốn! Liên hệ Phát!")
            print(shot_keyframes)

#             show_fml_keyframes(fml_frames(video, bound))
#             show_lmske_keyframes(video, shot_keyframes)
#             print(time() - start)
    finally:
        video.release()


0it [00:00, ?it/s]

[4345 4417]: Initializing centers ... - 

1it [00:16, 16.71s/it]

[4392, 4417]
[4418 4533]: Initializing centers ... - 

2it [00:31, 15.81s/it]

[4486]
[4534 4697]: Initializing centers ... - 

3it [00:58, 20.65s/it]

[4536, 4613, 4686]
[4698 4946]: Initializing centers ... - 

4it [01:29, 24.95s/it]

[4776]
[4947 5020]: Initializing centers ... - 

5it [01:36, 18.53s/it]

[4961, 5007]
[5021 5065]: Initializing centers ... - 

6it [01:44, 14.78s/it]

[5058]
[5066 5132]: Initializing centers ... - 

7it [01:52, 12.48s/it]

[5096]
[5133 5193]: Initializing centers ... - 

8it [01:59, 10.75s/it]

[5172]
[5194 5259]: Initializing centers ... - 

9it [02:08, 10.33s/it]

[5204]
[5260 5327]: Initializing centers ... - 

10it [02:24, 11.95s/it]

[5283]
[5328 5395]: Initializing centers ... - 

11it [02:32, 10.85s/it]

[5371]
[5396 5463]: Initializing centers ... - 

12it [02:45, 11.40s/it]

[5426]
[5464 5528]: Initializing centers ... - 

13it [02:53, 10.49s/it]

[5473]
[5529 5595]: Initializing centers ... - 

14it [03:05, 10.76s/it]

[5565]
[5596 5662]: Initializing centers ... - 

15it [03:12,  9.77s/it]

[5610]
[5663 5780]: Initializing centers ... - 

16it [03:28, 11.55s/it]

[5672, 5688, 5762]
[5781 5804]: Initializing centers ... - 

17it [03:33,  9.69s/it]

[5797]
[5805 5830]: Initializing centers ... - 

18it [03:40,  8.71s/it]

[5814]
[5831 5885]: Initializing centers ... - 

19it [03:46,  8.00s/it]

[5849]
[5886 5966]: Initializing centers ... - 

20it [03:55,  8.48s/it]

[5959]
[5967 6021]: Initializing centers ... - 

21it [04:01,  7.69s/it]

[5995]
[6022 6067]: Initializing centers ... - 

22it [04:07,  7.15s/it]

[6054]
[6068 6124]: Initializing centers ... - 

23it [04:17,  7.93s/it]

[6090]
[6125 6207]: Initializing centers ... - 

24it [04:27,  8.56s/it]

[6149]
[6208 6305]: Initializing centers ... - 

25it [04:37,  8.96s/it]

[6261]
[6306 6400]: Initializing centers ... - 

26it [04:48,  9.67s/it]

[6341, 6386]
[6401 6470]: Initializing centers ... - 

27it [05:04, 11.54s/it]

[6455]
[6471 6535]: Initializing centers ... - 

28it [05:13, 10.86s/it]

[6516]
[6536 6599]: Initializing centers ... - 

29it [05:25, 11.10s/it]

[6590]
[6600 6677]: Initializing centers ... - 

30it [05:35, 10.85s/it]

[6620]
[6678 6799]: Initializing centers ... - 

31it [05:52, 12.60s/it]

[6716, 6778]
[6800 6884]: Initializing centers ... - 

32it [06:03, 12.16s/it]

[6843]
[6885 6969]: Initializing centers ... - 

33it [06:12, 11.13s/it]

[6906]
[6970 7016]: Initializing centers ... - 

34it [06:17,  9.24s/it]

[7012]
[7017 7068]: Initializing centers ... - 

35it [06:24,  8.61s/it]

[7045]
[7069 7111]: Initializing centers ... - 

36it [06:33,  8.78s/it]

[7091]
[7112 7131]: Initializing centers ... - 

37it [06:38,  7.64s/it]

[7120]
[7132 7184]: Initializing centers ... - 

38it [06:44,  7.09s/it]

[7160]
[7185 7245]: Initializing centers ... - 

39it [06:54,  7.98s/it]

[7228]
[7246 7307]: Initializing centers ... - 

40it [07:02,  8.07s/it]

[7281]
[7308 7369]: Initializing centers ... - 

41it [07:12,  8.63s/it]

[7350]
[7370 7415]: Initializing centers ... - 

42it [07:22,  9.17s/it]

[7395]
[7416 7465]: Initializing centers ... - 

43it [07:29,  8.39s/it]

[7429]
[7466 7515]: Initializing centers ... - 

44it [07:35,  7.51s/it]

[7492]
[7516 7564]: Initializing centers ... - 

45it [07:42,  7.46s/it]

[7552]
[7565 7614]: Initializing centers ... - 

46it [07:48,  6.97s/it]

[7593]
[7615 7667]: Initializing centers ... - 

47it [07:55,  7.14s/it]

[7628]
[7668 7707]: Initializing centers ... - 

48it [08:04,  7.65s/it]

[7678]
[7708 7757]: Initializing centers ... - 

49it [08:11,  7.43s/it]

[7739]
[7758 7807]: Initializing centers ... - 

50it [08:18,  7.41s/it]

[7795]
[7808 7857]: Initializing centers ... - 

51it [08:29,  8.32s/it]

[7833]
[7858 7907]: Initializing centers ... - 

52it [08:36,  8.04s/it]

[7890]
[7908 7936]: Initializing centers ... - 

53it [08:40,  6.76s/it]

[7923]
[7937 7956]: Initializing centers ... - 

54it [08:43,  5.59s/it]

[7944]
[7957 8055]: Initializing centers ... - 

55it [08:56,  7.75s/it]

[7970]
[8056 8161]: Initializing centers ... - 

56it [09:10,  9.89s/it]

[8084]
[8162 8233]: Initializing centers ... - 

57it [09:26, 11.50s/it]

[8216]
[8234 8321]: Initializing centers ... - 

58it [09:36, 11.17s/it]

[8283]
[8322 8400]: Initializing centers ... - 

59it [09:46, 10.73s/it]

[8340, 8394]
[8401 8486]: Initializing centers ... - 

60it [10:02, 12.50s/it]

[8472]
[8487 8555]: Initializing centers ... - 

61it [10:11, 11.27s/it]

[8527]
[8556 8620]: Initializing centers ... - 

62it [10:19, 10.45s/it]

[8594]
[8621 8671]: Initializing centers ... - 

63it [10:25,  8.85s/it]

[8628]
[8672 9062]: Initializing centers ... - 

63it [11:11, 10.66s/it]


KeyboardInterrupt: 

In [None]:
# Write the shot table, including any keyframes filled in above, to the Kaggle
# working directory. The '-' separator matches the one used to read the table,
# and the DataFrame index is not written.
vid_insights.to_csv('/kaggle/working/vid_insights.csv', sep='-', index=False)