# Multiclass Defect Detection with Distributed training using PyTorch Object Detection Models in Snowflake Notebooks


In [None]:
!pip freeze | grep snow
!pip install opencv-python-headless

session = get_active_session()
import torch
import torchvision
print(torch.__version__)
print(torchvision.__version__)

## Install necessary packages:

* torch
* torchvision
* opencv
* matplotlib
* Pillow

In [None]:
# Install the OpenCV wheel plus shared libraries via apt.
# NOTE(review): an earlier cell installs opencv-python-headless; both wheels
# ship a `cv2` module — confirm that only one of them is actually needed.
!pip install opencv-python
# Refresh the apt index, then install libsm6 and libxext6.
!apt update && apt install -y libsm6 libxext6
# Install the libxrender development package.
!apt-get install -y libxrender-dev

### Import necessary packages

In [None]:
# Import python packages
import streamlit as st
import pandas as pd
from snowflake.snowpark.context import get_active_session

import os
import sys
import time
import pandas as pd
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection import FasterRCNN


import numpy as np

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.image as img
from snowflake.ml.registry import Registry

from snowflake.snowpark import functions as F
from snowflake.snowpark import types as T

from snowflake.ml.modeling.distributors.pytorch import PyTorchDistributor, PyTorchScalingConfig, WorkerResourceConfig
from snowflake.ml.data.sharded_data_connector import ShardedDataConnector


import warnings
warnings.filterwarnings("ignore")



# Tag the session's queries with the metadata of this quickstart notebook.
session.query_tag = {
    "origin": "sf_sit-is",
    "name": "distributed_ml_crt_imageanomaly_detection",
    "version": {"major": 1, "minor": 0},
    "attributes": {"is_quickstart": 1, "source": "notebook"},
}

In [None]:
# The NVIDIA A10G is a professional GPU designed for data center workloads, such as AI inference, virtual desktops
#  (VDI), and professional graphics. It's based on the NVIDIA Ampere architecture and features 24 GB of GDDR6
#  memory, 80 RT Cores, and 320 third-generation Tensor Cores, delivering up to 250 TOPS of compute power for AI. 
# Key features include its 300W power envelope, single-slot form factor, and versatility in handling both 
# graphically intensive and AI-accelerated applications. 

# Report the CUDA devices visible to this process and select the first one.
if not torch.cuda.is_available():
    print("CUDA is not available. Check your installation or GPU setup.")
else:
    num_gpus = torch.cuda.device_count()
    print(f"Number of GPU devices available: {num_gpus}")

    for i in range(num_gpus):
        print(f"Device {i} : {torch.cuda.get_device_name(i)}")

    # Make device 0 the default CUDA device for this process.
    torch.cuda.set_device(0)

In [None]:
# The data_stage stage holds the PCB images under the /images subfolder and their labels under /labels.
# NOTE(review): session.sql() only builds a lazy DataFrame; the listing is produced only if the
# notebook renders this last expression — append .show() or .collect() to force execution.
session.sql("ls @data_stage")

### View the training dataset

In [None]:
# Fetch the first five rows of the training table.
# NOTE(review): the trailing ';' suppresses cell output in IPython-style kernels, so the rows
# may not actually be displayed — consider `.show()` if the goal is to view them.
session.table("training_data").limit(5).collect();

# # Distributed Model Training
# 
# This section demonstrates how to train a Faster R-CNN object detection model using Snowflake's distributed training capabilities. The training process leverages multiple GPU workers in parallel to accelerate model training on the PCB defect detection dataset.
# 
# ## Step 1: Define a Training Function for Each Worker
# 
# We create a `train_func()` that encapsulates the complete training logic for a single worker. This function will:
# - Initialize the distributed training environment and establish communication between workers
# - Load and preprocess the training data from Snowflake tables
# - Create a custom PyTorch Dataset that decodes base64-encoded images and prepares bounding box annotations
# - Initialize the Faster R-CNN model with a custom classifier head for our specific number of defect classes
# - Wrap the model in DistributedDataParallel (DDP) to enable gradient synchronization across workers
# - Execute the training loop with forward pass, loss calculation, backpropagation, and optimizer updates
# - Save the trained model weights to a Snowflake stage for later inference
# 
# Each worker executes this function independently on its assigned data shard, with PyTorch's distributed training framework coordinating gradient updates across all workers.
# 
# ## Step 2: Execute the Training Function Using PyTorchDistributor
# 
# The `PyTorchDistributor` is Snowflake's orchestration layer that manages the distributed training job. It handles:
# - Provisioning GPU compute resources from the specified compute pool
# - Distributing the training function code to all workers
# - Coordinating the initialization and synchronization of the distributed training process
# - Monitoring worker health and handling failures
# - Collecting results and logs from all workers
# 
# Key components of the distributed training configuration:
# 
# * **ShardedDataConnector**: Automatically partitions the training dataset into non-overlapping shards, with each worker receiving a unique subset of the data. This ensures:
#   - No data duplication across workers (each image is processed by exactly one worker per epoch)
#   - Balanced workload distribution for optimal GPU utilization
#   - Efficient data loading directly from Snowflake tables without manual partitioning logic
# 
# * **PyTorchScalingConfig**: Defines the distributed training cluster topology and resource allocation:
#   - `num_nodes`: Number of compute nodes that take part in training
#   - `num_workers_per_node`: Number of parallel training processes on each node (typically matches the number of GPUs available on the node)
#   - `resource_requirements_per_worker`: A `WorkerResourceConfig` giving the CPU cores (`num_cpus`, used for data preprocessing and loading) and GPUs (`num_gpus`, usually 1 per worker for optimal performance) reserved for each worker
#   - The configuration used in this notebook does not set a separate per-worker memory limit
#   - These settings directly impact training speed, memory usage, and cost


In [None]:
# ==============================================================================
# IMPORTS: Required libraries for distributed training
# ==============================================================================
import base64  # For decoding base64-encoded image data from Snowflake
import io  # For handling in-memory byte streams
import cv2  # OpenCV for advanced image processing (if needed)
import torch  # PyTorch deep learning framework
import numpy as np  # Numerical computing library
from torch.utils.data import Dataset, DataLoader, IterableDataset  # Data loading utilities
from PIL import Image  # Python Imaging Library for image manipulation
import torchvision  # Computer vision utilities and models
from torchvision.models.detection import fasterrcnn_resnet50_fpn  # Pretrained Faster R-CNN model
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor  # Classifier head for Faster R-CNN
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights  # Pretrained weights
import torch.distributed as dist  # PyTorch distributed training utilities
from snowflake.ml.modeling.distributors.pytorch import get_context  # Snowflake distributed context
from torch.nn.parallel import DistributedDataParallel as DDP  # Distributed Data Parallel wrapper
import tempfile  # For creating temporary files/directories
import cloudpickle as cp  # Enhanced pickling for Python objects

# ==============================================================================
# MAIN TRAINING FUNCTION: Executed by each worker in distributed training
# ==============================================================================
def train_func():
    """Train Faster R-CNN on this worker's data shard; rank 0 saves the weights.

    Intended to be passed to ``PyTorchDistributor`` as ``train_func``. The
    function joins the NCCL process group, builds a Faster R-CNN (ResNet50-FPN)
    model with a 7-class box predictor, wraps it in
    ``DistributedDataParallel`` and trains it on the ``"train"`` entry of the
    context's dataset map.

    Hyperparameters read from the context (both converted with ``int``):
        batch_size: Number of samples per DataLoader batch.
        num_epochs: Number of passes over the shard.

    Side effects:
        Rank 0 writes the unwrapped model's ``state_dict`` to
        ``/tmp/models/detectionmodel.pt`` (the directory is created if
        missing). Every rank prints one progress line per epoch.
    """
    # Imported locally so the function does not rely on another cell's imports.
    import os

    # ------------------------------------------------------------------------------
    # DISTRIBUTED TRAINING SETUP
    # ------------------------------------------------------------------------------
    context = get_context()

    # This worker's rank in the training cluster.
    rank = context.get_rank()

    # NCCL backend for the process group.
    dist.init_process_group(backend="nccl")

    print(f"Worker Rank : {rank}, world_size: {context.get_world_size()}")

    # ==============================================================================
    # CUSTOM DATASET CLASS: Transforms Snowflake data for PyTorch training
    # ==============================================================================
    class FCBData(IterableDataset):
        """Iterable dataset yielding ``(image, target)`` pairs for Faster R-CNN.

        Each source row must provide ``IMAGE_DATA`` (base64-encoded image),
        ``XMIN``/``YMIN``/``XMAX``/``YMAX`` (one box), ``CLASS`` and ``FILENAME``.
        """

        def __init__(self, source_dataset, transforms=None):
            self.source_dataset = source_dataset

            # Default transform: ToTensor, which returns a float [C, H, W]
            # tensor already scaled to [0, 1]. A custom transform must return
            # tensors in the same range, because the training loop does not
            # rescale the images again.
            self.transforms = transforms if transforms else torchvision.transforms.ToTensor()

        def __iter__(self):
            """Yield one ``(image, target)`` pair per source row."""
            box_columns = ("XMIN", "YMIN", "XMAX", "YMAX")
            for row in self.source_dataset:
                # base64 string -> bytes -> PIL image. convert("RGB") forces
                # three channels, the same way the inference helper decodes images.
                raw_bytes = base64.b64decode(row['IMAGE_DATA'])
                image = Image.open(io.BytesIO(raw_bytes)).convert("RGB")
                image = self.transforms(image)

                # One annotated box per row: [[xmin, ymin, xmax, ymax]].
                boxes = torch.as_tensor(
                    [[row[column].item() for column in box_columns]],
                    dtype=torch.float32,
                )
                labels = torch.as_tensor([row["CLASS"].item()], dtype=torch.int64)

                # Target dictionary in the format consumed by Faster R-CNN.
                target = {
                    'boxes': boxes,  # [N, 4]
                    'labels': labels,  # [N]
                    'image_id': torch.tensor([int(row["FILENAME"])]),
                    'area': (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]),  # width * height
                    'iscrowd': torch.zeros((boxes.shape[0],), dtype=torch.uint8),
                }

                yield (image, target)

    # ------------------------------------------------------------------------------
    # GPU CONTEXT: Set the GPU device for this worker
    # ------------------------------------------------------------------------------
    # NOTE(review): the rank is used directly as the CUDA device index. That
    # matches a single-node run; with several nodes a node-local rank would
    # presumably be required — confirm before scaling out.
    with torch.cuda.device(rank):
        # ==============================================================================
        # MODEL INITIALIZATION: pretrained Faster R-CNN with a custom predictor head
        # ==============================================================================
        weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
        model = fasterrcnn_resnet50_fpn(weights=weights)

        # Number of classes (background + 6 defect types).
        num_classes = 7

        # Replace the pretrained box predictor with one sized for our classes.
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

        model.to(rank)

        # Wrap the model in DistributedDataParallel for multi-worker training.
        model = DDP(model, device_ids=[rank])

        # ==============================================================================
        # OPTIMIZER AND SCHEDULER
        # ==============================================================================
        optimizer = torch.optim.Adam(
            [p for p in model.parameters() if p.requires_grad],
            lr=0.0001,
            weight_decay=0.0005,
        )

        # Multiply the learning rate by 0.1 every 3 epochs.
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

        # ==============================================================================
        # DATA LOADING
        # ==============================================================================
        dataset_map = context.get_dataset_map()

        # This worker's shard of the training data, exposed as a torch dataset.
        train_shard = dataset_map["train"].get_shard().to_torch_dataset()
        train_dataset = FCBData(train_shard)

        # Hyperparameters passed in through PyTorchDistributor.run().
        hyper_parms = context.get_hyper_params()

        def collate_fn(batch):
            """Turn a list of (image, target) pairs into (images, targets) tuples.

            The model is called with sequences of images and targets, so the
            samples are regrouped instead of being stacked into one tensor.
            """
            return tuple(zip(*batch))

        batch_size = int(hyper_parms['batch_size'])

        train_data_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=False,  # an IterableDataset cannot be shuffled by the DataLoader
            collate_fn=collate_fn,
            pin_memory=True,
            pin_memory_device=f"cuda:{rank}",
        )

        # ==============================================================================
        # TRAINING LOOP
        # ==============================================================================
        num_epochs = int(hyper_parms['num_epochs'])

        # Exact number of images seen by this worker across all epochs; partial
        # final batches are counted by their real size.
        images_seen = 0

        for epoch in range(num_epochs):
            model.train()

            running_loss = 0.0
            running_batches = 0

            for images, targets in train_data_loader:
                running_batches += 1
                images_seen += len(images)

                # The dataset's ToTensor transform already scaled pixels to
                # [0, 1]; dividing by 255 again would squash every image to
                # roughly [0, 0.004], so the images are only moved to the GPU.
                images = [image.to(rank) for image in images]
                targets = [{k: v.to(rank) for k, v in t.items()} for t in targets]

                # In training mode the model returns a dict of losses.
                loss_dict = model(images, targets)
                losses = sum(loss for loss in loss_dict.values())

                optimizer.zero_grad()
                losses.backward()
                optimizer.step()

                running_loss += losses.item()

            # max(..., 1) keeps the report from dividing by zero when the shard
            # produced no batches in this epoch.
            epoch_loss = running_loss / max(running_batches * batch_size, 1)
            print(f"[Rank {rank}] Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Processed {images_seen} images so far")

            lr_scheduler.step()

        # ==============================================================================
        # MODEL SAVING: only rank 0 writes the file
        # ==============================================================================
        MODEL_PATH = "/tmp/models/detectionmodel.pt"

        if rank == 0:
            # open() does not create missing parent directories.
            os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
            with open(MODEL_PATH, mode="w+b") as model_file:
                # model.module is the underlying model inside the DDP wrapper.
                torch.save(model.module.state_dict(), model_file)
            print(f"Model written to {MODEL_PATH}")

        print(f"[Rank {rank}] Training completed.")

### For the purpose of this quickstart, we use a small dataset as the data source, but the same approach can scale to millions of rows

1. Split the dataset (shard) for distributed training across multiple workers.
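2. Train a PyTorch model with one GPU per worker. The cell below is configured with `num_workers_per_node=1`; raise it (for example to 4) to match the number of GPUs available in your compute pool. Control the training with hyperparameters such as batch size and number of epochs.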
3. Adjust the number of epochs as needed

In [None]:
# ==============================================================================
# DISTRIBUTED TRAINING: launch train_func through the PyTorch distributor
# ==============================================================================

# Components needed to run distributed PyTorch training in Snowflake
from snowflake.ml.modeling.distributors.pytorch import (
    PyTorchDistributor,
    PyTorchScalingConfig,
    WorkerResourceConfig,
)
from snowflake.ml.data.sharded_data_connector import ShardedDataConnector
import ray
import torch

# Start Ray only when it has not been initialized in this process yet
if not ray.is_initialized():
    ray.init()

# Sanity checks: resources reported by Ray and CUDA devices visible to PyTorch
print(ray.cluster_resources())        # should include 'GPU': N
print(torch.cuda.is_available(), torch.cuda.device_count())

# Training table and the sharded connector built from it
df = session.table("training_data")
train_data = ShardedDataConnector.from_dataframe(df)

# As configured here: one node, one worker on that node, one GPU for the worker.
# Raise num_workers_per_node to train with more workers.
pytorch_trainer = PyTorchDistributor(
    train_func=train_func,
    scaling_config=PyTorchScalingConfig(
        num_nodes=1,
        num_workers_per_node=1,
        resource_requirements_per_worker=WorkerResourceConfig(num_cpus=0, num_gpus=1),
    ),
)

# Run the job: the sharded data is exposed to train_func under the "train" key,
# and the hyperparameters are passed as strings (train_func converts them).
pytorch_trainer.run(
    dataset_map={"train": train_data},
    hyper_params={"batch_size": "32", "num_epochs": "100"},
)

# MODEL DEPLOYMENT


# Snowflake Model Registry - Securely manage models and their metadata in Snowflake.

The model registry stores machine learning models as first-class schema-level objects in Snowflake.

* Load the model produced by trainer 
* Define custom wrapper for the PyTorch model
* Save it to the Model Registry by specifying `model_name`, `version_name`, a sample input DataFrame (passed as `sample_input_data`), and `conda_dependencies`

In [None]:
import torch
import torchvision.transforms as transforms
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, FasterRCNN_ResNet50_FPN_Weights
from PIL import Image
import io
import json
import base64
# Pull a single training row to use as the sample input when logging the model.
df=session.table("training_data").limit(1).to_pandas()

first_row = df.iloc[0]  
base64_image = first_row['IMAGE_DATA'] 
# Keep only the IMAGE_DATA column in the sample frame (this rebinds `df`).
df = pd.DataFrame({'IMAGE_DATA': [base64_image]})  

# Snowpark version of the sample frame; passed to log_model as sample_input_data.
spdf=session.create_dataframe(df)
# Function to load the model
def load_model(model_path, num_classes=7):
    """Rebuild the Faster R-CNN architecture and load trained weights into it.

    Args:
        model_path: Path to a ``state_dict`` file written with ``torch.save``.
        num_classes: Size of the box-predictor head, background included.
            Defaults to 7, the head size built in ``train_func``. The value
            must equal the one used for training: ``strict=False`` only
            tolerates missing or unexpected keys, not tensors of a different
            shape, so a 6-class head cannot load 7-class weights.

    Returns:
        The model in eval mode with its parameters converted to float64.
    """
    weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
    model = fasterrcnn_resnet50_fpn(weights=weights)

    # Replace the box predictor so its shape matches the trained checkpoint.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # map_location="cpu" lets a checkpoint saved from GPU tensors load on a
    # host without CUDA; the model itself is on the CPU here either way.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict, strict=False)

    # decode_and_transform_image produces float64 tensors, so the weights are
    # converted to the same dtype.
    model.double()
    model.eval()
    return model

# Function to decode and transform an image
def decode_and_transform_image(base64_image):
    """Decode a base64-encoded image into a float64 tensor for the detector.

    Args:
        base64_image: Base64 string (or bytes) holding an encoded image file.

    Returns:
        A ``torch.float64`` tensor of shape ``[3, 224, 224]`` with values in
        the [0, 1] range.
    """
    image_data = base64.b64decode(base64_image)
    # convert('RGB') already guarantees three channels, so no further mode
    # check is needed afterwards.
    image = Image.open(io.BytesIO(image_data)).convert('RGB')

    # A fixed size lets predict() torch.stack a batch of images.
    # No Normalize step: torchvision's Faster R-CNN expects inputs in [0, 1]
    # and applies the ImageNet mean/std normalization internally, so
    # normalizing here as well would apply it twice.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),  # Converts to [C, H, W] scaled to [0, 1]
    ])
    image_tensor = transform(image).double()

    # Debugging: Print the shape after transformation
    print(f"Shape after transformation: {image_tensor.shape}")

    return image_tensor


# Load the weights from the path that train_func (rank 0) saves them to.
model_path = '/tmp/models/detectionmodel.pt'
model = load_model(model_path)

from snowflake.ml.model import custom_model

class DefectDetectionModel(custom_model.CustomModel):
    """Custom registry model that runs the detector on base64-encoded images.

    The wrapped detector is looked up in the model context under the key
    ``"rcnn"``.
    """

    def __init__(self, context: custom_model.ModelContext) -> None:
        super().__init__(context)

    @custom_model.inference_api
    def predict(self, input_df: pd.DataFrame) -> pd.DataFrame:
        """Run detection for every row of ``input_df``.

        Args:
            input_df: Frame with an ``IMAGE_DATA`` column of base64-encoded
                images.

        Returns:
            A frame with a single ``output`` column. Each value is a JSON
            string of the detector's result dict for that image, with tensors
            converted to nested lists. An empty input yields an empty frame.
        """
        # torch.stack raises on an empty list, so answer empty input directly.
        if input_df.empty:
            return pd.DataFrame({"output": pd.Series(dtype="object")})

        image_tensors = [decode_and_transform_image(encoded) for encoded in input_df['IMAGE_DATA']]
        processed_input = torch.stack(image_tensors)

        # Inference only: no_grad avoids building an autograd graph.
        with torch.no_grad():
            raw_output = self.context.model_ref("rcnn").forward(processed_input)

        serialized = [
            json.dumps({k: v.detach().cpu().numpy().tolist() for k, v in res.items()})
            for res in raw_output
        ]
        return pd.DataFrame({"output": serialized})

# Wrap the loaded detector; predict() looks it up under the key 'rcnn'.
ddm = DefectDetectionModel(context = custom_model.ModelContext(models={'rcnn': model}))


ml_reg = Registry(session=session)
# Log the model with the sample input for Snowflake registry
mv = ml_reg.log_model(
    ddm,
    model_name="DefectDetectionModel",
    version_name='v3',
    sample_input_data=spdf,
    conda_dependencies=["pytorch", "torchvision"],
    options={
        "embed_local_ml_library": True,
        # "relax_version" is the documented option name; the previous "relax"
        # key is not a documented log_model option.
        "relax_version": True,
    },
)
    

## Fetch the logged Model from Snowflake Registry



In [None]:
# Usage Example
# Open the registry for this session and list the models it knows about.
reg = Registry(session=session) 
model_ref = reg.show_models()
# Last expression of the cell, so the notebook renders the listing.
model_ref

## Detect Defects on Validation dataset
Let's assume there is a validation table, VAL_IMAGES_LABELS, which contains the Base64-encoded validation images.

* Get a reference to a specific model from the registry by name using the registry’s get_model method
* Get a reference to a specific version of a model as a ModelVersion instance using the model’s version method.
* Run inference using the model and output the predictions


In [None]:

# Look up the model by name, then the version to run.
m = reg.get_model("DEFECTDETECTIONMODEL")
# The logging cell registers the model with version_name='v3'. The previous
# value "GENTLE_DONKEY_4" looks like an auto-generated name and is never
# created by this notebook.
mv = m.version("v3")


# Use one validation image as the inference input.
df=session.table("VAL_IMAGES_LABELS").limit(1).to_pandas()

first_row = df.iloc[0]
base64_image = first_row['IMAGE_DATA']
image_data_df = pd.DataFrame({'IMAGE_DATA': [base64_image]})
image_data_df.head()



# Run the registered model's predict function on the input frame.
remote_prediction = mv.run(image_data_df, function_name="predict")
remote_prediction.head()

Fetch predictions and use a function display_image_with_boxes() to display Image with Bounding Boxes and Labels


In [None]:
import json
import base64
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
import io

# Class mapping dictionary: integer label -> defect name, used when drawing predictions.
# NOTE(review): torchvision detection models reserve label 0 for background, and
# train_func builds a 7-class head ("background + 6 defect types"), so predicted
# labels are presumably 1..6 — confirm whether these keys should be 1..6 rather than 0..5.
classes_la = {
    0: "open",
    1: "short",
    2: "mousebite",
    3: "spur",
    4: "copper",
    5: "pin-hole"
}

# Function to display the image with bounding boxes and class labels
def display_image_with_boxes(image, boxes, labels, scores, target_size=(800, 600), box_size=(224, 224)):
    """Draw predicted boxes and class labels on a resized copy of ``image``.

    Args:
        image: PIL image to display.
        boxes: Iterable of ``[xmin, ymin, xmax, ymax]`` expressed in the
            coordinate space given by ``box_size``.
        labels: Integer class ids, looked up in ``classes_la``.
        scores: Confidence score for each box.
        target_size: ``(width, height)`` the image is resized to for display.
        box_size: ``(width, height)`` of the image the boxes were predicted
            on. Defaults to the 224x224 input produced by
            ``decode_and_transform_image``. Pass ``None`` to draw the boxes
            without rescaling.
    """
    # Resize the image to a target size
    img = image.resize(target_size).convert("RGB")
    img_np = np.array(img)

    # The boxes live in the model-input coordinate space, while the image is
    # shown at target_size; without rescaling they land in the wrong place.
    if box_size is None:
        scale_x = scale_y = 1.0
    else:
        scale_x = target_size[0] / box_size[0]
        scale_y = target_size[1] / box_size[1]

    # Adjust the DPI and figure size
    fig, ax = plt.subplots(figsize=(3, 6), dpi=10)
    ax.imshow(img_np)

    for label, box, score in zip(labels, boxes, scores):
        xmin, ymin, xmax, ymax = box
        xmin, xmax = xmin * scale_x, xmax * scale_x
        ymin, ymax = ymin * scale_y, ymax * scale_y

        # Fall back to the raw id so an unmapped label does not raise KeyError.
        class_label = classes_la.get(label, f"class {label}")

        # Create a Rectangle patch
        rect = patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, linewidth=2, edgecolor='r', facecolor='none')
        ax.text(xmin, ymin, f"{class_label}: {score:.2f}", verticalalignment='top', color='red', fontsize=13, weight='bold')
        ax.add_patch(rect)

    plt.axis('off')
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)  # Ensure no padding/margins around the image
    plt.show()

# Put each input image side by side with the prediction returned for it
combined_df = pd.concat([image_data_df, remote_prediction], axis=1)

# One dict per kept detection; turned into the final DataFrame below
rows = []

for index, row in combined_df.iterrows():
    output_str = row.get('output', None)

    # Guard: the prediction must be a JSON string
    if not isinstance(output_str, str):
        print(f"Invalid output type (not a string) in row {index}, skipping this row.")
        continue

    # Guard: the string must parse as JSON
    try:
        output_data = json.loads(output_str)
    except json.JSONDecodeError:
        print(f"Invalid JSON in row {index}, skipping this row.")
        continue

    # Guard: all three result fields must be present
    if not all(key in output_data for key in ('boxes', 'labels', 'scores')):
        print("Missing keys 'boxes', 'labels', or 'scores' in the output data.")
        continue

    boxes = output_data['boxes']
    labels = output_data['labels']
    scores = output_data['scores']

    # Decode the image that belongs to this prediction
    image_data = base64.b64decode(row['IMAGE_DATA'])
    image = Image.open(io.BytesIO(image_data)).convert("RGB")

    if len(scores) == 0:
        print("No scores available to limit to top 5.")
        continue

    # Keep the five detections with the highest scores
    data = pd.DataFrame({'box': boxes, 'label': labels, 'score': scores})
    top_classes = data.nlargest(5, 'score')
    top_boxes = top_classes['box'].tolist()
    top_labels = top_classes['label'].tolist()
    top_scores = top_classes['score'].tolist()

    # One output row per kept detection
    rows.extend(
        {
            'image_data': row['IMAGE_DATA'],
            'output': row['output'],
            'label': top_label,
            'box': top_box,
            'score': top_score,
        }
        for top_label, top_box, top_score in zip(top_labels, top_boxes, top_scores)
    )

    # Show the image with the kept detections drawn on it
    display_image_with_boxes(image, top_boxes, top_labels, top_scores)

# Final DataFrame: one row per label/box/score
final_df = pd.DataFrame(rows)
session.sql("create TABLE if not exists PCB_DATASET.PUBLIC.DETECTION_OUTPUTS (\
	image_data VARCHAR(16777216),\
	output VARCHAR(16777216),\
	label NUMBER(38,0),\
	box VARIANT,\
	score FLOAT\
)").collect()

# Write the detections to the Snowflake table, replacing its previous contents
combined_spdf = session.create_dataframe(final_df)
combined_spdf.write.save_as_table("DETECTION_OUTPUTS", mode="overwrite")
