In [None]:
import os, sys
import pandas as pd
from torchvision import transforms
from collections import defaultdict
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import utils
import collections
import re
import cv2
import torch
import numpy as np
from torch.utils.data import Dataset
import types
from collections import defaultdict
from torchvision import transforms


def create_images_360(video_dir= '/Users/nagasaithadishetty/Desktop/WEB360/videos', image_size=(2048, 1024), allowed_serials=None):
    """Extract the frames of every ``NNNNNN.mp4`` video into numbered PNG files.

    For each file in ``video_dir`` whose name is exactly six digits followed by
    ``.mp4``, frames are resized to ``image_size`` and written to
    ``<video_dir>/Frames_2/<serial>/<frame index + 1>.png``. A video whose
    output directory already exists is skipped.

    Args:
        video_dir: Directory containing the source videos.
        image_size: ``(width, height)`` passed to ``cv2.resize``.
        allowed_serials: Optional iterable of serial numbers (strings or ints)
            to process. ``None`` processes every matching video.

    Returns:
        None. Results are written to disk and progress is printed.
    """
    video_name_pattern = re.compile(r'^(\d{6})\.mp4$')
    allowed = None
    if allowed_serials is not None:
        # Normalise so that 100001 and "100001" both match the file name.
        allowed = {str(serial).zfill(6) for serial in allowed_serials}

    for video in sorted(os.listdir(video_dir)):
        match = video_name_pattern.match(video)
        if not match:
            continue
        serial_number = match.group(1)
        if allowed is not None and serial_number not in allowed:
            continue
        output_dir = os.path.join(video_dir, "Frames_2", f"{serial_number}")
        if os.path.exists(output_dir):
            continue

        video_path = os.path.join(video_dir, video)
        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                # Do not create the output directory here: an empty directory
                # would make every later run skip this video.
                print(f"* Could not open '{video}', skipping")
                continue
            os.makedirs(output_dir, exist_ok=True)
            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            print(f"* Processing '{video}' — {frame_count} frames")
            written = 0
            for i in range(frame_count):
                ret, frame = cap.read()
                if not ret:
                    # Keep advancing i so file names stay aligned with the
                    # frame index even when a read fails.
                    continue
                frame = cv2.resize(frame, image_size, interpolation=cv2.INTER_AREA)
                image_path = os.path.join(output_dir, f"{i+1}.png")
                if cv2.imwrite(image_path, frame):
                    written += 1
        finally:
            cap.release()
        # Report what was actually written, not what the container advertised.
        print(f"Created {written} frames in '{output_dir}'\n")

class coordiantes_3d():
    """Indexable collection pairing an image with its scanpaths on the unit sphere.

    Each entry of ``mini_batch_images`` is read as ``entry[0]`` (image path) and
    ``entry[1]`` (iterable of scanpaths). A scanpath is a flat sequence
    ``[x0, y0, x1, y1, ...]``; the coordinates are normalised against the
    4096x2048 frame the image is resized to in ``__getitem__``.
    """

    def __init__(self, mini_batch_images, transform=None):
        """Store the entries and an optional per-image ``transform`` callable."""
        self.mini_batch_images = mini_batch_images
        self.transform = transform
        self.utils = types.SimpleNamespace()
        self.utils.image_size = (128, 256)  # (height, width) of the returned image

    def __len__(self):
        """Return the number of (image, scanpaths) entries."""
        return len(self.mini_batch_images)

    def _scanpath_to_sphere(self, scanpath, ratio_w, ratio_h):
        """Convert one flat pixel scanpath to (lon/lat list, xyz list).

        Longitude is scaled to ``±(pi - 1e-2)`` and latitude to
        ``±(pi/2 - 1e-2)``; xyz are the matching unit-sphere coordinates.
        """
        height, width = self.utils.image_size
        lon_lat = []
        xyz = []
        for i in range(0, len(scanpath), 2):
            y = scanpath[i + 1]
            x = scanpath[i]
            # Pixel -> output-image pixel -> [0, 1] -> signed angle.
            lon = (x / ratio_w) / width
            lon = ((lon * 2) - 1) * (np.pi - 1e-2)
            lat = (y / ratio_h) / height
            lat = ((lat * 2) - 1) * (np.pi / 2 - 1e-2)
            lon_lat.extend([lon, lat])
            xyz.extend([
                np.cos(lat) * np.cos(lon),
                np.cos(lat) * np.sin(lon),
                np.sin(lat),
            ])
        return lon_lat, xyz

    def __getitem__(self, idx):
        """Load entry ``idx`` and return ``(images, scanpaths_3d, scanpaths_original)``.

        The three lists have one element per scanpath: the (optionally
        transformed) image, a FloatTensor of flattened xyz coordinates, and a
        FloatTensor of flattened (lon, lat) angles in radians.

        Raises:
            FileNotFoundError: If the image cannot be read from disk.
        """
        image_path = self.mini_batch_images[idx][0]
        scanpaths = self.mini_batch_images[idx][1]
        self.utils.image_size = (128, 256)

        print(f"Processing image: {image_path} with scanpaths: {scanpaths}")
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (4096, 2048), interpolation=cv2.INTER_AREA)

        original_h, original_w, _ = image.shape
        out_h, out_w = self.utils.image_size
        ratio_w = original_w / out_w
        ratio_h = original_h / out_h
        image = cv2.resize(image, (out_w, out_h), interpolation=cv2.INTER_AREA)
        image = image.astype(np.float32) / 255.0

        images = []
        scanpaths_3d = []
        scanpaths_original = []
        for scanpath in scanpaths:
            lon_lat, xyz = self._scanpath_to_sphere(scanpath, ratio_w, ratio_h)
            # The transform is applied once per scanpath, as it may be random.
            if self.transform:
                images.append(self.transform(image))
            else:
                images.append(torch.tensor(image).permute(2, 0, 1))
            scanpaths_3d.append(torch.FloatTensor(xyz))
            scanpaths_original.append(torch.FloatTensor(lon_lat))
        return images, scanpaths_3d, scanpaths_original

# Runs the frame extraction whenever this cell is executed; videos whose
# output directory already exists are skipped by create_images_360 itself.
create_images_360(video_dir= '/Users/nagasaithadishetty/Desktop/WEB360/videos', image_size=(2048, 1024), allowed_serials=None)

