# Exploring the lab42/cov-segm-v3 Dataset

This notebook explores the `lab42/cov-segm-v3` dataset from Hugging Face.
We will load a few samples, visualize the primary image, parse the `conversations` field,
and display the associated mask images, both raw and overlaid on the primary image.

In [1]:
import io
import json

import matplotlib.pyplot as plt
import numpy as np
from datasets import load_dataset

## Load Dataset

Load the validation split, which is smaller and suitable for exploration.
We'll examine the first few samples.

In [None]:
# Configuration for this exploration: which dataset to read and how many
# leading rows of the validation split to keep.
DATASET_NAME = "lab42/cov-segm-v3"
NUM_SAMPLES_TO_LOAD = 3

# The slice suffix on the split name keeps only the first NUM_SAMPLES_TO_LOAD
# rows in `dset`. (Passing streaming=True to load_dataset is an alternative
# when only a handful of samples are needed.)
split_spec = f"validation[:{NUM_SAMPLES_TO_LOAD}]"

try:
    dset = load_dataset(DATASET_NAME, split=split_spec)
    print(f"Loaded {len(dset)} samples from {DATASET_NAME}")
except Exception as load_error:
    # Later cells test `dset` before using it, so fall back to None rather
    # than leaving the name undefined.
    print(f"Failed to load dataset: {load_error}")
    dset = None

In [None]:
# Peek at the raw structure of the first record. The load cell sets `dset` to
# None on failure, so guard the lookup instead of raising a TypeError here.
dset[0] if dset else None

## Explore Samples

Iterate through the loaded samples, displaying their contents.

In [10]:
def _describe_mask_refs(conv_item):
    """Return a one-line description of the mask reference(s) of one conversation element."""
    mask_info = conv_item.get("mask")  # Check if 'mask' exists directly
    if mask_info:
        mask_col = mask_info.get("column", "N/A")
        mask_val = mask_info.get("positive_value", "N/A")
        return f"Mask Column: '{mask_col}', Positive Value: {mask_val}"

    # Might be instance masks or no mask (negative example).
    instance_masks = conv_item.get("instance_masks") or []
    if instance_masks:
        im_details = "; ".join(
            f"Col: {im.get('column', 'N/A')}, Val: {im.get('positive_value', 'N/A')}"
            for im in instance_masks
        )
        return f"Instance Masks: [{im_details}]"

    return "No direct 'mask' or 'instance_masks' field found."


def parse_conversations(sample):
    """Parse, summarise and return the ``conversations`` field of a sample.

    Prints the first conversation element pretty-printed as JSON, followed by
    one line per element giving its prompt text and the mask column /
    positive value (or instance masks) it refers to.

    Args:
        sample: Mapping that may hold a ``"conversations"`` entry, either as a
            JSON string or as an already-decoded list of dicts.

    Returns:
        The decoded list of conversation elements, or ``None`` when the field
        is missing, empty, not valid JSON, or not a list.
    """
    raw_conversations = sample.get("conversations")
    if not raw_conversations:
        print("conversations field not found or is empty.")
        return None

    if isinstance(raw_conversations, (str, bytes, bytearray)):
        try:
            parsed_conversations = json.loads(raw_conversations)
        except json.JSONDecodeError:
            print("Error parsing conversations JSON.")
            return None
    else:
        # Already decoded (json.loads would raise TypeError on a list).
        parsed_conversations = raw_conversations

    if not isinstance(parsed_conversations, list):
        print(
            f"Unexpected conversations payload of type {type(parsed_conversations).__name__}; expected a list."
        )
        return None

    print("Parsed Conversations:")
    # Pretty print only the first conversation element for brevity.
    if parsed_conversations:
        print(json.dumps(parsed_conversations[0], indent=2))
        if len(parsed_conversations) > 1:
            print(
                f"  ... and {len(parsed_conversations) - 1} more conversation elements."
            )
    else:
        print("  (Conversations list is empty)")

    print("\n--- Example: Parsing details for each prompt ---")
    for i, conv_item in enumerate(parsed_conversations):
        if not isinstance(conv_item, dict):
            print(f"  Prompt {i}: skipped (element is not a JSON object).")
            continue
        # `phrases` may be absent or null; treat both as "no prompt text".
        phrases = conv_item.get("phrases") or []
        prompt_text = "; ".join(p.get("text", "N/A") for p in phrases)
        print(f"  Prompt {i}: '{prompt_text}' -> {_describe_mask_refs(conv_item)}")

    return parsed_conversations

