In [5]:
import os
import cv2
import torch
import numpy as np
from PIL import Image
from ultralytics import YOLO  # Assuming YOLOv8 is being used
from sam2.build_sam import build_sam2_video_predictor
from torch.utils.checkpoint import checkpoint
from torch.utils.data import DataLoader, Dataset
from torch.cuda.amp import autocast, GradScaler

# Paths and Directories
video_path = './clip 2.mp4'
output_frames_dir = './video_frames'
output_video_path = 'segmented_video.mp4'

# Create directory to store frames
os.makedirs(output_frames_dir, exist_ok=True)

# Load Video and Extract Frames
cap = cv2.VideoCapture(video_path)
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
fps = int(cap.get(cv2.CAP_PROP_FPS))

frame_names = []
for i in range(frame_count):
    ret, frame = cap.read()
    if not ret:
        break
    frame_name = f'{i:05d}.jpg'
    cv2.imwrite(os.path.join(output_frames_dir, frame_name), frame)
    frame_names.append(frame_name)
cap.release()


# Custom Dataset for Batch Processing
class VideoFrameDataset(torch.utils.data.Dataset):
    def __init__(self, frame_names, frame_dir,target_size=(640, 640)):
        self.frame_names = frame_names
        self.frame_dir = frame_dir
        self.target_size = target_size  # (Height, Width)
    def __len__(self):
        return len(self.frame_names)

    def __getitem__(self, idx):
        frame_path = os.path.join(self.frame_dir, self.frame_names[idx])
        frame = Image.open(frame_path).convert('RGB')
        frame = frame.resize(self.target_size, Image.BILINEAR)
        frame_np = np.array(frame)
        # Convert to (C, H, W)
        frame_np = frame_np.transpose(2, 0, 1)
        frame_tensor = torch.from_numpy(frame_np).float() / 255.0  # Normalize to [0,1]
        return frame_tensor, idx

# Load Dataset
dataset = VideoFrameDataset(frame_names, output_frames_dir)
batch_size = 8 # Adjust batch size based on GPU memory
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

# Load SAM 2 Model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sam2_checkpoint = "./segment-anything-2/checkpoints/sam2_hiera_small.pt"
model_cfg = "sam2_hiera_s.yaml"
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=device)

# Initialize YOLO model for object detection
yolo_model = YOLO('yolov8s.pt').to(device)  # Use the best YOLO model variant

# Initialize SAM 2 State
inference_state = predictor.init_state(video_path=output_frames_dir)
predictor.reset_state(inference_state)

ann_obj_id = 1  # Initialize object ID



    
# Process frames in batches
video_segments = {}
for batch_frames, batch_indices in dataloader:
    batch_frames = batch_frames.to(device)

    # YOLO Object Detection with Gradient Checkpointing
    with autocast():
        results = yolo_model(batch_frames, imgsz=640)  # Ensure img size matches target_size

    for i, frame_idx in enumerate(batch_indices):
        detected_boxes = results[i].boxes.xyxy.cpu().numpy().astype(np.float32)  # Shape: (num_boxes, 4)

        # Add detected boxes to SAM 2 for segmentation
        for single_box in detected_boxes:
            # SAM expects boxes in [x1, y1, x2, y2] format
            _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
                inference_state=inference_state,
                frame_idx=int(frame_idx),
                obj_id=ann_obj_id,
                box=single_box.tolist(),  # Convert to list if necessary
            )
            ann_obj_id += 1

        # Propagate masks in video after adding new objects
        video_segments[int(frame_idx)] = {
            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
            for i, out_obj_id in enumerate(out_obj_ids)
        }
          # Clear cache to free up memory
    torch.cuda.empty_cache()

# Save the Segmented Video
# Assuming all frames are resized to target_size
height, width = dataset.target_size
fourcc = cv2.VideoWriter_fourcc(*'mp4v')

out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))