In [None]:
import os
import pandas as pd
import numpy as np
from collections import defaultdict
import types
import re
import pickle
import csv

# Removing outliers and interpolating over them to obtain consistent scanpaths
class ScanpathData360:
    """Load per-frame gaze scanpaths, repair spatial outliers and export them.

    Scanpath points are kept as ``(lat, lon)`` tuples taken from the ``y`` and
    ``x`` CSV columns; the geometry helpers treat both values as degrees.
    """

    def __init__(self, dataset_root, scan_path_directory):
        """Remember the input locations and load the dataset immediately.

        Args:
            dataset_root: Directory with one sub-directory of ``<n>.png``
                frames per video.
            scan_path_directory: Directory with one ``<video>.csv`` per video.
        """
        self.dataset_root = dataset_root
        self.scan_path_directory = scan_path_directory
        self.images = []
        self.image_scanpath_pairs = []
        self.image_and_scanpath_dict = defaultdict(dict)
        self.cleaned = defaultdict(lambda: defaultdict(dict))
        self.load_dataset()

    @staticmethod
    def great_circle_distance(phi1, lambda1, phi2, lambda2):
        """Return the central angle in radians between two points given in degrees."""
        delta_lambda = np.abs(lambda1 - lambda2)
        K = (np.sin(np.radians(phi1)) * np.sin(np.radians(phi2))) + \
            (np.cos(np.radians(phi1)) * np.cos(np.radians(phi2)) * np.cos(np.radians(delta_lambda)))
        # Clip guards arccos against rounding slightly outside [-1, 1].
        return np.arccos(np.clip(K, -1, 1))

    def latlon_to_unitvec(self, lat, lon):
        """Convert latitude/longitude in degrees to a 3-D unit vector."""
        lat_rad = np.radians(lat)
        lon_rad = np.radians(lon)
        x = np.cos(lat_rad) * np.cos(lon_rad)
        y = np.cos(lat_rad) * np.sin(lon_rad)
        z = np.sin(lat_rad)
        return np.array([x, y, z])

    def unitvec_to_latlon(self, vec):
        """Convert a 3-D unit vector back to ``(lat, lon)`` in degrees."""
        x, y, z = vec
        # Rounding can push z marginally past +/-1, which would yield NaN.
        lat = np.degrees(np.arcsin(np.clip(z, -1.0, 1.0)))
        lon = np.degrees(np.arctan2(y, x))
        return (lat, lon)

    def slerp(self, q1, q2, t, tm):
        """Spherically interpolate between unit vectors ``q1`` and ``q2``.

        Args:
            q1, q2: Unit vectors (NumPy arrays).
            t, tm: The interpolation fraction is ``t / tm``.

        Returns:
            The interpolated vector. Nearly parallel inputs fall back to a
            normalised linear blend; (numerically) antipodal inputs, for which
            the great circle is undefined, return a copy of the nearer endpoint.
        """
        fraction = t / tm
        dot = np.clip(np.dot(q1, q2), -1.0, 1.0)
        if dot > 0.9995:
            result = q1 + fraction * (q2 - q1)
            return result / np.linalg.norm(result)
        theta_0 = np.arccos(dot)
        sin_theta_0 = np.sin(theta_0)
        if sin_theta_0 < 1e-12:
            # Avoid dividing by ~0; copy so callers may modify the result.
            return np.array(q1 if fraction < 0.5 else q2, dtype=float)
        theta = theta_0 * fraction
        s1 = np.sin(theta_0 - theta) / sin_theta_0
        s2 = np.sin(theta) / sin_theta_0
        return (s1 * q1) + (s2 * q2)

    def remove_outliers(self, scanpath, threshold_factor=2.5):
        """Drop points whose jump from the previous point is unusually large.

        A point ``i`` is an outlier when the great-circle distance from point
        ``i - 1`` exceeds ``mean + threshold_factor * std`` of all consecutive
        distances in the scanpath.

        Returns:
            ``(cleaned_scanpath, sorted_outlier_indices)``. Scanpaths with
            fewer than two points are returned unchanged.
        """
        if len(scanpath) < 2:
            return scanpath, []

        distances = [self.great_circle_distance(*scanpath[i], *scanpath[i - 1]) for i in range(1, len(scanpath))]
        spatial_threshold = np.mean(distances) + threshold_factor * np.std(distances)

        # distances[j] belongs to point j + 1.
        spatial_outliers = {j + 1 for j, dist in enumerate(distances) if dist > spatial_threshold}
        cleaned_scanpath = [pt for idx, pt in enumerate(scanpath) if idx not in spatial_outliers]

        return cleaned_scanpath, sorted(spatial_outliers)

    def load_dataset(self, max_videos=2):
        """Read frames and gaze CSVs into the instance containers.

        Args:
            max_videos: Stop after this many videos have been loaded.
                ``None`` loads every video. Defaults to 2.

        Returns:
            ``self.image_and_scanpath_dict``, a mapping
            ``video -> frame number -> scanpath id -> [(y, x), ...]``.
        """
        self.images = []
        self.image_scanpath_pairs = []
        self.image_and_scanpath_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
        frame_name_pattern = re.compile(r'(\d+)\.png$')

        video_count = 0
        for video in sorted(os.listdir(self.dataset_root)):
            if max_videos is not None and video_count >= max_videos:
                break
            video_frames_dir = os.path.join(self.dataset_root, video)
            csv_path = os.path.join(self.scan_path_directory, f"{video}.csv")
            if not os.path.exists(csv_path):
                print(f"CSV missing for video {video}, skipping")
                continue
            if not os.path.isdir(video_frames_dir):
                print(f"{video_frames_dir} is not a directory, skipping")
                continue

            gaze_df = pd.read_csv(csv_path)
            gaze_mapping = defaultdict(lambda: defaultdict(list))
            # Column-wise iteration avoids the per-row Series that iterrows builds.
            for frame, spid, y, x in zip(gaze_df['frame'], gaze_df['scanpath_id'], gaze_df['y'], gaze_df['x']):
                gaze_mapping[int(frame)][int(spid)].append((y, x))

            frame_files = []
            for img_file in os.listdir(video_frames_dir):
                img_match = frame_name_pattern.match(img_file)
                if not img_match:
                    continue
                frame_files.append((int(img_match.group(1)), os.path.join(video_frames_dir, img_file)))
            # Numeric sort, so that 10.png comes after 2.png.
            frame_files.sort(key=lambda item: item[0])

            for frame, img_path in frame_files:
                gaze_points_per_scanid = gaze_mapping.get(frame, {})
                self.images.append(img_path)
                self.image_scanpath_pairs.append((img_path, gaze_points_per_scanid))
                self.image_and_scanpath_dict[video][frame] = gaze_points_per_scanid
            video_count += 1
        return self.image_and_scanpath_dict

    def interpolate_only_outliers(self, original_scanpath, removed_indices):
        """Return a copy of the scanpath with outlier points re-estimated.

        Each index in ``removed_indices`` is replaced by a spherical
        interpolation between its nearest non-outlier neighbours. When only
        one such neighbour exists its value is copied; when none exists the
        original point is kept. The result has the same length as the input.
        """
        if not removed_indices:
            return list(original_scanpath)

        n = len(original_scanpath)
        interpolated_scanpath = list(original_scanpath)
        removed_set = set(removed_indices)

        for idx in sorted(removed_indices):
            prev_idx = idx - 1
            while prev_idx >= 0 and prev_idx in removed_set:
                prev_idx -= 1

            next_idx = idx + 1
            while next_idx < n and next_idx in removed_set:
                next_idx += 1

            has_prev = prev_idx >= 0
            has_next = next_idx < n
            if has_prev and has_next:
                p0 = self.latlon_to_unitvec(*original_scanpath[prev_idx])
                p1 = self.latlon_to_unitvec(*original_scanpath[next_idx])
                t = (idx - prev_idx) / (next_idx - prev_idx)
                interp_vec = self.slerp(p0, p1, t, 1)
                interp_vec /= np.linalg.norm(interp_vec)
                interpolated_scanpath[idx] = self.unitvec_to_latlon(interp_vec)
            elif has_next:
                interpolated_scanpath[idx] = original_scanpath[next_idx]
            elif has_prev:
                interpolated_scanpath[idx] = original_scanpath[prev_idx]
            else:
                interpolated_scanpath[idx] = original_scanpath[idx]
        return interpolated_scanpath

    def process_scanpaths(self, save_folder=None):
        """Clean every loaded scanpath and optionally export one CSV per video.

        Results are stored in ``self.cleaned[video][frame][scan_id]``.

        Args:
            save_folder: Directory for the ``<video>.csv`` files (created if
                missing). ``None`` skips writing and only fills ``self.cleaned``.
        """
        self.cleaned = defaultdict(lambda: defaultdict(dict))
        if save_folder is not None:
            os.makedirs(save_folder, exist_ok=True)

        for video in self.image_and_scanpath_dict:
            print(f"Processing Video: {video}")
            video_rows = []

            for frame in sorted(self.image_and_scanpath_dict[video]):
                scanpaths_dict = self.image_and_scanpath_dict[video][frame]

                for scan_id, scanpath in scanpaths_dict.items():
                    initial_len = len(scanpath)
                    _, removed_indices = self.remove_outliers(scanpath)
                    if removed_indices:
                        interpolated = self.interpolate_only_outliers(scanpath, removed_indices)
                    else:
                        interpolated = list(scanpath)

                    # Internal invariant: interpolation must never change the length.
                    assert len(interpolated) == initial_len, (f"Length mismatch for video {video} frame {frame} scan {scan_id}: "f"{len(interpolated)} != {initial_len}")

                    self.cleaned[video][frame][scan_id] = interpolated
                    for idx, (lat, lon) in enumerate(interpolated):
                        video_rows.append([frame, scan_id, idx, lat, lon])

            if save_folder is None:
                continue
            csv_path = os.path.join(save_folder, f"{video}.csv")
            with open(csv_path, "w", newline="") as f:
                writer = csv.writer(f)
                writer.writerow(["frame", "scanpath_id", "point_idx", "y", "x"])
                writer.writerows(video_rows)
