In [None]:
import argparse
import glob
import os
from time import time
from utils.build_chunks import build_chunks
import numpy as np
import pandas as pd
import shutil
from pathlib import Path
from tqdm import tqdm
from pairs.utils import s3_save_numpy
from utils.lip_coordinates import extract_and_crop_lips
import dlib
import boto3
import re
import torch
import os
import torch
from torch.utils.data import Dataset
from pathlib import Path
import numpy as np
import random
import pandas as pd
from typing import List, Dict, Tuple
from math import floor
import re
# from pairs.utils import list_s3_files, s3_load_numpy
# from pairs.config import S3_BUCKET_NAME
from tqdm import tqdm
from collections import defaultdict
# from datasets.LoaderTest import TestDataset

# DataLoader is used below but was missing from this cell's imports, which made
# the cell fail with NameError on a fresh kernel.
from torch.utils.data import DataLoader

# Module-level S3 client and list_objects_v2 paginator for this cell.
s3        = boto3.client("s3")
paginator = s3.get_paginator("list_objects_v2")

# Location of the preprocessed test data.
bucket       = "mmml-proj"
root_prefix  = "test_preprocessed/"

# Load the test dataset.
# NOTE(review): TestDataset is defined in a later cell of this notebook (the
# import from datasets.LoaderTest above is commented out), so that cell has to
# be executed before this one - confirm the intended cell order.
test_dataset = TestDataset(bucket, root_prefix, transform=None, visual_type='lip')
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

print(f"Test dataset size: {len(test_dataset)} samples")

# Load the trained model.
# NOTE(review): neither load_model nor checkpoint_path is defined in the visible
# cells - they must come from another cell or module; confirm before running.
model = load_model(checkpoint_path, embedding_dims=512)

# Run inference and save predictions
# run_inference(model, test_loader, output_csv="predictions.csv")

In [1]:
pwd

'/Users/AnuranjanAnand/Desktop/MML/mml_diarization'

In [None]:
import boto3, io, re
from collections import defaultdict
from typing import Dict, List, Tuple, Optional

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset

# ─────────────────────────────────────────────────────────────
# helpers
# ─────────────────────────────────────────────────────────────
# Shared S3 client used by the helper functions below.
_s3        = boto3.client("s3")
# Paginator over list_objects_v2, used to enumerate prefixes and keys.
_paginator = _s3.get_paginator("list_objects_v2")

def _load_npy(bucket: str, key: str) -> np.ndarray:
    """Fetch the S3 object `key` from `bucket` and decode it with `np.load`.

    The object's bytes are read fully into memory and parsed from an in-memory
    buffer; nothing is written to disk.
    """
    response = _s3.get_object(Bucket=bucket, Key=key)
    buffer = io.BytesIO(response["Body"].read())
    return np.load(buffer)

def _load_csv(bucket: str, key: str) -> Optional[pd.DataFrame]:
    """Read the CSV object `key` from `bucket` into a DataFrame.

    Returns None when S3 reports that the key does not exist; any other error
    propagates to the caller.
    """
    try:
        response = _s3.get_object(Bucket=bucket, Key=key)
        raw = response["Body"].read()
    except _s3.exceptions.NoSuchKey:
        return None
    else:
        return pd.read_csv(io.BytesIO(raw))

def _extract_audio_segment(
    mel: torch.Tensor,
    frame_idx: int,
    total_frames: int,
    desired_len: int = 22,
) -> torch.Tensor:
    """Return the mel-spectrogram segment for one video frame, sized to `desired_len`.

    Two layouts are handled:

    * 3-D `[n_mels, L, T]`: frame `frame_idx` is selected along the last axis.
    * otherwise the tensor is unpacked as `[n_mels, A]`: the audio axis is split
      proportionally into `total_frames` windows and the window belonging to
      `frame_idx` (at least one column wide) is taken.

    In both cases the segment is then zero-padded on the right or truncated so
    its last axis is exactly `desired_len`. Previously only the 2-D layout was
    normalised, so 3-D input could come back with a different length than the
    documented `[n_mels, desired_len]`.

    Args:
        mel: mel spectrogram in one of the layouts above.
        frame_idx: index of the video frame to extract audio for.
        total_frames: number of video frames the 2-D audio axis spans
            (unused for 3-D input).
        desired_len: length of the returned segment's last axis.

    Returns:
        Tensor of shape `[n_mels, desired_len]`.

    Raises:
        ValueError: if `mel` is neither 3-D nor unpackable as 2-D.
        ZeroDivisionError: if `total_frames` is 0 for 2-D input.
    """
    if mel.ndim == 3:                                # [n_mels, L, T]
        seg = mel[:, :, frame_idx]
    else:                                            # [n_mels, A]
        _, A = mel.shape
        a0 = int(frame_idx * A / total_frames)
        # max(...) guarantees a non-empty window even when A < total_frames
        a1 = max(a0 + 1, int((frame_idx + 1) * A / total_frames))
        seg = mel[:, a0:a1]

    # Normalise the length for BOTH layouts so callers always get a fixed shape.
    cur = seg.shape[1]
    if cur < desired_len:
        seg = torch.nn.functional.pad(seg, (0, desired_len - cur))
    elif cur > desired_len:
        seg = seg[:, :desired_len]
    return seg                                       # [n_mels, desired_len]