In [16]:
def _plot_mask_with_overlay(mask_image, label, idx, image_np):
    """Draw one mask twice: raw with a colorbar, and blended over ``image_np``."""
    mask_np = np.array(mask_image)

    fig, (raw_ax, overlay_ax) = plt.subplots(1, 2, figsize=(12, 6))

    # Raw mask; the colorbar exposes the pixel values stored in the mask.
    raw_im = raw_ax.imshow(mask_np, cmap="viridis")
    raw_ax.set_title(f"Sample {idx}: Raw {label}")
    raw_ax.axis("off")
    fig.colorbar(raw_im, ax=raw_ax, fraction=0.046, pad=0.04)

    # Overlay: base image first, then the semi-transparent mask on top.
    if image_np is not None:
        overlay_ax.imshow(image_np)
        overlay_ax.imshow(mask_np, cmap="viridis", alpha=0.5)
        overlay_ax.set_title(f"Sample {idx}: {label} Overlay on image_0")
    else:
        overlay_ax.set_title(f"Sample {idx}: Overlay (image_0 missing)")
    overlay_ax.axis("off")

    plt.tight_layout()
    plt.show()
    # Release the figure; this function is called in a loop over samples.
    plt.close(fig)


def visualize_masks(sample, idx, image_np=None):
    """Plot every mask found in ``sample``, raw and overlaid on the primary image.

    Looks at the ``mask_0``, ``mask_1`` and ``mask_2`` entries and then at each
    element of the ``masks_rest`` list, skipping entries that are missing.

    Args:
        sample: Mapping holding the mask entries.
        idx: Sample index, used only in plot titles and log lines.
        image_np: Optional array of the primary image used as overlay
            background. When omitted, a module-level ``image_np`` is used if
            one exists (the behaviour of earlier revisions of this function);
            otherwise the overlay panel is left empty.

    Returns:
        None.
    """
    if image_np is None:
        # Backward compatibility with callers that set a global `image_np`.
        image_np = globals().get("image_np")

    for col_name in (f"mask_{i}" for i in range(3)):  # mask_0, mask_1, mask_2
        mask_image = sample.get(col_name)
        if mask_image is not None:
            _plot_mask_with_overlay(mask_image, col_name, idx, image_np)

    # Handle masks_rest (sequence of masks)
    masks_rest_list = sample.get("masks_rest")
    if isinstance(masks_rest_list, (list, tuple)) and len(masks_rest_list) > 0:
        print(f"--- Visualizing masks_rest (found {len(masks_rest_list)}) ---")
        for rest_idx, rest_mask_image in enumerate(masks_rest_list):
            if rest_mask_image is not None:
                _plot_mask_with_overlay(
                    rest_mask_image, f"masks_rest[{rest_idx}]", idx, image_np
                )

    print(f"--- Finished Sample {idx} ---")


In [None]:
def _image_to_array(image_data):
    """Convert an ``image_0`` payload into a NumPy array.

    Accepts an already-decoded image (anything ``np.asarray`` turns into a
    2-D or 3-D non-object array), a dict carrying encoded bytes under the
    ``"bytes"`` key, or raw encoded bytes.

    Returns:
        The image array, or ``None`` when the payload is empty or of a type
        that cannot be interpreted as an image.
    """
    if isinstance(image_data, dict):
        image_data = image_data.get("bytes")
    if image_data is None:
        return None
    if isinstance(image_data, (bytes, bytearray)):
        if not image_data:
            return None
        # For unnamed file-like objects plt.imread assumes PNG unless told
        # otherwise, so give it a format hint based on the PNG signature.
        fmt = "png" if bytes(image_data[:4]) == b"\x89PNG" else "jpg"
        return plt.imread(io.BytesIO(image_data), format=fmt)
    candidate = np.asarray(image_data)
    if candidate.dtype == object or candidate.ndim not in (2, 3):
        return None
    return candidate


