In [1]:
import sys
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
from lutils import openf, writef
from tqdm import tqdm

# Put the directory two levels above the working directory on sys.path so the
# project-local `src` and `tools` packages imported below can be resolved.
sys.path.append(str(Path().cwd().parent.parent))
from src.data.utils_asr import ChaptersASR
from tools.captions.caption_selection import CaptionSelection

### Common

In [2]:
def select_furthest_frames_dict(frame_dict, target_count=100):
    """Thin out a frame-keyed dict until at most ``target_count`` entries remain.

    Keys must look like ``"<frame_idx>/<total_frames>"``; only the integer
    before the slash is used. While there are too many frames, the pair of
    consecutive frames with the smallest index gap is located (the first such
    pair on ties) and the later frame of that pair is dropped. The only
    exception is when that later frame is the last frame overall: then the
    earlier one is dropped instead, so the final frame is always kept.

    Args:
        frame_dict: Mapping from frame key to an arbitrary value.
        target_count: Maximum number of entries to keep.

    Returns:
        A new dict with the surviving entries, ordered by ascending frame index.

    Raises:
        ValueError: If ``target_count`` is negative.
    """
    if target_count < 0:
        raise ValueError(f"target_count must be non-negative, got {target_count}")
    if target_count == 0:
        # Nothing may be kept. Without this guard the loop below would reach a
        # single remaining frame and fail on min() of an empty gap list.
        return {}

    # (frame index, original key) pairs in ascending frame order; the sort is
    # stable, so keys sharing an index keep their original relative order.
    frames = sorted(
        ((int(key.split("/")[0]), key) for key in frame_dict),
        key=lambda frame: frame[0],
    )

    # target_count >= 1 here, so the loop body always sees at least two frames.
    while len(frames) > target_count:
        gaps = [frames[i + 1][0] - frames[i][0] for i in range(len(frames) - 1)]
        # Left frame of the tightest pair; list.index returns the first minimum.
        left = gaps.index(min(gaps))
        # Drop the right frame of the pair unless it is the very last frame.
        drop = left + 1 if left + 1 < len(gaps) else left
        frames.pop(drop)

    return {key: frame_dict[key] for _, key in frames}