def _iter_prefixes(bucket: str, root: str):
    """Lazily yield each immediate sub-prefix (video folder) found under `root`.

    Lists `bucket` with a '/' delimiter and yields the `Prefix` of every
    `CommonPrefixes` entry, page by page.
    """
    pages = _paginator.paginate(Bucket=bucket, Prefix=root, Delimiter='/')
    for page in pages:
        # pages without sub-folders carry no "CommonPrefixes" entry
        yield from (entry["Prefix"] for entry in page.get("CommonPrefixes", []))


In [None]:
from tqdm.auto import tqdm
class TestDataset(Dataset):
    """
    S3-backed test dataset for audio-visual diarization.
    Supports `visual_type='face'` or `'lip'`.
    Returns (face_tensor, mel_segment, label, metadata).

    Construction walks `<prefix>/<video_id>/<chunk folder>/` in `bucket` and
    registers one sample per frame of every `<visual_type>_<id>.npy` track
    found in a chunk that also holds a key ending in `melspectrogram.npy`.
    Labels come from `<video_id>/is_speaking.csv` (columns `face_id`,
    `frame_id`, `is_speaking`) and default to 0 when the file or row is absent.

    NOTE(review): labels are looked up with the frame index *within the chunk's
    track array*; confirm that `frame_id` in the CSV uses the same indexing.

    Attributes:
        samples: one dict per (track, frame) with keys `face_key`, `mel_key`,
            `frame_idx`, `label` and `meta`.
        vid2nspeakers: number of distinct `face_id`s per video id, taken from
            the label CSVs (videos without label rows are absent).
    """
    def __init__(
        self,
        bucket: str,
        prefix: str,
        transform=None,
        visual_type: str = "face",
        verbose: bool = True,
    ):
        """Scan S3 and build the per-frame sample index.

        Args:
            bucket: S3 bucket holding the preprocessed data.
            prefix: root prefix under which the video folders live.
            transform: optional callable applied to each frame tensor.
            visual_type: which tracks to index, 'face' or 'lip'.
            verbose: print warnings and a final summary when True.
        """
        super().__init__()
        assert visual_type in {"face", "lip"}
        self.bucket      = bucket
        self.prefix      = prefix.rstrip("/") + "/"       # ensure trailing /
        self.transform   = transform
        self.visual_type = visual_type
        self.verbose     = verbose

        self.samples: List[Dict] = []
        # last array fetched per kind ("face" / "mel"); see _load_cached
        self._npy_cache: Dict = {}
        # compiled once instead of re-parsing the pattern for every key
        self._track_re = re.compile(f"{self.visual_type}_(\\d+)\\.npy$")
        vid2speakers = defaultdict(set)

        # ───────── 1. iterate videos ─────────
        video_prefixes = list(_iter_prefixes(bucket, self.prefix))
        for vid_prefix in tqdm(video_prefixes, desc="Scanning videos"):
            video_id  = vid_prefix.split("/")[-2]
            speak_map = self._load_speak_map(vid_prefix)
            for face_id, _frame_id in speak_map:
                vid2speakers[video_id].add(face_id)

            # ─────── 2. iterate chunk folders ───────
            for chunk_prefix in self._iter_chunk_prefixes(vid_prefix):
                self._index_chunk(video_id, chunk_prefix, speak_map)

        self.vid2nspeakers = {v: len(s) for v, s in vid2speakers.items()}
        # fill num_speakers in meta (0 for videos without any label rows)
        for s in self.samples:
            s["meta"]["num_speakers"] = self.vid2nspeakers.get(s["meta"]["video_id"], 0)

        if self.verbose:
            print(f"[info] built TestDataset with {len(self.samples):,} samples")

    # ─────────────────────────────────────────────────────────
    # index-building helpers
    # ─────────────────────────────────────────────────────────
    def _load_speak_map(self, vid_prefix: str) -> Dict:
        """Return {(face_id, frame_id): is_speaking} for one video (empty if no CSV)."""
        csv_key   = f"{vid_prefix}is_speaking.csv"
        df_labels = _load_csv(self.bucket, csv_key)
        speak_map = {}
        if df_labels is not None:
            for r in df_labels.itertuples():
                speak_map[(int(r.face_id), int(r.frame_id))] = int(r.is_speaking)
        elif self.verbose:
            print(f"[warn] {csv_key} missing – all labels default to 0")
        return speak_map

    def _iter_chunk_prefixes(self, vid_prefix: str):
        """Yield the immediate sub-prefixes (chunk folders) of one video folder."""
        for page in _paginator.paginate(
            Bucket=self.bucket, Prefix=vid_prefix, Delimiter='/'
        ):
            for cp in page.get("CommonPrefixes", []):
                yield cp["Prefix"]                               # …/Chunk_3/

    def _index_chunk(self, video_id: str, chunk_prefix: str, speak_map: Dict) -> None:
        """Append one sample per frame of every visual track found in a chunk."""
        chunk_id = chunk_prefix.split("/")[-2].split("_")[-1]

        # list keys in this chunk
        keys = [
            obj["Key"]
            for page in _paginator.paginate(Bucket=self.bucket, Prefix=chunk_prefix)
            for obj in page.get("Contents", [])
        ]

        mel_key = next((k for k in keys if k.endswith("melspectrogram.npy")), None)
        if mel_key is None:
            if self.verbose:
                print(f"[warn] no mel in {chunk_prefix}")
            return

        # (key, speaker_id) for every track of the requested visual type
        tracks = []
        for key in keys:
            match = self._track_re.search(key)
            if match and not key.endswith("_bboxes.npy"):
                tracks.append((key, int(match.group(1))))
        if not tracks:
            if self.verbose:
                print(f"[warn] no {self.visual_type} tracks in {chunk_prefix}")
            return

        # ─────── 3. build sample list ───────
        for face_key, speaker_id in tracks:
            # The whole track is downloaded only to learn its frame count.
            total_frames = _load_npy(self.bucket, face_key).shape[0]   # [T, C, H, W]
            for frame_idx in range(total_frames):
                label = speak_map.get((speaker_id, frame_idx), 0)
                self.samples.append({
                    "face_key" : face_key,
                    "mel_key"  : mel_key,
                    "frame_idx": frame_idx,
                    "label"    : float(label),
                    "meta"     : {
                        "video_id"    : video_id,
                        "chunk_id"    : chunk_id,
                        "speaker_id"  : speaker_id,
                        "frame_idx"   : frame_idx,
                        "total_frames": total_frames,
                        # num_speakers set later
                    },
                })

    def _load_cached(self, kind: str, key: str):
        """Return the array stored at `key`, reusing the last one fetched for `kind`.

        Consecutive samples usually share the same track and mel objects, so
        remembering the most recent array per kind avoids re-downloading the
        full object from S3 for every single frame.
        """
        # setdefault keeps this working on instances created without __init__
        cache = self.__dict__.setdefault("_npy_cache", {})
        entry = cache.get(kind)
        if entry is None or entry[0] != key:
            entry = (key, _load_npy(self.bucket, key))
            cache[kind] = entry
        return entry[1]

    # ─────────────────────────────────────────────────────────
    # PyTorch hooks
    # ─────────────────────────────────────────────────────────
    def __len__(self) -> int:
        """Number of (track, frame) samples."""
        return len(self.samples)

    def __getitem__(self, idx: int):
        """Return `(frame, mel_segment, label, meta)` for sample `idx`."""
        rec   = self.samples[idx]
        meta  = rec["meta"]

        # face frame - copy=True so the result never aliases the cached array
        face_arr = self._load_cached("face", rec["face_key"])
        frame    = torch.from_numpy(face_arr[rec["frame_idx"]]).to(
            dtype=torch.float32, copy=True
        )
        if self.transform:
            frame = self.transform(frame)

        # mel segment - cloned for the same reason (slicing returns a view)
        mel_arr    = self._load_cached("mel", rec["mel_key"])
        mel_tensor = torch.from_numpy(mel_arr).float()
        mel_seg    = _extract_audio_segment(
            mel_tensor, rec["frame_idx"], meta["total_frames"], desired_len=22
        ).clone()

        return frame, mel_seg, rec["label"], meta