# Zero-Shot PCB Defect Detection

This notebook demonstrates how to use Vision-Language Models (VLMs) for zero-shot detection of defects in PCB images.

## Overview

- **Zero-Shot Learning**: No labeled training data required
- **Prompt-Based Classification**: Uses natural language prompts to identify defect types
- **Semiconductor-Specific**: Tailored for PCB and semiconductor manufacturing defects
- **Hugging Face Integration**: Leverages the pre-trained CLIP model (with a ViT-B/32 image encoder) through the `transformers` library

## Setup and Imports

In [None]:
# Install required packages if needed
# Uncomment these lines if running for the first time
# !pip install torch torchvision transformers pillow matplotlib numpy scikit-learn pandas tqdm

In [None]:
import os
import sys
import json
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel
from typing import List, Dict, Any, Union, Optional, Tuple

# Enable showing images in the notebook
%matplotlib inline
plt.rcParams['figure.figsize'] = (12, 8)  # Set default figure size

## Define Model and Zero-Shot Detection Classes

First, let's implement the core functionality:

In [None]:
class PCBDefectVLM:
    """PCB defect detection using Vision-Language Models."""
    
    def __init__(self, model_name: str = "openai/clip-vit-base-patch32"):
        """
        Initialize the Vision Language Model for PCB defect detection.
        
        Args:
            model_name: Hugging Face model identifier for the VLM
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        
        print(f"Loading model: {model_name}...")
        self.model = CLIPModel.from_pretrained(model_name).to(self.device)
        self.processor = CLIPProcessor.from_pretrained(model_name)
        print("Model loaded successfully!")
        
    def load_image(self, image_path: str) -> Image.Image:
        """
        Load and prepare an image for inference.
        
        Args:
            image_path: Path to the image file
            
        Returns:
            PIL Image object
        """
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image not found at {image_path}")
        
        return Image.open(image_path).convert("RGB")
    
    def classify(self, image: Union[str, Image.Image], categories: List[str]) -> Dict[str, float]:
        """
        Perform zero-shot classification on PCB image.
        
        Args:
            image: Path to image or PIL Image object
            categories: List of defect categories as text prompts
            
        Returns:
            Dictionary of category -> probability mappings
        """
        if isinstance(image, str):
            image = self.load_image(image)
            
        # Prepare text prompts for the model
        text_inputs = self.processor(
            text=categories,
            return_tensors="pt",
            padding=True,
            truncation=True
        ).to(self.device)
        
        # Prepare image for the model
        image_inputs = self.processor(
            images=image,
            return_tensors="pt"
        ).to(self.device)
        
        # Get embeddings
        with torch.no_grad():
            image_features = self.model.get_image_features(**image_inputs)
            text_features = self.model.get_text_features(**text_inputs)
            
            # Normalize features
            image_features = image_features / image_features.norm(dim=1, keepdim=True)
            text_features = text_features / text_features.norm(dim=1, keepdim=True)
            
            # Calculate similarity scores
            logits_per_image = (100.0 * image_features @ text_features.T).squeeze(0)
            probs = logits_per_image.softmax(dim=0)
            
        # Create and return results dictionary
        results = {}
        for category, prob in zip(categories, probs.cpu().numpy()):
            results[category] = float(prob)
            
        return results
    
    def classify_batch(self, images: List[Union[str, Image.Image]], categories: List[str]) -> List[Dict[str, float]]:
        """
        Perform zero-shot classification on a batch of PCB images.
        
        Args:
            images: List of image paths or PIL Image objects
            categories: List of defect categories as text prompts
            
        Returns:
            List of dictionaries mapping categories to probabilities
        """
        # Load images if paths are provided
        loaded_images = []
        for img in images:
            if isinstance(img, str):
                loaded_images.append(self.load_image(img))
            else:
                loaded_images.append(img)
                
        # Prepare text prompts for the model
        text_inputs = self.processor(
            text=categories,
            return_tensors="pt",
            padding=True,
            truncation=True
        ).to(self.device)
        
        # Prepare images for the model (the processor resizes and crops every
        # image to the same resolution, so no padding argument is needed)
        image_inputs = self.processor(
            images=loaded_images,
            return_tensors="pt"
        ).to(self.device)
        
        # Get embeddings
        with torch.no_grad():
            image_features = self.model.get_image_features(**image_inputs)
            text_features = self.model.get_text_features(**text_inputs)
            
            # Normalize features
            image_features = image_features / image_features.norm(dim=1, keepdim=True)
            text_features = text_features / text_features.norm(dim=1, keepdim=True)
            
            # Calculate similarity scores
            logits_per_image = (100.0 * image_features @ text_features.T)
            probs = logits_per_image.softmax(dim=-1)
            
        # Create and return results
        results = []
        for i, prob_set in enumerate(probs.cpu().numpy()):
            result = {}
            for category, prob in zip(categories, prob_set):
                result[category] = float(prob)
            results.append(result)
            
        return results
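
The scoring step in `classify` and `classify_batch` (normalize, scale the cosine similarity, apply softmax) can be illustrated without loading a model. The following is a minimal sketch in plain Python using made-up 3-dimensional vectors in place of real CLIP embeddings; the helper names `cosine` and `zero_shot_probs` are hypothetical, and the scale of 100.0 simply mirrors the constant used in the class above.

```python
import math
from typing import Dict, List


def cosine(u: List[float], v: List[float]) -> float:
    """Cosine similarity between two vectors of equal length."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def zero_shot_probs(image_vec: List[float],
                    prompt_vecs: Dict[str, List[float]],
                    scale: float = 100.0) -> Dict[str, float]:
    """Softmax over scaled cosine similarities, one entry per prompt."""
    logits = {p: scale * cosine(image_vec, v) for p, v in prompt_vecs.items()}
    # Subtract the maximum logit before exponentiating for numerical stability
    max_logit = max(logits.values())
    exps = {p: math.exp(l - max_logit) for p, l in logits.items()}
    total = sum(exps.values())
    return {p: e / total for p, e in exps.items()}


# Toy vectors (illustrative only, not real embeddings)
toy_image = [0.9, 0.1, 0.0]
toy_prompts = {
    "solder bridge": [1.0, 0.0, 0.0],
    "normal board": [0.0, 1.0, 0.0],
}
toy_probs = zero_shot_probs(toy_image, toy_prompts)
print(toy_probs)
```

Because the similarities are multiplied by a large scale before the softmax, small differences in cosine similarity become sharp differences in probability. The probabilities always sum to 1 across the prompts supplied, so adding more prompts lowers every individual score.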

In [None]:
class PCBDefectDetector:
    """Zero-shot PCB defect detection with prompt-based categorization."""
    
    def __init__(self, model_name: str = "openai/clip-vit-base-patch32"):
        """
        Initialize the PCB defect detector.
        
        Args:
            model_name: Hugging Face model identifier for the VLM
        """
        self.vlm = PCBDefectVLM(model_name=model_name)
        self.defect_categories = []
        self.defect_prompts = {}
        
    def load_defect_categories(self, json_path: Optional[str] = None) -> None:
        """
        Load defect categories and prompts from a JSON file or use defaults.
        
        Args:
            json_path: Path to the JSON file containing defect categories
        """
        if json_path and os.path.exists(json_path):
            with open(json_path, 'r') as f:
                data = json.load(f)
                
            self.defect_categories = [item['category'] for item in data['defects']]
            
            # Store the detailed prompts for each category
            self.defect_prompts = {}
            for item in data['defects']:
                self.defect_prompts[item['category']] = item['prompts']
        else:
            print("Using default defect categories...")
            # Default defect categories for PCBs
            self.defect_categories = [
                "Solder Bridge",
                "Missing Component",
                "Component Misalignment",
                "Cold Solder Joint",
                "Lifted Pad",
                "Excess Solder",
                "Insufficient Solder",
                "Cracked Solder Joint",
                "PCB Scratch",
                "Burnt Component",
                "Reversed Component",
                "Foreign Material"
            ]
            
            # Default prompts for each category
            self.defect_prompts = {
                "Solder Bridge": [
                    "solder bridging between adjacent pins",
                    "short circuit between traces or pads"
                ],
                "Missing Component": [
                    "missing electronic component",
                    "component placement area with no part installed"
                ],
                "Component Misalignment": [
                    "misaligned component on the PCB",
                    "component shifted from its correct position"
                ],
                "Cold Solder Joint": [
                    "cold solder joint",
                    "dull, grainy solder connection"
                ],
                "Lifted Pad": [
                    "pad lifted from PCB substrate",
                    "copper pad delamination"
                ],
                "Excess Solder": [
                    "too much solder on joint",
                    "solder ball or blob"
                ],
                "Insufficient Solder": [
                    "not enough solder on joint",
                    "incomplete solder coverage"
                ],
                "Cracked Solder Joint": [
                    "cracked solder connection",
                    "fracture in solder joint"
                ],
                "PCB Scratch": [
                    "scratch on PCB surface",
                    "damaged trace on board"
                ],
                "Burnt Component": [
                    "burnt or charred component",
                    "blackened electronic part"
                ],
                "Reversed Component": [
                    "component installed backwards",
                    "reversed polarity component"
                ],
                "Foreign Material": [
                    "debris on PCB surface",
                    "contaminant on circuit board"
                ]
            }
    
    def get_prompts_for_detection(self, enhance_with_domain: bool = True) -> List[str]:
        """
        Generate prompts for zero-shot detection.
        
        Args:
            enhance_with_domain: Whether to enhance prompts with domain-specific language
            
        Returns:
            List of formatted prompts for the model
        """
        if not self.defect_categories:
            self.load_defect_categories()
        
        detection_prompts = []
        
        for category in self.defect_categories:
            # Get the most generic prompt for this category
            base_prompt = self.defect_prompts[category][0]
            
            if enhance_with_domain:
                # Format with PCB/semiconductor domain knowledge
                prompt = f"A PCB with {base_prompt}"
                prompt_alt = f"A printed circuit board showing {base_prompt}"
                detection_prompts.extend([prompt, prompt_alt])
            else:
                detection_prompts.append(base_prompt)
                
        # Always add a "normal" category
        detection_prompts.append("A normal PCB with no defects")
        detection_prompts.append("A perfectly manufactured printed circuit board")
        
        return detection_prompts
    
    def detect(self, image_path: str, threshold: float = 0.2, 
               top_k: int = 3, enhance_prompts: bool = True) -> Dict[str, Any]:
        """
        Detect PCB defects in an image using zero-shot classification.
        
        Args:
            image_path: Path to the PCB image
            threshold: Confidence threshold for detection
            top_k: Number of top categories to return
            enhance_prompts: Whether to enhance prompts with domain-specific language
            
        Returns:
            Detection results with categories and confidence scores
        """
        # Get formatted prompts
        prompts = self.get_prompts_for_detection(enhance_with_domain=enhance_prompts)
        
        # Perform zero-shot classification
        raw_results = self.vlm.classify(image_path, prompts)
        
        # Post-process results to combine similar categories
        processed_results = self._process_results(raw_results)
        
        # Get top k results above threshold
        top_results = {k: v for k, v in sorted(
            processed_results.items(), 
            key=lambda item: item[1], 
            reverse=True
        ) if v >= threshold}
        
        # Limit to top k
        top_k_results = dict(list(top_results.items())[:top_k])
        
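        # Determine if the PCB is defective: the image counts as normal only when
        # the top-ranked category is "Normal". If no category clears the threshold,
        # the image is flagged as defective.
        top_category = next(iter(top_k_results), None)
        is_defective = top_category is None or "normal" not in top_category.lower()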
        
        return {
            "is_defective": is_defective,
            "defects": top_k_results,
            "all_scores": processed_results
        }
    
    def _process_results(self, raw_results: Dict[str, float]) -> Dict[str, float]:
        """
        Process raw classification results to combine similar categories.
        
        Args:
            raw_results: Raw classification results
            
        Returns:
            Processed results with combined categories
        """
        processed = {}
        
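        # Group by category and take maximum score
        for prompt, score in raw_results.items():
            prompt_lower = prompt.lower()
            category = None

            # Match the prompt to its category through the category's own prompt
            # texts. Matching on the category name alone misses most prompts
            # (e.g. "solder bridging ..." never contains "solder bridge").
            for cat in self.defect_categories:
                if any(p.lower() in prompt_lower for p in self.defect_prompts.get(cat, [])):
                    category = cat
                    break

            # The defect-free prompts map to "Normal"
            if category is None and any(
                key in prompt_lower
                for key in ("normal", "no defects", "perfectly manufactured")
            ):
                category = "Normal"

            if category:
                processed[category] = max(processed.get(category, 0.0), score)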
        
        return processed
    
    def batch_detect(self, image_paths: List[str], threshold: float = 0.2,
                    top_k: int = 3, enhance_prompts: bool = True) -> List[Dict[str, Any]]:
        """
        Detect PCB defects in multiple images.
        
        Args:
            image_paths: List of paths to PCB images
            threshold: Confidence threshold for detection
            top_k: Number of top categories to return
            enhance_prompts: Whether to enhance prompts with domain-specific language
            
        Returns:
            List of detection results for each image
        """
        results = []
        for image_path in image_paths:
            result = self.detect(
                image_path=image_path,
                threshold=threshold,
                top_k=top_k,
                enhance_prompts=enhance_prompts
            )
            results.append(result)
            
        return results
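
The detector groups per-prompt scores back into categories and keeps the maximum score per category. One way to make that grouping explicit is to record, at the time each prompt is generated, which category it came from. The sketch below assumes two hypothetical helpers, `build_prompt_index` and `aggregate_by_category`, which are not part of the classes above, and it only covers the `"A PCB with ..."` template; the scores are invented for illustration.

```python
from typing import Dict, List


def build_prompt_index(defect_prompts: Dict[str, List[str]],
                       normal_prompts: List[str]) -> Dict[str, str]:
    """Map every generated prompt string to the category it belongs to."""
    index: Dict[str, str] = {}
    for category, prompts in defect_prompts.items():
        for prompt in prompts:
            index[f"A PCB with {prompt}"] = category
    for prompt in normal_prompts:
        index[prompt] = "Normal"
    return index


def aggregate_by_category(raw_scores: Dict[str, float],
                          index: Dict[str, str]) -> Dict[str, float]:
    """Keep the highest prompt score within each category."""
    combined: Dict[str, float] = {}
    for prompt, score in raw_scores.items():
        category = index[prompt]
        combined[category] = max(combined.get(category, 0.0), score)
    return combined


# Illustrative inputs (scores are made up, not model output)
example_index = build_prompt_index(
    {"Solder Bridge": ["solder bridging between adjacent pins",
                       "short circuit between traces or pads"]},
    ["A normal PCB with no defects"],
)
example_raw = {
    "A PCB with solder bridging between adjacent pins": 0.5,
    "A PCB with short circuit between traces or pads": 0.2,
    "A normal PCB with no defects": 0.3,
}
example_combined = aggregate_by_category(example_raw, example_index)
print(example_combined)
```

An explicit index avoids any dependence on substring matching between prompt wording and category names. Taking the maximum is one choice; summing the scores of a category's prompts is another, and it would give categories with more prompts a larger share of the probability mass.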

## Visualization Functions

Let's define some helper functions to visualize our results:

In [None]:
def visualize_detection(image_path: str, results: Dict[str, Any], 
                        output_path: Optional[str] = None,
                        show: bool = True) -> None:
    """
    Visualize defect detection results with confidence scores.
    
    Args:
        image_path: Path to the PCB image
        results: Detection results from PCBDefectDetector
        output_path: Optional path to save the visualization
        show: Whether to display the plot
    """
    # Load image
    img = Image.open(image_path)
    
    # Create figure
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 7))
    
    # Plot image
    ax1.imshow(np.array(img))
    ax1.set_title("PCB Image")
    ax1.axis('off')
    
    # Plot defect scores
    defects = results['defects']
    categories = list(defects.keys())
    scores = list(defects.values())
    
    # Sort by score in descending order
    sorted_indices = np.argsort(scores)[::-1]
    categories = [categories[i] for i in sorted_indices]
    scores = [scores[i] for i in sorted_indices]
    
    # Set colors based on defect status
    colors = ['red' if category.lower() != "normal" else 'green' for category in categories]
    
    # Plot horizontal bar chart
    y_pos = np.arange(len(categories))
    bars = ax2.barh(y_pos, scores, color=colors, alpha=0.7)
    ax2.set_yticks(y_pos)
    ax2.set_yticklabels(categories)
    ax2.set_xlim(0, 1.0)
    ax2.set_xlabel('Confidence Score')
    
    # Determine overall status
    if results['is_defective']:
        status_text = "DEFECTIVE"
        status_color = "red"
    else:
        status_text = "NORMAL"
        status_color = "green"
        
    ax2.set_title(f"Detection Results: {status_text}", color=status_color, fontweight='bold')
    
    # Add score values
    for bar, score in zip(bars, scores):
        ax2.text(bar.get_width() + 0.01, bar.get_y() + bar.get_height()/2, 
                f'{score:.2f}', va='center')
    
    plt.tight_layout()
    
    if output_path:
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Visualization saved to {output_path}")
    
    if show:
        plt.show()
    else:
        plt.close()

def create_comparison_grid(image_paths: List[str], results: List[Dict[str, Any]],
                           output_path: Optional[str] = None,
                           grid_size: Optional[Tuple[int, int]] = None,
                           show: bool = True) -> None:
    """
    Create a grid of PCB images with their detection results.
    
    Args:
        image_paths: List of paths to PCB images
        results: List of detection results
        output_path: Optional path to save the visualization
        grid_size: Optional tuple of (rows, cols) for the grid layout
        show: Whether to display the plot
    """
    n_images = len(image_paths)
    
    if grid_size is None:
        # Calculate grid size based on number of images
        cols = min(4, n_images)
        rows = (n_images + cols - 1) // cols
    else:
        rows, cols = grid_size
    
    # Create figure; squeeze=False keeps `axes` two-dimensional even for a
    # single row, a single column, or a single subplot
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4), squeeze=False)
    
    # Plot each image with its top defect
    for i, (image_path, result) in enumerate(zip(image_paths, results)):
        if i >= rows * cols:
            break
            
        row, col = i // cols, i % cols
        ax = axes[row, col]
        
        # Load and display image
        img = Image.open(image_path)
        ax.imshow(np.array(img))
        
        # Get top defect
        defects = result['defects']
        if defects:
            top_defect = list(defects.keys())[0]
            top_score = list(defects.values())[0]
            
            # Set color based on defect status
            color = 'red' if result['is_defective'] else 'green'
            
            # Set title with top defect and score
            ax.set_title(f"{top_defect}\n({top_score:.2f})", color=color)
        else:
            ax.set_title("No defects detected")
            
        ax.axis('off')
    
    # Hide unused subplots
    for i in range(n_images, rows * cols):
        row, col = i // cols, i % cols
        axes[row, col].axis('off')
    
    plt.tight_layout()
    
    if output_path:
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Comparison grid saved to {output_path}")
    
    if show:
        plt.show()
    else:
        plt.close()

## PCB Defect Categories

Let's define the defect categories in a JSON format:

In [None]:
# Define defect categories
defect_categories = {
    "defects": [
        {
            "category": "Solder Bridge",
            "prompts": [
                "solder bridging between adjacent pins",
                "short circuit between traces or pads",
                "excess solder creating unwanted connections"
            ],
            "severity": "high"
        },
        {
            "category": "Missing Component",
            "prompts": [
                "missing electronic component",
                "component placement area with no part installed",
                "empty pad where a component should be"
            ],
            "severity": "high"
        },
        {
            "category": "Component Misalignment",
            "prompts": [
                "misaligned component on the PCB",
                "component shifted from its correct position",
                "rotated or tilted component"
            ],
            "severity": "medium"
        },
        {
            "category": "Cold Solder Joint",
            "prompts": [
                "cold solder joint",
                "dull, grainy solder connection",
                "incomplete solder wetting"
            ],
            "severity": "high"
        },
        {
            "category": "Lifted Pad",
            "prompts": [
                "pad lifted from PCB substrate",
                "copper pad delamination",
                "detached pad from board"
            ],
            "severity": "high"
        },
        {
            "category": "Excess Solder",
            "prompts": [
                "too much solder on joint",
                "solder ball or blob",
                "overflowed solder joint"
            ],
            "severity": "medium"
        },
        {
            "category": "Insufficient Solder",
            "prompts": [
                "not enough solder on joint",
                "incomplete solder coverage",
                "starved solder joint"
            ],
            "severity": "high"
        },
        {
            "category": "PCB Scratch",
            "prompts": [
                "scratch on PCB surface",
                "damaged trace on board",
                "visible gouge in PCB"
            ],
            "severity": "medium"
        },
        {
            "category": "Burnt Component",
            "prompts": [
                "burnt or charred component",
                "blackened electronic part",
                "component with burn marks"
            ],
            "severity": "high"
        },
        {
            "category": "Foreign Material",
            "prompts": [
                "debris on PCB surface",
                "contaminant on circuit board",
                "flux residue on board"
            ],
            "severity": "medium"
        }
    ]
}

# Save defect categories to a JSON file
os.makedirs('data/prompts', exist_ok=True)
    
with open('data/prompts/defect_categories.json', 'w') as f:
    json.dump(defect_categories, f, indent=4)
    
print("Defect categories saved to data/prompts/defect_categories.json")
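
Since `load_defect_categories` reads `category` and `prompts` from every entry under `defects`, it can help to check a file's structure before handing it to the detector. The following is a minimal sketch of such a check; `validate_defect_file` is a hypothetical helper, and the rules it enforces (non-empty prompts, unique category names) are assumptions about what a well-formed file should look like rather than requirements stated by the notebook.

```python
from typing import Any, Dict, List


def validate_defect_file(data: Dict[str, Any]) -> List[str]:
    """Return a list of human-readable problems; an empty list means OK."""
    defects = data.get("defects")
    if not isinstance(defects, list) or not defects:
        return ["top-level 'defects' must be a non-empty list"]

    problems: List[str] = []
    seen = set()
    for i, item in enumerate(defects):
        if not isinstance(item, dict):
            problems.append(f"entry {i}: must be an object")
            continue
        category = item.get("category")
        prompts = item.get("prompts")
        if not isinstance(category, str) or not category:
            problems.append(f"entry {i}: missing 'category'")
        elif category in seen:
            problems.append(f"entry {i}: duplicate category '{category}'")
        else:
            seen.add(category)
        if (not isinstance(prompts, list) or not prompts
                or not all(isinstance(p, str) and p for p in prompts)):
            problems.append(f"entry {i}: 'prompts' must be a non-empty list of strings")
    return problems


good = {"defects": [{"category": "Lifted Pad",
                     "prompts": ["pad lifted from PCB substrate"]}]}
bad = {"defects": [{"category": "Lifted Pad", "prompts": []},
                   {"prompts": ["debris on PCB surface"]}]}
print(validate_defect_file(good))
print(validate_defect_file(bad))
```

In the notebook, the same function could be applied to the `defect_categories` dictionary before it is written to disk.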

## Initialize the Detector

Let's initialize our PCB defect detector:

In [None]:
# Initialize the detector with the CLIP model
detector = PCBDefectDetector(model_name="openai/clip-vit-base-patch32")

# Load defect categories
detector.load_defect_categories('data/prompts/defect_categories.json')

# Check the loaded categories
print("Loaded defect categories:")
for category in detector.defect_categories:
    print(f"- {category}")

## Analyze Your Images

Let's analyze your 2 normal and 2 defective PCB images:

In [None]:
# Define image paths - replace these with your actual image paths
normal_images = [
    "path/to/normal_pcb_1.jpg",
    "path/to/normal_pcb_2.jpg"
]

defective_images = [
    "path/to/defective_pcb_1.jpg",
    "path/to/defective_pcb_2.jpg"
]

all_images = normal_images + defective_images

# Check if all images exist
missing_images = [img for img in all_images if not os.path.exists(img)]
if missing_images:
    print("Warning: The following images could not be found:")
    for img in missing_images:
        print(f"- {img}")
    print("Please update the image paths to point to your actual images.")
else:
    print("All images found successfully!")

In [None]:
# Create results directory if it doesn't exist
os.makedirs('results', exist_ok=True)

# Process normal images
print("Processing normal PCB images...")
normal_results = []
for i, image_path in enumerate(normal_images):
    print(f"Processing {image_path}...")
    result = detector.detect(
        image_path=image_path,
        threshold=0.1,  # Lower threshold to see more potential defects
        top_k=5,        # Show top 5 categories
        enhance_prompts=True
    )
    normal_results.append(result)
    
    # Visualize and save the result
    output_path = f"results/normal_{i+1}_detection.png"
    visualize_detection(image_path, result, output_path)

In [None]:
# Process defective images
print("Processing defective PCB images...")
defective_results = []
for i, image_path in enumerate(defective_images):
    print(f"Processing {image_path}...")
    result = detector.detect(
        image_path=image_path,
        threshold=0.1,  # Lower threshold to see more potential defects
        top_k=5,        # Show top 5 categories
        enhance_prompts=True
    )
    defective_results.append(result)
    
    # Visualize and save the result
    output_path = f"results/defective_{i+1}_detection.png"
    visualize_detection(image_path, result, output_path)

## Compare All Images

Now let's create a comparison grid of all images:

In [None]:
# Combine all results
all_results = normal_results + defective_results

# Create a comparison grid
create_comparison_grid(
    image_paths=all_images,
    results=all_results,
    output_path="results/all_pcbs_comparison.png",
    grid_size=(2, 2)  # 2x2 grid for 4 images
)

## Summarize Results

Let's create a summary of our detection results:

In [None]:
# Create a summary table
print("Summary of PCB Defect Detection Results")
print("-" * 80)
print(f"{'Image Path':<40} {'Defective':<10} {'Top Defect':<20} {'Confidence':<10}")
print("-" * 80)

for image_path, result in zip(all_images, all_results):
    is_defective = result['is_defective']
    
    if result['defects']:
        top_defect = list(result['defects'].keys())[0]
        confidence = list(result['defects'].values())[0]
    else:
        top_defect = "None"
        confidence = 0.0
        
    # Shorten image path for display
    short_path = image_path
    if len(short_path) > 40:
        short_path = "..." + short_path[-37:]
        
    print(f"{short_path:<40} {str(is_defective):<10} {top_defect:<20} {confidence:.4f}")
    
print("-" * 80)

## Analyze Detection Accuracy

Let's analyze how well our model performed:

In [None]:
# Calculate accuracy metrics
# Ground-truth labels: False = normal, True = defective
expected_labels = [False] * len(normal_images) + [True] * len(defective_images)

predicted_labels = [result['is_defective'] for result in all_results]

# Calculate accuracy
correct = sum(1 for expected, predicted in zip(expected_labels, predicted_labels) if expected == predicted)
accuracy = correct / len(expected_labels)

# Calculate confusion matrix
true_positive = sum(1 for expected, predicted in zip(expected_labels, predicted_labels) if expected and predicted)
true_negative = sum(1 for expected, predicted in zip(expected_labels, predicted_labels) if not expected and not predicted)
false_positive = sum(1 for expected, predicted in zip(expected_labels, predicted_labels) if not expected and predicted)
false_negative = sum(1 for expected, predicted in zip(expected_labels, predicted_labels) if expected and not predicted)

print("Detection Performance Metrics")
print("-" * 40)
print(f"Accuracy: {accuracy:.2%}")
print()
print("Confusion Matrix:")
print(f"{'':<20} {'Predicted Normal':<20} {'Predicted Defective':<20}")
print(f"{'Actual Normal':<20} {true_negative:<20} {false_positive:<20}")
print(f"{'Actual Defective':<20} {false_negative:<20} {true_positive:<20}")
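
Accuracy alone can hide whether errors are missed defects or false alarms. The four confusion-matrix counts are enough to derive precision, recall, and F1 as well. A minimal sketch follows; `binary_metrics` is a hypothetical helper, "defective" is treated as the positive class, and the example counts are invented rather than taken from a real run.

```python
from typing import Dict


def binary_metrics(tp: int, tn: int, fp: int, fn: int) -> Dict[str, float]:
    """Derive standard metrics from confusion-matrix counts."""
    def safe_div(num: float, den: float) -> float:
        # Return 0.0 instead of raising when a denominator is empty
        return num / den if den else 0.0

    precision = safe_div(tp, tp + fp)   # flagged boards that are truly defective
    recall = safe_div(tp, tp + fn)      # defective boards that were flagged
    f1 = safe_div(2 * precision * recall, precision + recall)
    accuracy = safe_div(tp + tn, tp + tn + fp + fn)
    return {"precision": precision, "recall": recall,
            "f1": f1, "accuracy": accuracy}


# Invented counts: 2 defects caught, 1 normal correct, 1 false alarm, 0 misses
example_metrics = binary_metrics(tp=2, tn=1, fp=1, fn=0)
print(example_metrics)
```

With the variables from the cell above, the call would be `binary_metrics(true_positive, true_negative, false_positive, false_negative)`. With only four images, each metric moves in large steps, so the values are best read as a sanity check.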

## Advanced Analysis

Let's examine the raw detection scores to better understand the model behavior:

In [None]:
# Analyze raw detection scores
for i, (image_path, result) in enumerate(zip(all_images, all_results)):
    print(f"\nDetailed analysis for image {i+1}: {image_path}")
    print("-" * 60)
    
    # Get all scores
    all_scores = result['all_scores']
    
    # Sort by score
    sorted_scores = {k: v for k, v in sorted(all_scores.items(), key=lambda item: item[1], reverse=True)}
    
    # Print top 10 scores
    print(f"{'Category':<25} {'Score':<10}")
    print("-" * 35)
    
    for j, (category, score) in enumerate(sorted_scores.items()):
        if j >= 10:  # Limit to top 10
            break
        print(f"{category:<25} {score:.4f}")
    
    # Calculate ratio between top defect and "Normal" category
    top_defect_category = next((cat for cat in sorted_scores.keys() if cat.lower() != "normal"), None)
    if top_defect_category and "Normal" in sorted_scores:
        top_defect_score = sorted_scores[top_defect_category]
        normal_score = sorted_scores["Normal"]
        ratio = top_defect_score / normal_score if normal_score > 0 else float('inf')
        
        print("\nDefect vs Normal Analysis:")
        print(f"Top defect ({top_defect_category}): {top_defect_score:.4f}")
        print(f"Normal score: {normal_score:.4f}")
        print(f"Defect/Normal ratio: {ratio:.2f}")
        
        if ratio > 2.0:
            print("Assessment: Strong indication of defect")
        elif ratio > 1.0:
            print("Assessment: Moderate indication of defect")
        else:
            print("Assessment: Likely normal PCB")

## Threshold Tuning

Let's examine how different confidence thresholds affect the detection results:

In [None]:
# Define a range of thresholds to test
thresholds = [0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.4, 0.5]

# Test different thresholds
threshold_results = []

for threshold in thresholds:
    correct_detections = 0
    
    for i, (image_path, expected_defective) in enumerate(list(zip(normal_images, [False] * len(normal_images))) + 
                                                        list(zip(defective_images, [True] * len(defective_images)))):
        result = detector.detect(
            image_path=image_path,
            threshold=threshold,
            top_k=3,
            enhance_prompts=True
        )
        
        if result['is_defective'] == expected_defective:
            correct_detections += 1
    
    accuracy = correct_detections / len(all_images)
    threshold_results.append(accuracy)

# Plot threshold vs accuracy
plt.figure(figsize=(10, 6))
plt.plot(thresholds, threshold_results, marker='o', linestyle='-', linewidth=2)
plt.xlabel('Confidence Threshold')
plt.ylabel('Detection Accuracy')
plt.title('Effect of Confidence Threshold on Detection Accuracy')
plt.grid(True, linestyle='--', alpha=0.7)
plt.xticks(thresholds)
plt.yticks(np.arange(0, 1.1, 0.1))

# Find and mark the best threshold
best_threshold_index = threshold_results.index(max(threshold_results))
best_threshold = thresholds[best_threshold_index]
best_accuracy = threshold_results[best_threshold_index]

plt.plot(best_threshold, best_accuracy, 'ro', markersize=10)
plt.annotate(f'Best threshold: {best_threshold}\nAccuracy: {best_accuracy:.2%}', 
             xy=(best_threshold, best_accuracy),
             xytext=(best_threshold+0.05, best_accuracy-0.1),
             arrowprops=dict(facecolor='black', shrink=0.05, width=1.5))

plt.savefig('results/threshold_tuning.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"Optimal threshold: {best_threshold} with accuracy: {best_accuracy:.2%}")
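
The sweep above calls `detector.detect` once per image per threshold, although the model scores do not depend on the threshold. An alternative is to compute the category scores once and replay only the thresholding logic. The sketch below assumes the same decision rule as `detect` (keep categories at or above the threshold, take the top `top_k`, and call the image normal only if the first kept category is "Normal"); `is_defective_at` and `sweep` are hypothetical helpers, and the scores are invented.

```python
from typing import Dict, List


def is_defective_at(scores: Dict[str, float], threshold: float,
                    top_k: int = 3) -> bool:
    """Replay the thresholding decision on cached category scores."""
    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    kept = [name for name, score in ranked if score >= threshold][:top_k]
    # Normal only when the best surviving category is "Normal";
    # an empty list therefore counts as defective.
    return not (kept and "normal" in kept[0].lower())


def sweep(cached_scores: List[Dict[str, float]], expected: List[bool],
          thresholds: List[float]) -> Dict[float, float]:
    """Accuracy for each threshold, computed without re-running the model."""
    accuracies: Dict[float, float] = {}
    for t in thresholds:
        hits = sum(is_defective_at(s, t) == e
                   for s, e in zip(cached_scores, expected))
        accuracies[t] = hits / len(expected)
    return accuracies


# Invented scores for one normal and one defective board
cached = [
    {"Normal": 0.6, "Solder Bridge": 0.3},
    {"Solder Bridge": 0.45, "Normal": 0.35},
]
labels = [False, True]
sweep_result = sweep(cached, labels, [0.2, 0.7])
print(sweep_result)
```

In the notebook, the cached scores could come from `[r['all_scores'] for r in all_results]`. The example also shows a side effect of the decision rule: at a high threshold nothing survives, so every image is flagged as defective and accuracy reflects only the share of defective images.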

## Conclusion

In this notebook, we demonstrated how to use a Vision-Language Model (CLIP) for zero-shot detection of defects in PCB images. Key takeaways:

1. The zero-shot approach allows defect detection without labeled training data
2. A confidence threshold can be chosen with a simple sweep, although a value tuned on four images should be treated as provisional
3. Different defect types can be described and scored using natural language prompts
4. How well the system separates normal from defective PCBs depends on your images and prompts; the accuracy and confusion matrix above are the evidence to check

### Next Steps

- Add more PCB images to improve robustness
- Fine-tune the defect category prompts for better detection
- Experiment with different VLM models to compare performance
- Consider implementing a hybrid approach with few-shot learning for specific defect types