class CaptionSelector:
    """Select pre-computed frame captions that lie closest to given timestamps.

    Captions for a video are read from
    ``captions_dir/<first two characters of the id>/<id>.json`` and are keyed
    by ``"<frame_idx>/<total_frames>"`` strings. Video ids and durations come
    from one ``ChaptersASR`` object per subset.

    Attributes set in ``__init__``:
        subset2chp: subset name -> ``ChaptersASR`` instance.
        subsets: the subset names as passed in (a lone string is wrapped in a list).
        video_ids: unique video ids over all subsets, in first-seen order.
        vid2subset: video id -> subset name (the last subset listing the id wins).
        max_gap: largest accepted distance, in seconds, between a requested
            timestamp and the frame chosen for it.
        vid2timestamps: video id -> timestamps for which no frame was close enough.
    """

    def __init__(
        self,
        captions_dir: Path,
        vidc_dir: Path = Path("../../dataset/"),
        subsets: List[str] = ("s1k_train", "s100_val"),
        max_gap: float = 2,
    ):
        """Load one ``ChaptersASR`` per subset and index their video ids.

        Args:
            captions_dir: Root directory holding the per-video caption files.
            vidc_dir: Dataset directory forwarded to ``ChaptersASR``.
            subsets: Subset name or sequence of subset names.
            max_gap: Maximum timestamp-to-frame distance in seconds.
        """
        subsets = [subsets] if isinstance(subsets, str) else subsets
        self.subset2chp = {
            subset: ChaptersASR(vidc_dir=vidc_dir, subset=subset) for subset in subsets
        }
        self.subsets = subsets

        # De-duplicate ids shared between subsets. dict.fromkeys keeps
        # first-seen order, so iteration order is reproducible across kernel
        # restarts (a plain set orders strings by their randomized hash).
        self.video_ids = list(
            dict.fromkeys(
                vid_id
                for subset in subsets
                for vid_id in self.subset2chp[subset].video_ids
            )
        )
        self.captions_dir = captions_dir

        self.vid2subset = {}
        for subset in subsets:
            for vid_id in self.subset2chp[subset].video_ids:
                self.vid2subset[vid_id] = subset

        self.max_gap = max_gap
        self.vid2timestamps = defaultdict(list)

    def get_chp(self, string: str):
        """Return the ``ChaptersASR`` object for a subset name or a video id.

        Raises:
            ValueError: If ``string`` is neither a known subset nor a known video id.
        """
        if string in self.subsets:
            return self.subset2chp[string]
        elif string in self.vid2subset:
            return self.subset2chp[self.vid2subset[string]]
        else:
            raise ValueError(f"{string} not in subsets or vid2subset")

    def get_duration(self, string: str):
        """Return ``get_duration(string)`` of the ``ChaptersASR`` owning ``string``."""
        return self.get_chp(string).get_duration(string)

    def get_captions(self, vid_id):
        """Load every caption stored for ``vid_id`` via ``openf``.

        NOTE(review): ``select_captions`` treats the result as a dict keyed by
        ``"<frame_idx>/<total_frames>"`` - confirm against the caption files.
        """
        captions_file = self.captions_dir / f"{vid_id[:2]}" / f"{vid_id}.json"
        all_captions = openf(captions_file)
        return all_captions

    def select_captions(
        self, vid_id: str, timestamps: List[float], max_captions: int = 100
    ) -> Dict[str, str]:
        """Pick the caption of the closest frame for each timestamp.

        Timestamps with no frame within ``self.max_gap`` seconds are appended to
        ``self.vid2timestamps[vid_id]`` and otherwise ignored. Several
        timestamps that map to the same frame yield a single entry.

        Args:
            vid_id: Video whose captions are selected.
            timestamps: Timestamps in seconds (each one is divided by the video
                duration, so they are scalars, not ``[start, end]`` pairs).
            max_captions: Upper bound on the number of returned captions; any
                excess is thinned with ``select_furthest_frames_dict``.

        Returns:
            Frame key -> caption, ordered by ascending frame index.
        """
        vid_duration = self.get_duration(vid_id)
        vid_captions = self.get_captions(vid_id)

        selected_captions = {}
        for timestamp in timestamps:
            closest_frame = self.find_closest_frame(
                timestamp, vid_duration, vid_captions
            )
            if closest_frame is None:
                self.vid2timestamps[vid_id].append(timestamp)
                continue
            selected_captions.setdefault(closest_frame, vid_captions[closest_frame])

        if len(selected_captions) > max_captions:
            # Too many captions: drop the ones that sit closest to a neighbour.
            selected_captions = select_furthest_frames_dict(
                selected_captions, max_captions
            )

        # Sort by frame index.
        return dict(
            sorted(selected_captions.items(), key=lambda x: int(x[0].split("/")[0]))
        )

    def extract_timestamps(self, asr_segments: List[List[float]]) -> List[float]:
        """Flatten ``(start, end)`` segments into a sorted list of unique timestamps."""
        timestamps = []
        for start, end in asr_segments:
            timestamps.extend([start, end])
        return sorted(set(timestamps))

    @staticmethod
    def _frame_ratio(frame_key: str) -> float:
        """Relative position of a ``"<frame_idx>/<total_frames>"`` key in the video."""
        parts = frame_key.split("/")
        return float(parts[0]) / float(parts[1])

    def find_closest_frame(
        self, timestamp: float, vid_duration: float, captions: Dict[str, str]
    ) -> Optional[str]:
        """Return the caption key whose frame lies closest to ``timestamp``.

        Frames are compared by relative position (``frame_idx / total_frames``
        versus ``timestamp / vid_duration``).

        Returns:
            The closest frame key, or ``None`` when even that frame is more
            than ``self.max_gap`` seconds away from ``timestamp``.

        Raises:
            ValueError: If ``captions`` is empty.
        """
        if not captions:
            # Same exception type min() would raise, with an actionable message.
            raise ValueError("no captions available to match the timestamp against")

        target_ratio = timestamp / vid_duration
        closest_frame = min(
            captions,
            key=lambda frame_key: abs(self._frame_ratio(frame_key) - target_ratio),
        )
        # Reject the match when it is further than max_gap seconds from the target.
        frame_timestamp = self._frame_ratio(closest_frame) * vid_duration
        if abs(frame_timestamp - timestamp) > self.max_gap:
            return None
        return closest_frame