for out_frame_idx in range(frame_count):
    # Load the resized frame
    frame_path = os.path.join(output_frames_dir, frame_names[out_frame_idx])
    frame = cv2.imread(frame_path)
    frame = cv2.resize(frame, (width, height))  # Ensure size matches target_size

    if out_frame_idx in video_segments:
        for out_obj_id, out_mask in video_segments[out_frame_idx].items():
            # Create a color overlay for the mask
            color = np.array([0, 255, 0], dtype=np.uint8)  # Green mask
            mask = (out_mask * 255).astype(np.uint8)
            mask_rgb = cv2.merge([mask, mask, mask])
            colored_mask = cv2.bitwise_and(color, color, mask=mask)
            # Blend the mask with the frame
            frame = cv2.addWeighted(frame, 1.0, colored_mask, 0.5, 0)
    out.write(frame)

out.release()

print(f'Segmented video saved to {output_video_path}')


Downloading https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8s.pt to 'yolov8s.pt'...


100%|██████████| 21.5M/21.5M [00:00<00:00, 32.3MB/s]
frame loading (JPEG): 100%|██████████| 2395/2395 [02:09<00:00, 18.46it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 28.07 GiB. GPU 0 has a total capacity of 8.00 GiB of which 0 bytes is free. Of the allocated memory 29.46 GiB is allocated by PyTorch, and 198.65 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
import os
import cv2
import torch
import numpy as np
from PIL import Image
from ultralytics import YOLO  # Ensure you have ultralytics installed
from sam2.build_sam import build_sam2_video_predictor
from torch.utils.checkpoint import checkpoint
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm  # For progress bars

# Paths and Directories
video_path = './clip 1.mp4'
output_frames_dir = './video_frames'
output_video_path = 'segmented_video.mp4'

# Create directory to store frames
os.makedirs(output_frames_dir, exist_ok=True)

# Load Video and Extract Frames
def extract_frames(video_path, output_dir):
    cap = cv2.VideoCapture(video_path)
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    
    frame_names = []
    for i in tqdm(range(frame_count), desc="Extracting frames"):
        ret, frame = cap.read()
        if not ret:
            break
        frame_name = f'{i:05d}.jpg'
        cv2.imwrite(os.path.join(output_frames_dir, frame_name), frame)
        frame_names.append(frame_name)
    cap.release()
    return frame_names, fps

frame_names, fps = extract_frames(video_path, output_frames_dir)

# Custom Dataset for Batch Processing
class VideoFrameDataset(Dataset):
    def __init__(self, frame_names, frame_dir, target_size=(320, 320)):
        self.frame_names = frame_names
        self.frame_dir = frame_dir
        self.target_size = target_size  # (Width, Height)
    
    def __len__(self):
        return len(self.frame_names)
    
    def __getitem__(self, idx):
        frame_path = os.path.join(self.frame_dir, self.frame_names[idx])
        frame = Image.open(frame_path).convert('RGB')
        frame = frame.resize(self.target_size, Image.BILINEAR)
        frame_np = np.array(frame)
        # Convert to (C, H, W)
        frame_np = frame_np.transpose(2, 0, 1)
        frame_tensor = torch.from_numpy(frame_np).float() / 255.0  # Normalize to [0,1]
        return frame_tensor, idx

# Load Dataset and DataLoader
dataset = VideoFrameDataset(frame_names, output_frames_dir)
batch_size = 2  # Reduced batch size to minimize memory usage
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)
# Check for CUDA and set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load SAM 2 Model and move it to GPU if available
sam2_checkpoint = "./segment-anything-2/checkpoints/sam2_hiera_small.pt"  # Ensure this path is correct
model_cfg = "sam2_hiera_s.yaml"  # Ensure this config file is present
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=device)

# Initialize YOLO model for object detection and move to GPU if available
yolo_model = YOLO('yolov8s.pt').to(device)  # Using a smaller YOLO model variant

# Initialize SAM 2 State with offloading frames to CPU to save GPU memory
inference_state = predictor.init_state(
    video_path=output_frames_dir,
    offload_video_to_cpu=True  # Offload video frames to CPU
)
predictor.reset_state(inference_state)

ann_obj_id = 1  # Initialize object ID

# Dictionary to store segmentation masks
video_segments = {}

