In [None]:
"""
# Exploring the lab42/cov-segm-v3 Dataset with `dataops`

This notebook uses the `src.dataops` module to load and visualize samples
from the `lab42/cov-segm-v3` dataset.
"""

In [None]:
import logging
import os  # Added for environ check
import sys
from pathlib import Path

In [None]:
# Make the project checkout importable so that project-local packages can be
# found regardless of the notebook's working directory.
#
# Location resolution, in order:
#   1. the VIBE_CODING_ROOT environment variable, if set and non-empty;
#   2. ~/Development/vibe/vibe_coding under the *current* user's home directory.
# Path.home() is used instead of a hard-coded fallback so the default never
# points at another user's home directory when HOME is unset (e.g. on Windows).
project_root = Path(
    os.environ.get("VIBE_CODING_ROOT") or Path.home() / "Development/vibe/vibe_coding"
)
src_path = project_root

# A wrong path would otherwise only show up later as an ImportError.
if not src_path.is_dir():
    print(f"WARNING: project root does not exist: {src_path}")

if str(src_path) not in sys.path:
    print(f"Appending to sys.path: {src_path}")
    sys.path.append(str(src_path))
else:
    print(f"{src_path} already in sys.path")

In [None]:
# Import third-party and project-local dependencies in one guarded step. If any
# of them is missing, print environment diagnostics and re-raise so the notebook
# stops here instead of failing later with a NameError.
try:
    import datasets
    import matplotlib.pyplot as plt  # pyplot interface used by the plotting cell
    from PIL import Image

    from dataops.cov_segm.loader import load_sample
    from dataops.cov_segm.visualizer import visualize_prompt_masks
except ImportError as import_err:
    print(
        f"Error importing modules: {import_err}",
        f"PYTHONPATH: {os.environ.get('PYTHONPATH')}",
        f"sys.path: {sys.path}",
        "Ensure 'datasets', 'Pillow', 'matplotlib' are installed in the correct environment",
        "and the 'src' directory is accessible.",
        sep="\n",
    )
    # Propagate the original error; the diagnostics above are informational only.
    raise
else:
    print("Imports successful.")

In [None]:
# Notebook-wide logging: a module-named logger plus a root configuration that
# emits INFO-and-above records as "timestamp - level - message".
logger = logging.getLogger(__name__)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

In [None]:
# Dataset identifier passed to datasets.load_dataset in the loading cell.
DATASET_NAME = "lab42/cov-segm-v3"
# Split to read; combined with NUM_SAMPLES_TO_LOAD into a "<split>[:N]" slice.
SPLIT = "validation"
NUM_SAMPLES_TO_LOAD = 20  # Kept small so loading and visualization stay quick

In [None]:
# Load the leading NUM_SAMPLES_TO_LOAD rows of the configured split. Any failure
# is logged with its traceback and echoed to the cell output, and `dset` is set
# to None so that later cells can tell nothing was loaded.
logger.info(f"Loading dataset: {DATASET_NAME}, split: {SPLIT}, samples: {NUM_SAMPLES_TO_LOAD}")
try:
    # Direct (non-streaming) load: simplest for a small slice. streaming=True is
    # an option if the initial load becomes slow, and trust_remote_code=True may
    # be needed depending on the dataset / HF version.
    dset = datasets.load_dataset(DATASET_NAME, split=f"{SPLIT}[:{NUM_SAMPLES_TO_LOAD}]")
    logger.info(f"Loaded {len(dset)} samples.")
except Exception as load_err:
    dset = None
    # logger.exception logs at ERROR level and attaches the active traceback.
    logger.exception(f"Failed to load dataset: {load_err}")
    # Also surface the problem directly in the notebook cell output.
    print(f"ERROR loading dataset: {load_err}")

In [None]:
if dset:
    logger.info("Processing and visualizing samples...")
    # Ensure matplotlib backend is suitable for non-interactive saving or inline display
    %matplotlib inline
    plt.ioff()  # Turn off interactive mode initially

    for i, sample_dict in enumerate(dset):
        sample_id = sample_dict.get("id", f"index_{i}")
        logger.info(f"--- Processing Sample {i} (ID: {sample_id}) ---")
        print(f"--- Processing Sample {i} (ID: {sample_id}) ---")  # Also print to notebook output
        try:
            processed_sample = load_sample(sample_dict)
            logger.info(f"Successfully loaded data for sample {i}.")

            # Visualize based on prompts
            visualized_count = 0
            if processed_sample and "processed_conversations" in processed_sample:
                for conv_idx, conv_item in enumerate(processed_sample["processed_conversations"]):
                    if conv_item.get("phrases"):
                        # Use first phrase text as the prompt title
                        prompt_text = conv_item["phrases"][0]["text"]
                        if prompt_text != "object":
                            continue

                        logger.info(
                            f"  Visualizing for conversation {conv_idx}, prompt: '{prompt_text}'"
                        )
                        print(f"  Visualizing prompt: '{prompt_text}'")

                        try:
                            # Call the visualizer function
                            fig = visualize_prompt_masks(
                                processed_sample,
                                prompt=prompt_text)
                            if fig:
                                plt.figure(fig.number)  # Ensure we're using the figure returned
                                plt.show()  # Display the plot inline in the notebook
                                visualized_count += 1
                            else:
                                logger.warning(
                                    f"  Visualization skipped for prompt '{prompt_text}' (no masks found or error)."
                                )
                                print(
                                    f"  Visualization skipped for prompt '{prompt_text}' (no masks found or error)."
                                )

                        except Exception as vis_e:
                            logger.error(
                                f"  Error visualizing prompt '{prompt_text}' for sample {i}: {vis_e}",
                                exc_info=True,
                            )
                            print(f"  ERROR visualizing prompt '{prompt_text}': {vis_e}")
                    else:
                        logger.warning(
                            f"  Skipping conversation {conv_idx} in sample {i} as it has no phrases."
                        )
                        print(f"  Skipping conversation {conv_idx} (no phrases).")

            if visualized_count == 0:
                logger.warning(f"No visualizations generated for sample {i}.")
                print(f"No visualizations generated for sample {i}.")

        except Exception as load_e:
            logger.error(
                f"Failed to load or process sample {i} (ID: {sample_id}): {load_e}", exc_info=True
            )
            print(f"ERROR loading/processing sample {i} (ID: {sample_id}): {load_e}")

    plt.ion()  # Turn interactive mode back on if needed at the end
else:
    logger.warning("Dataset not loaded. Skipping visualization.")
    print("Dataset not loaded. Skipping visualization.")

In [None]:
# Completion marker: written to the log and echoed to the cell output so it is
# obvious that the final cell was reached.
logger.info("Notebook execution finished.")
print("Notebook execution finished.")