In [3]:
# Captioning model whose outputs are post-processed; later cells reuse `captioner`.
captioner = "HwwwH_MiniCPM-V-2"
captions_dir = Path(f"../../dataset/captions/{captioner}/")

# Subsets whose videos every following cell iterates over (through `cs.video_ids`).
# Selections used for other runs:
#   ["sml1k_train", "sml10k_train", "sml300_val"]
#   ["sml10k_train", "sml300_val"]
#   ["sml1k_train", "sml300_val"]
subsets = ["sml_no-asr_val"]

# The full caption dump lives in the "all" sub-directory; frames further than
# 3 seconds from a requested timestamp are rejected.
cs = CaptionSelector(captions_dir / "all", subsets=subsets, max_gap=3)

# s10k-2_train

In [None]:
# Select captions at the timestamps produced by the "asr-preds" sampling method
# and store them per video under captions/<captioner>/asr_<subset_train>_preds/.
base_dir = Path("../../")
vidc_dir = base_dir / "dataset/"

# Value forwarded to CaptionSelection as `subset_train`.
# Other values this cell has been run with:
#   "sml10k-2_train", "s1k-2_train", "sml1k_train"
subset_train = "s10k-2_train"

selection = CaptionSelection(
    vidc_dir=vidc_dir,
    base_dir=base_dir,
    sampling_methods=("asr-preds",),
    subset_train=subset_train,
)

method_flag = f"asr_{subset_train}_preds"

captions_dir = Path(f"../../dataset/captions/{captioner}/{method_flag}")
captions_dir.mkdir(exist_ok=True)

# Both lists are created once, before the loop, so the summary printed at the
# end covers every video rather than only the last one processed.
no_preds_vid = []  # videos for which no caption could be selected
missing_vids = []  # videos whose processing raised an exception
for vid_id in tqdm(cs.video_ids):
    caption_pth = captions_dir / f"{vid_id[:2]}" / f"{vid_id}.json"

    vid_duration = cs.get_duration(vid_id)
    try:
        timestamps = selection(vid_id, duration=vid_duration)

        selected_captions = cs.select_captions(vid_id, timestamps)
        if not selected_captions:
            no_preds_vid.append(vid_id)
            continue
        caption_pth.parent.mkdir(exist_ok=True)
        writef(caption_pth, selected_captions)
    except Exception:
        # Best effort: failing videos are only counted and reported below.
        missing_vids.append(vid_id)
        continue

# Guard the percentages against an empty video list.
n_videos = len(cs.video_ids) or 1
print(
    f"No preds for {len(no_preds_vid)} videos ({len(no_preds_vid) / n_videos:.2%})"
)
print(
    f"Missing {len(missing_vids)} videos ({len(missing_vids) / n_videos:.2%})"
)

10s

In [4]:
# Select captions at the timestamps of the "10s" sampling method and write one
# JSON file per video under captions/<captioner>/10s/.
base_dir = Path("../../")
vidc_dir = base_dir / "dataset/"

method_flag = "10s"
selection = CaptionSelection(
    vidc_dir=vidc_dir,
    base_dir=base_dir,
    sampling_methods=("10s",),
)

captions_dir = Path(f"../../dataset/captions/{captioner}/{method_flag}")
captions_dir.mkdir(exist_ok=True)

