In [None]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from ultralytics import YOLO
import torch
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# Trained YOLOv8 weights; model.names maps class ids to names (including 'wall').
model_path = "/content/drive/MyDrive/trained_yolo_models/best.pt"
model = YOLO(model_path)
print(f"YOLOv8 model loaded from: {model_path}")
print(f"Model classes: {model.names}")

# SAM ViT-H checkpoint. The existence check, the download target and the loader
# all use this single path so they can never disagree.
sam_checkpoint_path = "/content/segment-anything/sam_vit_h.pth"
sam_checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"

if not os.path.exists(sam_checkpoint_path):
    import urllib.request  # only needed on the first run (plain Python instead of IPython `!wget`)

    print("Downloading SAM checkpoint...")
    os.makedirs(os.path.dirname(sam_checkpoint_path), exist_ok=True)
    # Download to a temporary name and rename afterwards, so an interrupted
    # download never leaves a truncated file that the exists() check would accept.
    partial_path = sam_checkpoint_path + ".part"
    urllib.request.urlretrieve(sam_checkpoint_url, partial_path)
    os.replace(partial_path, sam_checkpoint_path)
    print("SAM checkpoint downloaded")

# Load SAM model and move it to the GPU when one is available.
sam = sam_model_registry["vit_h"](checkpoint=sam_checkpoint_path)
device = "cuda" if torch.cuda.is_available() else "cpu"
sam.to(device=device)

# Box-promptable predictor; an image must be set with set_image() before predict().
predictor = SamPredictor(sam)

print(f"SAM model loaded on device: {device}")

# =============================================================================
# STEP 1: Test Wall Detection on a Sample Image
# =============================================================================

# Pick a test image: the first (sorted) match of the glob pattern below, or,
# if nothing matches, ask the user to upload one through the Colab file picker.
import glob

test_image_pattern = "/content/obj (37).jpeg"  # glob pattern; point at your own image(s)
test_images = sorted(glob.glob(test_image_pattern))

if test_images:
    test_image_path = test_images[0]  # deterministic: first match in sorted order
    print(f"🖼️ Testing on: {test_image_path}")
else:
    print("❌ No test images found. Upload one manually:")
    from google.colab import files
    uploaded = files.upload()
    if not uploaded:
        raise FileNotFoundError("No image was uploaded; cannot continue.")
    test_image_path = next(iter(uploaded))  # name of the first uploaded file

# cv2.imread returns None (it does not raise) for missing or undecodable files,
# so fail here with a clear message instead of inside cvtColor.
image = cv2.imread(test_image_path)
if image is None:
    raise FileNotFoundError(f"Could not read image: {test_image_path}")
# OpenCV loads BGR; matplotlib and SamPredictor.set_image (default) expect RGB.
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Start the 1x3 summary figure; panels 2 and 3 are filled in by later steps.
plt.figure(figsize=(15, 5))
plt.subplot(1, 3, 1)
plt.imshow(image_rgb)
plt.title("Original Image")
plt.axis('off')

print(f"📏 Image shape: {image.shape}")

# =============================================================================
# STEP 2: Run Wall Detection
# =============================================================================

print("🔍 Running wall detection...")

# Single-image YOLOv8 inference; only boxes with confidence >= 0.3 are kept.
results = model.predict(test_image_path, conf=0.3, save=False, verbose=False)
first_result = results[0]

# One dict per detected box: XYXY pixel bbox, confidence, numeric class id and name.
detections = []
if first_result.boxes is not None:
    box_data = first_result.boxes
    for xyxy, confidence, raw_cls in zip(
        box_data.xyxy.cpu().numpy(),
        box_data.conf.cpu().numpy(),
        box_data.cls.cpu().numpy(),
    ):
        cls_idx = int(raw_cls)
        detections.append({
            'bbox': xyxy,
            'score': confidence,
            'class_id': cls_idx,
            'class_name': model.names[cls_idx],
        })

# The wall subset feeds the SAM stage; these are the same dict objects, not copies.
wall_detections = [found for found in detections if found['class_name'] == 'wall']

print(f"✅ Total detections: {len(detections)}")
print(f"🏠 Wall detections: {len(wall_detections)}")