# Driver for the outlier-cleaning cell. The constructor already calls
# load_dataset(), so the explicit call below reloads the same data; the
# cleaned scanpaths are then written as one CSV per video.
data_loader = ScanpathData360("/Users/nagasaithadishetty/Desktop/WEB360/videos/Frames","/Users/nagasaithadishetty/Desktop/Diffusion_Research/Multi_Concept_Code/ScanDMM/demo/output")
data_loader.load_dataset()
data_loader.process_scanpaths(save_folder="/Users/nagasaithadishetty/Desktop/Diffusion_Research/Multi_Concept_Code/ScanDMM/demo/output/interpolated")

CSV missing for video .DS_Store, skipping
CSV missing for video .DS_Store, skipping
Processing Video: 100001
Processing Video: 100002


In [50]:
import os 
import re
import pandas as pd
import numpy as np
from collections import defaultdict
import types
from functools import reduce
from fastdtw import fastdtw
from scipy.cluster.hierarchy import linkage, fcluster, dendrogram
from scipy.spatial.distance import squareform
import matplotlib.pyplot as plt
import seaborn as sns
import pickle

#clustering and final source scan-path for each image

def great_circle_distance(phi1, lambda1, phi2, lambda2):
    """Return the central angle in radians between two points on a sphere.

    Args:
        phi1, lambda1: Latitude and longitude of the first point, in degrees.
        phi2, lambda2: Latitude and longitude of the second point, in degrees.

    Note:
        This is a plain module-level function. It must not be wrapped in
        ``staticmethod``: before Python 3.10 a staticmethod object is not
        callable outside a class body.
    """
    phi1, lambda1, phi2, lambda2 = map(np.radians, [phi1, lambda1, phi2, lambda2])
    delta_lambda = np.abs(lambda1 - lambda2)
    K = np.sin(phi1) * np.sin(phi2) + np.cos(phi1) * np.cos(phi2) * np.cos(delta_lambda)
    # Clip guards arccos against rounding slightly outside [-1, 1].
    K = np.clip(K, -1.0, 1.0)
    return np.arccos(K)

