In [None]:
import os
import pickle
import hashlib
import numpy as np
from collections import defaultdict
from array_record.python.array_record_module import ArrayRecordReader
import multiprocessing
from tqdm import tqdm  # Using tqdm for a nice progress bar


In [None]:

# --- Helper Functions ---


def hash_byte_data(bytes_data: bytes) -> str:
    """Return the hex-encoded SHA256 digest of ``bytes_data``."""
    digest = hashlib.sha256()
    digest.update(bytes_data)
    return digest.hexdigest()


def hash_numpy_frame(frame: np.ndarray) -> str:
    """Return the hex-encoded SHA256 digest of an array's contiguous buffer.

    The array is first passed through ``np.ascontiguousarray`` so that views
    (e.g. slices or transposes) can be fed to the hasher. Only the raw buffer
    is hashed; the array's shape and dtype are not part of the digest.
    """
    contiguous = np.ascontiguousarray(frame)
    digest = hashlib.sha256()
    digest.update(contiguous)
    return digest.hexdigest()


def get_episode_level_hash(video_path: str) -> tuple[str, str] | None:
    """
    Reads a single array_record file, extracts the video, hashes it,
    and returns the (hash, video_path) tuple.
    """
    try:
        reader = ArrayRecordReader(video_path)
        record_data = reader.read()
        record_unpickled = pickle.loads(record_data)
        video_bytes = record_unpickled["raw_video"]
        video_hash = hash_byte_data(video_bytes)
        return video_hash, video_path
    except Exception as e:
        print(f"Error processing file {video_path}: {e}")
        return None


def get_frame_level_hashes(video_path: str) -> tuple[list[str], str] | None:
    """
    Reads a single array_record file, extracts the frames, hashes each frame,
    and returns the (list of hashes, video_path) tuple.
    """
    try:
        reader = ArrayRecordReader(video_path)
        record_data = pickle.loads(reader.read())

        # video shape (seq_len, 64, 64, 3)
        video_shape = (record_data["sequence_length"], 64, 64, 3)
        episode_tensor = np.frombuffer(record_data["raw_video"], dtype=np.uint8)
        episode_tensor = episode_tensor.reshape(video_shape)

        frame_hashes = [hash_numpy_frame(frame) for frame in episode_tensor]

        return frame_hashes, video_path
    except Exception as e:
        print(f"Error processing file {video_path}: {e}")
        return None


def get_array_record_files(dir: str) -> list[str]:
    """
    Lists the ``.array_record`` files directly inside ``dir``.

    Args:
        dir: Directory to scan (not recursive).

    Returns:
        Paths (``dir`` joined with each matching entry name), sorted so the
        result is the same on every run; ``os.listdir`` alone returns entries
        in arbitrary order.
    """
    return sorted(
        os.path.join(dir, x) for x in os.listdir(dir) if x.endswith(".array_record")
    )



In [None]:
# Root directory that holds the train/test/val split sub-directories
base = "data/coinrun_episodes"

train_dir, test_dir, val_dir = (
    os.path.join(base, split) for split in ("train", "test", "val")
)

# Collect the array_record files of each split separately ...
train_array_record_files = get_array_record_files(train_dir)
test_array_record_files = get_array_record_files(test_dir)
val_array_record_files = get_array_record_files(val_dir)

# ... and as one combined list (train, then test, then val) for the hashing passes
array_record_files = [
    *train_array_record_files,
    *test_array_record_files,
    *val_array_record_files,
]
print(f"Found {len(array_record_files)} files to process.")

# One worker process per CPU reported by the OS
num_processes = multiprocessing.cpu_count()
print(f"Using {num_processes} worker processes.")


In [None]:
# Find episode level duplicates
# Maps each episode hash to the list of files whose raw video produced it.
duplicate_episode = defaultdict(list)

# Leaving the 'with' block terminates the pool's worker processes
with multiprocessing.Pool(processes=num_processes) as pool:
    # Use pool.imap_unordered for efficiency. It returns results as they
    # complete rather than in submission order, which is perfect for progress
    # bars. We wrap the result iterator with tqdm to show progress.
    results = pool.imap_unordered(get_episode_level_hash, array_record_files)

    print("Processing files and calculating hashes...")
    for result in tqdm(results, total=len(array_record_files)):
        if result is not None:  # The worker returns None when a file failed
            video_hash, video_path = result
            duplicate_episode[video_hash].append(video_path)

print("\nAggregation complete. Finding duplicates...")
duplicates = {h: paths for h, paths in duplicate_episode.items() if len(paths) > 1}

print(f"\nFound {len(duplicates)} sets of duplicate videos.")
if duplicates:
    # Print only the first few duplicate sets as an example, so a heavily
    # duplicated dataset does not flood the cell output.
    max_sets_to_show = 5
    for h, paths in list(duplicates.items())[:max_sets_to_show]:
        print(f"  - Hash: {h[:10]}... ({len(paths)} files)")
        for path in paths:
            print(f"    - {os.path.basename(path)}")
    num_hidden_sets = len(duplicates) - max_sets_to_show
    if num_hidden_sets > 0:
        print(f"  ... and {num_hidden_sets} more duplicate sets not shown.")


In [None]:
# Frame level duplicates
# This dictionary will map a frame hash to a list of (video_path, frame_index)
frame_dup_dict = defaultdict(list)

with multiprocessing.Pool(processes=num_processes) as pool:
    results = pool.imap_unordered(get_frame_level_hashes, array_record_files)

    print("Processing files and calculating frame hashes...")
    for result in tqdm(results, total=len(array_record_files)):
        if result is not None:  # The worker returns None when a file failed
            frame_hashes, video_path = result
            for frame_idx, frame_hash in enumerate(frame_hashes):
                frame_dup_dict[frame_hash].append((video_path, frame_idx))

print("\nAggregation complete. Finding duplicate frames...")
# Keep only the hashes that were seen at more than one location.
duplicate_frames = {
    dup_hash: locations
    for dup_hash, locations in frame_dup_dict.items()
    if len(locations) > 1
}


total_frames = sum(len(locations) for locations in frame_dup_dict.values())
print(f"Total frames: {total_frames}")
# Number of distinct frame contents that occur more than once.
num_duplicate_frames = len(duplicate_frames)
print(f"Number of duplicate frames: {num_duplicate_frames}")
# Redundant copies: every occurrence beyond the first one of each hash. This
# is counted in frames, so it is comparable with total_frames (the count of
# distinct hashes above is not: 100 identical frames would give 1 / 100).
num_redundant_frames = total_frames - len(frame_dup_dict)
print(f"Number of redundant frame copies: {num_redundant_frames}")
# Guard against an empty dataset (or every file failing) to avoid dividing by 0.
percentage = num_redundant_frames / total_frames if total_frames else 0.0
print(f"Percentage of duplicate frames: {percentage:.2%}")
