# 04 - Demo de Inferencia Interactivo

Demo interactivo para probar el modelo con diferentes configuraciones.

## Contenido
1. Cargar modelo
2. Inferencia en imagen con widgets
3. Comparar velocidad de formatos (TensorRT, PyTorch, ONNX)
4. Procesar video
5. Subir tu propia imagen

In [None]:
import sys
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import cv2
from PIL import Image
import ipywidgets as widgets
from IPython.display import display, clear_output

# The notebook is expected to run from a sub-directory of the project, so the
# parent of the working directory is treated as the project root and put at
# the front of the import path.
PROJECT_ROOT = Path('..').resolve()
sys.path.insert(0, str(PROJECT_ROOT))

from ultralytics import YOLO

# Locations of the exported models and of the datasets/videos.
MODELS_DIR, DATA_DIR = PROJECT_ROOT / 'models', PROJECT_ROOT / 'data'

# Banner: title followed by a 40-character rule.
print("VR Pillar Detector - Demo Interactivo", "=" * 40, sep="\n")

## 1. Cargar Modelos Disponibles

In [None]:
# Discover one model file per export format. The list order is the loading
# priority: TensorRT first, then PyTorch, then ONNX.
available_models = {}

for ext, name in [('.engine', 'TensorRT'), ('.pt', 'PyTorch'), ('.onnx', 'ONNX')]:
    # glob() yields files in an unspecified order; sort so the same file is
    # picked on every run when several models share an extension.
    models = sorted(MODELS_DIR.glob(f'*{ext}'))
    if models:
        model_path = models[0]
        available_models[name] = model_path
        size_mb = model_path.stat().st_size / (1024*1024)
        print(f"{name}: {model_path.name} ({size_mb:.1f} MB)")

# Fail early with a clear message instead of a KeyError further down.
if not available_models:
    raise FileNotFoundError(
        f"No se encontraron modelos (.engine, .pt, .onnx) en: {MODELS_DIR}"
    )

# Load the default model: the dict preserves insertion order, so its first key
# is the highest-priority format actually present (also covers the ONNX-only
# case, which previously raised KeyError on 'PyTorch').
default_format = next(iter(available_models))
model = YOLO(str(available_models[default_format]))
print(f"\nModelo cargado: {default_format}")

## 2. Inferencia Interactiva en Imagen

In [None]:
# Sample images: at most 20 JPEG files from the validation split.
val_images_dir = DATA_DIR / 'dataset' / 'val' / 'images'
sample_images = list(val_images_dir.glob('*.jpg'))[:20]
image_names = [img.name for img in sample_images]


def _threshold_slider(description, value):
    """Build a 0.1-1.0 float slider (step 0.05) with the given label and start value."""
    return widgets.FloatSlider(
        description=description,
        value=value,
        min=0.1,
        max=1.0,
        step=0.05
    )


# Controls
initial_image = image_names[0] if image_names else None
image_dropdown = widgets.Dropdown(
    description='Imagen:',
    options=image_names,
    value=initial_image
)

conf_slider = _threshold_slider('Confianza:', 0.65)

iou_slider = _threshold_slider('IoU:', 0.45)

model_dropdown = widgets.Dropdown(
    description='Modelo:',
    options=list(available_models.keys()),
    value=default_format
)

# Output area that receives the rendered figure and the printed details.
output = widgets.Output()

# Format currently held by the global `model` as far as run_inference knows.
# None forces a load on the first call, so the function never trusts a stale
# model left over from a previous execution of the cell.
_loaded_model_format = None