def great_circle_distance_squared(p1, p2):
    """Return the squared great-circle distance between two points.

    Each point is indexed as ``p[0]`` and ``p[1]``, which are forwarded to
    ``great_circle_distance`` as its (phi, lambda) arguments.
    """
    first_a, second_a = p1[0], p1[1]
    first_b, second_b = p2[0], p2[1]
    distance = great_circle_distance(first_a, second_a, first_b, second_b)
    return distance**2

def fill_delta_mat_dtw(center, s, delta_mat):
    """Fill ``delta_mat`` in place with pairwise squared distances.

    Entry ``[i, j]`` of the top-left ``len(center) x len(s)`` corner receives
    the squared great-circle distance between ``center[i]`` and ``s[j]``.
    The rest of ``delta_mat`` is left untouched. Nothing is returned.
    """
    n_rows, n_cols = len(center), len(s)
    # Basic slicing yields a view, so writes land in delta_mat itself.
    window = delta_mat[:n_rows, :n_cols]
    for row, center_point in enumerate(center):
        for col, sample_point in enumerate(s):
            window[row, col] = great_circle_distance_squared(center_point, sample_point)

def squared_DTW(s, t, cost_mat, delta_mat):
    """Return the DTW cost between ``s`` and ``t`` using squared point distances.

    ``cost_mat`` and ``delta_mat`` are caller-provided scratch matrices that
    are overwritten in their top-left ``len(s) x len(t)`` corner.
    """
    n_rows, n_cols = len(s), len(t)
    fill_delta_mat_dtw(s, t, delta_mat)

    # Borders: only one way to reach cells in the first row or column.
    cost_mat[0, 0] = delta_mat[0, 0]
    for col in range(1, n_cols):
        cost_mat[0, col] = cost_mat[0, col - 1] + delta_mat[0, col]
    for row in range(1, n_rows):
        cost_mat[row, 0] = cost_mat[row - 1, 0] + delta_mat[row, 0]

    # Interior: extend the cheapest of the diagonal, left and upper cells.
    for row in range(1, n_rows):
        for col in range(1, n_cols):
            cheapest = min(
                cost_mat[row - 1, col - 1],
                cost_mat[row, col - 1],
                cost_mat[row - 1, col],
            )
            cost_mat[row, col] = cheapest + delta_mat[row, col]

    return cost_mat[n_rows - 1, n_cols - 1]

def approximate_medoid_index(series, cost_mat, delta_mat):
    """Return the index of the candidate with the smallest total DTW cost.

    With at most 50 sequences every one is a candidate; otherwise 50 distinct
    candidates are drawn with ``np.random.choice``. Each candidate is scored by
    the sum of ``squared_DTW`` against every sequence in ``series``.
    Returns -1 when ``series`` is empty.
    """
    total = len(series)
    if total > 50:
        candidate_indices = np.random.choice(range(total), 50, replace=False)
    else:
        candidate_indices = range(total)

    best_index = -1
    best_cost = 1e20
    for candidate_index in candidate_indices:
        reference = series[candidate_index]
        cost = sum(squared_DTW(reference, other, cost_mat, delta_mat) for other in series)
        if best_index == -1 or cost < best_cost:
            best_index = candidate_index
            best_cost = cost
    return best_index

def DBA_update(center, series, cost_mat, path_mat, delta_mat):
    """Return a new center obtained by averaging the points aligned to it.

    For every sequence a DTW alignment against ``center`` is computed and
    back-tracked; each center position collects the sequence points matched
    to it, and the result is the per-position mean of those points.
    ``cost_mat``, ``path_mat`` and ``delta_mat`` are scratch matrices that are
    overwritten. ``center`` and each sequence are indexed as 2-D arrays.
    """
    # Back-tracking moves, indexed by the code stored in path_mat:
    # 0 = diagonal, 1 = previous sequence point, 2 = previous center point.
    moves = [(-1, -1), (0, -1), (-1, 0)]
    n_center = len(center)
    accumulated = np.zeros(center.shape)
    counts = np.zeros(center.shape[0], dtype=int)

    for sequence in series:
        n_seq = len(sequence)
        fill_delta_mat_dtw(center, sequence, delta_mat)

        cost_mat[0, 0] = delta_mat[0, 0]
        path_mat[0, 0] = -1  # sentinel that ends the back-tracking

        for row in range(1, n_center):
            cost_mat[row, 0] = cost_mat[row - 1, 0] + delta_mat[row, 0]
            path_mat[row, 0] = 2

        for col in range(1, n_seq):
            cost_mat[0, col] = cost_mat[0, col - 1] + delta_mat[0, col]
            path_mat[0, col] = 1

        for row in range(1, n_center):
            for col in range(1, n_seq):
                candidates = [
                    cost_mat[row - 1, col - 1],
                    cost_mat[row, col - 1],
                    cost_mat[row - 1, col],
                ]
                best = np.argmin(candidates)
                path_mat[row, col] = best
                cost_mat[row, col] = candidates[best] + delta_mat[row, col]

        # Walk back from the last cell to the sentinel at (0, 0).
        row, col = n_center - 1, n_seq - 1
        while path_mat[row, col] != -1:
            accumulated[row, :] += sequence[col, :]
            counts[row] += 1
            d_row, d_col = moves[path_mat[row, col]]
            row += d_row
            col += d_col

        # The loop stops before handling (0, 0), so add that pairing here.
        accumulated[0, :] += sequence[0, :]
        counts[0] += 1

    # Avoid dividing by zero for positions that received no points.
    counts[counts == 0] = 1
    return np.divide(accumulated, counts[:, np.newaxis])

def performDBA(series, max_iterations=300, threshold=0.001):
    """Compute an average sequence for ``series`` by iterating ``DBA_update``.

    The starting center is a copy of the (approximate) medoid. It is refined
    until the mean squared change between two successive centers falls below
    ``threshold`` or ``max_iterations`` updates have been performed.

    Args:
        series: List of 2-D arrays, one per sequence.
        max_iterations: Upper bound on the number of refinement steps.
        threshold: Convergence tolerance on the mean squared center change.

    Returns:
        The averaged sequence, or an empty array when ``series`` is empty.
    """
    if not series:
        return np.array([])

    # Scratch matrices sized for the longest sequence, shared by all DTW runs.
    max_length = max(map(len, series))
    cost_mat = np.zeros((max_length, max_length))
    delta_mat = np.zeros((max_length, max_length))
    path_mat = np.zeros((max_length, max_length), dtype=np.int8)

    medoid_ind = approximate_medoid_index(series, cost_mat, delta_mat)
    center = series[medoid_ind].copy()

    for _ in range(max_iterations):
        prev_center = center
        center = DBA_update(center, series, cost_mat, path_mat, delta_mat)

        change = np.mean(np.square(center - prev_center))
        # Stop once the center has stabilised, i.e. the change is small.
        if change < threshold:
            break
    return center

