# In [None]:
import torch
import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from torchvision import transforms
import matplotlib.pyplot as plt
import os
import json
from deepface import DeepFace
from scipy.spatial.distance import cosine
from ultralytics import YOLO
from sklearn.cluster import DBSCAN  # new import for clustering

# -----------------------------
# Configuration and Constants
# -----------------------------
# Run tensors on the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Transform for gaze detection model: resize every frame to 448x448 and
# convert it to a tensor before it is batched for the gaze model.
gaze_transform = transforms.Compose([
    transforms.Resize((448, 448)),
    transforms.ToTensor(),
])

# Threshold for matching face embeddings across frames (used initially).
# analyze_frames treats a cosine distance below this value as the same profile.
EMBEDDING_THRESHOLD = 0.6
# Confidence threshold for YOLO face detection (default of detect_faces_yolo).
YOLO_CONF_THRESHOLD = 0.5

# Flag to show YOLO detections per frame visually; when True, analyze_frames
# plots the annotated detections of every frame with matplotlib.
SHOW_YOLO_DETECTIONS = True

# -----------------------------
# Custom JSON Encoder
# -----------------------------
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy values to built-in Python types.

    Handles NumPy booleans, integers, floating-point scalars and arrays.
    Anything else is delegated to ``json.JSONEncoder``, which raises
    ``TypeError`` for unsupported objects.
    """

    def default(self, obj):
        """Return a JSON-serializable equivalent of ``obj``.

        Args:
            obj: Object the standard encoder could not serialize.

        Returns:
            A ``bool``, ``int``, ``float`` or (nested) ``list``.

        Raises:
            TypeError: If ``obj`` is not a supported NumPy type.
        """
        # np.bool_ is not a subclass of Python's bool, so the stock encoder
        # rejects it (e.g. the result of comparing a NumPy scalar).
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)

# -----------------------------
# YOLO Face Detection Function
# -----------------------------
def detect_faces_yolo(img_array, yolo_face_model, conf_threshold=YOLO_CONF_THRESHOLD):
    """
    Run the YOLO model on an image and keep the confident detections.

    Args:
        img_array: Input image as a NumPy array.
        yolo_face_model: Loaded YOLO model (called directly on the image).
        conf_threshold: Minimum confidence a detection needs to be kept.

    Returns:
        List of bounding boxes [x1, y1, x2, y2] in pixel coordinates
        (coordinates are truncated to integers).
    """
    predictions = yolo_face_model(img_array)
    # Nothing to report when the model produced no result or no boxes.
    if not predictions or predictions[0].boxes is None:
        return []
    coords = predictions[0].boxes.xyxy.cpu().numpy()   # [N,4]
    scores = predictions[0].boxes.conf.cpu().numpy()   # [N]
    return [
        [int(value) for value in corners]
        for corners, score in zip(coords, scores)
        if score >= conf_threshold
    ]

# -----------------------------
# (Optional) Known Faces Loading (Profiles)
# -----------------------------
def load_known_faces(image_paths, yolo_face_model):
    """
    Build reference embeddings for known people.

    Each image is read with OpenCV, the first YOLO detection is cropped and
    the crop is embedded with DeepFace (Facenet512, detection skipped).
    Images that cannot be read, contain no detection or raise an error are
    reported on stdout and skipped.

    Args:
        image_paths: Dictionary {name: image_path}
        yolo_face_model: Loaded YOLO model.

    Returns:
        Dictionary {name: embedding} for known people.
    """
    references = {}
    for person, img_path in image_paths.items():
        try:
            frame = cv2.imread(img_path)
            if frame is None:
                print(f"Could not load image for {person} at {img_path}")
                continue
            boxes = detect_faces_yolo(frame, yolo_face_model)
            if not boxes:
                print(f"No face detected for {person} in {img_path}")
                continue
            # Only the first detection is used as the person's reference face.
            left, top, right, bottom = boxes[0]
            crop = frame[top:bottom, left:right]
            representation = DeepFace.represent(
                crop,
                model_name='Facenet512',
                detector_backend='skip',
                enforce_detection=False,
            )
            references[person] = np.array(representation[0]['embedding'])
            print(f"Successfully loaded face for {person}")
        except Exception as err:
            print(f"Error loading face for {person}: {err}")
    return references

# -----------------------------
# Face Embedding and Matching
# -----------------------------
def get_face_embedding(face_roi):
    """
    Embed a cropped face with DeepFace (Facenet512, detection skipped).

    Returns the embedding as a NumPy array, or None (after printing the
    error) when DeepFace fails.
    """
    try:
        representations = DeepFace.represent(
            face_roi,
            model_name='Facenet512',
            detector_backend='skip',
            enforce_detection=False,
        )
        vector = np.array(representations[0]['embedding'])
    except Exception as err:
        print(f"Embedding error: {err}")
        return None
    return vector

def match_known_face(face_emb, known_faces):
    """
    Find the known face whose embedding is closest to ``face_emb``.

    Args:
        face_emb: Detected face embedding.
        known_faces: Dictionary {name: embedding}.

    Returns:
        Tuple (name, score) where score is the cosine distance to the best
        match (lower means more similar). With no known faces the result is
        (None, inf). No threshold is applied here.
    """
    scored = ((cosine(reference, face_emb), label) for label, reference in known_faces.items())
    winner, lowest = None, float('inf')
    for distance, label in scored:
        # Strict comparison keeps the first entry when distances tie.
        if distance < lowest:
            winner, lowest = label, distance
    return winner, lowest

def analyze_emotions(face_roi):
    """
    Return the dominant emotion DeepFace reports for a cropped face.

    Detection is skipped because the crop is already a face. Any failure is
    printed and reported as "Unknown".
    """
    try:
        report = DeepFace.analyze(
            face_roi,
            actions=['emotion'],
            detector_backend='skip',
            enforce_detection=False,
        )
        dominant = report[0]['dominant_emotion']
    except Exception as err:
        print(f"Emotion analysis error: {err}")
        return "Unknown"
    return dominant

# -----------------------------
# Visualization Function
# -----------------------------
def _as_float_score(raw_score):
    """Convert an in/out score (tensor, number or None) to float; None stays None."""
    if raw_score is None:
        return None
    if torch.is_tensor(raw_score):
        return float(raw_score.item())
    return float(raw_score)


def _load_arial(size, fallback=None):
    """Load Arial at ``size``; use ``fallback`` (or Pillow's default font) if unavailable."""
    try:
        return ImageFont.truetype("Arial", size)
    except OSError:
        # truetype raises OSError when the font file cannot be found or read.
        return fallback if fallback is not None else ImageFont.load_default()


def visualize_all(pil_image, heatmaps, bboxes, inout_scores=None, emotions=None, names=None, inout_thresh=0.5):
    """
    Create a visualization image with bounding boxes, gaze points, and header text.

    The frame is pasted below a header strip that lists one line per face
    (name, emotion, gaze verdict). For each face whose score exceeds
    ``inout_thresh`` (or that has no score), the heatmap maximum is drawn as a
    dot connected to the centre of the face box.

    Args:
        pil_image: PIL Image.
        heatmaps: Heatmaps for each face (tensors or arrays), or None.
        bboxes: Normalized bounding boxes (values in [0, 1]).
        inout_scores: In-out scores for gaze (optional). Individual entries
            may be None when no score is available for that face.
        emotions: Emotion strings (optional).
        names: Face names (optional).
        inout_thresh: Threshold for considering gaze as "looking at camera".

    Returns:
        An annotated PIL Image (RGBA). The header is 100 px tall and grows
        when more than three faces need a line of text.
    """
    colors = ['lime', 'tomato', 'cyan', 'fuchsia', 'yellow']
    width, height = pil_image.size
    font_size = 24
    row_height = font_size + 5
    # Keep the historical 100 px header, but grow it so that the per-face
    # lines never spill over the frame itself.
    header_height = max(100, 10 + row_height * len(bboxes))
    output_img = Image.new("RGBA", (width, height + header_height), (255, 255, 255, 255))
    output_img.paste(pil_image, (0, header_height))
    draw = ImageDraw.Draw(output_img)
    draw.rectangle([0, 0, width, header_height], fill=(240, 240, 240, 255))
    draw.line([(0, header_height), (width, header_height)], fill=(100, 100, 100, 255), width=2)

    # Stroke width and gaze-dot radius both scale with the smaller image side.
    line_width = max(1, int(min(width, height) * 0.01))
    radius = line_width

    header_text = []
    for i, norm_box in enumerate(bboxes):
        xmin, ymin, xmax, ymax = norm_box
        xmin_px = xmin * width
        ymin_px = ymin * height + header_height
        xmax_px = xmax * width
        ymax_px = ymax * height + header_height
        color = colors[i % len(colors)]
        draw.rectangle([xmin_px, ymin_px, xmax_px, ymax_px], outline=color, width=line_width)

        info_list = []
        if names is not None and i < len(names):
            info_list.append(f"{names[i]}")
        else:
            info_list.append(f"Person {i+1}")
        if emotions is not None and i < len(emotions):
            info_list.append(f"Emotion: {emotions[i]}")

        # A missing (None) score is treated like "no score available".
        score_val = None
        if inout_scores is not None and i < len(inout_scores):
            score_val = _as_float_score(inout_scores[i])
        if score_val is not None:
            look_str = "Looking at camera" if score_val > inout_thresh else "Not looking at camera"
            info_list.append(f"{look_str} ({score_val:.2f})")
        header_text.append((info_list, color))

        if heatmaps is None or i >= len(heatmaps):
            continue
        if score_val is not None and score_val <= inout_thresh:
            continue
        heat_np = heatmaps[i].detach().cpu().numpy() if torch.is_tensor(heatmaps[i]) else heatmaps[i]
        gaze_y, gaze_x = np.unravel_index(np.argmax(heat_np), heat_np.shape)
        # Scale the heatmap cell to frame pixels (shifted below the header).
        gaze_x = gaze_x / heat_np.shape[1] * width
        gaze_y = gaze_y / heat_np.shape[0] * height + header_height
        center_x = ((xmin + xmax) / 2) * width
        center_y = ((ymin + ymax) / 2) * height + header_height
        draw.ellipse([(gaze_x - radius, gaze_y - radius), (gaze_x + radius, gaze_y + radius)], fill=color)
        draw.line([(center_x, center_y), (gaze_x, gaze_y)], fill=color, width=line_width)

    font = _load_arial(font_size)
    y_offset = 10
    box_size = 15
    for info, col in header_text:
        draw.rectangle([10, y_offset, 10 + box_size, y_offset + box_size], fill=col)
        draw.text((10 + box_size + 5, y_offset), " | ".join(info), fill="black", font=font)
        y_offset += row_height

    title = f"Analysis for {len(bboxes)} detected face(s)"
    title_font_size = 32
    title_font = _load_arial(title_font_size, fallback=font)
    if hasattr(title_font, "getlength"):
        title_width = title_font.getlength(title)
    else:
        # Fonts without getlength: estimate from the character count.
        title_width = len(title) * title_font_size * 0.6
    title_x = (width - title_width) / 2
    draw.text((title_x, 0), title, fill="black", font=title_font)
    return output_img

# -----------------------------
# Visualize YOLO Detections Function
# -----------------------------
def visualize_yolo_detections(img_array, yolo_face_model, save_path=None):
    """
    Render the model's detections onto the image via the result's plot() method.

    Args:
        img_array: Input image as a NumPy array.
        yolo_face_model: Loaded YOLO model.
        save_path: Optional path; when given, the annotated image is also
            written there with cv2.imwrite.

    Returns:
        Annotated image as a NumPy array.
    """
    first_result = yolo_face_model(img_array)[0]
    rendered = first_result.plot()
    if not save_path:
        return rendered
    cv2.imwrite(save_path, rendered)
    return rendered

# -----------------------------
# Recluster Profiles Using DBSCAN
# -----------------------------
def recluster_profiles(profiles, eps=0.2):
    """
    Recluster the collected profile embeddings using DBSCAN to separate distinct individuals.

    Profiles whose embeddings fall in the same cluster are merged: their
    frame appearances are concatenated and the embedding of the first member
    is kept as the cluster's embedding.

    Args:
        profiles: Dictionary mapping profile IDs to dicts that contain at least an "embedding" and "frames_seen".
            An optional "name" that is not an auto-generated "Person N" label
            is carried over to the merged profile.
        eps: The epsilon parameter for DBSCAN (in cosine distance).

    Returns:
        new_profiles: Dictionary keyed by cluster label (plain int) with
            "embedding", "name" and aggregated "frames_seen".
        cluster_map: Dictionary mapping old profile IDs to new cluster labels (plain int).
        Both are empty when ``profiles`` is empty.
    """
    # DBSCAN cannot be fitted on zero samples.
    if not profiles:
        return {}, {}

    old_ids = list(profiles)
    embeddings = np.array([profiles[pid]["embedding"] for pid in old_ids])
    # min_samples=1 makes every embedding a core point, so no sample is
    # labelled as noise (-1).
    labels = DBSCAN(eps=eps, min_samples=1, metric='cosine').fit(embeddings).labels_
    # Convert NumPy integers to int so keys and ids are ordinary Python values.
    cluster_map = {old: int(label) for old, label in zip(old_ids, labels)}

    new_profiles = {}
    for old_pid, label in cluster_map.items():
        member = profiles[old_pid]
        member_name = member.get("name")
        has_identity = isinstance(member_name, str) and not member_name.startswith("Person ")
        if label not in new_profiles:
            new_profiles[label] = {
                "embedding": member["embedding"],
                # Keep a recognised identity instead of replacing it with a placeholder.
                "name": member_name if has_identity else f"Person {label + 1}",
                "frames_seen": list(member["frames_seen"]),
            }
            continue
        merged = new_profiles[label]
        merged["frames_seen"].extend(member["frames_seen"])
        if has_identity and merged["name"].startswith("Person "):
            merged["name"] = member_name
    return new_profiles, cluster_map

# -----------------------------
# Create JSON for LLM Analysis (without id field)
# -----------------------------
def _gaze_target_from_heatmap(heat):
    """Return the normalized {"x", "y"} position of the heatmap maximum, or None if unusable."""
    if heat is None:
        return None
    heat_np = np.asarray(heat)
    # Works for nested lists and arrays alike; empty or 1-D data has no 2-D peak.
    if heat_np.ndim < 2 or heat_np.size == 0:
        return None
    max_idx = np.unravel_index(np.argmax(heat_np), heat_np.shape)
    return {
        "x": float(max_idx[1]) / heat_np.shape[1],
        "y": float(max_idx[0]) / heat_np.shape[0],
    }


def _person_entry(face):
    """Build the per-person record (name, emotion, position, optional gaze) for one face."""
    x1, y1, x2, y2 = face['bbox'][:4]
    # Note: the "id" field is intentionally left out of each person.
    person = {
        "name": face.get('name', "Unknown"),
        "emotion": face.get('emotion', "Unknown"),
        "position": {"x1": int(x1), "y1": int(y1), "x2": int(x2), "y2": int(y2)},
    }
    gaze_info = face.get('gaze', {})
    if gaze_info:
        person["gaze"] = {
            "looking_at_camera": gaze_info.get("looking_at_camera", False),
            # A missing or None score is reported as 0.
            "confidence": float(gaze_info.get("inout_score") or 0),
        }
        target = _gaze_target_from_heatmap(gaze_info.get("heatmap"))
        if target is not None:
            person["gaze"]["target"] = target
    return person


def _people_summary(people):
    """Return the one-sentence-per-person natural language summary of a frame."""
    if not people:
        return "No people detected in this frame."
    sentences = []
    for person in people:
        sentence = f"{person['name']} shows {person['emotion']} emotion"
        if "gaze" in person:
            if person["gaze"].get("looking_at_camera", False):
                sentence += " and is looking at the camera"
            else:
                sentence += " and is not looking at the camera"
        sentences.append(sentence)
    return ". ".join(sentences) + "."


def create_llm_input(results, output_path="llm_analysis_input.json"):
    """
    Convert analysis results to a structured JSON file without including an 'id' field for persons.

    The file holds a session summary (frame count, people with their frame
    appearances) and one entry per frame with the people found in it and a
    natural language summary. A preview (first 1000 characters) is printed.

    Args:
        results: Analysis result dictionary. The keys 'profiles' and
            'visualizations' are reserved; every other key is a frame path
            mapping to a dict with a 'faces' list.
        output_path: Path to save the JSON file.

    Returns:
        The structured data dictionary.
    """
    reserved_keys = ('profiles', 'visualizations')
    frame_items = [(path, data) for path, data in results.items() if path not in reserved_keys]

    llm_data = {
        "session_summary": {
            "total_frames": len(frame_items),
            "people_detected": {},
        },
        "frames": [],
    }
    for pid, prof in results.get('profiles', {}).items():
        # In session summary, we keep only the person name and frames_seen.
        llm_data["session_summary"]["people_detected"][str(pid)] = {
            "name": prof['name'],
            "frame_appearances": prof['frames_seen'],
        }
    for frame_path, frame_data in frame_items:
        people = [_person_entry(face) for face in frame_data.get('faces', [])]
        llm_data["frames"].append({
            "frame_path": frame_path,
            "people": people,
            "natural_language_summary": _people_summary(people),
        })

    # Serialize once and reuse the text for both the file and the preview.
    serialized = json.dumps(llm_data, indent=2, cls=NumpyEncoder)
    with open(output_path, 'w', encoding="utf-8") as f:
        f.write(serialized)
    print(f"LLM analysis input saved to {output_path}")
    print(serialized[:1000] + "...\n" if len(serialized) > 1000 else serialized)
    return llm_data

# -----------------------------
# Main Analysis Function (per frame)
# -----------------------------
def _is_placeholder_name(name):
    """Return True for auto-generated "Person N" labels (as opposed to a recognised identity)."""
    return isinstance(name, str) and name.startswith("Person ")


def _gaze_model_device(gaze_model):
    """Return the device of the gaze model's parameters, falling back to the global ``device``."""
    try:
        return next(gaze_model.parameters()).device
    except (AttributeError, StopIteration, TypeError):
        return device


def _collect_frame_faces(idx, cv_img, detected, profiles, known_faces):
    """Embed every detection of one frame and assign it a provisional profile (mutates ``profiles``)."""
    faces = []
    used_pids = set()
    for x1, y1, x2, y2 in detected:
        face_roi = cv_img[y1:y2, x1:x2]
        if face_roi.size == 0:
            continue
        emb = get_face_embedding(face_roi)
        if emb is None:
            continue
        emotion = analyze_emotions(face_roi)

        ident = None
        sim_score = None
        if known_faces:
            candidate, sim_score = match_known_face(emb, known_faces)
            # The nearest known face is only an identity when it is close
            # enough; otherwise every stranger would get a known name.
            if candidate is not None and sim_score < EMBEDDING_THRESHOLD:
                ident = candidate
                print(f"Frame {idx}: Matched face with {ident} (score: {sim_score:.4f})")
            else:
                print(f"Frame {idx}: No known face within threshold (closest: {candidate}, score: {sim_score:.4f})")

        pid = None
        # Only match profiles not already used in this frame.
        for candidate_pid, prof in profiles.items():
            if candidate_pid in used_pids:
                continue
            if cosine(prof['embedding'], emb) < EMBEDDING_THRESHOLD:
                pid = candidate_pid
                break
        if pid is None:
            pid = len(profiles) + 1
            profiles[pid] = {"embedding": emb, "name": ident if ident else f"Person {pid}", "frames_seen": []}
        profiles[pid]["frames_seen"].append(idx)
        used_pids.add(pid)
        if ident and _is_placeholder_name(profiles[pid]["name"]):
            profiles[pid]["name"] = ident

        faces.append({
            "bbox": (x1, y1, x2, y2),
            "profile_id": pid,
            "identity": ident,
            "similarity_score": sim_score,
            "emotion": emotion,
            "name": profiles[pid]["name"],
        })
    return faces


def _resolve_profiles(profiles):
    """Recluster provisional profiles and return (final_profiles, old_id -> final_id map)."""
    # --- Partition profiles into multi-frame and single-frame detections ---
    multi = {pid: prof for pid, prof in profiles.items() if len(set(prof["frames_seen"])) > 1}
    single = {pid: prof for pid, prof in profiles.items() if len(set(prof["frames_seen"])) == 1}

    # --- Recluster multi-frame profiles using DBSCAN ---
    if multi:
        clustered, cluster_map = recluster_profiles(multi, eps=0.2)
    else:
        clustered, cluster_map = {}, {}

    combined = {}
    id_map = {}
    for label, prof in clustered.items():
        combined[int(label) + 1] = {"name": prof["name"], "frames_seen": sorted(set(prof["frames_seen"]))}
    for old_pid in multi:
        id_map[old_pid] = int(cluster_map[old_pid]) + 1

    # Single-frame detections get unique IDs starting after the multi clusters.
    offset = len(clustered)
    for position, old_pid in enumerate(sorted(single), start=1):
        new_id = offset + position
        id_map[old_pid] = new_id
        combined[new_id] = {
            "name": single[old_pid]["name"],
            "frames_seen": sorted(set(single[old_pid]["frames_seen"])),
        }

    # A recognised identity of any merged member wins over a placeholder ...
    for old_pid, new_id in id_map.items():
        original_name = profiles[old_pid]["name"]
        if not _is_placeholder_name(original_name) and _is_placeholder_name(combined[new_id]["name"]):
            combined[new_id]["name"] = original_name
    # ... and remaining placeholders are renumbered to match the final ID.
    for new_id, prof in combined.items():
        if _is_placeholder_name(prof["name"]):
            prof["name"] = f"Person {new_id}"
    return combined, id_map


def _gaze_entry(gaze_out, key, frame_pos, face_pos):
    """Return gaze_out[key][frame_pos][face_pos], or None when any level is missing."""
    per_frame = gaze_out.get(key) if gaze_out else None
    if per_frame is None or frame_pos >= len(per_frame):
        return None
    per_face = per_frame[frame_pos]
    if per_face is None or face_pos >= len(per_face):
        return None
    return per_face[face_pos]


def analyze_frames(image_paths, gaze_model, yolo_face_model, known_faces=None, output_dir=None, llm_output_path="llm_analysis_input.json"):
    """
    Analyze a sequence of frames: detect faces, track identities, read emotions and gaze.

    Steps:
      1. Per frame: YOLO detection, DeepFace embedding/emotion, optional
         known-face matching and assignment to a provisional profile.
      2. One batched gaze-model forward pass over all readable frames.
      3. Reclustering of the provisional profiles into final profiles.
      4. Assembly of per-frame results and visualizations, then export of the
         LLM input JSON via create_llm_input.

    Frames that cannot be opened are reported and left out of the results;
    "frames_seen" always holds indices into ``image_paths``.

    Args:
        image_paths: Sequence of image file paths, one per frame.
        gaze_model: Gaze model called with {"images": batch, "bboxes": boxes};
            its output is read through the "inout" and "heatmap" keys.
        yolo_face_model: Loaded YOLO model.
        known_faces: Optional dictionary {name: embedding}. A face is given a
            known name only when its cosine distance is below
            EMBEDDING_THRESHOLD.
        output_dir: Optional directory where visualizations are saved.
        llm_output_path: Path of the JSON file written by create_llm_input.

    Returns:
        Tuple (results, llm_data). ``results`` maps every readable frame path
        to {"faces": [...]} and additionally holds "profiles" and
        "visualizations".
    """
    profiles = {}
    frames = []  # one record per readable frame, in batch order
    print(f"Processing {len(image_paths)} frames...")
    for idx, path in enumerate(image_paths):
        print(f"Processing frame {idx+1}/{len(image_paths)}: {path}")
        try:
            pil_image = Image.open(path).convert("RGB")
        except Exception as e:
            print(f"Error opening {path}: {e}")
            continue
        width, height = pil_image.size
        # OpenCV, DeepFace and YOLO all expect BGR channel order for arrays.
        cv_img = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)

        # --- YOLO Detection Step ---
        detected = detect_faces_yolo(cv_img, yolo_face_model)
        print(f"YOLO detections for frame {idx+1} ({path}):")
        if detected:
            for j, bbox in enumerate(detected):
                print(f"  Detection {j+1}: Bounding box = {bbox}")
        else:
            print("  No faces detected by YOLO.")

        if SHOW_YOLO_DETECTIONS:
            annotated_img = visualize_yolo_detections(cv_img, yolo_face_model)
            plt.figure(figsize=(8, 6))
            plt.imshow(cv2.cvtColor(annotated_img, cv2.COLOR_BGR2RGB))
            plt.title(f"YOLO Detections for Frame {idx+1}")
            plt.axis("off")
            plt.show()

        faces = _collect_frame_faces(idx, cv_img, detected, profiles, known_faces)
        norm_boxes = [
            np.array([x1 / width, y1 / height, x2 / width, y2 / height])
            for x1, y1, x2, y2 in (face["bbox"] for face in faces)
        ]
        frames.append({
            "path": path,
            "size": (width, height),
            "faces": faces,
            "norm_boxes": norm_boxes,
            "tensor": gaze_transform(pil_image).unsqueeze(0),
        })

    gaze_out = None
    if any(rec["norm_boxes"] for rec in frames):
        # Feed the batch on the device the model actually lives on.
        img_batch = torch.cat([rec["tensor"] for rec in frames], dim=0).to(_gaze_model_device(gaze_model))
        inp = {"images": img_batch, "bboxes": [rec["norm_boxes"] for rec in frames]}
        print("Running gaze detection model...")
        with torch.no_grad():
            gaze_out = gaze_model(inp)

    # Final identities are resolved before results are assembled so that
    # per-frame names/ids agree with the profile summary.
    combined_profiles, id_map = _resolve_profiles(profiles)

    results = {}
    visuals = {}
    for batch_idx, rec in enumerate(frames):
        path = rec["path"]
        width, height = rec["size"]  # size of this frame, not of the last one read
        results[path] = {"faces": []}
        if not rec["faces"]:
            continue

        frame_norm_boxes = []
        frame_names = []
        frame_emotions = []
        frame_inout = []
        for face_idx, face in enumerate(rec["faces"]):
            x1, y1, x2, y2 = face["bbox"]
            new_pid = id_map.get(face["profile_id"], face["profile_id"])
            name = combined_profiles[new_pid]["name"] if new_pid in combined_profiles else face["name"]
            frame_norm_boxes.append((x1 / width, y1 / height, x2 / width, y2 / height))
            frame_names.append(name)
            frame_emotions.append(face["emotion"])

            gaze = {}
            if gaze_out:
                score = _gaze_entry(gaze_out, "inout", batch_idx, face_idx)
                if score is not None:
                    score = score.item()
                    gaze["inout_score"] = score
                    gaze["looking_at_camera"] = score > 0.5
                frame_inout.append(score)
                hm = _gaze_entry(gaze_out, "heatmap", batch_idx, face_idx)
                if hm is not None:
                    gaze["heatmap"] = hm.detach().cpu().numpy().tolist()

            results[path]["faces"].append({
                "bbox": face["bbox"],
                "profile_id": new_pid,
                "name": name,
                "similarity_score": face.get("similarity_score"),
                "emotion": face["emotion"],
                "gaze": gaze,
            })

        frame_hm = None
        if gaze_out:
            all_heatmaps = gaze_out.get("heatmap")
            if all_heatmaps is not None and batch_idx < len(all_heatmaps):
                frame_hm = all_heatmaps[batch_idx]
        if frame_hm is None:
            continue
        try:
            pil_image = Image.open(path).convert("RGB")
        except Exception as e:
            print(f"Error reopening {path} for visualization: {e}")
            continue
        # Pass scores only when at least one face actually has one.
        scores = frame_inout if any(s is not None for s in frame_inout) else None
        vis_img = visualize_all(
            pil_image,
            frame_hm,
            frame_norm_boxes,
            scores,
            frame_emotions,
            frame_names
        )
        visuals[path] = vis_img
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
            out_path = os.path.join(output_dir, f"viz_{os.path.basename(path)}")
            try:
                vis_img.save(out_path)
            except OSError:
                # Formats without alpha support (e.g. JPEG) reject RGBA images.
                vis_img.convert("RGB").save(out_path)
            print(f"Saved visualization for {path} to {out_path}")

    results["profiles"] = combined_profiles
    results["visualizations"] = visuals
    print("Analysis complete!")
    llm_data = create_llm_input(results, llm_output_path)
    return results, llm_data