if dset:
    for idx, sample in enumerate(dset):
        # --- Basic metadata ---
        print(f"--- Sample Index: {idx}, ID: {sample.get('id', 'N/A')} ---")
        print(f"Number of Images reported: {sample.get('n_images', 'N/A')}")
        print(f"Number of Masks reported: {sample.get('n_masks', 'N/A')}")
        print(f"Number of Conversations reported: {sample.get('n_conversations', 'N/A')}")
        print(f"Dataset Source: {sample.get('dataset', 'N/A')}")
        print(f"Split: {sample.get('split', 'N/A')}")

        # --- Inspect the sample structure ---
        print(f"\nAvailable keys in sample: {list(sample.keys())}")
        image_data = sample.get("image_0")
        if "image_0" in sample:
            # The stored type determines how the image has to be decoded below.
            print(f"Type of 'image_0' field: {type(image_data)}")
            if isinstance(image_data, dict) and "bytes" in image_data:
                print("   -> 'image_0' seems to be a dict containing bytes.")
            elif isinstance(image_data, bytes):
                print("   -> 'image_0' seems to be raw bytes.")
        else:
            print("\n'image_0' key not found in sample.")

        # --- Image visualization ---
        # Display the primary image (`image_0`). Other `image_*` columns might
        # exist for different camera views or data types, but `image_0` is
        # typically the main one.
        # `image_np` stays at module scope because visualize_masks looks it up
        # there for the overlay background.
        image_np = None
        if image_data is None:
            print("image_0 data not found or is empty in this sample.")
        else:
            try:
                image_np = _image_to_array(image_data)
            except Exception as exc:
                # Decoding errors vary by image backend; report and continue
                # with the remaining samples rather than aborting the loop.
                print(f"Error processing or displaying image_0: {exc}")
            else:
                if image_np is None:
                    print(
                        f"Cannot automatically display image_0 of type: {type(image_data)}. Manual handling needed."
                    )

        if image_np is not None:
            fig, ax = plt.subplots(figsize=(6, 6))
            ax.imshow(image_np)
            ax.set_title(f"Sample {idx}: image_0")
            ax.axis("off")
            plt.show()
            plt.close(fig)

        # --- Conversations and masks ---
        parse_conversations(sample)
        visualize_masks(sample, idx)
else:
    print("No samples loaded; nothing to explore.")
        


### Basic Metadata

Each sample reports an `id`, the counts `n_images`, `n_masks` and `n_conversations`,
and its source `dataset` and `split`; these are printed first for every sample.

### Image Visualization

Display the primary image (`image_0`). Other `image_*` columns might exist
for different camera views or data types, but `image_0` is typically the main one.

### Parse `conversations` Field

The `conversations` field is a JSON string containing a list of dictionaries.
Each dictionary links a text prompt to its corresponding mask information (which mask column and which pixel value).

### Mask Visualization

Display the raw mask images (`mask_0`, `mask_1`, `mask_2`, `masks_rest`)
and overlay them onto `image_0`. Note that these are the 'packed' masks.
To see the mask for a *specific* prompt, you need to parse the `conversations`
field (as shown above) to find the correct column and pixel value, then filter the NumPy array.


End of exploration for the first few samples. To explore further, increase
`NUM_SAMPLES_TO_LOAD`, change the split slice passed to `load_dataset`
(for example `validation[10:20]`), or pick specific rows with `dset.select([...])`.
Remember that visualizing the mask for a specific *prompt* requires parsing
`conversations` to get the mask column and `positive_value`, then filtering the
corresponding mask NumPy array for pixels equal to that value.