# Print all detections
for found in detections:
    print(f"   {found['class_name']}: {found['score']:.3f}")

# =============================================================================
# STEP 2b: Visualize All Detections (RGB colors, displayed with matplotlib)
# =============================================================================

# Visualize all detections (FIXED COLOR ISSUE)
image_with_boxes = image_rgb.copy()

for i, det in enumerate(detections):
    x1, y1, x2, y2 = det['bbox'].astype(int)
    
    # FIXED: Use proper color format for OpenCV (BGR format as tuples)
    if det['class_name'] == 'wall':
        color = (255, 0, 0)  # Red for walls
        thickness = 3
    else:
        # Different colors for different classes
        colors = [(0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255), 
                 (0, 255, 255), (128, 128, 0), (128, 0, 128), (0, 128, 128)]
        color = colors[det['class_id'] % len(colors)]
        thickness = 2
    
    # Draw bounding box
    cv2.rectangle(image_with_boxes, (x1, y1), (x2, y2), color, thickness)
    
    # Add label with better formatting
    label = f"{det['class_name']}: {det['score']:.2f}"
    cv2.putText(image_with_boxes, label, (x1, y1-10), 
                cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)

plt.subplot(1, 3, 2)
plt.imshow(image_with_boxes)
plt.title(f"All Detections ({len(detections)} objects)")
plt.axis('off')

# =============================================================================
# STEP 3: Focus on Wall Segmentation with SAM
# =============================================================================

print("🎯 Processing walls with SAM...")

wall_segments = []
if wall_detections:
    # Encode the image once (the costly ViT-H encoder pass); every box prompt
    # below reuses that embedding. Skipped entirely when there is nothing to segment.
    predictor.set_image(image_rgb)

    for i, wall_det in enumerate(wall_detections):
        print(f"   Processing wall {i+1}/{len(wall_detections)}")

        # The YOLO wall box (XYXY pixels) is the SAM prompt; SamPredictor.predict
        # documents `box` as a single length-4 array.
        x1, y1, x2, y2 = wall_det['bbox'].astype(int)
        input_box = np.array([x1, y1, x2, y2])

        # multimask_output=False -> one mask and one score per prompt.
        masks, scores, logits = predictor.predict(
            point_coords=None,
            point_labels=None,
            box=input_box,
            multimask_output=False,
        )

        wall_segments.append({
            'mask': masks[0],
            'sam_score': scores[0],
            'yolo_score': wall_det['score'],
            'bbox': wall_det['bbox'],
            'wall_id': i + 1,
        })

print(f"✅ Processed {len(wall_segments)} wall segments")

# Visualize wall segments: panel 3 of the summary figure, then a separate overlay figure.
if wall_segments:
    # Label map: 0 = background, k = k-th wall (later walls overwrite overlaps).
    # int32 so the labels can't overflow; an 8-bit "(i + 1) * 80" encoding
    # exceeds 255 at the 4th wall. viridis rescales to the actual label range.
    combined_wall_mask = np.zeros(image_rgb.shape[:2], dtype=np.int32)

    for i, segment in enumerate(wall_segments):
        combined_wall_mask[segment['mask']] = i + 1

    plt.subplot(1, 3, 3)
    plt.imshow(combined_wall_mask, cmap='viridis')
    plt.title(f"Wall Segments ({len(wall_segments)} walls)")
    plt.axis('off')

    plt.tight_layout()
    plt.show()

    # Second figure: translucent per-wall color overlay on the original image.
    plt.figure(figsize=(12, 8))

    overlay = image_rgb.copy()
    colors = [(255, 100, 100), (100, 255, 100), (100, 100, 255), (255, 255, 100)]

    for i, segment in enumerate(wall_segments):
        mask = segment['mask']
        color = colors[i % len(colors)]
        # 60/40 blend; the float result is cast back into the uint8 image on assignment.
        overlay[mask] = overlay[mask] * 0.6 + np.array(color) * 0.4

        # Label each wall near the centroid of its mask pixels.
        ys, xs = np.nonzero(mask)
        if ys.size > 0:
            cv2.putText(overlay, f"Wall {i+1}", (int(xs.mean()) - 30, int(ys.mean())),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)

    plt.imshow(overlay)
    plt.title("Walls with Segmentation Masks")
    plt.axis('off')
    plt.show()