class ScanpathData360:
    """Cluster per-frame scanpaths and derive one source scanpath per frame.

    Pipeline: ``load_dataset`` -> ``compute_all_frames_fastdtw`` ->
    ``cluster_and_save_all_frames`` -> ``compute_and_save_final_source_scanpath``.
    """

    def __init__(self, dataset_root, scan_path_directory):
        """Store the input locations; no data is read until ``load_dataset``.

        Args:
            dataset_root: Directory with one sub-directory of ``<n>.png``
                frames per video.
            scan_path_directory: Directory with one ``<video>.csv`` per video.
        """
        self.dataset_root = dataset_root
        self.scan_path_directory = scan_path_directory
        self.images = []
        self.image_scanpath_pairs = []
        self.image_and_scanpath_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
        self.fastdtw_results_by_video_and_frame = defaultdict(lambda: defaultdict(dict))

    def load_dataset(self, max_videos=2):
        """Read frames and gaze CSVs into the instance containers.

        Args:
            max_videos: Stop after this many videos have been loaded.
                ``None`` loads every video. Defaults to 2.

        Returns:
            ``self.image_and_scanpath_dict``, a mapping
            ``video -> frame number -> scanpath id -> [(y, x), ...]``.
        """
        self.images = []
        self.image_scanpath_pairs = []
        self.image_and_scanpath_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
        frame_name_pattern = re.compile(r'(\d+)\.png$')

        video_count = 0
        for video in sorted(os.listdir(self.dataset_root)):
            if max_videos is not None and video_count >= max_videos:
                break
            video_frames_dir = os.path.join(self.dataset_root, video)
            csv_path = os.path.join(self.scan_path_directory, f"{video}.csv")
            if not os.path.exists(csv_path):
                print(f"CSV missing for video {video}, skipping")
                continue
            if not os.path.isdir(video_frames_dir):
                print(f"{video_frames_dir} is not a directory, skipping")
                continue

            gaze_df = pd.read_csv(csv_path)
            gaze_mapping = defaultdict(lambda: defaultdict(list))
            # Column-wise iteration avoids the per-row Series that iterrows builds.
            for frame, spid, y, x in zip(gaze_df['frame'], gaze_df['scanpath_id'], gaze_df['y'], gaze_df['x']):
                gaze_mapping[int(frame)][int(spid)].append((y, x))

            frame_files = []
            for img_file in os.listdir(video_frames_dir):
                img_match = frame_name_pattern.match(img_file)
                if not img_match:
                    continue
                frame_files.append((int(img_match.group(1)), os.path.join(video_frames_dir, img_file)))
            # Numeric sort, so that 10.png comes after 2.png.
            frame_files.sort(key=lambda item: item[0])

            for frame, img_path in frame_files:
                gaze_points_per_scanid = gaze_mapping.get(frame, {})
                self.images.append(img_path)
                self.image_scanpath_pairs.append((img_path, gaze_points_per_scanid))
                self.image_and_scanpath_dict[video][frame] = gaze_points_per_scanid
            video_count += 1
        return self.image_and_scanpath_dict

    def pairwise_scanpath_fastdtw(self, scanpaths_dict, radius=20):
        """Return ``(matrix, scan_ids)`` of pairwise FastDTW distances.

        ``matrix[i, j]`` is the FastDTW distance between the scanpaths of
        ``scan_ids[i]`` and ``scan_ids[j]``, using the great-circle distance
        between points. The matrix is symmetric with a zero diagonal.
        """
        scan_ids = list(scanpaths_dict.keys())
        N = len(scan_ids)
        fastdtw_matrix = np.zeros((N, N))

        def point_distance(p1, p2):
            return great_circle_distance(p1[0], p1[1], p2[0], p2[1])

        for i in range(N):
            sp1 = scanpaths_dict[scan_ids[i]]
            # Only the upper triangle is computed; the result is mirrored.
            for j in range(i + 1, N):
                sp2 = scanpaths_dict[scan_ids[j]]
                dist, _ = fastdtw(sp1, sp2, radius=radius, dist=point_distance)
                fastdtw_matrix[i, j] = dist
                fastdtw_matrix[j, i] = dist
        return fastdtw_matrix, scan_ids

    def compute_all_frames_fastdtw(self, radius=20):
        """Compute and cache the pairwise distance matrix of every loaded frame.

        Returns:
            ``self.fastdtw_results_by_video_and_frame``, mapping
            ``video -> frame -> {'matrix': ..., 'scan_ids': ...}``.
        """
        print("\nComputing pairwise FastDTW for all frames...")
        for video in self.image_and_scanpath_dict:
            for frame in self.image_and_scanpath_dict[video]:
                scanpaths_dict = self.image_and_scanpath_dict[video][frame]
                fastdtw_matrix, scan_ids = self.pairwise_scanpath_fastdtw(scanpaths_dict, radius=radius)
                self.fastdtw_results_by_video_and_frame[video][frame] = {'matrix': fastdtw_matrix, 'scan_ids': scan_ids}
        print("FastDTW computation complete.")
        return self.fastdtw_results_by_video_and_frame

    @staticmethod
    def _cluster_and_save_frame(video, frame, distance_matrix, scan_ids, k, video_dir):
        """Cluster one frame (complete linkage, at most ``k`` clusters) and pickle it."""
        condensed_dist_matrix = squareform(distance_matrix)
        Z = linkage(condensed_dist_matrix, method='complete')
        cluster_labels = fcluster(Z, t=k, criterion='maxclust')

        clusters = defaultdict(list)
        for scan_id, label in zip(scan_ids, cluster_labels):
            clusters[label].append(scan_id)

        output_path = os.path.join(video_dir, f"frame_{frame}.pkl")
        with open(output_path, 'wb') as f:
            pickle.dump({'clusters': clusters, 'linkage_matrix': Z, 'scan_ids': scan_ids}, f)
        print(f"  Saved clustering for {video} frame {frame} to {output_path}")

    def cluster_and_save_all_frames_1(self, k, output_dir):
        """Cluster every cached frame without any guard.

        Frames with fewer than two scanpaths are not filtered out here, so
        errors from SciPy propagate; ``cluster_and_save_all_frames`` skips
        such frames instead.
        """
        os.makedirs(output_dir, exist_ok=True)
        for video, frame_dict in self.fastdtw_results_by_video_and_frame.items():
            video_dir = os.path.join(output_dir, video)
            os.makedirs(video_dir, exist_ok=True)
            for frame, results in frame_dict.items():
                self._cluster_and_save_frame(video, frame, results['matrix'], results['scan_ids'], k, video_dir)

    def cluster_and_save_all_frames(self, k, output_dir):
        """Cluster every cached frame, skipping frames with fewer than two scanpaths.

        Writes ``<output_dir>/<video>/frame_<frame>.pkl`` containing the
        clusters, the linkage matrix and the scanpath ids.
        """
        os.makedirs(output_dir, exist_ok=True)
        for video, frame_dict in self.fastdtw_results_by_video_and_frame.items():
            video_dir = os.path.join(output_dir, video)
            os.makedirs(video_dir, exist_ok=True)
            for frame, results in frame_dict.items():
                scan_ids = results['scan_ids']
                if len(scan_ids) < 2:
                    print(f"  Skipping clustering for {video} frame {frame}: not enough scanpaths ({len(scan_ids)})")
                    continue
                self._cluster_and_save_frame(video, frame, results['matrix'], scan_ids, k, video_dir)

    def cluster_and_save_all_frames_now(self, k, output_dir):
        """Most defensive clustering variant.

        Skips frames with fewer than two scanpaths or a missing/empty distance
        matrix, and reports a ``ValueError`` raised while clustering a frame
        instead of aborting the run.
        """
        os.makedirs(output_dir, exist_ok=True)
        for video, frame_dict in self.fastdtw_results_by_video_and_frame.items():
            video_dir = os.path.join(output_dir, video)
            os.makedirs(video_dir, exist_ok=True)

            for frame, results in frame_dict.items():
                scan_ids = results.get('scan_ids', [])
                distance_matrix = results.get('matrix', None)

                if len(scan_ids) < 2:
                    print(f"  Skipping {video} frame {frame}: not enough scanpaths ({len(scan_ids)})")
                    continue

                if distance_matrix is None or np.size(distance_matrix) == 0:
                    print(f"  Skipping {video} frame {frame}: distance matrix is empty")
                    continue

                try:
                    self._cluster_and_save_frame(video, frame, distance_matrix, scan_ids, k, video_dir)
                except ValueError as e:
                    print(f"  Error clustering {video} frame {frame}: {e}")

    def display_fastdtw_results(self, k=30, output_dir='clustering_results'):
        """Print the saved clusters of every frame and plot a truncated dendrogram.

        Args:
            k: Number of leaves kept in the truncated dendrogram.
            output_dir: Directory previously filled by a clustering method.
        """
        frame_file_pattern = re.compile(r'frame_(\d+)\.pkl$')
        print("\nDisplaying FastDTW Results and Clustering:")
        for video in sorted(os.listdir(output_dir)):
            video_dir = os.path.join(output_dir, video)
            if not os.path.isdir(video_dir):
                continue
            print(f"Video: {video}")
            for frame_file in sorted(os.listdir(video_dir)):
                frame_match = frame_file_pattern.match(frame_file)
                if not frame_match:
                    # Ignore files that were not written by the clustering step.
                    continue
                frame = int(frame_match.group(1))
                output_path = os.path.join(video_dir, frame_file)

                # NOTE: pickle.load executes arbitrary code from the file;
                # only point this at directories produced by this class.
                with open(output_path, 'rb') as f:
                    saved_data = pickle.load(f)
                clusters = saved_data['clusters']
                Z = saved_data['linkage_matrix']

                print(f"  Frame {frame}:")
                print("\n=== Scanpaths grouped by cluster ===")
                for cluster_id, member_scanids in sorted(clusters.items()):
                    print(f"\nCluster {cluster_id} (size={len(member_scanids)}):")
                    for scan_id in member_scanids:
                        print(f"  Scanpath {scan_id}")

                plt.figure(figsize=(10, 5))
                dendrogram(Z, truncate_mode='lastp', p=k, leaf_rotation=90., leaf_font_size=12., show_contracted=True)
                plt.title(f"Hierarchical Clustering Dendrogram for {video}, Frame {frame} (truncated)")
                plt.xlabel("Cluster Index")
                plt.ylabel("Distance")
                plt.show()

    def compute_and_save_final_source_scanpath(self, clustering_dir, output_csv_dir):
        """Average the clusters of every frame into one source scanpath CSV.

        For each ``frame_<n>.pkl`` under ``<clustering_dir>/<video>/`` the
        member scanpaths of every cluster are averaged with ``performDBA``,
        the cluster averages are averaged again, and the result is written to
        ``<output_csv_dir>/<video>/Frame_<n>.csv`` with columns ``y`` and ``x``.
        Entries of ``clustering_dir`` that are not directories are ignored.
        """
        frame_file_pattern = re.compile(r'frame_(\d+)\.pkl$')
        for video in sorted(os.listdir(clustering_dir)):
            video_dir = os.path.join(clustering_dir, video)
            if not os.path.isdir(video_dir):
                # e.g. a stray .DS_Store next to the per-video directories
                continue
            print(f"\n--- Processing Video: {video} ---")
            output_dir = os.path.join(output_csv_dir, video)
            os.makedirs(output_dir, exist_ok=True)

            # Parse the frame number once and sort numerically.
            frame_files = []
            for frame_file in os.listdir(video_dir):
                frame_match = frame_file_pattern.match(frame_file)
                if frame_match:
                    frame_files.append((int(frame_match.group(1)), frame_file))
            frame_files.sort()
            total_frames = len(frame_files)

            for frame_idx, (frame, frame_file) in enumerate(frame_files):
                try:
                    # NOTE: pickle.load executes arbitrary code from the file;
                    # only use clustering directories produced by this class.
                    with open(os.path.join(video_dir, frame_file), 'rb') as f:
                        data = pickle.load(f)
                        clusters = data['clusters']
                except Exception as e:
                    print(f"    Frame {frame}: Could not load pickle file. Error: {e}. Skipping.")
                    continue

                frame_scanpaths = self.image_and_scanpath_dict[video][frame]
                cluster_average_scanpaths = []
                for cluster_id, member_scanids in sorted(clusters.items()):
                    cluster_scanpaths = [np.array(frame_scanpaths[spid]) for spid in member_scanids if spid in frame_scanpaths]

                    if len(cluster_scanpaths) > 1:
                        cluster_average_scanpaths.append(performDBA(cluster_scanpaths))
                    elif len(cluster_scanpaths) == 1:
                        cluster_average_scanpaths.append(cluster_scanpaths[0])

                source_scanpath = None
                if len(cluster_average_scanpaths) > 1:
                    source_scanpath = performDBA(cluster_average_scanpaths)
                elif len(cluster_average_scanpaths) == 1:
                    source_scanpath = cluster_average_scanpaths[0]

                if source_scanpath is not None and source_scanpath.size > 0:
                    df_scanpath = pd.DataFrame(source_scanpath, columns=['y', 'x'])
                    csv_path = os.path.join(output_dir, f'Frame_{frame}.csv')
                    df_scanpath.to_csv(csv_path, index=False)
                    print(f"    Saved scanpath for Frame {frame} to CSV")
                else:
                    print(f"    Frame {frame_idx+1}/{total_frames} ({frame}): No source scanpath computed.")

