# Object Tracking Notebook
This notebook combines approaches from Lab2_A, Lab2_C, and Lab2_E to track whole objects across frames, rather than tracking individual voxels.

## 1. Import Required Libraries
Import all necessary libraries and utility functions.

In [1]:
# Dependencies and Imports
# Install required packages if not already installed
!pip install --upgrade pip
!pip install torch==2.4.0+cu121 torchvision==0.19.0+cu121 --index-url https://download.pytorch.org/whl/cu121
!pip install transformers==4.44.0 huggingface-hub==0.24.0 pillow numpy opencv-python open3d ipympl rerun-sdk[notebook]==0.24.1

import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch

# SAM and CLIP imports
from lab_utils.model_loaders import load_owlv2_model, load_sam_model
from lab_utils.model_loaders import load_sam_model, load_clip_model

# Utility imports (if available)
from lab_utils.data_utils import *
from lab_utils.detection_utils import *
from lab_utils.visualization_utils import *


Collecting pip
  Using cached pip-25.2-py3-none-any.whl (1.8 MB)
Installing collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 23.0.1
    Uninstalling pip-23.0.1:
      Successfully uninstalled pip-23.0.1
Successfully installed pip-25.2
Looking in indexes: https://download.pytorch.org/whl/cu121
Collecting torch==2.4.0+cu121
  Using cached https://download.pytorch.org/whl/cu121/torch-2.4.0%2Bcu121-cp310-cp310-linux_x86_64.whl (799.1 MB)
Collecting torchvision==0.19.0+cu121
  Using cached https://download.pytorch.org/whl/cu121/torchvision-0.19.0%2Bcu121-cp310-cp310-linux_x86_64.whl (7.1 MB)
Collecting filelock (from torch==2.4.0+cu121)
  Using cached https://download.pytorch.org/whl/filelock-3.13.1-py3-none-any.whl.metadata (2.8 kB)
Collecting sympy (from torch==2.4.0+cu121)
  Using cached https://download.pytorch.org/whl/sympy-1.13.3-py3-none-any.whl.metadata (12 kB)
Collecting networkx (from torch==2.4.0+cu121)
  Using cached https://download.pytor



Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


## 2. Load Data
Load scene data, images, and annotations.

In [None]:
# Load data (update paths and logic as needed)
import os
scene_id = '40753679'
# Derive the frames directory from scene_id so switching scenes only requires editing one value.
data_path = f'ARKitScenesData/{scene_id}/{scene_id}_frames/lowres_wide/'

# Collect all .png frame paths, sorted by file name.
# NOTE(review): lexicographic order equals temporal order only if frame file names are
# fixed-width (e.g. zero-padded timestamps) — confirm for this dataset.
images = sorted(
    os.path.join(data_path, fname)
    for fname in os.listdir(data_path)
    if fname.endswith('.png')
)

# TODO: add annotation loading here if needed, e.g.
# annotations = load_annotations(data_path)


In [None]:
# Test: Check loaded images — report how many frames were found and preview the first few paths.
image_count = len(images)
preview_paths = images[:3]
print(f"Loaded {image_count} images.")
print("First 3 image paths:", preview_paths)

## 3. Object Detection
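Detect objects in each frame using methods from previous labs: SAM proposes segmentation masks, and CLIP assigns each masked region the best-matching label from a list of text prompts.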

In [None]:
# 3.1 Load SAM and CLIP models
print("Loading models...")
# load_sam_model hands back (model, processor, device); the device is reused for CLIP
# so both models end up on the same device.
sam_bundle = load_sam_model(model_size='base')
sam_model, sam_processor, device = sam_bundle
clip_bundle = load_clip_model(device=device)
clip_model, clip_processor, _ = clip_bundle

In [None]:
class TrackedObject:
    """A single object followed across frames, with its position and embedding history.

    Args:
        object_id: Track identifier.
        class_name: Semantic label of the object.
        initial_position: First observed 3D position (any sequence convertible to a float array).
        initial_embedding: First appearance embedding; stored as given.
        motion_threshold: Displacement between consecutive observations above which the object
            is flagged as dynamic. Expressed in the units of the positions
            (presumably metres — TODO confirm).
    """

    def __init__(self, object_id, class_name, initial_position, initial_embedding, motion_threshold=0.1):
        self.id = object_id
        self.class_name = class_name
        self.positions = [initial_position]  # Historical 3D positions
        self.embeddings = [initial_embedding]  # Historical embeddings
        self.last_seen = 0  # Frame index of the most recent observation
        self.is_static = True  # Flag for static/dynamic objects; cleared once the object moves
        self.motion_threshold = motion_threshold

    def update(self, new_position, new_embedding, frame_idx):
        """Record a new observation and refresh the static/dynamic flag.

        Args:
            new_position: Newly observed 3D position.
            new_embedding: Newly observed embedding; appended as given.
            frame_idx: Index of the frame the observation comes from.

        Returns:
            float: Euclidean distance between the new position and the previous one.
        """
        previous = np.asarray(self.positions[-1], dtype=float)
        current = np.asarray(new_position, dtype=float)
        displacement = float(np.linalg.norm(current - previous))

        self.positions.append(new_position)
        self.embeddings.append(new_embedding)
        self.last_seen = frame_idx

        # Sticky flag: once an object has moved significantly it stays classified as dynamic.
        if displacement > self.motion_threshold:
            self.is_static = False
        return displacement

In [None]:
# Test: Check SAM and CLIP model loading.
# Reports the objects created by the model-loading cell (sam_model, sam_processor,
# clip_model, clip_processor, device); `predictor` / `preprocess` are never defined there.
print("SAM model type:", type(sam_model))
print("SAM processor type:", type(sam_processor))
print("CLIP model type:", type(clip_model))
print("CLIP processor type:", type(clip_processor))
print("Device:", device)

In [None]:
# 3.2 Detect objects in images using SAM + CLIP
def detect_objects_sam_clip(images, predictor, clip_model, preprocess, device, text_prompts, tokenize=None):
    """Segment each image with SAM and label every mask via CLIP zero-shot classification.

    Args:
        images: Iterable of image file paths, processed in the given order.
        predictor: SAM-style predictor exposing ``set_image(np.ndarray)`` and
            ``predict(point_coords=..., point_labels=..., multimask_output=...)`` that
            returns a 3-tuple whose first element is an iterable of boolean masks.
        clip_model: CLIP model exposing ``encode_image`` and ``encode_text``.
        preprocess: Callable mapping a PIL image to a tensor accepted by ``clip_model``.
        device: Torch device the CLIP inputs are moved to.
        text_prompts: Candidate class names; each mask gets the highest-scoring one.
        tokenize: Optional tokenizer for ``text_prompts``. Defaults to ``clip.tokenize``
            from the OpenAI ``clip`` package, imported lazily.

    Returns:
        List of dicts with keys ``"image"`` (path), ``"bbox"``
        (``[x_min, y_min, x_max, y_max]``, inclusive pixel indices), ``"mask"`` and ``"class"``.
    """
    import numpy as np
    from PIL import Image
    import torch

    if tokenize is None:
        # `clip` is not imported at notebook level, so bring it into scope here.
        import clip
        tokenize = clip.tokenize

    detected_objects = []

    # Text embeddings do not depend on the image: encode them once, not once per mask.
    text_tokens = tokenize(text_prompts).to(device)
    with torch.no_grad():
        text_features = clip_model.encode_text(text_tokens)
        # CLIP scores are cosine similarities, so embeddings must be unit-normalised.
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    for img_path in images:
        # Force 3-channel RGB (PNGs may carry alpha or be grayscale) and close the file handle.
        with Image.open(img_path) as pil_image:
            image = np.array(pil_image.convert("RGB"))
        predictor.set_image(image)
        # NOTE(review): no point/box prompt is supplied; if whole-image automatic segmentation
        # is intended, an automatic mask generator may be the right tool — confirm.
        masks, _, _ = predictor.predict(
            point_coords=None,
            point_labels=None,
            multimask_output=True
        )
        for mask in masks:
            y_indices, x_indices = np.where(mask)
            if y_indices.size == 0:
                continue
            y_min, y_max = int(y_indices.min()), int(y_indices.max())
            x_min, x_max = int(x_indices.min()), int(x_indices.max())
            # The max indices are inclusive, so slice one past them; otherwise a
            # one-pixel-wide mask would produce an empty crop.
            crop = image[y_min:y_max + 1, x_min:x_max + 1]
            clip_input = preprocess(Image.fromarray(crop)).unsqueeze(0).to(device)
            with torch.no_grad():
                image_features = clip_model.encode_image(clip_input)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)
                probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
            pred_class = text_prompts[probs.argmax().item()]
            detected_objects.append({
                "image": img_path,
                "bbox": [x_min, y_min, x_max, y_max],
                "mask": mask,
                "class": pred_class
            })
    return detected_objects

# Example usage:
# NOTE(review): `predictor` (a SAM predictor with set_image/predict) and `preprocess`
# (a CLIP image transform) are not defined anywhere in this notebook — the loading cell
# creates sam_model / sam_processor / clip_model / clip_processor instead.
# TODO: construct compatible objects (or adapt the detector) before running this cell.
text_prompts = ["sofa", "shelf", "table", "chair"]
detected_objects = detect_objects_sam_clip(
    images,
    predictor,
    clip_model,
    preprocess,
    device,
    text_prompts,
)


In [None]:
# Test: Check detection output — count detections and show the first one, if any.
detection_count = len(detected_objects)
print(f"Detected {detection_count} objects.")
if detection_count > 0:
    print("First detected object:", detected_objects[0])

## 4. Object Tracking Logic
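Associate detections across consecutive frames: a detection inherits the track ID of a same-class detection in the previous frame whose bounding box overlaps it with IoU above a threshold; otherwise it starts a new track.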

In [None]:
# 4.1 Define object tracking function (IoU-based, no voxels)
def _bbox_iou(box_a, box_b):
    """Intersection-over-union of two ``[x_min, y_min, x_max, y_max]`` boxes (0 when disjoint)."""
    x1 = max(box_a[0], box_b[0])
    y1 = max(box_a[1], box_b[1])
    x2 = min(box_a[2], box_b[2])
    y2 = min(box_a[3], box_b[3])
    if x2 < x1 or y2 < y1:
        return 0
    intersection = (x2 - x1) * (y2 - y1)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - intersection
    return intersection / union if union > 0 else 0


def track_objects(detected_objects, iou_threshold=0.5):
    """
    Track objects across frames using IoU and class consistency.

    Detections are grouped by their ``"image"`` path (sorted order is treated as temporal
    order). Each detection is matched against the previous frame's detections of the same
    ``"class"``; pairs with IoU strictly above ``iou_threshold`` are assigned greedily in
    descending IoU order, one-to-one, so no two detections in a frame share a track ID.
    Unmatched detections start a new track.

    Args:
        detected_objects: List of dicts with at least ``"image"``, ``"bbox"`` and ``"class"``.
            Neither the list nor its dicts are modified.
        iou_threshold: Minimum (exclusive) IoU for a detection to continue a track.

    Returns:
        New list of shallow-copied detection dicts, ordered by frame, each with a
        ``"track_id"`` key added.
    """
    # sorted() is stable, so within-frame order is preserved, and the caller's list is untouched.
    frames = {}
    for detection in sorted(detected_objects, key=lambda d: d["image"]):
        frames.setdefault(detection["image"], []).append(detection)

    tracked_objects = []
    prev_frame = []  # tracked dicts of the previous frame (empty before the first frame)
    next_track_id = 0
    for frame_path in sorted(frames):
        current_detections = frames[frame_path]

        # Collect every admissible (detection, previous-track) pair with its IoU.
        candidates = []
        for det_idx, detection in enumerate(current_detections):
            for trk_idx, track in enumerate(prev_frame):
                if track["class"] != detection["class"]:
                    continue
                iou = _bbox_iou(detection["bbox"], track["bbox"])
                if iou > iou_threshold:
                    candidates.append((iou, det_idx, trk_idx))

        # Greedy one-to-one assignment, best overlaps first (stable sort keeps ties deterministic).
        candidates.sort(key=lambda c: c[0], reverse=True)
        matched_ids = {}
        used_tracks = set()
        for _, det_idx, trk_idx in candidates:
            if det_idx in matched_ids or trk_idx in used_tracks:
                continue
            matched_ids[det_idx] = prev_frame[trk_idx]["track_id"]
            used_tracks.add(trk_idx)

        frame_tracks = []
        for det_idx, detection in enumerate(current_detections):
            if det_idx in matched_ids:
                track_id = matched_ids[det_idx]
            else:
                track_id = next_track_id
                next_track_id += 1
            frame_tracks.append({**detection, "track_id": track_id})

        tracked_objects.extend(frame_tracks)
        prev_frame = frame_tracks
    return tracked_objects

# 4.2 Apply tracking to detected objects (default IoU threshold of 0.5)
tracked_objects = track_objects(detected_objects, iou_threshold=0.5)


In [None]:
# Test: Check tracking output — count tracked detections and show the first one, if any.
tracked_count = len(tracked_objects)
print(f"Tracked {tracked_count} objects.")
if tracked_count > 0:
    print("First tracked object:", tracked_objects[0])

## 5. Visualization
Visualize tracked objects over time. (Placeholder: the `visualize_tracking` call below is currently commented out.)

In [None]:
# Visualize tracking results
# visualize_tracking(tracked_objects, images)

## 6. Evaluation
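Evaluate tracking performance against ground-truth annotations. (Placeholder: the `evaluate_tracking` call below is commented out, and no `ground_truth` variable is defined in this notebook yet.)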

In [None]:
# Evaluate tracking
# evaluation_metrics = evaluate_tracking(tracked_objects, ground_truth)