In [0]:
# Install Zia Neurolabs SDK Python  
!pip install --upgrade zia-sdk-python[databricks] 
# Use below if you're installing updates to the package 
# %restart_python 
# dbutils.library.restartPython()

[43mNote: you may need to restart the kernel using %restart_python or dbutils.library.restartPython() to use updated packages.[0m


In [0]:
%python

# Base 
from typing import List, Any
import datetime

# Import data handling dependencies 
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

# Import zia-sdk depdendencies 
from neurolabszia import Zia, NLIRResult 
from neurolabszia.utils import ir_results_to_dataframe, get_spark_schema_from_dataframe, to_spark_dataframe

# 1. Read the Neurolabs API key from Databricks secrets (scope "neurolabs-api")
#    so it never has to be written into the notebook itself.
try:
    api_key = dbutils.secrets.get(scope="neurolabs-api", key="demo_abi")
except Exception as exc:
    # Re-raise with a setup hint; the original error stays attached as the cause.
    raise RuntimeError(
        "Failed to retrieve API key from Databricks secrets. "
        "Make sure the secret scope and key are set up."
    ) from exc

# Helper method to get paginated results using Neurolabs SDK 
async def get_paginated_results(
    client: Zia, task_uuid: str, batch_size: int = 10, max_iter: int = 5
) -> list[NLIRResult]:
    """
    Page through a task's results and return them as a single list.

    Batches of ``batch_size`` results are requested at increasing offsets.
    Paging stops as soon as a batch comes back empty or shorter than
    ``batch_size``, or once ``batch_size * max_iter`` results were collected.

    Args:
        client: Zia client instance
        task_uuid: The UUID of the task
        batch_size: Number of results to fetch per request (default: 10)
        max_iter: Maximum number of batches to fetch results for (default: 5)
    Returns:
        List of all NLIRResult objects that were fetched
    """
    collected = []
    cursor = 0

    print(f"🔍 Fetching paginated results for task: {task_uuid}")
    print(f"📦 Batch size: {batch_size}")

    # Upper bound on how many results this call will gather
    cap = batch_size * max_iter

    while len(collected) < cap:
        print(f"\n📄 Fetching batch at offset {cursor}...")

        page = await client.result_management.get_task_results(
            task_uuid=task_uuid, limit=batch_size, offset=cursor
        )

        # An empty page means the previous one was the last
        if not page:
            print(f"✅ No more results found at offset {cursor}")
            break

        page_size = len(page)
        print(f"✅ Retrieved {page_size} results")
        collected.extend(page)

        # A short page means there is nothing left to fetch
        if page_size < batch_size:
            print(f"✅ Reached end of results (got {page_size} < {batch_size})")
            break

        cursor += batch_size

    print(f"\n🎉 Total results retrieved: {len(collected)}")
    return collected

In [0]:
# Databricks redacts secrets by default, so print only the first and last 4 characters of the
# API key to make sure you've got what's required.
# Keys of 8 characters or fewer are masked completely: the [:4] and [-4:] slices would
# otherwise overlap and print the whole secret.
if len(api_key) > 8:
    print(f"API Key: {api_key[:4]}{'*' * (len(api_key) - 8)}{api_key[-4:]}")
else:
    print(f"API Key: {'*' * len(api_key)}")

API Key: eyJh*********************************************************************************************************************************************************************************************************************************************************************************************-RdI


In [0]:
# Run `get_paginated_results` when you want to get new data via Neurolabs API 
# UUID of the task whose IR results are fetched; `get_all` below reads this notebook-level variable.
task_uuid = "ec01e5c7-51c9-4889-8136-19a4ab7168c1"

async def get_all(max_iter):
    """
    Fetch up to ``max_iter`` batches of 20 results for the notebook-level ``task_uuid``.

    Opens one Zia client (authenticated with the notebook-level ``api_key``)
    and pages through the task's results with ``get_paginated_results``.

    Args:
        max_iter: Maximum number of batches to fetch (passed through to
            ``get_paginated_results``)

    Returns:
        List of fetched results; an empty list if fetching raised an error
        (the error is printed, not re-raised).
    """

    print("🚀 Zia SDK - Pull Batched IR Results - Example")
    print("=" * 60)
    all_results = []
    # Initialize client once and reuse it
    async with Zia(api_key) as client:
        # Example 1: Get all results with pagination
        print("\n1️⃣ Getting all paginated results:")
        try:
            all_results = await get_paginated_results(client, task_uuid, batch_size=20, max_iter=max_iter)
            print(f"✅ Successfully retrieved {len(all_results)} total results")
        except Exception as e:
            # Deliberate best-effort: report the failure and fall through with an empty list
            print(f"❌ Error getting all results: {e}")

    # These closing lines used to sit after the `return` and therefore never ran
    print("\n" + "=" * 60)
    print("🎉 Paginated results example complete!")
    return all_results


#all_results = []
#for results in results_path: 
#    data = load_json(results)
    # Parse into our NLB
    # results = [NLIRResult.model_validate(result) for result in data["items"]]
    # all_results.extend(results)

#print(f"Total results retrieved: {len(all_results)}")

In [0]:
# Fetch at most 2 batches; get_all requests batches of 20, so this yields up to 40 results
all_results = await get_all(2)

🚀 Zia SDK - Pull Batched IR Results - Example

1️⃣ Getting all paginated results:
🔍 Fetching paginated results for task: ec01e5c7-51c9-4889-8136-19a4ab7168c1
📦 Batch size: 20

📄 Fetching batch at offset 0...
✅ Retrieved 20 results

📄 Fetching batch at offset 20...
✅ Retrieved 20 results

🎉 Total results retrieved: 40
✅ Successfully retrieved 40 total results


In [0]:
# Sanity check: how many results get_all returned
len(all_results)

40

In [0]:
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from datetime import datetime

# Turn the fetched IR results into a Spark DataFrame and store them as a Delta table in Unity Catalog

print("Execution started at:", datetime.now())

# Destination, addressed as <catalog>.<schema>.<table>
catalog_name = "catalog_integration"
schema_name = "neurolabs_ir_results_demo"
table_name = "ir_results_sample"

# Get (or start) the Spark session used for both the conversion and the write
spark = SparkSession.builder.appName("NLIRResultsIngestion").getOrCreate()

# NLIRResults -> Spark DataFrame via the SDK helper
# (the two-step pandas route is kept below for reference)
#pdf = ir_results_to_dataframe(all_results)
#ir_results_schema = get_spark_schema_from_dataframe(pdf)
df_spark = to_spark_dataframe(all_results, spark)
# df_spark.head(2)

# Ensure the schema exists, then overwrite the table with this run's rows
spark.sql(f"CREATE SCHEMA IF NOT EXISTS {catalog_name}.{schema_name}")
(
    df_spark.write
    .format("delta")
    .mode("overwrite")
    .saveAsTable(f"{catalog_name}.{schema_name}.{table_name}")
)

print(f"Successfully wrote {df_spark.count()} records to table {schema_name}.{table_name}.")

Execution started at: 2025-09-02 21:55:22.010910
Successfully wrote 1480 records to table neurolabs_ir_results_demo.ir_results_sample.


# Deprecated - helpers for the IRResults -> DataFrame -> Spark conversion now live in neurolabszia.utils

In [0]:
# Used for debugging 
# Manual version of the conversion: NLIRResults -> pandas -> Spark -> Delta table.
# Relies on `all_results`, `spark`, `catalog_name`, `schema_name` and `table_name`
# already being defined by earlier cells.
pdf = ir_results_to_dataframe(all_results)
# Not the last expression of the cell, so this preview is not displayed when the cell runs
pdf.head(2)
# Replace NaN values with None for Spark compatibility
pdf = pdf.where(pd.notnull(pdf), None)
#schema = get_dynamic_spark_schema(pdf)
# NOTE(review): an empty list is passed as the schema below; presumably meant to leave
# type inference to Spark - confirm this is accepted by createDataFrame.
schema = []

# 1. Create Spark session
spark_session = SparkSession.builder.appName("NLIRResultsIngestion").getOrCreate()

# Convert to Spark DataFrame using records
df_spark = spark_session.createDataFrame(pdf, schema=schema)
# 4. Write to Databricks Delta table
# NOTE(review): uses the `spark` variable from an earlier cell rather than `spark_session`
# created above - presumably the same session via getOrCreate(); confirm.
spark.sql(f"CREATE SCHEMA IF NOT EXISTS {catalog_name}.{schema_name}")

# Overwrites the same table the non-deprecated cell writes to
df_spark.write.format("delta").mode("overwrite").saveAsTable(f"{catalog_name}.{schema_name}.{table_name}")

print(f"Successfully wrote {df_spark.count()} records to table {schema_name}.{table_name}.") 



In [0]:
# Helpers to call the API, get paginated IRResults & convert NLIRResults into a DataFrame
# TODO: Include the NLIRResults -> Spark DataFrame conversion in zia/utils in the next iteration of the SDK

def get_dynamic_spark_schema(df: pd.DataFrame) -> 'StructType':
    """
    Build a Spark schema from the columns and dtypes of a pandas DataFrame.

    Every column becomes one nullable StructField. The Spark type is chosen
    from the pandas dtype, with two name-based special cases for object
    columns:

    * ``result_created_at`` / ``result_updated_at`` -> TimestampType
    * ``alternative_predictions`` -> array of structs with the fields
      ``category_id`` (int), ``category_name`` (string), ``score`` (float)

    Other mappings: object -> StringType, int64 -> IntegerType,
    float64 -> FloatType, bool -> BooleanType,
    datetime64[ns] -> TimestampType, anything else -> StringType.

    Args:
        df: pandas DataFrame created by ir_results_to_dataframe()

    Returns:
        pyspark.sql.types.StructType with one field per DataFrame column,
        in column order
    """
    # Import every Spark type used below locally. StructField, StringType and
    # IntegerType were previously only available if another notebook cell had
    # imported them first, which raised NameError on a fresh kernel.
    from pyspark.sql.types import (
        ArrayType, BooleanType, FloatType, IntegerType, StringType,
        StructField, StructType, TimestampType,
    )

    fields = []

    for column_name, dtype in df.dtypes.items():
        # Handle different pandas dtypes
        if dtype == 'object':
            # Check if it's a datetime column
            if column_name in ['result_created_at', 'result_updated_at']:
                spark_type = TimestampType()
            # Check if it's the alternative_predictions column (list of dicts)
            elif column_name == 'alternative_predictions':
                # Define schema for alternative prediction items
                alt_pred_schema = StructType([
                    StructField("category_id", IntegerType(), True),
                    StructField("category_name", StringType(), True),
                    StructField("score", FloatType(), True),
                ])
                spark_type = ArrayType(alt_pred_schema)
            else:
                spark_type = StringType()
        # NOTE(review): int64/float64 are mapped to the 32-bit IntegerType/FloatType;
        # LongType/DoubleType would be the lossless match, but switching changes the
        # written table schema - confirm before changing.
        elif dtype == 'int64':
            spark_type = IntegerType()
        elif dtype == 'float64':
            spark_type = FloatType()
        elif dtype == 'bool':
            spark_type = BooleanType()
        elif dtype == 'datetime64[ns]':
            spark_type = TimestampType()
        else:
            # Default to string for unknown types
            spark_type = StringType()

        # All fields are declared nullable
        fields.append(StructField(column_name, spark_type, True))

    return StructType(fields)


def get_spark_schema_from_dataframe(df: pd.DataFrame) -> 'StructType':
    """
    Generate a Spark schema for a DataFrame by delegating to get_dynamic_spark_schema().

    Thin alias: the argument is passed through unchanged and the result is
    returned as-is.

    NOTE(review): a function with this same name is imported from
    ``neurolabszia.utils`` at the top of the notebook; running this cell
    rebinds the name to this local version.

    Args:
        df: pandas DataFrame created by ir_results_to_dataframe()

    Returns:
        The pyspark.sql.types.StructType produced by get_dynamic_spark_schema(df)
    """
    return get_dynamic_spark_schema(df)

def ir_results_to_dataframe(
    results: List[Any],
    include_bbox: bool = True,
    include_alternative_predictions: bool = True) -> pd.DataFrame:
    """
    Flatten NLIRResult objects into a pandas DataFrame with one row per annotation.

    Only results that carry COCO data and whose status is "PROCESSED"
    contribute rows. Each annotation is matched to its category through
    ``category_id`` and to its image through ``image_id``; annotations whose
    category is missing are skipped, while a missing image only leaves the
    image width/height/filename columns empty.

    Args:
        results: List of NLIRResult objects
        include_bbox: Whether to include bounding box coordinates as separate
            columns (bbox_x, bbox_y, bbox_width, bbox_height)
        include_alternative_predictions: Whether to include alternative
            predictions (a list of dicts per row)

    Returns:
        pandas DataFrame with one row per detected item, or an empty
        DataFrame when no annotation qualifies. ``result_created_at`` and
        ``result_updated_at`` are converted with ``pd.to_datetime``.

    Example:
        >>> results = await client.result_management.get_task_results(
        ...     task_uuid=task_uuid, limit=10, offset=0)
        >>> df = ir_results_to_dataframe(results)
        >>> print(df.head())
    """
    rows = []

    for result in results:
        if not result.coco or result.status.value != "PROCESSED":
            continue

        coco = result.coco

        # Create a mapping of category_id to category for quick lookup
        category_map = {cat.id: cat for cat in coco.categories}

        # Index images once per result instead of scanning the image list three
        # times for every annotation. setdefault keeps the FIRST image for a
        # duplicated id, matching the previous first-match scan.
        image_map = {}
        for img in coco.images or ():
            image_map.setdefault(img.id, img)

        for annotation in coco.annotations:
            # Get the corresponding category
            category = category_map.get(annotation.category_id)
            if not category:
                continue

            image = image_map.get(annotation.image_id)
            has_image = image is not None

            # Base row; key order defines the DataFrame column order
            row = {
                # Result-level information
                "result_uuid": result.uuid,
                "task_uuid": result.task_uuid,
                "image_url": result.image_url,
                "result_status": result.status.value,
                "result_duration": result.duration,
                "result_created_at": result.created_at,
                "result_updated_at": result.updated_at,
                "confidence_score": result.confidence_score,
                # Image information (None when no image matches the annotation)
                "image_id": annotation.image_id,
                "image_width": image.width if has_image else None,
                "image_height": image.height if has_image else None,
                "image_filename": image.file_name if has_image else None,
                # Annotation information
                "annotation_id": annotation.id,
                "category_id": annotation.category_id,
                "area": annotation.area,
                "iscrowd": annotation.iscrowd,
                "detection_score": annotation.neurolabs.score,
                # Category information
                "category_name": category.name,
                "category_supercategory": category.supercategory,
            }

            # Add bounding box coordinates if requested
            if include_bbox and annotation.bbox:
                row.update(
                    {
                        "bbox_x": annotation.bbox[0],
                        "bbox_y": annotation.bbox[1],
                        "bbox_width": annotation.bbox[2],
                        "bbox_height": annotation.bbox[3],
                    }
                )

            # Add Neurolabs category information
            if category.neurolabs:
                row.update(
                    {
                        "product_uuid": category.neurolabs.productUuid,
                        "product_name": category.neurolabs.name,
                        "product_brand": category.neurolabs.brand,
                        "product_barcode": category.neurolabs.barcode,
                        "product_custom_id": category.neurolabs.customId,
                        "product_label": category.neurolabs.label,
                    }
                )

            # Add alternative predictions if requested
            if (
                include_alternative_predictions
                and annotation.neurolabs.alternative_predictions
            ):
                alt_predictions = []
                for alt_pred in annotation.neurolabs.alternative_predictions:
                    alt_category = category_map.get(alt_pred.category_id)
                    alt_predictions.append(
                        {
                            "category_id": alt_pred.category_id,
                            # Categories absent from this result get a placeholder name
                            "category_name": alt_category.name
                            if alt_category
                            else f"Unknown_{alt_pred.category_id}",
                            "score": alt_pred.score,
                        }
                    )
                row["alternative_predictions"] = alt_predictions

            # Add modalities if present, one "modality_<name>" column each
            if annotation.neurolabs.modalities:
                for (
                    modality_name,
                    modality_value,
                ) in annotation.neurolabs.modalities.items():
                    row[f"modality_{modality_name}"] = modality_value

            rows.append(row)

    if not rows:
        return pd.DataFrame()

    df = pd.DataFrame(rows)

    # Convert datetime columns
    for column in ("result_created_at", "result_updated_at"):
        if column in df.columns:
            df[column] = pd.to_datetime(df[column])

    return df


async def get_paginated_results(client: Zia, task_uuid: str, 
                                batch_size: int = 10, max_offset: int = 100) -> List[NLIRResult]:
    """
    Get results from a task using pagination, fetching at most ``max_offset`` of them.

    Batches are requested at increasing offsets until the task has no more
    results (empty or short batch) or ``max_offset`` results were fetched.
    The last request is shrunk so the cap is never exceeded; previously one
    extra batch beyond ``max_offset`` was fetched.

    NOTE: an earlier notebook cell defines a function with the same name but
    a ``max_iter`` parameter; running this cell rebinds the name.

    Args:
        client: Zia client instance
        task_uuid: The UUID of the task
        batch_size: Number of results to fetch per request (default: 10)
        max_offset: Maximum number of images to fetch results for (default: 100)
    Returns:
        List of the fetched NLIRResult objects
    Raises:
        ValueError: If ``batch_size`` is smaller than 1 (the offset would
            never advance)
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    all_results = []
    offset = 0

    print(f"🔍 Fetching paginated results for task: {task_uuid}")
    print(f"📦 Batch size: {batch_size}")

    while offset < max_offset:
        # Never ask for more than what is left under the cap
        limit = min(batch_size, max_offset - offset)
        print(f"\n📄 Fetching batch at offset {offset}...")

        # Get a batch of results
        batch = await client.result_management.get_task_results(
            task_uuid=task_uuid,
            limit=limit,
            offset=offset
        )

        if not batch:
            print(f"✅ No more results found at offset {offset}")
            break

        print(f"✅ Retrieved {len(batch)} results")
        all_results.extend(batch)

        # If we got fewer results than requested, we've reached the end
        if len(batch) < limit:
            print(f"✅ Reached end of results (got {len(batch)} < {limit})")
            break

        offset += limit
    else:
        # Loop ended because the cap was hit, not because the data ran out
        print(f"✅ Reached max_offset limit ({max_offset})")

    print(f"\n🎉 Total results retrieved: {len(all_results)}")
    return all_results


async def get_paginated_results_with_status_filter(
    client: Zia,
    task_uuid: str, 
    status_filter: str = "PROCESSED",
    batch_size: int = 10,
    max_offset: int = 100
) -> List[NLIRResult]:
    """
    Get paginated results and keep only those with a given status.

    Filtering happens client-side: every batch is fetched in full and then
    reduced to the results whose ``status.value`` equals ``status_filter``.
    At most ``max_offset`` results are fetched (before filtering); the last
    request is shrunk so the cap is never exceeded. Previously one extra
    batch beyond ``max_offset`` was fetched.

    Args:
        client: Zia client instance
        task_uuid: The UUID of the task
        status_filter: Status to filter by (default: "PROCESSED")
        batch_size: Number of results to fetch per request
        max_offset: Maximum number of results to fetch before filtering
            (default: 100)

    Returns:
        List of filtered NLIRResult objects
    Raises:
        ValueError: If ``batch_size`` is smaller than 1 (the offset would
            never advance)
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    all_results = []
    offset = 0

    print(f"🔍 Fetching {status_filter} results for task: {task_uuid}")
    print(f"📦 Batch size: {batch_size}")

    while offset < max_offset:
        # Never ask for more than what is left under the cap
        limit = min(batch_size, max_offset - offset)
        print(f"\n📄 Fetching batch at offset {offset}...")

        # Get a batch of results
        batch = await client.result_management.get_task_results(
            task_uuid=task_uuid,
            limit=limit,
            offset=offset
        )

        if not batch:
            print(f"✅ No more results found at offset {offset}")
            break

        # Filter by status
        filtered_batch = [result for result in batch if result.status.value == status_filter]

        print(f"✅ Retrieved {len(batch)} results, {len(filtered_batch)} with status '{status_filter}'")
        all_results.extend(filtered_batch)

        # End-of-data is judged on the unfiltered batch size
        if len(batch) < limit:
            print(f"✅ Reached end of results (got {len(batch)} < {limit})")
            break

        offset += limit
    else:
        # Loop ended because the cap was hit, not because the data ran out
        print(f"✅ Reached max_offset limit ({max_offset})")

    print(f"\n🎉 Total {status_filter} results: {len(all_results)}")
    return all_results

In [0]:
# 2. Fetch all catalog items with pagination
def fetch_all_catalog_items(api_key, base_url, limit=50):
    """
    Fetch every catalog item from a paginated endpoint.

    Requests pages of ``limit`` items at increasing offsets and stops at the
    first empty or short page.

    Args:
        api_key: API key sent in the ``X-API-Key`` header
        base_url: Endpoint URL; ``limit`` and ``offset`` are added as query
            parameters
        limit: Number of items to request per page (default: 50)

    Returns:
        List with the contents of the ``"items"`` field of every page, in order

    Raises:
        ValueError: If ``limit`` is smaller than 1 (the offset would never advance)
        requests.HTTPError: If a page request returns an error status
    """
    # Imported here because no notebook cell imports `requests`; without this
    # the function raised NameError on a fresh kernel.
    import requests

    if limit < 1:
        raise ValueError("limit must be a positive integer")

    headers = {"accept": "application/json", "X-API-Key": api_key}
    offset = 0
    all_items = []
    while True:
        # Let requests build and encode the query string; the timeout keeps a
        # stalled connection from hanging the notebook indefinitely.
        response = requests.get(
            base_url,
            headers=headers,
            params={"limit": limit, "offset": offset},
            timeout=30,
        )
        response.raise_for_status()
        items = response.json().get("items", [])
        if not items:
            break
        all_items.extend(items)
        # A short page means this was the last one
        if len(items) < limit:
            break
        offset += limit
    return all_items

# Pull the whole catalog, 10 items per request
base_url = "https://api.neurolabs.ai/v2/catalog-items"
items = fetch_all_catalog_items(api_key, base_url, limit=10)

# 3. Build a Spark DataFrame from the raw items
# TODO - remove the need for converting to pandas
spark = SparkSession.builder.appName("DataCatalogIngestion").getOrCreate()
pdf = pd.DataFrame(items)

# 4. Destination: a managed Delta table in a custom database
schema = "neurolabs_catalog"
table_name = "agbarr_catalog"

# Cast the dimensions to double and the audit columns to timestamp
spark_df = (
    spark.createDataFrame(pdf)
    .withColumn("height", col("height").cast("double"))
    .withColumn("width", col("width").cast("double"))
    .withColumn("depth", col("depth").cast("double"))
    .withColumn("created_at", col("created_at").cast("timestamp"))
    .withColumn("updated_at", col("updated_at").cast("timestamp"))
)

# Preview the raw pandas rows, then overwrite the table with this run's data
print(pdf.head(2))
spark_df.write.format("delta").mode("overwrite").saveAsTable(f"{schema}.{table_name}")

print(f"Successfully wrote {spark_df.count()} records to table {schema}.{table_name}.")

                                   uuid  ...                  updated_at
0  a264509a-8533-4943-bd48-3f1333ebd052  ...  2025-03-11T13:28:49.695921
1  6b59bb4b-76ff-4159-b5a8-715f96d3f45d  ...  2025-01-23T12:00:59.785723

[2 rows x 16 columns]
Successfully wrote 328 records to table neurolabs_catalog.agbarr_catalog.