# Process frames in batches
for batch_frames, batch_indices in tqdm(dataloader, desc="Processing batches"):
    batch_frames = batch_frames.to(device, non_blocking=True)
    
    with torch.no_grad():  # Disable gradient calculations for inference
        with autocast():  # Enable mixed precision
            # Perform YOLO object detection on the batch
            results = yolo_model(batch_frames, imgsz=320)  # Ensure img size matches target_size
    
    for i, frame_idx in enumerate(batch_indices):
        detected_boxes = results[i].boxes.xyxy.cpu().numpy().astype(np.float32)  # Shape: (num_boxes, 4)
        
        # Add detected boxes to SAM 2 for segmentation
        for single_box in detected_boxes:
            # SAM expects boxes in [x1, y1, x2, y2] format
            try:
                _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
                    inference_state=inference_state,
                    frame_idx=int(frame_idx),
                    obj_id=ann_obj_id,
                    box=single_box.tolist(),  # Convert to list if necessary
                )
                ann_obj_id += 1
            except Exception as e:
                print(f"Error processing frame {frame_idx} with box {single_box}: {e}")
                continue
    
        # Propagate masks in video after adding new objects
        try:
            video_segments[int(frame_idx)] = {
                out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
                for i, out_obj_id in enumerate(out_obj_ids)
            }
        except Exception as e:
            print(f"Error propagating masks for frame {frame_idx}: {e}")
            continue
    
    # Clear cache to free up memory
    torch.cuda.empty_cache()

# Save the Segmented Video
def save_segmented_video(output_path, frame_names, video_segments, target_size=(320, 320)):
    height, width = target_size
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    
    for idx in tqdm(range(len(frame_names)), desc="Saving video"):
        frame_path = os.path.join(output_frames_dir, frame_names[idx])
        frame = cv2.imread(frame_path)
        frame = cv2.resize(frame, (width, height))  # Ensure size matches target_size
    
        if idx in video_segments:
            for out_obj_id, out_mask in video_segments[idx].items():
                # Create a color overlay for the mask
                color = np.array([0, 255, 0], dtype=np.uint8)  # Green mask
                mask = (out_mask * 255).astype(np.uint8)
                mask_rgb = cv2.merge([mask, mask, mask])
                colored_mask = cv2.bitwise_and(color, color, mask=mask)
                # Blend the mask with the frame
                frame = cv2.addWeighted(frame, 1.0, colored_mask, 0.5, 0)
        out.write(frame)
    
    out.release()
    print(f'Segmented video saved to {output_path}')

save_segmented_video(output_video_path, frame_names, video_segments, target_size=(320, 320))


Extracting frames: 100%|██████████| 153/153 [00:04<00:00, 34.05it/s]


Using device: cuda


frame loading (JPEG): 100%|██████████| 476/476 [00:28<00:00, 16.78it/s]
Processing batches:   0%|          | 0/77 [00:00<?, ?it/s]


0: 320x320 2 cars, 4.3ms
1: 320x320 1 car, 4.3ms
Speed: 1.0ms preprocess, 4.3ms inference, 1.5ms postprocess per image at shape (1, 3, 320, 320)


Processing batches:   1%|▏         | 1/77 [00:00<00:49,  1.54it/s]


0: 320x320 1 car, 7.4ms
1: 320x320 1 car, 7.4ms
Speed: 0.0ms preprocess, 7.4ms inference, 2.0ms postprocess per image at shape (1, 3, 320, 320)


Processing batches:   3%|▎         | 2/77 [00:01<00:39,  1.91it/s]


0: 320x320 2 cars, 6.5ms
1: 320x320 2 cars, 6.5ms
Speed: 0.0ms preprocess, 6.5ms inference, 2.0ms postprocess per image at shape (1, 3, 320, 320)


Processing batches:   4%|▍         | 3/77 [00:01<00:36,  2.02it/s]


0: 320x320 2 cars, 1 traffic light, 7.2ms
1: 320x320 2 cars, 7.2ms
Speed: 0.0ms preprocess, 7.2ms inference, 2.0ms postprocess per image at shape (1, 3, 320, 320)