else:
    plt.subplot(1, 3, 3)
    plt.text(0.5, 0.5, "No walls detected", ha='center', va='center', fontsize=16)
    plt.title("No Wall Segments")
    plt.axis('off')
    plt.tight_layout()
    plt.show()

# =============================================================================
# STEP 4: Enhanced Measurement Preparation
# =============================================================================

banner = "=" * 50
print("\n" + banner)
print("📊 WALL MEASUREMENT ANALYSIS")
print(banner)

# Quality tiers checked best-first: (min YOLO confidence, min SAM score, label).
# Both thresholds are strict (>) and must hold together; otherwise "POOR".
_QUALITY_TIERS = (
    (0.7, 0.9, "🏆 EXCELLENT"),
    (0.5, 0.8, "✅ GOOD"),
    (0.3, 0.6, "⚠️  MODERATE"),
)

if not wall_segments:
    print("❌ No walls detected!")
    print("💡 Try:")
    print("   - Lower confidence threshold (conf=0.1)")
    print("   - Different test image with clear walls")
    print("   - Check if image has good lighting")
else:
    for wall_number, segment in enumerate(wall_segments, start=1):
        mask = segment['mask']
        bbox = segment['bbox']
        x_min, y_min, x_max, y_max = bbox

        # Pixel-space measurements; coverage = share of the YOLO box filled by the SAM mask.
        mask_area_pixels = np.sum(mask)
        bbox_width = x_max - x_min
        bbox_height = y_max - y_min
        bbox_area = bbox_width * bbox_height
        coverage = mask_area_pixels / bbox_area if bbox_area > 0 else 0

        yolo_score = segment['yolo_score']
        sam_score = segment['sam_score']
        quality = next(
            (tier_label for min_yolo, min_sam, tier_label in _QUALITY_TIERS
             if yolo_score > min_yolo and sam_score > min_sam),
            "❌ POOR",
        )

        print(f"\n🏠 Wall {wall_number}:")
        print(f"   YOLO confidence: {yolo_score:.3f}")
        print(f"   SAM score: {sam_score:.3f}")
        print(f"   Bounding box: [{x_min:.0f}, {y_min:.0f}, {x_max:.0f}, {y_max:.0f}]")
        print(f"   Dimensions: {bbox_width:.0f} × {bbox_height:.0f} pixels")
        print(f"   Mask area: {mask_area_pixels:,} pixels")
        print(f"   Coverage: {coverage:.1%}")
        print(f"   Quality: {quality}")

    print(f"\n🎉 SUCCESS! Your pipeline detected and segmented {len(wall_segments)} walls!")
    print("📏 Ready for the next step: DEPTH ESTIMATION")

# =============================================================================
# STEP 5: Pipeline Status and Next Steps
# =============================================================================

# Static roadmap of the pipeline stages, followed by a summary of this run.
status_lines = (
    "✅ 1. YOLOv8 wall detection - WORKING",
    "✅ 2. SAM segmentation - WORKING",
    "🔄 3. Next: Add depth estimation",
    "🔄 4. Next: Camera calibration",
    "🔄 5. Next: Real-world measurements",
)

print("\n🚀 PIPELINE STATUS:")
print("=" * 40)
for status_line in status_lines:
    print(status_line)

if not wall_segments:
    print("\n⚠️  Need to improve wall detection before proceeding.")
else:
    print("\n🎉 EXCELLENT PROGRESS!")
    print("Your pipeline is working correctly! Wall detection and segmentation are successful.")
    print("Ready to proceed with depth estimation for measurement calculation.")

    # Aggregate numbers handed to the next (depth-estimation) stage.
    mean_yolo = np.mean([s['yolo_score'] for s in wall_segments])
    mean_sam = np.mean([s['sam_score'] for s in wall_segments])
    img_height, img_width = image.shape[:2]

    print("\n📋 Ready for Integration:")
    print(f"   - Wall segments: {len(wall_segments)} found")
    print(f"   - Average YOLO confidence: {mean_yolo:.3f}")
    print(f"   - Average SAM score: {mean_sam:.3f}")
    print(f"   - Image resolution: {img_width}x{img_height}")