# -----------------------------
# Display Analysis Results
# -----------------------------
def display_results(results):
    """
    Show every stored visualization with matplotlib and print a profile summary.

    Reads the "visualizations" and "profiles" entries of ``results``; missing
    or empty entries are simply skipped.
    """
    for frame_path, rendered in results.get("visualizations", {}).items():
        plt.figure(figsize=(12, 8))
        plt.imshow(rendered)
        plt.title(f"Analysis for {os.path.basename(frame_path)}")
        plt.axis("off")
        plt.show()

    profile_table = results.get("profiles", {})
    if not profile_table:
        return
    print("\nProfile Summary:")
    for profile_id, profile in profile_table.items():
        appearances = len(profile['frames_seen'])
        print(f"Profile {profile_id} ({profile['name']}): Appeared in {appearances} frames")

# -----------------------------
# Main Execution Block
# -----------------------------
if __name__ == "__main__":
    # Twenty consecutive frames: frame_pbm_0000.png ... frame_pbm_0019.png
    image_paths = [f"frame_pbm_{frame_no:04d}.png" for frame_no in range(20)]
    # Load gaze detection model (the second value returned by the hub entry is unused here)
    gaze_model, _ = torch.hub.load('fkryan/gazelle', 'gazelle_dinov2_vitl14_inout')
    gaze_model.eval()
    # Keep the model on the same device the input tensors are sent to;
    # pinning it to the CPU breaks inference on CUDA machines.
    gaze_model.to(device)
    # Load YOLO face detection model (replace with your YOLO–face model if available)
    yolo_model_path = "yolov8n.pt"
    print(f"Loading YOLO model from {yolo_model_path} on {device}...")
    try:
        yolo_face_model = YOLO(yolo_model_path)
        yolo_face_model.to(device)
    except Exception as e:
        print(f"Error loading YOLO model: {e}")
        # exit() is the interactive-shell helper; SystemExit works everywhere.
        raise SystemExit(1)
    # For this run, we set known_faces to None (thus no known faces are used)
    known_faces = None
    # Analyze frames
    results, llm_data = analyze_frames(
        image_paths,
        gaze_model,
        yolo_face_model,
        known_faces,
        output_dir="output",
        llm_output_path="llm_analysis_input.json"
    )
    # Display visual results
    display_results(results)
    print("\nLLM data has been created and saved to llm_analysis_input.json")
    print("This data can now be sent to an LLM for analysis.")