Processing batches:   5%|▌         | 4/77 [00:02<00:35,  2.04it/s]


0: 320x320 3 cars, 1 traffic light, 5.5ms
1: 320x320 3 cars, 5.5ms
Speed: 0.0ms preprocess, 5.5ms inference, 1.5ms postprocess per image at shape (1, 3, 320, 320)


Processing batches:   6%|▋         | 5/77 [00:02<00:36,  1.99it/s]


0: 320x320 4 cars, 4.9ms
1: 320x320 3 cars, 4.9ms
Speed: 0.0ms preprocess, 4.9ms inference, 1.0ms postprocess per image at shape (1, 3, 320, 320)


Processing batches:   8%|▊         | 6/77 [00:03<00:36,  1.95it/s]


0: 320x320 1 person, 3 cars, 4.8ms
1: 320x320 2 cars, 4.8ms
Speed: 0.0ms preprocess, 4.8ms inference, 1.0ms postprocess per image at shape (1, 3, 320, 320)


Processing batches:   9%|▉         | 7/77 [00:03<00:36,  1.92it/s]


0: 320x320 1 person, 1 car, 7.2ms
1: 320x320 1 car, 7.2ms
Speed: 0.0ms preprocess, 7.2ms inference, 2.0ms postprocess per image at shape (1, 3, 320, 320)


Processing batches:  10%|█         | 8/77 [00:04<00:35,  1.92it/s]


0: 320x320 1 car, 5.2ms
1: 320x320 1 person, 1 car, 5.2ms
Speed: 0.0ms preprocess, 5.2ms inference, 2.0ms postprocess per image at shape (1, 3, 320, 320)


Processing batches:  12%|█▏        | 9/77 [00:04<00:35,  1.94it/s]


0: 320x320 1 person, 1 car, 5.5ms
1: 320x320 2 persons, 1 car, 5.5ms
Speed: 0.0ms preprocess, 5.5ms inference, 1.0ms postprocess per image at shape (1, 3, 320, 320)


Processing batches:  13%|█▎        | 10/77 [00:05<00:35,  1.90it/s]


0: 320x320 2 persons, 1 car, 7.5ms
1: 320x320 3 persons, 1 car, 7.5ms
Speed: 0.0ms preprocess, 7.5ms inference, 2.0ms postprocess per image at shape (1, 3, 320, 320)


Processing batches:  14%|█▍        | 11/77 [00:05<00:36,  1.82it/s]


0: 320x320 3 persons, 1 car, 5.2ms
1: 320x320 1 person, 1 car, 5.2ms
Speed: 0.0ms preprocess, 5.2ms inference, 1.5ms postprocess per image at shape (1, 3, 320, 320)


Processing batches:  16%|█▌        | 12/77 [00:06<00:36,  1.78it/s]


0: 320x320 1 person, 1 car, 5.5ms
1: 320x320 1 car, 5.5ms
Speed: 0.0ms preprocess, 5.5ms inference, 1.0ms postprocess per image at shape (1, 3, 320, 320)


Processing batches:  17%|█▋        | 13/77 [00:06<00:35,  1.80it/s]


0: 320x320 3 persons, 3 cars, 5.0ms
1: 320x320 3 persons, 2 cars, 5.0ms
Speed: 0.0ms preprocess, 5.0ms inference, 1.0ms postprocess per image at shape (1, 3, 320, 320)


Processing batches:  18%|█▊        | 14/77 [00:07<00:37,  1.68it/s]


0: 320x320 3 persons, 1 car, 7.5ms
1: 320x320 3 persons, 3 cars, 7.5ms
Speed: 0.0ms preprocess, 7.5ms inference, 1.5ms postprocess per image at shape (1, 3, 320, 320)


Processing batches:  19%|█▉        | 15/77 [00:08<00:38,  1.60it/s]


0: 320x320 3 persons, 3 cars, 4.0ms
1: 320x320 3 persons, 1 car, 4.0ms
Speed: 0.0ms preprocess, 4.0ms inference, 1.0ms postprocess per image at shape (1, 3, 320, 320)