def run_inference(image_name, conf, iou, model_format):
    """Run the detector on one validation image and render the result.

    Output (figure and per-detection confidences) is written to the `output`
    widget, replacing whatever it showed before.

    Args:
        image_name: File name inside ``DATA_DIR/dataset/val/images``. A falsy
            value (no images available) prints a message and returns.
        conf: Confidence threshold forwarded to the model call.
        iou: IoU threshold forwarded to the model call.
        model_format: Key of ``available_models`` to run with. The global
            ``model`` is replaced only when this differs from the format
            loaded by the previous call; unknown keys leave ``model`` as is.
    """
    global model, _loaded_model_format

    import time

    with output:
        clear_output(wait=True)

        # An empty dropdown yields None, which cannot be joined to a path.
        if not image_name:
            print("No hay imágenes disponibles para la inferencia.")
            return

        # Reload only when the selected format actually changed; rebuilding
        # the model on every slider move made the UI needlessly slow.
        if model_format in available_models and model_format != _loaded_model_format:
            model = YOLO(str(available_models[model_format]))
            _loaded_model_format = model_format

        # Locate the image
        img_path = DATA_DIR / 'dataset' / 'val' / 'images' / image_name
        if not img_path.exists():
            print(f"Imagen no encontrada: {img_path}")
            return

        # Inference, timed in milliseconds
        start = time.perf_counter()
        results = model(str(img_path), conf=conf, iou=iou, verbose=False)
        elapsed = (time.perf_counter() - start) * 1000

        boxes = results[0].boxes

        # plot() is converted from BGR to RGB before handing it to matplotlib
        annotated = results[0].plot()
        annotated = cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)

        plt.figure(figsize=(14, 8))
        plt.imshow(annotated)
        plt.title(f"{image_name} | {len(boxes)} detecciones | {elapsed:.1f}ms | {model_format}")
        plt.axis('off')
        plt.tight_layout()
        plt.show()

        # Per-detection details
        if len(boxes) > 0:
            print("\nDetecciones:")
            for i, box in enumerate(boxes):
                conf_val = box.conf.item()
                print(f"  [{i}] Confianza: {conf_val:.3f}")

# Layout: selectors on the first row, thresholds on the second, results below.
selector_row = widgets.HBox([image_dropdown, model_dropdown])
threshold_row = widgets.HBox([conf_slider, iou_slider])
ui = widgets.VBox([selector_row, threshold_row, output])


def _run_with_current_values():
    """Call run_inference with whatever the four controls currently hold."""
    run_inference(
        image_dropdown.value,
        conf_slider.value,
        iou_slider.value,
        model_dropdown.value,
    )


def on_change(change):
    """Widget observer: re-run inference; the change payload itself is unused."""
    _run_with_current_values()


# Every control triggers the same callback when its value changes.
for _control in (image_dropdown, conf_slider, iou_slider, model_dropdown):
    _control.observe(on_change, names='value')

display(ui)

# Initial render with the default selections
_run_with_current_values()

## 3. Comparar Velocidad de Formatos

In [None]:
import time
import torch

def benchmark_model(model_path, iterations=50):
    """Measure the mean per-call latency of a model, in milliseconds.

    The model at ``model_path`` is loaded, run 5 times untimed as warm-up and
    then ``iterations`` times on a random 640x640x3 uint8 array. When CUDA is
    available the device is synchronized before the clock is stopped.

    Returns:
        Mean wall-clock time per call in milliseconds.
    """

    def _sync_gpu():
        # Wait for pending CUDA work, if there is a CUDA device at all.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    net = YOLO(str(model_path))
    frame = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8)

    # Warm-up runs, excluded from the measurement
    warmup_runs = 5
    for _ in range(warmup_runs):
        net(frame, verbose=False)

    _sync_gpu()

    # Timed runs
    durations = []
    for _ in range(iterations):
        t0 = time.perf_counter()
        net(frame, verbose=False)
        _sync_gpu()
        durations.append(time.perf_counter() - t0)

    mean_seconds = np.mean(durations)
    return mean_seconds * 1000

# Benchmark every discovered format and report latency / throughput.
print("Benchmark de velocidad (50 iteraciones):")
print("=" * 40)

results = {}
for name, path in available_models.items():
    try:
        mean_ms = benchmark_model(path)
        fps = 1000 / mean_ms
    except Exception as e:
        # Best effort: a format that fails to load or run is reported and skipped.
        print(f"{name:12} Error: {e}")
    else:
        results[name] = {'ms': mean_ms, 'fps': fps}
        print(f"{name:12} {mean_ms:>8.2f} ms  ({fps:>6.1f} FPS)")