for vid_id in tqdm(cs.video_ids):
    # Output file: <captions_dir>/<first two characters of the id>/<id>.json
    caption_pth = captions_dir / f"{vid_id[:2]}" / f"{vid_id}.json"
    vid_duration = cs.get_duration(vid_id)
    try:
        timestamps = selection(vid_id, duration=vid_duration)
        selected_captions = cs.select_captions(vid_id, timestamps)

        vid_caption_dir = caption_pth.parent
        vid_caption_dir.mkdir(exist_ok=True)
        writef(caption_pth, selected_captions)
    except Exception as e:
        # Report the failure and move on to the next video.
        print(f"Error for {vid_id}: {e}")

100%|██████████| 190/190 [00:10<00:00, 17.75it/s]


Shot detection

In [None]:
# Select captions at shot-based timestamps and write one JSON file per video
# under captions/<captioner>/shot_<shot_location>/.
base_dir = Path("../../")
vidc_dir = base_dir / "dataset/"

# Active variant; the other pairing used by this cell is ("boundary", "boundaries").
sampling_method, shot_location = "midpoint", "midpoints"

selection = CaptionSelection(
    vidc_dir=vidc_dir,
    base_dir=base_dir,
    sampling_methods=(f"shot-{sampling_method}",),
)

method_flag = f"shot_{shot_location}"

captions_dir = Path(f"../../dataset/captions/{captioner}/{method_flag}")
captions_dir.mkdir(exist_ok=True)

for vid_id in tqdm(cs.video_ids):
    # Output file: <captions_dir>/<first two characters of the id>/<id>.json
    caption_pth = captions_dir / f"{vid_id[:2]}" / f"{vid_id}.json"
    vid_duration = cs.get_duration(vid_id)
    try:
        timestamps = selection(vid_id, duration=vid_duration)
        selected_captions = cs.select_captions(vid_id, timestamps)

        vid_caption_dir = caption_pth.parent
        vid_caption_dir.mkdir(exist_ok=True)
        writef(caption_pth, selected_captions)
    except Exception as e:
        # Report the failure and move on to the next video.
        print(f"Error for {vid_id}: {e}")

60s

In [None]:
# Select captions at the timestamps of the "60s" sampling method and write one
# JSON file per video under captions/<captioner>/60s/.
base_dir = Path("../../")
vidc_dir = base_dir / "dataset/"

selection = CaptionSelection(
    vidc_dir=vidc_dir,
    base_dir=base_dir,
    sampling_methods=("60s",),
)
method_flag = "60s"


captions_dir = Path(f"../../dataset/captions/{captioner}/{method_flag}")
captions_dir.mkdir(exist_ok=True)

for vid_id in tqdm(cs.video_ids):
    # Output file: <captions_dir>/<first two characters of the id>/<id>.json
    caption_pth = captions_dir / f"{vid_id[:2]}" / f"{vid_id}.json"
    vid_duration = cs.get_duration(vid_id)
    try:
        timestamps = selection(vid_id, duration=vid_duration)

        selected_captions = cs.select_captions(vid_id, timestamps)
        caption_pth.parent.mkdir(exist_ok=True)
        writef(caption_pth, selected_captions)
    except Exception as e:
        # Still best effort, but report the failure (as the 10s and shot cells
        # do) so skipped videos do not go unnoticed.
        print(f"Error for {vid_id}: {e}")
        continue

100f

In [None]:
# Select captions for a fixed number of frames per video and write one JSON
# file per video under captions/<captioner>/<method_flag>/.
base_dir = Path("../../")
vidc_dir = base_dir / "dataset/"

method_flag = "10f"  # alternative used by this cell: "100f"
# NOTE(review): assumes a flag of the form "<N>f" requests N frames. This
# matches the count of 10 that was previously hard-coded for "10f"; confirm
# against CaptionSelection before relying on it for other flags.
expected_frames = int(method_flag.rstrip("f"))

selection = CaptionSelection(
    vidc_dir=vidc_dir,
    base_dir=base_dir,
    sampling_methods=(method_flag,),
)


captions_dir = Path(f"../../dataset/captions/{captioner}/{method_flag}")
captions_dir.mkdir(exist_ok=True)