Processing batches:  21%|██        | 16/77 [00:09<00:39,  1.53it/s]


0: 320x320 3 persons, 1 car, 7.7ms
1: 320x320 3 persons, 1 car, 7.7ms
Speed: 0.0ms preprocess, 7.7ms inference, 1.5ms postprocess per image at shape (1, 3, 320, 320)


Processing batches:  22%|██▏       | 17/77 [00:09<00:40,  1.50it/s]


0: 320x320 2 persons, 1 car, 4.5ms
1: 320x320 2 persons, 1 car, 4.5ms
Speed: 0.0ms preprocess, 4.5ms inference, 0.5ms postprocess per image at shape (1, 3, 320, 320)


Processing batches:  23%|██▎       | 18/77 [00:10<00:40,  1.47it/s]


0: 320x320 2 persons, 2 cars, 6.0ms
1: 320x320 2 persons, 2 cars, 6.0ms
Speed: 0.0ms preprocess, 6.0ms inference, 1.5ms postprocess per image at shape (1, 3, 320, 320)


Processing batches:  25%|██▍       | 19/77 [00:11<00:40,  1.43it/s]


0: 320x320 2 persons, 2 cars, 7.0ms
1: 320x320 1 person, 1 car, 7.0ms
Speed: 0.0ms preprocess, 7.0ms inference, 1.5ms postprocess per image at shape (1, 3, 320, 320)


Processing batches:  26%|██▌       | 20/77 [00:11<00:40,  1.40it/s]


0: 320x320 1 person, 1 car, 7.0ms
1: 320x320 1 person, 7.0ms
Speed: 0.0ms preprocess, 7.0ms inference, 1.0ms postprocess per image at shape (1, 3, 320, 320)


Processing batches:  27%|██▋       | 21/77 [00:12<00:39,  1.42it/s]


0: 320x320 1 person, 5.0ms
1: 320x320 1 person, 5.0ms
Speed: 0.0ms preprocess, 5.0ms inference, 1.0ms postprocess per image at shape (1, 3, 320, 320)


Processing batches:  29%|██▊       | 22/77 [00:13<00:37,  1.45it/s]


0: 320x320 1 person, 4.5ms
1: 320x320 1 person, 4.5ms
Speed: 0.0ms preprocess, 4.5ms inference, 1.0ms postprocess per image at shape (1, 3, 320, 320)


Processing batches:  30%|██▉       | 23/77 [00:13<00:36,  1.47it/s]


0: 320x320 1 person, 7.0ms
1: 320x320 4 persons, 7.0ms
Speed: 0.0ms preprocess, 7.0ms inference, 2.0ms postprocess per image at shape (1, 3, 320, 320)


Processing batches:  31%|███       | 24/77 [00:14<00:37,  1.43it/s]


0: 320x320 3 persons, 4.0ms
1: 320x320 2 persons, 1 car, 4.0ms
Speed: 0.0ms preprocess, 4.0ms inference, 1.0ms postprocess per image at shape (1, 3, 320, 320)


Processing batches:  32%|███▏      | 25/77 [00:15<00:37,  1.37it/s]


0: 320x320 1 person, 1 car, 6.0ms
1: 320x320 1 person, 2 cars, 6.0ms
Speed: 0.0ms preprocess, 6.0ms inference, 1.0ms postprocess per image at shape (1, 3, 320, 320)


Processing batches:  34%|███▍      | 26/77 [00:16<00:39,  1.29it/s]


0: 320x320 2 persons, 2 cars, 8.5ms
1: 320x320 2 persons, 4 cars, 1 fire hydrant, 8.5ms
Speed: 0.0ms preprocess, 8.5ms inference, 2.0ms postprocess per image at shape (1, 3, 320, 320)


Processing batches:  35%|███▌      | 27/77 [00:19<01:11,  1.43s/it]


0: 320x320 3 persons, 4 cars, 1 fire hydrant, 4.5ms
1: 320x320 2 persons, 2 cars, 4.5ms
Speed: 0.0ms preprocess, 4.5ms inference, 2.0ms postprocess per image at shape (1, 3, 320, 320)