# Bar chart of FPS per format (only when at least one benchmark succeeded)
if results:
    names = list(results.keys())
    fps_values = [results[n]['fps'] for n in names]
    palette = ['#2ecc71', '#3498db', '#e74c3c']

    fig, ax = plt.subplots(figsize=(10, 5))
    bars = ax.bar(names, fps_values, color=palette[:len(names)])
    ax.set_ylabel('FPS')
    ax.set_title('Comparativa de Velocidad por Formato')

    # Write the rounded FPS value just above each bar.
    for bar, fps in zip(bars, fps_values):
        x_center = bar.get_x() + bar.get_width()/2
        y_label = bar.get_height() + 5
        ax.text(x_center, y_label, f'{fps:.0f}',
                ha='center', fontsize=12, fontweight='bold')

    plt.tight_layout()
    plt.show()

## 4. Procesar Video (muestra)

In [None]:
# Look for a test video and show detections on 6 frames spread evenly over it.
video_path = DATA_DIR / 'video.mp4'

if video_path.exists():
    print(f"Video encontrado: {video_path}")

    cap = cv2.VideoCapture(str(video_path))
    try:
        if not cap.isOpened():
            # The file exists but could not be opened/decoded.
            print(f"No se pudo abrir el video: {video_path}")
        else:
            fps = cap.get(cv2.CAP_PROP_FPS)
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            print(f"FPS: {fps}, Frames totales: {total_frames}")

            if total_frames <= 0:
                # Without a frame count the sampling below would produce
                # negative frame indices.
                print("El video no reporta frames; no hay nada que mostrar.")
            else:
                fig, axes = plt.subplots(2, 3, figsize=(15, 10))
                axes = axes.flatten()

                frame_indices = np.linspace(0, total_frames-1, 6, dtype=int)

                for ax, frame_idx in zip(axes, frame_indices):
                    # Hide ticks up front so panels of unreadable frames
                    # do not show empty axes.
                    ax.axis('off')

                    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
                    ret, frame = cap.read()
                    if not ret:
                        ax.set_title(f"Frame {frame_idx} - no disponible")
                        continue

                    results = model(frame, conf=0.65, verbose=False)
                    # plot() is converted from BGR to RGB for matplotlib
                    annotated = results[0].plot()
                    annotated = cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)

                    ax.imshow(annotated)
                    ax.set_title(f"Frame {frame_idx} - {len(results[0].boxes)} det.")

                plt.tight_layout()
                plt.show()
    finally:
        # Always free the capture handle, even if inference or plotting fails.
        cap.release()
else:
    print(f"Video no encontrado en: {video_path}")

## 5. Subir tu propia imagen

In [None]:
# Single-file upload control restricted to image MIME types
upload = widgets.FileUpload(
    description='Subir imagen',
    accept='image/*',
    multiple=False
)

# Output area where the annotated upload is rendered
upload_output = widgets.Output()

def on_upload(change):
    """Run the detector on the uploaded image and show the annotated result.

    Observer for the ``upload`` widget's ``value`` trait; the change payload
    is unused and the widget is read directly. Output goes to
    ``upload_output``, replacing its previous content.
    """
    import io

    with upload_output:
        clear_output(wait=True)

        value = upload.value
        if not value:
            return

        # ipywidgets 7 exposes a dict keyed by file name, ipywidgets 8 a
        # tuple of per-file entries; accept both.
        files = list(value.values()) if isinstance(value, dict) else list(value)
        uploaded_file = files[0]
        content = uploaded_file['content']

        # Decode and normalise to 3-channel RGB so RGBA, palette and
        # grayscale images do not reach the model with an unexpected shape.
        try:
            img = Image.open(io.BytesIO(content)).convert('RGB')
        except OSError as e:
            print(f"No se pudo leer la imagen: {e}")
            return

        # Ultralytics treats numpy input as BGR, while PIL yields RGB; convert
        # so the model does not see swapped channels.
        img_bgr = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)

        # Inference
        results = model(img_bgr, conf=0.65, verbose=False)

        # plot() draws on the BGR input; convert back to RGB for matplotlib.
        annotated = cv2.cvtColor(results[0].plot(), cv2.COLOR_BGR2RGB)

        plt.figure(figsize=(14, 8))
        plt.imshow(annotated)
        plt.title(f"Tu imagen - {len(results[0].boxes)} detecciones")
        plt.axis('off')
        plt.show()

# Process the file whenever the widget's value changes.
upload.observe(on_upload, names='value')

# Stack the upload button above its result area and show it.
upload_ui = widgets.VBox([upload, upload_output])
display(upload_ui)