# Input/output locations for the clustering cell. The scanpath CSVs read here
# live in the "interpolated" folder, and the final source scanpaths are
# written to a sub-folder of that same directory.
dataset_root = "/Users/nagasaithadishetty/Desktop/WEB360/videos/Frames"
scan_path_directory = "/Users/nagasaithadishetty/Desktop/Diffusion_Research/Multi_Concept_Code/ScanDMM/demo/output/interpolated"
OUTPUT_DIR = "/Users/nagasaithadishetty/Desktop/Diffusion_Research/Multi_Concept_Code/ScanDMM/demo/output/clustering_results"
Final_source_DIR = os.path.join(scan_path_directory, 'source_scanpaths')

# Full pipeline: load -> pairwise FastDTW -> hierarchical clustering (k=30)
# -> one averaged source scanpath per frame.
data_loader = ScanpathData360(dataset_root, scan_path_directory)
data_loader.load_dataset()
data_loader.compute_all_frames_fastdtw(radius=20)
data_loader.cluster_and_save_all_frames(k=30, output_dir=OUTPUT_DIR)
# Optional visual inspection of the saved clusters (prints and plots per frame).
#data_loader.display_fastdtw_results(k=30, output_dir=OUTPUT_DIR)  
data_loader.compute_and_save_final_source_scanpath(clustering_dir=OUTPUT_DIR,output_csv_dir=Final_source_DIR)

  img_match = re.match(r'frame_(\d+)\.png$', img_file)