Processing batches:  36%|███▋      | 28/77 [00:23<01:48,  2.21s/it]


0: 320x320 2 persons, 2 cars, 4.0ms
1: 320x320 1 person, 2 cars, 4.0ms
Speed: 0.0ms preprocess, 4.0ms inference, 1.0ms postprocess per image at shape (1, 3, 320, 320)


Processing batches:  38%|███▊      | 29/77 [00:26<02:01,  2.53s/it]


0: 320x320 1 person, 2 cars, 7.0ms
1: 320x320 2 cars, 7.0ms
Speed: 0.0ms preprocess, 7.0ms inference, 1.5ms postprocess per image at shape (1, 3, 320, 320)


Processing batches:  39%|███▉      | 30/77 [00:30<02:19,  2.96s/it]


0: 320x320 1 car, 6.0ms
1: 320x320 1 car, 6.0ms
Speed: 0.0ms preprocess, 6.0ms inference, 1.0ms postprocess per image at shape (1, 3, 320, 320)


Processing batches:  40%|████      | 31/77 [00:33<02:15,  2.94s/it]


0: 320x320 1 car, 4.5ms
1: 320x320 1 car, 4.5ms
Speed: 0.0ms preprocess, 4.5ms inference, 1.0ms postprocess per image at shape (1, 3, 320, 320)
Error processing frame 63 with box [     285.25       213.5      308.75         248]: CUDA out of memory. Tried to allocate 1.85 GiB. GPU 0 has a total capacity of 8.00 GiB of which 2.10 GiB is free. Of the allocated memory 2.32 GiB is allocated by PyTorch, and 2.43 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)


Processing batches:  42%|████▏     | 32/77 [00:36<02:10,  2.89s/it]


0: 320x320 1 car, 5.0ms
1: 320x320 1 car, 5.0ms
Speed: 0.0ms preprocess, 5.0ms inference, 1.0ms postprocess per image at shape (1, 3, 320, 320)
Error processing frame 64 with box [     285.25       213.5      308.75         248]: CUDA out of memory. Tried to allocate 1.85 GiB. GPU 0 has a total capacity of 8.00 GiB of which 2.10 GiB is free. Of the allocated memory 2.32 GiB is allocated by PyTorch, and 2.43 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
Error processing frame 65 with box [        290       214.5         310         250]: CUDA out of memory. Tried to allocate 1.85 GiB. GPU 0 has a total capacity of 8.00 GiB of which 2.10 GiB is free. Of the allocated memory 2.32 GiB is allocated by PyTorch, and 2.43 GiB is reserved by PyTorch but una

In [11]:
import os
import cv2
import torch
import numpy as np
from PIL import Image
from ultralytics import YOLO  # Ensure you have ultralytics installed
from sam2.build_sam import build_sam2_video_predictor
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm  # For progress bars

# Paths and Directories
video_path = './clip 1.mp4'
output_frames_dir = './video_frames'
output_video_path = 'segmented_video.mp4'

# Create directory to store frames
os.makedirs(output_frames_dir, exist_ok=True)

# Load Video and Extract Frames
def extract_frames(video_path, output_dir):
    cap = cv2.VideoCapture(video_path)
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    
    frame_names = []
    for i in tqdm(range(frame_count), desc="Extracting frames"):
        ret, frame = cap.read()
        if not ret:
            break
        frame_name = f'{i:05d}.jpg'
        cv2.imwrite(os.path.join(output_frames_dir, frame_name), frame)
        frame_names.append(frame_name)
    cap.release()
    return frame_names, fps

frame_names, fps = extract_frames(video_path, output_frames_dir)

# Custom Dataset for Batch Processing
class VideoFrameDataset(Dataset):
    def __init__(self, frame_names, frame_dir, target_size=(320, 320)):
        self.frame_names = frame_names
        self.frame_dir = frame_dir
        self.target_size = target_size  # (Width, Height)
    
    def __len__(self):
        return len(self.frame_names)
    
    def __getitem__(self, idx):
        frame_path = os.path.join(self.frame_dir, self.frame_names[idx])
        frame = Image.open(frame_path).convert('RGB')
        frame = frame.resize(self.target_size, Image.BILINEAR)
        frame_np = np.array(frame)
        # Convert to (C, H, W)
        frame_np = frame_np.transpose(2, 0, 1)
        frame_tensor = torch.from_numpy(frame_np).float() / 255.0  # Normalize to [0,1]
        return frame_tensor, idx