for vid_id in tqdm(cs.video_ids):
    # Output file: <captions_dir>/<first two characters of the id>/<id>.json
    caption_pth = captions_dir / f"{vid_id[:2]}" / f"{vid_id}.json"
    vid_duration = cs.get_duration(vid_id)
    try:
        timestamps = selection(vid_id, duration=vid_duration)

        selected_captions = cs.select_captions(vid_id, timestamps)
        if len(selected_captions) != expected_frames:
            # Videos with the wrong number of captions are still skipped, but
            # the skip is reported and no empty directory is left behind.
            print(
                f"Skipping {vid_id}: expected {expected_frames} captions, "
                f"got {len(selected_captions)}"
            )
            continue
        caption_pth.parent.mkdir(exist_ok=True)
        writef(caption_pth, selected_captions)
    except Exception as e:
        # Best effort: report the failure and move on to the next video.
        print(f"Error for {vid_id}: {e}")
        continue

### ASR + 10s if no ASR

In [5]:
# Build a combined caption set: the ASR-prediction captions where they apply,
# otherwise the 10s captions, capped at 100 frames per video.
captioner = "HwwwH_MiniCPM-V-2"  # alternative: "openbmb_MiniCPM-V-2_6"

caption_dir = Path(f"../../dataset/captions/{captioner}/")
# Directory of ASR-prediction captions. Other values used with this cell:
#   "asr_sml10k-2_train_preds", "asr_sml1k_train_preds", "asr_s1k-2_train_preds"
data1 = "asr_s10k-2_train_preds"

asr_preds_dir = caption_dir / data1

data2 = "10s"
data2_dir = Path(f"../../dataset/captions/{captioner}/{data2}")

add_prefix = False

# Name of the combined directory. `.name` (not `.stem`) keeps the full
# directory name even when it contains a dot; `.stem` would cut a name such as
# "asr_gemini-1.5-pro_zero-shot" down to "asr_gemini-1".
combined_stem = asr_preds_dir.name + f"+no-asr-{data2}"

combined_stem += "_captionsWithPrefix" if add_prefix else ""

print(f"Data: {combined_stem}")
combined_dir = caption_dir / combined_stem
combined_dir.mkdir(exist_ok=True)

for vid_id in tqdm(cs.video_ids):
    asr_pth = asr_preds_dir / f"{vid_id[:2]}" / f"{vid_id}.json"
    s10_pth = data2_dir / f"{vid_id[:2]}" / f"{vid_id}.json"

    # NOTE(review): `vid_id in cs.get_chp(vid_id)` relies on ChaptersASR's
    # membership test; presumably true when the video has ASR - confirm.
    if vid_id in cs.get_chp(vid_id) and asr_pth.exists():
        vid_captions = openf(asr_pth)
        if add_prefix:
            vid_captions = {k: "<speech> " + v for k, v in vid_captions.items()}

    # The 10s captions are only a fallback when no ASR-prediction file exists.
    elif s10_pth.exists() and not asr_pth.exists():
        vid_captions = openf(s10_pth)
        if add_prefix:
            vid_captions = {k: "<visual> " + v for k, v in vid_captions.items()}
    else:
        print(f"{vid_id} has no captions")
        continue

    vid_captions = select_furthest_frames_dict(vid_captions, 100)

    # Write combined captions
    combined_file = combined_dir / f"{vid_id[:2]}" / f"{vid_id}.json"
    combined_file.parent.mkdir(exist_ok=True)
    writef(combined_file, vid_captions)

print(f"{combined_dir.resolve()} {combined_stem}")

Data: asr_s10k-2_train_preds+no-asr-10s


100%|██████████| 190/190 [00:11<00:00, 16.37it/s]

/storage/lucas/datasets/VidChapters/captions/HwwwH_MiniCPM-V-2/asr_s10k-2_train_preds+no-asr-10s asr_s10k-2_train_preds+no-asr-10s





### Proprietary models