CSV missing for video .DS_Store, skipping

Computing pairwise FastDTW for all frames...
FastDTW computation complete.
  Saved clustering for 100001 frame 1 to /Users/nagasaithadishetty/Desktop/Diffusion_Research/Multi_Concept_Code/ScanDMM/demo/output/clustering_results/100001/frame_1.pkl
  Saved clustering for 100001 frame 2 to /Users/nagasaithadishetty/Desktop/Diffusion_Research/Multi_Concept_Code/ScanDMM/demo/output/clustering_results/100001/frame_2.pkl
  Saved clustering for 100001 frame 3 to /Users/nagasaithadishetty/Desktop/Diffusion_Research/Multi_Concept_Code/ScanDMM/demo/output/clustering_results/100001/frame_3.pkl
  Saved clustering for 100001 frame 4 to /Users/nagasaithadishetty/Desktop/Diffusion_Research/Multi_Concept_Code/ScanDMM/demo/output/clustering_results/100001/frame_4.pkl
  Saved clustering for 100001 frame 5 to /Users/nagasaithadishetty/Desktop/Diffusion_Research/Multi_Concept_Code/ScanDMM/demo/output/clustering_results/100001/frame_5.pkl
  Saved clustering for 1000

In [None]:
import os
import pandas as pd
import re
import cv2
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm

# Input locations for this cell.
# NOTE(review): absolute paths tied to the machine this was run on.
source_scanpaths_dir = "/media/scratch/Trauma_Detection/code/ScanDMM/demo/output/source_scanpaths_csv_initial"
frame_directory = "/media/scratch/datasets/WEB360/videos_512x1024x100/Frames/"
caption_path = "/media/scratch/datasets/WEB360/WEB360_360TF_train.csv"

# Load the caption table, then keep only the rows whose 'videoid' (compared as
# a string) matches the name of an entry found in `frame_directory`.
CaptionData = pd.read_csv(caption_path)
CaptionData = CaptionData[CaptionData['videoid'].astype(str).isin(os.listdir(frame_directory))]

def load_scanpath(scanpath_dir, frame_number):
    """Read the scanpath stored for a single frame.

    The file probed is ``Frame_<frame_number>.csv`` inside ``scanpath_dir``.

    Args:
        scanpath_dir: Directory holding the per-frame scanpath CSV files.
        frame_number: Frame number used to build the CSV file name.

    Returns:
        A ``(points, csv_path)`` tuple. ``points`` is a list of ``[x, y]``
        rows taken from the CSV's ``x`` and ``y`` columns, or ``None`` when
        the file does not exist. ``csv_path`` is the path that was probed and
        is returned in both cases so callers can report it.
    """
    csv_file = os.path.join(scanpath_dir, f"Frame_{frame_number}.csv")
    if os.path.exists(csv_file):
        table = pd.read_csv(csv_file)
        # Select by column name so the on-disk column order does not matter.
        xy_rows = table[['x', 'y']].values.tolist()
        return xy_rows, csv_file
    return None, csv_file

