In [71]:
from pathlib import Path
import numpy as np
import torch
import json
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE
from pathlib import Path
import random

from IPython.display import display
from PIL import Image
import decord
from einops import rearrange

np.set_printoptions(suppress=True)

In [73]:
from sentence_transformers import SentenceTransformer
# Sentence encoder used later to embed the captions. The device is chosen at
# load time so the notebook still runs on machines without CUDA instead of
# failing on an unconditional .cuda(); .eval() switches the module to
# evaluation mode.
model = SentenceTransformer(
    'all-mpnet-base-v2',
    device='cuda' if torch.cuda.is_available() else 'cpu',
).eval()

def load_and_select_caption(annotation_path, p_caps_dir, clip_uid=None, fps=30):
    """Pick one caption file and return its captions with their timestamps.

    When ``clip_uid`` is given, the first ``*.json`` file in ``p_caps_dir``
    whose stem contains it is used. Otherwise a ``video_id`` is drawn at
    random from the annotation file and matched against the stems the same
    way.

    Args:
        annotation_path: Path to a JSON file holding a list of objects that
            each have a ``'video_id'`` key. Only read when ``clip_uid`` is
            ``None``.
        p_caps_dir: Directory searched (non-recursively) for ``*.json``
            caption files.
        clip_uid: Optional identifier; matched as a substring of the file
            stem.
        fps: Divisor applied to the first field of every caption entry
            before truncating to an int. The default of 30 matches the
            previously hard-coded value (presumably the field is a frame
            index at 30 fps -- confirm against the caption generator).

    Returns:
        Tuple ``(selected_p_cap, cap_data, time, caps)``: the chosen file as
        a ``Path``, the raw ``'answers'`` list from that file, the converted
        first field of each entry, and the third field of each entry.

    Raises:
        ValueError: If a random draw is requested but the annotation list is
            empty, or if no caption file matches.
    """
    if clip_uid is None:
        # Annotations are only needed to sample an id, so they are not
        # loaded (and no random number is consumed) when the caller already
        # names the clip.
        annotations = json.loads(Path(annotation_path).read_text())
        valid_video_id = [annotation['video_id'] for annotation in annotations]
        if not valid_video_id:
            raise ValueError("No annotations to sample a video_id from.")
        target = valid_video_id[random.randint(0, len(valid_video_id) - 1)]
    else:
        target = clip_uid

    # First match in directory-listing order; substring match on the stem.
    selected_p_cap = next(
        (p_cap for p_cap in Path(p_caps_dir).glob('*.json') if target in p_cap.stem),
        None,
    )
    if selected_p_cap is None:
        raise ValueError("No matching p_cap found.")

    # Context manager so the file handle is closed deterministically.
    with selected_p_cap.open() as f:
        cap_data = json.load(f)['answers']

    time = [int(entry[0] / fps) for entry in cap_data]
    caps = [entry[2] for entry in cap_data]

    return selected_p_cap, cap_data, time, caps

def plot_cosine_similarity_heatmap(cos_sim, time, p_cap, interval=10):
    """Draw a similarity matrix as a heatmap with time-labelled ticks.

    Args:
        cos_sim: 2-D array-like handed to ``sns.heatmap`` (presumably
            ``len(time)`` x ``len(time)`` -- confirm at the call site).
        time: Sequence of tick labels, indexed by row/column position.
        p_cap: Object with a ``.stem`` attribute; the stem is shown in the
            title and printed after the figure.
        interval: Label every ``interval``-th row/column.
    """
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(cos_sim, cmap='viridis', ax=ax)

    # seaborn draws cell i over [i, i + 1], so its centre is at i + 0.5.
    # Using the bare integer would pin each label to the cell's edge.
    tick_idx = list(range(0, len(time), interval))
    tick_pos = [i + 0.5 for i in tick_idx]
    tick_labels = [time[i] for i in tick_idx]
    ax.set_xticks(tick_pos)
    ax.set_xticklabels(tick_labels, rotation=90)
    ax.set_yticks(tick_pos)
    ax.set_yticklabels(tick_labels, rotation=0)

    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Time (s)')
    ax.set_title(f'<Cosine Similarity Heatmap>\nclip_uid: {p_cap.stem}')

    plt.show()

    # Echo the stem as text so it can be copied from the cell output.
    print(p_cap.stem)
    
# Demo: no clip_uid is passed, so a video id is drawn at random from the
# annotation file below and its caption file is loaded.
annotation_path = '/data/joohyun7u/project/NLQ/nlq_lightning/data/unified/annotations.NLQ_train.json'
# Active caption directory; swap in the commented line to read the other set.
p_caps_dir = '/data/joohyun7u/project/NLQ/nlq_lightning/data/llava-v1.6-34b/global'
# p_caps_dir = '/data/joohyun7u/project/NLQ/nlq_lightning/data/LLaVA-NeXT-Video-7B-DPO/global_v2'
(selected_p_cap, cap_data, time, caps) = load_and_select_caption(
    annotation_path=annotation_path,
    p_caps_dir=p_caps_dir,
)

In [None]:
# Re-draw cell: each run picks another random video id (no clip_uid given).
# Toggle the two lines below to switch the caption directory before drawing.
p_caps_dir = '/data/joohyun7u/project/NLQ/nlq_lightning/data/llava-v1.6-34b/global'
# p_caps_dir = '/data/joohyun7u/project/NLQ/nlq_lightning/data/LLaVA-NeXT-Video-7B-DPO/global_v2'
(selected_p_cap, cap_data, time, caps) = load_and_select_caption(
    annotation_path=annotation_path,
    p_caps_dir=p_caps_dir,
)