In [None]:
# Select captions at the "asr-preds" timestamps of a zero-shot proprietary model
# and write one JSON file per video under captions/<captioner>/asr_<model>_zero-shot/.
base_dir = Path("../../")
vidc_dir = base_dir / "dataset/"

# Active model; other values used with this cell:
#   "gpt-4o", "gemini-1.5-pro", "gpt-4o-mini"
model = "gemini-2.0-flash"
method_flag = f"asr_{model}_zero-shot"

selection = CaptionSelection(
    vidc_dir=vidc_dir,
    base_dir=base_dir,
    sampling_methods=("asr-preds",),
    model=model,
    subset_train="zero-shot",
)

captions_dir = Path(f"../../dataset/captions/{captioner}/{method_flag}")
captions_dir.mkdir(exist_ok=True)
print("Saving data in:", captions_dir.resolve())

# Videos for which no caption could be selected.
no_preds_vid = []
for vid_id in tqdm(cs.video_ids):
    # Output file: <captions_dir>/<first two characters of the id>/<id>.json
    caption_pth = captions_dir / f"{vid_id[:2]}" / f"{vid_id}.json"
    vid_duration = cs.get_duration(vid_id)
    try:
        timestamps = selection(vid_id, duration=vid_duration)
        selected_captions = cs.select_captions(vid_id, timestamps)
        if selected_captions:
            vid_caption_dir = caption_pth.parent
            vid_caption_dir.mkdir(exist_ok=True)
            writef(caption_pth, selected_captions)
        else:
            no_preds_vid.append(vid_id)
    except Exception as e:
        # Report the failure and move on to the next video.
        print(f"Error for {vid_id}: {e}")

print(
    f"No preds for {len(no_preds_vid)} videos ({len(no_preds_vid) / len(cs.video_ids):.2%})"
)

# Cross-check against the files actually present on disk.
files_in_captions_dir = list(captions_dir.glob("**/*.json"))
vids_in_captions_dir = {p.stem for p in files_in_captions_dir}
vids_not_in_captions_dir = [
    vid for vid in cs.video_ids if vid not in vids_in_captions_dir
]

print(len(vids_not_in_captions_dir))

### No ASR predictions from 10s

In [None]:
# Select captions at the "asr-preds" timestamps obtained with the "10s" data
# flag and the "captions" prompt, and write one JSON file per video under
# captions/<captioner>/captions10s_s1k-2_no-asr_train/.
base_dir = Path("../../")
vidc_dir = base_dir / "dataset/"

selection = CaptionSelection(
    vidc_dir=vidc_dir,
    base_dir=base_dir,
    sampling_methods=("asr-preds",),
    data_flags="10s",
    prompt="captions",
    subset_train="s1k-2_no-asr_train",
)

method_flag = "captions10s_s1k-2_no-asr_train"

captions_dir = Path(f"../../dataset/captions/{captioner}/{method_flag}")
captions_dir.mkdir(exist_ok=True)

# Created once, before the loop, so the summary printed at the end counts every
# video without a selection rather than only the last one processed.
no_preds_vid = []
for vid_id in tqdm(cs.video_ids):
    # Output file: <captions_dir>/<first two characters of the id>/<id>.json
    caption_pth = captions_dir / f"{vid_id[:2]}" / f"{vid_id}.json"

    vid_duration = cs.get_duration(vid_id)
    try:
        timestamps = selection(vid_id, duration=vid_duration)

        selected_captions = cs.select_captions(vid_id, timestamps)
        if not selected_captions:
            no_preds_vid.append(vid_id)
            continue
        caption_pth.parent.mkdir(exist_ok=True)
        writef(caption_pth, selected_captions)
    except Exception as e:
        # Best effort: report the failure and move on to the next video.
        print(f"Error for {vid_id}: {e}")
        continue

# Guard the percentage against an empty video list.
n_videos = len(cs.video_ids) or 1
print(
    f"No preds for {len(no_preds_vid)} videos ({len(no_preds_vid) / n_videos:.2%})"
)
print(f"saved at {captions_dir.resolve()}")