# Load Dataset and DataLoader
dataset = VideoFrameDataset(frame_names, output_frames_dir)
batch_size = 2  # Reduced batch size to minimize memory usage
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)

# Check for CUDA and set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load SAM 2 Model and move it to GPU if available
sam2_checkpoint = "./segment-anything-2/checkpoints/sam2_hiera_small.pt"  # Ensure this path is correct
model_cfg = "sam2_hiera_s.yaml"  # Ensure this config file is present
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=device)

# Initialize YOLO model for object detection and move to GPU if available
yolo_model = YOLO('yolov8s.pt').to(device)  # Using a smaller YOLO model variant

# Initialize SAM 2 State with offloading frames to CPU to save GPU memory
inference_state = predictor.init_state(
    video_path=output_frames_dir,
    offload_video_to_cpu=True  # Offload video frames to CPU
)
predictor.reset_state(inference_state)

ann_obj_id = 1  # Initialize object ID

# Dictionary to store segmentation masks
video_segments = {}

# Process frames in batches
scaler = GradScaler()  # For mixed precision training
for batch_frames, batch_indices in tqdm(dataloader, desc="Processing batches"):
    batch_frames = batch_frames.to(device, non_blocking=True)
    
    with torch.no_grad():  # Disable gradient calculations for inference
        with autocast():  # Enable mixed precision
            # Perform YOLO object detection on the batch
            results = yolo_model(batch_frames, imgsz=320)  # Ensure img size matches target_size
    
    for i, frame_idx in enumerate(batch_indices):
        detected_boxes = results[i].boxes.xyxy.cpu().numpy().astype(np.float32)  # Shape: (num_boxes, 4)
        
        # Add detected boxes to SAM 2 for segmentation
        for single_box in detected_boxes:
            # SAM expects boxes in [x1, y1, x2, y2] format
            try:
                _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
                    inference_state=inference_state,
                    frame_idx=int(frame_idx),
                    obj_id=ann_obj_id,
                    box=single_box.tolist(),  # Convert to list if necessary
                )
                ann_obj_id += 1
            except Exception as e:
                print(f"Error processing frame {frame_idx} with box {single_box}: {e}")
                continue
    
        # Propagate masks in video after adding new objects
        try:
            video_segments[int(frame_idx)] = {
                out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
                for i, out_obj_id in enumerate(out_obj_ids)
            }
        except Exception as e:
            print(f"Error propagating masks for frame {frame_idx}: {e}")
            continue
    
    # Clear cache to free up memory
    torch.cuda.empty_cache()

# Save the Segmented Video
def save_segmented_video(output_path, frame_names, video_segments, target_size=(320, 320)):
    height, width = target_size
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    
    for idx in tqdm(range(len(frame_names)), desc="Saving video"):
        frame_path = os.path.join(output_frames_dir, frame_names[idx])
        frame = cv2.imread(frame_path)
        frame = cv2.resize(frame, (width, height))  # Ensure size matches target_size
    
        if idx in video_segments:
            for out_obj_id, out_mask in video_segments[idx].items():
                # Create a color overlay for the mask
                color = np.array([0, 255, 0], dtype=np.uint8)  # Green mask
                mask = (out_mask * 255).astype(np.uint8)
                mask_rgb = cv2.merge([mask, mask, mask])
                colored_mask = cv2.bitwise_and(color, color, mask=mask)
                # Blend the mask with the frame
                frame = cv2.addWeighted(frame, 1.0, colored_mask, 0.5, 0)
        out.write(frame)
    
    out.release()
    print(f'Segmented video saved to {output_path}')

save_segmented_video(output_video_path, frame_names, video_segments, target_size=(320, 320))



RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