In [74]:
import h5py
import numpy as np
import torch
import os
# Directory and base name of the per-clip video feature file.
data_dir = '/data/joohyun7u/project/NLQ/nlq_lightning/data/unified'
feature_type = 'egovlp_internvideo'

# Opened read-only and intentionally left open: later cells index into it by
# clip uid.
video_features = h5py.File(os.path.join(data_dir, feature_type + '.hdf5'), 'r')

# Seed both generators. The clip/caption draws in this notebook go through
# the stdlib `random` module, which np.random.seed alone does not affect.
np.random.seed(0)
random.seed(0)

# Built from data_dir so the location is defined in one place; this resolves
# to the same Path as the previously hard-coded literal.
p_nlq_val_json = Path(data_dir) / 'annotations.NLQ_val.json'
# read_text() opens and closes the file itself, so no handle is left dangling.
nlq_val_data = json.loads(p_nlq_val_json.read_text())

def show_frames(clip_uid, num_frames=16, border=10, clips_dir='/data/datasets/ego4d_data/v2/clips/'):
    """Display a strip of evenly spaced frames from a clip and return them.

    The clip ``<clips_dir>/<clip_uid>.mp4`` is sampled at the centre of
    ``num_frames`` equal segments. Each displayed frame gets a coloured
    border taken from the 'jet' colormap at the frame's relative position in
    the clip, and the frames are laid out side by side in one row.

    Args:
        clip_uid: File stem of the clip to open.
        num_frames: Maximum number of frames to sample; shorter clips yield
            one frame per available index instead.
        border: Border thickness in pixels; 0 disables the border.
        clips_dir: Directory holding the ``.mp4`` clips (an earlier revision
            pointed at '/data/datasets/ego4d_data/clips_320p-non_official/').

    Returns:
        Tuple ``(frame_idxs, out_frames)``: the sampled frame indices as a
        NumPy array and a copy of the sampled frames taken before the
        borders were painted.

    Raises:
        ValueError: If the clip contains no frames.
    """
    p_clip = Path(clips_dir) / f'{clip_uid}.mp4'
    vr = decord.VideoReader(str(p_clip))
    total = len(vr)
    if total == 0:
        raise ValueError(f'{p_clip} contains no frames.')

    # A step of at least 1 keeps np.arange valid for clips shorter than
    # num_frames. The slice drops the extra index arange can emit when the
    # remainder of total / num_frames exceeds half a step, which would
    # otherwise break the row layout below.
    step = max(1, total // num_frames)
    frame_idxs = np.arange(step // 2, total, step)[:num_frames]

    border_colors = plt.get_cmap('jet')(frame_idxs / total)
    midframes = vr.get_batch(frame_idxs.tolist()).asnumpy()
    # Snapshot before drawing so callers receive the unmodified frames.
    out_frames = midframes.copy()
    if border > 0:  # frame[-0:] would select, and overwrite, the whole frame
        for frame, rgba in zip(midframes, border_colors):
            rgb = rgba[:3] * 255  # colormap values are in [0, 1]
            frame[:border] = rgb
            frame[-border:] = rgb
            frame[:, :border] = rgb
            frame[:, -border:] = rgb

    # One row containing every sampled frame, whatever their count.
    midframes = rearrange(midframes, '(th tw) h w c -> (th h) (tw w) c', tw=len(frame_idxs))
    display(Image.fromarray(midframes))
    return frame_idxs, out_frames


In [75]:
# Materialise the feature file's top-level keys so clips can be picked by
# position or sampled.
clip_list = [*video_features.keys()]

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import cosine_similarity

# Clip to inspect. Earlier revisions assigned clip_uid several times in a row
# and only the last assignment took effect; the previously tried uids are
# kept here for reference:
#   '13635dfb-dfcc-4ada-bcee-1482a568aa89'
#   '5e59031d-0deb-4557-a3e1-ba0ba2bb5465'
#   'a99baf07-ce1c-4f73-ab20-ed0dfc079510'
#   '00d9a297-d967-4d28-8e5a-6b891814ec65'
# For a random clip instead: clip_uid = random.choice(clip_list)
clip_uid = clip_list[40]

selected_p_cap, cap_data, time, caps = load_and_select_caption(annotation_path, p_caps_dir, clip_uid=clip_uid)

frame_idxs, frames = show_frames(clip_uid)

data = video_features[clip_uid][:]
print(data.shape)

# Pairwise cosine similarity between the rows of this clip's feature array.
# To look at one column group only, slice before calling, e.g.
#   data[:, :256], data[:, 256:256+1024] or data[:, 256+1024:]
# (which feature each slice corresponds to is not recorded here -- confirm).
similarity_matrix = cosine_similarity(data)

# Heatmap of the video-feature similarity.
fig, ax = plt.subplots(figsize=(10, 8))
im = ax.imshow(similarity_matrix, cmap='viridis')
ax.set_title(f'CLS Token Similarity Heatmap\nclip_uid: {clip_uid}')
fig.colorbar(im, ax=ax, label='Cosine Similarity')
ax.set_xlabel('CLS Token Index')
ax.set_ylabel('CLS Token Index')
plt.show()

# Caption-side similarity. Requesting unit-length embeddings makes the dot
# product below an actual cosine similarity regardless of whether the loaded
# model normalises its own output.
embeddings = model.encode(caps, normalize_embeddings=True)
cos_sim = embeddings @ embeddings.T

plot_cosine_similarity_heatmap(cos_sim, time, selected_p_cap)

In [77]:
# Echo the caption directory currently in use.
p_caps_dir

'/data/joohyun7u/project/NLQ/nlq_lightning/data/llava-v1.6-34b/global'

In [None]:
# Echo the uid of the clip currently selected for inspection.
clip_uid

'0a798ad9-e163-4d1a-9a26-0ba6a8dce89e'