def _plot_scanpath_link(ax, start, end, img_width):
    """Draw the connecting line between two consecutive scanpath points.

    A horizontal jump of at least half the image width is drawn as two
    segments that leave through one vertical image border and re-enter through
    the opposite one, meeting the border at the midpoint height.
    """
    line_style = dict(color='blue', linewidth=2, alpha=0.4)
    (x0, y0), (x1, y1) = start, end

    if abs(x0 - x1) < (img_width / 2):
        ax.plot([x0, x1], [y0, y1], **line_style)
        return

    border_y = y0 + (y1 - y0) / 2
    if x1 > x0:
        # Exit through the left border, re-enter from the right.
        ax.plot([x0, 0], [y0, border_y], **line_style)
        ax.plot([img_width, x1], [border_y, y1], **line_style)
    else:
        # Exit through the right border, re-enter from the left.
        ax.plot([x0, img_width], [y0, border_y], **line_style)
        ax.plot([0, x1], [border_y, y1], **line_style)


def plot_single_scanpath_on_image(scanpath, img_path, save_path=None, caption=None, img_height=256, img_width=512):
    """Overlay one scanpath on an image, optionally save it, and show it.

    The image is read with OpenCV, resized to ``(img_width, img_height)`` and
    converted from BGR to RGB. Each scanpath point is drawn as a numbered,
    rainbow-coloured marker; consecutive points are joined by blue lines that
    wrap around the left/right image border for large horizontal jumps.

    Args:
        scanpath: Sequence of ``(x, y)`` points. If every coordinate lies in
            ``[0, 1]`` the points are treated as normalized and scaled to the
            image size; otherwise they are used as pixel coordinates.
        img_path: Path of the image to draw on.
        save_path: Optional file path; when given, the figure is written
            there before it is shown.
        caption: Optional text rendered under the image. Skipped when ``None``.
        img_height: Height the image is resized to, in pixels.
        img_width: Width the image is resized to, in pixels.

    Raises:
        FileNotFoundError: If the image at ``img_path`` cannot be read.
    """
    image = cv2.imread(img_path)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising;
        # fail here with a clear message rather than inside cv2.resize.
        raise FileNotFoundError(f"Could not read image: {img_path}")
    image = cv2.resize(image, (img_width, img_height))
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.imshow(image)
    ax.axis('off')

    normalized = all(0 <= x <= 1 and 0 <= y <= 1 for x, y in scanpath)
    x_scale = img_width if normalized else 1
    y_scale = img_height if normalized else 1
    points = [(p[0] * x_scale, p[1] * y_scale) for p in scanpath]
    colors = cm.rainbow(np.linspace(0, 1, len(points)))

    previous_point = None
    for idx, (point, color) in enumerate(zip(points, colors)):
        if previous_point is not None:
            _plot_scanpath_link(ax, previous_point, point, img_width)
        previous_point = point
        x, y = point
        ax.plot(x, y, marker='o', markersize=8, color=color, alpha=0.9)
        ax.text(x + 4, y - 4, str(idx + 1), color='yellow', fontsize=9)

    if caption is not None:
        plt.figtext(0.5, 0.01, f"📜 {caption}", wrap=True, ha='center', fontsize=10)

    if save_path is not None:
        # Save before show(): some backends discard the figure once shown.
        fig.savefig(save_path, bbox_inches='tight')

    plt.show()
    # Release the figure so repeated calls in a loop do not accumulate memory.
    plt.close(fig)


def frame_with_scanpath(source_scanpaths_dir, frame_directory, caption_df):
    """Pair every frame of every video with its scanpath and caption.

    Each sub-directory of ``frame_directory`` is treated as one video. For
    every ``.png`` file whose name contains ``frame_<number>``, the matching
    ``Frame_<number>.csv`` is looked up under
    ``source_scanpaths_dir/<video_id>`` via ``load_scanpath``. Frames without
    a scanpath CSV are skipped. Progress is printed along the way.

    Args:
        source_scanpaths_dir: Root directory with one scanpath folder per video.
        frame_directory: Root directory with one frame folder per video.
        caption_df: DataFrame with a ``videoid`` column and a ``name`` column
            holding the caption text.

    Returns:
        dict mapping each video id to a list of dicts, ordered by frame
        number, with the keys ``frame_file``, ``frame_number``,
        ``image_path``, ``caption`` and ``scanpath``. A video with no usable
        frames maps to an empty list.
    """
    frame_pattern = re.compile(r'frame_(\d+)')
    # Stringify the id column once instead of once per video.
    caption_ids = caption_df['videoid'].astype(str)

    video_data = {}
    for video_id in sorted(os.listdir(frame_directory)):
        video_frame_dir = os.path.join(frame_directory, video_id)
        if not os.path.isdir(video_frame_dir):
            # Stray files (e.g. macOS ".DS_Store") are not video folders and
            # would make os.listdir below raise.
            continue

        entries = video_data[video_id] = []
        scanpath_video_dir = os.path.join(source_scanpaths_dir, video_id)
        caption_row = caption_df[caption_ids == video_id]
        caption_text = caption_row['name'].values[0] if not caption_row.empty else None

        print(f"\n🎞️ Video: {video_id}")

        numbered_frames = []
        for frame_file in os.listdir(video_frame_dir):
            if not frame_file.lower().endswith('.png'):
                continue
            match = frame_pattern.search(frame_file)
            if match is None:
                print(f"  Skipping {frame_file}: no frame number in file name")
                continue
            numbered_frames.append((int(match.group(1)), frame_file))

        # Sort numerically so that frame_10 follows frame_9, not frame_1.
        for frame_number, frame_file in sorted(numbered_frames):
            print(f"  Frame: {frame_file} (Number: {frame_number})")
            image_path = os.path.join(video_frame_dir, frame_file)
            scanpath, csv_path = load_scanpath(scanpath_video_dir, frame_number)
            print(f"  Frame Path : {image_path};  Scanpath CSV: {csv_path}")
            if scanpath is None:
                continue
            entries.append({
                "frame_file": frame_file,
                "frame_number": frame_number,
                "image_path": image_path,
                "caption": caption_text,
                "scanpath": scanpath,
            })

        # One summary line per video instead of dumping the whole list after
        # every frame.
        print(f"  Paired {len(entries)} frame(s) with scanpaths")

    return video_data

# Run the frame/scanpath pairing over every video, using the caption table filtered above.
frame_with_scanpath(source_scanpaths_dir, frame_directory, CaptionData)
