# Two-Stage PCB Defect Detection Pipeline

# This notebook implements a hybrid approach for PCB defect detection:
# 1. Stage 1: Train or fine-tune a CNN model on Keras datasets + PCB images
# 2. Stage 2: Integrate with zero-shot learning for enhanced defect detection
#
# The approach combines traditional supervised learning with zero-shot capabilities
# to create a more robust PCB inspection system.

import os
import sys
import json
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
from tensorflow.keras.applications import MobileNetV2, ResNet50V2
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from PIL import Image
from tqdm.notebook import tqdm
import torch
from transformers import CLIPProcessor, CLIPModel
from typing import List, Dict, Any, Union, Optional, Tuple
import time
import datetime
import glob
import random
import cv2
from sklearn.metrics import confusion_matrix, classification_report
from google.colab import drive

# Notebook display setup: render matplotlib figures inline in the cell output.
# NOTE: `%matplotlib inline` is an IPython magic, so this cell only runs inside
# Jupyter/Colab and is not valid syntax for a plain Python interpreter.
%matplotlib inline
plt.rcParams['figure.figsize'] = (12, 8)  # Default figure size (in inches) for later plots

# Report the installed framework versions and whether each framework can see
# an accelerator (TensorFlow GPU devices and PyTorch CUDA are queried separately).
print("TensorFlow version:", tf.__version__)
print("Torch version:", torch.__version__)
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))
print("PyTorch CUDA Available: ", torch.cuda.is_available())

# Mount Google Drive at /content/drive (for storing/reading PCB images and models).
# The default base_dir of PCBDataManager below points inside this mount.
drive.mount('/content/drive')

## Directory Setup and Data Management

# In [None]:
class PCBDataManager:
    """Manages PCB image data for both CNN training and zero-shot detection."""
    
    def __init__(self, base_dir='/content/drive/MyDrive/PCB_Defect_Detection'):
        """
        Initialize the PCB data manager with directories for all project assets.
        
        Args:
            base_dir: Base directory for the project in Google Drive
        """
        self.base_dir = base_dir
        
        # Create directory structure
        self.dirs = {
            'data': os.path.join(base_dir, 'data'),
            'models': os.path.join(base_dir, 'models'),
            'results': os.path.join(base_dir, 'results'),
            'temp': os.path.join(base_dir, 'temp')
        }
        
        # Create subdirectories for data
        self.data_dirs = {
            'normal': os.path.join(self.dirs['data'], 'normal'),
            'defective': os.path.join(self.dirs['data'], 'defective'),
            'test': os.path.join(self.dirs['data'], 'test'),
            'keras_processed': os.path.join(self.dirs['data'], 'keras_processed')
        }
        
        # Create all directories if they don't exist
        self._create_directories()
        
        # Initialize image path lists
        self.normal_images = []
        self.defective_images = []
        self.test_images = []
        
    def _create_directories(self):
        """Create the necessary directory structure."""
        # Create main directories
        for dir_path in self.dirs.values():
            os.makedirs(dir_path, exist_ok=True)
            
        # Create data subdirectories
        for dir_path in self.data_dirs.values():
            os.makedirs(dir_path, exist_ok=True)
            
        print(f"Directory structure created at {self.base_dir}")
    
    def scan_images(self):
        """
        Scan and list all available PCB images in the normal, defective, 
        and test directories.
        
        Returns:
            Tuple of normal, defective, and test image paths
        """
        # Find all images in normal and defective directories
        self.normal_images = self._find_images(self.data_dirs['normal'])
        self.defective_images = self._find_images(self.data_dirs['defective'])
        self.test_images = self._find_images(self.data_dirs['test'])
        
        print(f"Found {len(self.normal_images)} normal images")
        print(f"Found {len(self.defective_images)} defective images")
        print(f"Found {len(self.test_images)} test images")
        
        return self.normal_images, self.defective_images, self.test_images
    
    def _find_images(self, directory):
        """
        Find all image files in a directory.
        
        Args:
            directory: Directory to search for images
            
        Returns:
            List of image file paths
        """
        image_extensions = ['.jpg', '.jpeg', '.png', '.bmp']
        image_paths = []
        
        for ext in image_extensions:
            image_paths.extend(glob.glob(os.path.join(directory, f"*{ext}")))
            image_paths.extend(glob.glob(os.path.join(directory, f"*{ext.upper()}")))
            
        return image_paths
    
    def prepare_data_for_cnn(self, img_size=(224, 224), batch_size=32, validation_split=0.2,
                             max_images_per_class=1000):
        """
        Prepare data for CNN training using Keras ImageDataGenerator.

        The images listed in ``self.normal_images`` and ``self.defective_images``
        are staged into a two-class folder tree (``normal`` / ``defective``)
        under the project's temp directory, and a training and a validation
        generator are created on top of that tree.

        Args:
            img_size: Target size for the images
            batch_size: Batch size for training
            validation_split: Fraction of data to use for validation
            max_images_per_class: Upper bound on how many images of each class
                are staged; ``None`` stages every listed image

        Returns:
            train_generator, validation_generator: Data generators for training and validation
        """
        import shutil  # function-scope import: only used by the copy fallback below

        # Data augmentation for training
        train_datagen = ImageDataGenerator(
            rescale=1./255,
            rotation_range=20,
            width_shift_range=0.2,
            height_shift_range=0.2,
            shear_range=0.2,
            zoom_range=0.2,
            horizontal_flip=True,
            vertical_flip=True,
            fill_mode='nearest',
            validation_split=validation_split
        )

        # Only rescaling for validation
        validation_datagen = ImageDataGenerator(
            rescale=1./255,
            validation_split=validation_split
        )

        # Combined directory with one subfolder per class
        temp_train_dir = os.path.join(self.dirs['temp'], 'train_data')

        def stage_images(image_paths, class_name):
            """Symlink (or, failing that, copy) images into the class subfolder."""
            class_dir = os.path.join(temp_train_dir, class_name)
            os.makedirs(class_dir, exist_ok=True)

            for img_path in image_paths[:max_images_per_class]:
                dest_path = os.path.join(class_dir, os.path.basename(img_path))

                if os.path.lexists(dest_path):
                    if os.path.exists(dest_path):
                        continue  # already staged by an earlier call
                    # A dangling symlink left behind by an earlier call: drop it
                    # so the fallback copy cannot write through the stale link.
                    os.remove(dest_path)

                try:
                    # Link to the absolute path; a relative target would be
                    # resolved against the link's own folder and break.
                    os.symlink(os.path.abspath(img_path), dest_path)
                except (OSError, NotImplementedError):
                    # Symlinks are not available everywhere; copy instead.
                    shutil.copyfile(img_path, dest_path)

        stage_images(self.normal_images, 'normal')
        stage_images(self.defective_images, 'defective')

        # Create generators for training and validation
        train_generator = train_datagen.flow_from_directory(
            temp_train_dir,
            target_size=img_size,
            batch_size=batch_size,
            class_mode='binary',
            subset='training'
        )

        validation_generator = validation_datagen.flow_from_directory(
            temp_train_dir,
            target_size=img_size,
            batch_size=batch_size,
            class_mode='binary',
            subset='validation'
        )

        print(f"Training generator created with {train_generator.samples} samples")
        print(f"Validation generator created with {validation_generator.samples} samples")

        return train_generator, validation_generator
    
    def visualize_sample_images(self, num_samples=5):
        """
        Visualize sample PCB images from normal and defective categories.

        The first row shows normal boards and the second row defective boards.
        Cells for which no image is available are left blank.

        Args:
            num_samples: Number of samples to display from each category

        Raises:
            ValueError: If num_samples is smaller than 1
        """
        if num_samples < 1:
            raise ValueError(f"num_samples must be at least 1, got {num_samples}")

        # squeeze=False keeps `axes` two-dimensional even when num_samples == 1,
        # so the axes[row, col] indexing below is always valid.
        _, axes = plt.subplots(2, num_samples, figsize=(15, 6), squeeze=False)

        rows = (("Normal", self.normal_images), ("Defective", self.defective_images))
        for row, (label, image_paths) in enumerate(rows):
            for col, img_path in enumerate(image_paths[:num_samples]):
                # Read the pixels inside a context manager so the file handle
                # is released as soon as the data is in memory.
                with Image.open(img_path) as img:
                    pixels = np.array(img)
                axes[row, col].imshow(pixels)
                axes[row, col].set_title(f"{label} {col+1}")

        # Hide ticks and frames everywhere, including cells without an image
        for ax in axes.flat:
            ax.axis('off')

        plt.tight_layout()
        plt.show()
        
    def create_keras_dataset_hybrid(self, keras_dataset='cifar10', num_pcb_samples=500,
                                    augmentation_level='minimal', img_size=(224, 224),
                                    max_keras_samples=10000):
        """
        Create a hybrid dataset that combines a Keras dataset with PCB images.
        This helps with transfer learning from general visual features to PCB-specific features.

        Only the Keras images that end up in the result are resized, and they
        are written straight into their final (shuffled) position of a single
        pre-allocated array. Resizing the whole training split first and then
        normalising, concatenating and shuffling would each create another
        full-size float32 copy.

        Args:
            keras_dataset: Name of the Keras dataset to use ('cifar10' or 'cifar100')
            num_pcb_samples: Number of PCB samples to include in the hybrid dataset
                (half normal, half defective, limited by what is available)
            augmentation_level: Level of augmentation to apply ('minimal', 'moderate', 'extensive')
            img_size: (height, width) that every image is resized to
            max_keras_samples: Maximum number of Keras training images to include

        Returns:
            Hybrid dataset for training (x, y): float32 images scaled to [0, 1]
            and their 0/1 labels, shuffled with the same permutation

        Raises:
            ValueError: If the dataset name is unsupported or no PCB image could be read
        """
        # Pick the Keras source. Its classes are split into two dummy groups
        # (just for training purposes): the lower half of the class ids counts
        # as "normal" (0), the upper half as "defective" (1).
        if keras_dataset == 'cifar10':
            load_keras_data, num_source_classes = keras.datasets.cifar10.load_data, 10
        elif keras_dataset == 'cifar100':
            load_keras_data, num_source_classes = keras.datasets.cifar100.load_data, 100
        else:
            raise ValueError(f"Unsupported dataset: {keras_dataset}")

        height, width = img_size
        cv_size = (width, height)  # OpenCV takes sizes as (width, height)

        def load_pcb_image(img_path):
            """Read one PCB image as RGB, square-cropping extreme aspect ratios."""
            img = cv2.imread(img_path)
            if img is None:
                print(f"Warning: Could not read image {img_path}")
                return None

            img_height, img_width = img.shape[0], img.shape[1]
            aspect_ratio = img_width / img_height

            if aspect_ratio > 3 or aspect_ratio < 1/3:
                # Very wide or very tall image: keep the central square so the
                # resize below does not squash the board beyond recognition.
                if img_width > img_height:
                    left = max(0, img_width // 2 - img_height // 2)
                    img = img[:, left:left+img_height]
                else:
                    top = max(0, img_height // 2 - img_width // 2)
                    img = img[top:top+img_width, :]

            img = cv2.resize(img, cv_size)
            return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR

        # Sample PCB images (up to half of the requested amount per class)
        per_class = num_pcb_samples // 2
        normal_samples = random.sample(self.normal_images, min(per_class, len(self.normal_images)))
        defective_samples = random.sample(self.defective_images, min(per_class, len(self.defective_images)))

        pcb_images = []
        pcb_labels = []
        for label, sample_paths in ((0, normal_samples), (1, defective_samples)):
            for img_path in sample_paths:
                img = load_pcb_image(img_path)
                if img is not None:
                    pcb_images.append(img)
                    pcb_labels.append(label)

        # Fail before the Keras data is loaded and resized
        if not pcb_images:
            raise ValueError("No valid PCB images found. Check image paths and formats.")

        pcb_images = np.array(pcb_images, dtype=np.float32)

        # Apply augmentation to PCB images based on level
        pcb_augmented = self._apply_augmentation_to_numpy(pcb_images, level=augmentation_level)

        # Load the Keras training split; the test split is not used
        (x_train, y_train), _ = load_keras_data()
        cifar_subset_size = min(max_keras_samples, len(x_train))

        # Draw the shuffle up front: slots[i] is the output position of source
        # image i (Keras images first, then the PCB images).
        total = cifar_subset_size + len(pcb_augmented)
        slots = np.random.permutation(total)
        keras_slots = slots[:cifar_subset_size]
        pcb_slots = slots[cifar_subset_size:]

        hybrid_x = np.empty((total, height, width, 3), dtype=np.float32)
        hybrid_y = np.empty(total, dtype=int)

        for i in range(cifar_subset_size):
            hybrid_x[keras_slots[i]] = cv2.resize(x_train[i], cv_size)
        # y_train holds one class id per row; upper half of the ids -> label 1
        hybrid_y[keras_slots] = y_train[:cifar_subset_size, 0] >= num_source_classes // 2

        hybrid_x[pcb_slots] = pcb_augmented
        hybrid_y[pcb_slots] = pcb_labels

        # Normalize all images to [0, 1] in place
        hybrid_x /= 255.0

        print(f"Created hybrid dataset with {len(hybrid_x)} images:")
        print(f" - {cifar_subset_size} images from {keras_dataset}")
        print(f" - {len(pcb_augmented)} PCB images (with {augmentation_level} augmentation)")
        print(f" - Class distribution: {np.sum(hybrid_y == 0)} normal, {np.sum(hybrid_y == 1)} defective")

        # Save a few examples of the hybrid dataset for visualization
        hybrid_examples_dir = os.path.join(self.data_dirs['keras_processed'], 'hybrid_examples')
        os.makedirs(hybrid_examples_dir, exist_ok=True)

        for i in range(min(10, len(hybrid_x))):
            img = (hybrid_x[i] * 255).astype(np.uint8)
            label = "normal" if hybrid_y[i] == 0 else "defective"
            img_path = os.path.join(hybrid_examples_dir, f"hybrid_{i}_{label}.png")
            Image.fromarray(img).save(img_path)

        return hybrid_x, hybrid_y
    
    def _apply_augmentation_to_numpy(self, images, level='minimal'):
        """
        Apply augmentation to numpy array of images based on level.
        
        Args:
            images: Numpy array of images
            level: Augmentation level ('minimal', 'moderate', 'extensive')
            
        Returns:
            Augmented images
        """
        print(f"Applying {level} augmentation to {len(images)} images...")
        
        # Define augmentation parameters based on level
        if level == 'minimal':
            datagen = ImageDataGenerator(
                rotation_range=10,
                width_shift_range=0.1,
                height_shift_range=0.1,
                zoom_range=0.1,
                horizontal_flip=True
            )
        elif level == 'moderate':
            datagen = ImageDataGenerator(
                rotation_range=20,
                width_shift_range=0.2,
                height_shift_range=0.2,
                shear_range=0.2,
                zoom_range=0.2,
                horizontal_flip=True,
                vertical_flip=True,
                fill_mode='nearest'
            )
        elif level == 'extensive':
            datagen = ImageDataGenerator(
                rotation_range=40,
                width_shift_range=0.3,
                height_shift_range=0.3,
                shear_range=0.3,
                zoom_range=0.3,
                brightness_range=[0.7, 1.3],
                channel_shift_range=0.3,
                horizontal_flip=True,
                vertical_flip=True,
                fill_mode='nearest'
            )
        else:
            print(f"Unsupported augmentation level: {level}. Using original images.")
            return images
        
        # For minimal level, just return original images to retain Keras dataset characteristics
        if level == 'minimal' and len(images) > 1000:  # Only for large datasets like Keras
            return images
        
        # Create augmented versions (batch size of 1 for simplicity)
        augmented_images = np.copy(images)
        for i in range(len(images)):
            # Get a single image and reshape for ImageDataGenerator
            img = images[i:i+1]
            
            # Generate augmented version
            aug_iter = datagen.flow(img, batch_size=1)
            aug_img = next(aug_iter)[0]
            
            # Store augmented version
            augmented_images[i] = aug_img
        
        return augmented_images

## CNN Model Development (Stage 1)

# In [None]:
class PCBDefectCNN:
    """CNN model for PCB defect detection with transfer learning and progressive training."""
    
    def __init__(self, base_model='mobilenetv2', img_size=(224, 224, 3)):
        """
        Initialize the CNN model with the selected architecture.
        
        Args:
            base_model: Base model to use ('mobilenetv2' or 'resnet50v2')
            img_size: Input image size (height, width, channels)
        """
        self.img_size = img_size
        self.base_model_name = base_model
        self.model = None
        self.base_model = None
        self.feature_extractor = None
        self.history = None
        self.training_round = 0
        
    def build_model(self, num_classes=2, dropout_rate=0.5, fine_tune_layers=0):
        """
        Build the CNN model using transfer learning from a pre-trained base model.

        The ImageNet-pretrained backbone is frozen, optionally except for its
        last ``fine_tune_layers`` layers, and topped with a small dense
        classification head. The model is compiled with Adam (1e-4).

        Args:
            num_classes: Number of output classes (2 for binary classification)
            dropout_rate: Dropout rate for regularization
            fine_tune_layers: Number of layers in the base model to fine-tune
                (counted from the end; 0 keeps the whole base model frozen)

        Returns:
            Compiled model

        Raises:
            ValueError: If the configured base model name is not supported
        """
        # Create base model with pretrained weights
        if self.base_model_name == 'mobilenetv2':
            backbone_cls = MobileNetV2
        elif self.base_model_name == 'resnet50v2':
            backbone_cls = ResNet50V2
        else:
            raise ValueError(f"Unsupported base model: {self.base_model_name}")

        self.base_model = backbone_cls(
            input_shape=self.img_size,
            include_top=False,
            weights='imagenet'
        )

        # Freeze base model layers (for transfer learning)
        self.base_model.trainable = False

        if fine_tune_layers > 0:
            # A frozen model reports no trainable weights, whatever flags its
            # child layers carry. The base model therefore has to be switched
            # back on first; afterwards everything except the last
            # `fine_tune_layers` layers is frozen again.
            self.base_model.trainable = True
            for layer in self.base_model.layers[:-fine_tune_layers]:
                layer.trainable = False

        # Create the feature extractor model that we'll use for zero-shot integration
        self.feature_extractor = keras.Model(
            inputs=self.base_model.input,
            outputs=self.base_model.output,
            name="pcb_feature_extractor"
        )

        # NOTE(review): with num_classes == 2 the head has two sigmoid units
        # trained with binary_crossentropy, while the data prepared elsewhere in
        # this file carries a single 0/1 label per image. A one-unit head may be
        # what is intended; confirm before changing, since that would alter the
        # model's output shape.
        is_multiclass = num_classes > 2
        model = models.Sequential([
            self.base_model,
            layers.GlobalAveragePooling2D(),
            layers.Dense(256, activation='relu'),
            layers.Dropout(dropout_rate),
            layers.Dense(128, activation='relu'),
            layers.Dropout(dropout_rate/2),
            layers.Dense(num_classes, activation='softmax' if is_multiclass else 'sigmoid')
        ])

        # Compile model
        loss_function = 'sparse_categorical_crossentropy' if is_multiclass else 'binary_crossentropy'
        model.compile(
            optimizer=keras.optimizers.Adam(1e-4),
            loss=loss_function,
            metrics=['accuracy']
        )

        # Save the model
        self.model = model
        self.training_round = 1

        print(f"Model built with {self.base_model_name} base:")
        print(f"- Input shape: {self.img_size}")
        print(f"- Output classes: {num_classes}")
        print(f"- Fine-tuned layers: {fine_tune_layers}")

        # Print summary of trainable vs non-trainable parameters. The built-in
        # sum keeps the totals integral even when one of the lists is empty.
        trainable_count = int(sum(keras.backend.count_params(w) for w in model.trainable_weights))
        non_trainable_count = int(sum(keras.backend.count_params(w) for w in model.non_trainable_weights))
        print(f"- Trainable parameters: {trainable_count:,}")
        print(f"- Non-trainable parameters: {non_trainable_count:,}")

        return model
    
    def train(self, train_data, validation_data, epochs=20, callbacks=None, class_weights=None):
        """
        Train the CNN model.
        
        Args:
            train_data: Training data generator or tuple of (x_train, y_train)
            validation_data: Validation data generator or tuple of (x_val, y_val)
            epochs: Number of training epochs
            callbacks: Optional list of Keras callbacks
            class_weights: Optional class weights for imbalanced datasets
            
        Returns:
            Training history
        """
        if self.model is None:
            raise ValueError("Model not built. Call build_model() first.")
        
        # Create default callbacks if none provided
        if callbacks is None:
            callbacks = [
                keras.callbacks.EarlyStopping(
                    monitor='val_loss',
                    patience=5,
                    restore_best_weights=True
                ),
                keras.callbacks.ReduceLROnPlateau(
                    monitor='val_loss',
                    factor=0.2,
                    patience=3,
                    min_lr=1e-6
                )
            ]
            
        # Start training
        start_time = time.time()
        print(f"Starting training round {self.training_round} for {epochs} epochs...")
        
        # Check if data is in generator format or numpy arrays
        if isinstance(train_data, tuple) and len(train_data) == 2:
            # Numpy arrays
            x_train, y_train = train_data
            x_val, y_val = validation_data
            
            history = self.model.fit(
                x_train, y_train,
                epochs=epochs,
                validation_data=(x_val, y_val),
                callbacks=callbacks,
                class_weight=class_weights,
                verbose=1
            )
        else:
            # Data generators
            history = self.model.fit(
                train_data,
                epochs=epochs,
                validation_data=validation_data,
                callbacks=callbacks,
                class_weight=class_weights,
                verbose=1
            )
        
        # Calculate training time
        train_time = time.time() - start_time
        print(f"Training completed in {train_time:.2f} seconds")
        
        # Save history
        if self.history is None:
            self.history = history.history
        else:
            # Append new history to existing history
            for key in history.history:
                if key in self.history:
                    self.history[key].extend(history.history[key])
                else:
                    self.history[key] = history.history[key]
        
        return history.history
    
    def progressive_train(self, train_data, validation_data, rounds=3, 
                         epochs_per_round=[10, 15, 20], 
                         learning_rates=[1e-4, 1e-5, 1e-6],
                         unfreeze_percentages=[0, 0.3, 0.5],
                         data_augmentation_levels=['minimal', 'moderate', 'extensive'],
                         model_dir=None):
        """
        Implement progressive training strategy with multiple rounds.
        Each round unfreezes more layers and uses different learning rates.
        
        Args:
            train_data: Training data generator or tuple of (x_train, y_train) for round 1
            validation_data: Validation data for round 1
            rounds: Number of training rounds
            epochs_per_round: List of epoch counts for each round
            learning_rates: List of learning rates for each round
            unfreeze_percentages: Percentage of base model layers to unfreeze in each round
            data_augmentation_levels: Levels of augmentation for each round
            model_dir: Directory to save intermediate models
            
        Returns:
            List of training histories for each round
        """
        if self.model is None:
            raise ValueError("Model not built. Call build_model() first.")
        
        if model_dir is not None:
            os.makedirs(model_dir, exist_ok=True)
        
        histories = []
        
        # For each training round
        for round_idx in range(rounds):
            round_num = round_idx + 1
            self.training_round = round_num
            
            print(f"\n===== PROGRESSIVE TRAINING ROUND {round_num}/{rounds} =====")
            
            # 1. Adjust model parameters for this round
            if round_idx > 0:  # Skip for first round as it's already set up in build_model
                # Unfreeze layers based on percentage
                unfreeze_pct = unfreeze_percentages[round_idx]
                self._unfreeze_layers(unfreeze_pct)
                
                # Adjust learning rate
                lr = learning_rates[round_idx]
                self._adjust_learning_rate(lr)
            
            # 2. Create appropriate data augmentation for this round
            aug_level = data_augmentation_levels[round_idx]
            if isinstance(train_data, tuple) and isinstance(validation_data, tuple):
                # For numpy arrays, we need to create augmented versions
                x_train, y_train = train_data
                x_val, y_val = validation_data
                
                augmented_train_data = (self._augment_data(x_train, aug_level), y_train)
                # Don't augment validation data
                round_train_data = augmented_train_data
                round_val_data = validation_data
            else:
                # For generators, we'd need to adjust the augmentation parameters
                # This is more complex and would depend on how the generators are created
                round_train_data = train_data
                round_val_data = validation_data
                print(f"Using existing data generators for round {round_num}")
            
            # 3. Create appropriate callbacks for this round
            callbacks = [
                keras.callbacks.EarlyStopping(
                    monitor='val_loss',
                    patience=5 + round_idx * 2,  # More patience in later rounds
                    restore_best_weights=True
                ),
                keras.callbacks.ReduceLROnPlateau(
                    monitor='val_loss',
                    factor=0.2,
                    patience=3,
                    min_lr=learning_rates[-1] / 10  # Even lower min_lr
                )
            ]
            
            # Add model checkpoint in later rounds
            if model_dir is not None:
                checkpoint_path = os.path.join(model_dir, f"pcb_model_round{round_num}_best.h5")
                callbacks.append(
                    keras.callbacks.ModelCheckpoint(
                        checkpoint_path,
                        monitor='val_accuracy',
                        save_best_only=True,
                        verbose=1
                    )
                )
            
            # 4. Train for this round
            history = self.train(
                round_train_data,
                round_val_data,
                epochs=epochs_per_round[round_idx],
                callbacks=callbacks
            )
            
            histories.append(history)
            
            # 5. Save the model after this round
            if model_dir is not None:
                round_model_path = os.path.join(model_dir, f"pcb_model_round{round_num}")
                self.save_model(round_model_path, model_name=f"round{round_num}")
                print(f"Model saved after round {round_num} to {round_model_path}")
            
            # 6. Update feature extractor after each round
            self.feature_extractor = keras.Model(
                inputs=self.model.input,
                outputs=self.model.get_layer(self.base_model.name).output,
                name=f"pcb_feature_extractor_round{round_num}"
            )
        
        print("\n===== PROGRESSIVE TRAINING COMPLETED =====")
        return histories
    
    def _unfreeze_layers(self, percentage):
        """
        Unfreeze a percentage of base model layers from the end.
        
        Args:
            percentage: Percentage of layers to unfreeze (0.0 to 1.0)
        """
        base_model_layer = None
        for layer in self.model.layers:
            if layer.name == self.base_model.name:
                base_model_layer = layer
                break
        
        if base_model_layer is None:
            print("Warning: Base model not found in model layers")
            return
        
        # Make base model trainable
        base_model_layer.trainable = True
        
        # Freeze/unfreeze layers based on percentage
        num_layers = len(base_model_layer.layers)
        freeze_until = int((1 - percentage) * num_layers)
        
        for i, layer in enumerate(base_model_layer.layers):
            layer.trainable = (i >= freeze_until)
        
        # Count trainable parameters
        trainable_count = np.sum([keras.backend.count_params(w) for w in self.model.trainable_weights])
        non_trainable_count = np.sum([keras.backend.count_params(w) for w in self.model.non_trainable_weights])
        
        print(f"Unfroze {percentage:.1%} of base model layers ({num_layers - freeze_until}/{num_layers} layers)")
        print(f"- Trainable parameters: {trainable_count:,}")
        print(f"- Non-trainable parameters: {non_trainable_count:,}")
    
    def _adjust_learning_rate(self, learning_rate):
        """
        Adjust the learning rate of the optimizer.
        
        Args:
            learning_rate: New learning rate
        """
        if hasattr(self.model.optimizer, 'learning_rate'):
            self.model.optimizer.learning_rate.assign(learning_rate)
            print(f"Learning rate adjusted to {learning_rate}")
        else:
            # Recompile model with new learning rate
            self.model.compile(
                optimizer=keras.optimizers.Adam(learning_rate),
                loss=self.model.loss,
                metrics=self.model.metrics
            )
            print(f"Model recompiled with learning rate {learning_rate}")
    
    def _augment_data(self, images, level='minimal'):
        """
        Apply data augmentation to images based on augmentation level.
        
        Args:
            images: Numpy array of images
            level: Augmentation level ('minimal', 'moderate', or 'extensive')
            
        Returns:
            Augmented images
        """
        print(f"Applying {level} data augmentation...")
        # For simplicity in this notebook, we'll just return the original images
        # In a full implementation, you would apply varying degrees of augmentation
        # based on the specified level
        return images
    
    def save_model(self, model_dir, model_name=None):
        """
        Save the trained model and feature extractor.

        Creates ``<model_dir>/<model_name>/`` and writes the classifier to
        ``full_model``, the feature extractor to ``feature_extractor`` and,
        when a training history is available, ``training_history.json``.

        Args:
            model_dir: Directory to save the model
            model_name: Optional custom model name; when omitted, a name is
                built from ``self.base_model_name`` and the current timestamp

        Returns:
            Paths to the saved model and feature extractor. The extractor path
            is None when no feature extractor is available to save.

        Raises:
            ValueError: If there is no model to save.
        """
        if self.model is None:
            raise ValueError("No model to save. Train the model first.")

        # Create timestamp-based model name if not provided
        if model_name is None:
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            model_name = f"pcb_defect_cnn_{self.base_model_name}_{timestamp}"

        # Create model directory
        model_path = os.path.join(model_dir, model_name)
        os.makedirs(model_path, exist_ok=True)

        # Save the full classification model
        full_model_path = os.path.join(model_path, "full_model")
        self.model.save(full_model_path)

        # Save the feature extractor model. load_model() treats the extractor
        # as optional, so a missing one must not abort a save whose full model
        # has already been written.
        if self.feature_extractor is not None:
            extractor_path = os.path.join(model_path, "feature_extractor")
            self.feature_extractor.save(extractor_path)
        else:
            extractor_path = None
            print("Warning: no feature extractor available; only the full model was saved")

        # Save training history if available
        if self.history is not None:
            history_file = os.path.join(model_path, 'training_history.json')
            with open(history_file, 'w', encoding='utf-8') as f:
                # default=float converts NumPy scalars (e.g. a float32 learning
                # rate logged by a callback), which json cannot encode itself.
                json.dump(self.history, f, default=float)

        print(f"Model saved to {model_path}")
        return full_model_path, extractor_path
    
    def load_model(self, model_path):
        """
        Load a trained model.

        Also loads the feature extractor when a ``feature_extractor`` entry
        exists next to the model, which is the layout written by save_model().

        Args:
            model_path: Path to the saved model

        Returns:
            Loaded model
        """
        self.model = keras.models.load_model(model_path)

        # Try to load the feature extractor if it exists. The path is
        # normalized first: with a trailing separator ("…/full_model/")
        # os.path.dirname would return the model directory itself rather than
        # its parent, and the sibling extractor would never be found.
        parent_dir = os.path.dirname(os.path.normpath(model_path))
        extractor_path = os.path.join(parent_dir, "feature_extractor")
        if os.path.exists(extractor_path):
            self.feature_extractor = keras.models.load_model(extractor_path)

        print(f"Model loaded from {model_path}")
        return self.model
    
    def evaluate(self, test_data, test_labels=None):
        """
        Evaluate the model on test data.

        With a numpy array plus labels, this prints a classification report,
        plots a confusion matrix and returns the predictions alongside the
        metrics. Any other input is passed to ``model.evaluate`` as a data
        generator.

        Args:
            test_data: Test data generator or numpy array
            test_labels: Optional test labels (if test_data is numpy array);
                either class indices or one-hot encoded rows

        Returns:
            ``(metrics, y_pred)`` for numpy input, where ``y_pred`` holds the
            predicted class indices; otherwise the result of ``model.evaluate``

        Raises:
            ValueError: If no model has been loaded or trained.
        """
        if self.model is None:
            raise ValueError("No model to evaluate. Load or train a model first.")

        print("Evaluating model...")

        # Check if data is in generator format or numpy arrays
        if not (isinstance(test_data, np.ndarray) and test_labels is not None):
            # Data generator
            return self.model.evaluate(test_data, verbose=1)

        class_names = ['Normal', 'Defective']
        class_ids = list(range(len(class_names)))

        # Numpy arrays
        metrics = self.model.evaluate(test_data, test_labels, verbose=1)

        # Turn raw model output into class indices. A single-unit (sigmoid)
        # output must be thresholded: argmax over one column is always 0.
        predictions = np.asarray(self.model.predict(test_data))
        if predictions.ndim > 1 and predictions.shape[-1] > 1:
            y_pred = np.argmax(predictions, axis=1)
        else:
            y_pred = (predictions.reshape(-1) > 0.5).astype(int)

        # Accept one-hot labels as well as plain class indices.
        y_true = np.asarray(test_labels)
        if y_true.ndim > 1 and y_true.shape[-1] > 1:
            y_true = np.argmax(y_true, axis=-1)
        else:
            y_true = y_true.reshape(-1).astype(int)

        # Print classification report. Passing `labels` keeps the report and
        # the matrix at one row per class name even when a class is absent
        # from this particular test set.
        print("\nClassification Report:")
        print(classification_report(
            y_true, y_pred, labels=class_ids, target_names=class_names
        ))

        # Compute confusion matrix
        cm = confusion_matrix(y_true, y_pred, labels=class_ids)

        # Plot confusion matrix
        plt.figure(figsize=(8, 6))
        plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
        plt.title('Confusion Matrix')
        plt.colorbar()
        tick_marks = np.arange(len(class_names))
        plt.xticks(tick_marks, class_names, rotation=45)
        plt.yticks(tick_marks, class_names)

        # Add text annotations to the confusion matrix; light text on dark cells
        thresh = cm.max() / 2.
        for i in range(cm.shape[0]):
            for j in range(cm.shape[1]):
                plt.text(j, i, format(cm[i, j], 'd'),
                        horizontalalignment="center",
                        color="white" if cm[i, j] > thresh else "black")

        plt.ylabel('True label')
        plt.xlabel('Predicted label')
        # Lay out after the axis labels exist so they are not clipped.
        plt.tight_layout()
        plt.show()

        return metrics, y_pred
    
    def visualize_training_history(self):
        """
        Plot training and validation curves for accuracy and loss.

        Draws two side-by-side panels from ``self.history``: accuracy on the
        left and loss on the right.

        Raises:
            ValueError: If no training history has been recorded.
        """
        if self.history is None:
            raise ValueError("No training history available.")

        # (subplot index, training key, validation key,
        #  training legend, validation legend, title, y-axis label)
        panels = (
            (1, 'accuracy', 'val_accuracy',
             'Training Accuracy', 'Validation Accuracy', 'Model Accuracy', 'Accuracy'),
            (2, 'loss', 'val_loss',
             'Training Loss', 'Validation Loss', 'Model Loss', 'Loss'),
        )

        plt.figure(figsize=(12, 5))
        for index, train_key, val_key, train_label, val_label, title, y_label in panels:
            plt.subplot(1, 2, index)
            plt.plot(self.history[train_key], label=train_label)
            plt.plot(self.history[val_key], label=val_label)
            plt.title(title)
            plt.xlabel('Epoch')
            plt.ylabel(y_label)
            plt.legend()

        plt.tight_layout()
        plt.show()
    
    def extract_features(self, images):
        """
        Extract features from images using the CNN feature extractor.

        A list of file paths is loaded from disk, resized to ``self.img_size``
        and scaled to [0, 1]. Any other input is passed to the feature
        extractor unchanged, so arrays must already be preprocessed.

        Args:
            images: Input images (can be a numpy array or list of image paths)

        Returns:
            Extracted features

        Raises:
            ValueError: If no feature extractor is available, or if ``images``
                is empty.
        """
        if self.feature_extractor is None:
            raise ValueError("Feature extractor not available. Train or load a model first.")

        # Reject empty input up front; otherwise the path check below would
        # fail with an IndexError on images[0].
        if isinstance(images, (list, tuple, np.ndarray)) and len(images) == 0:
            raise ValueError("No images provided for feature extraction.")

        # Check if images is a list of paths
        if isinstance(images, list) and isinstance(images[0], str):
            # Load and preprocess images
            target_size = (self.img_size[0], self.img_size[1])
            loaded_images = [
                keras.preprocessing.image.img_to_array(
                    keras.preprocessing.image.load_img(img_path, target_size=target_size)
                )
                for img_path in images
            ]

            images = np.array(loaded_images)
            images = images / 255.0  # Normalize

        # Extract features
        features = self.feature_extractor.predict(images)

        return features

## Zero-Shot Learning (Stage 2)

In [None]:
class PCBDefectVLM:
    """PCB defect detection using Vision-Language Models for zero-shot learning."""

    def __init__(self, model_name: str = "openai/clip-vit-base-patch32"):
        """
        Initialize the Vision Language Model for PCB defect detection.

        Loads the CLIP model and its processor, placing the model on the GPU
        when CUDA is available and on the CPU otherwise.

        Args:
            model_name: Hugging Face model identifier for the VLM
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        print(f"Loading model: {model_name}...")
        self.model = CLIPModel.from_pretrained(model_name).to(self.device)
        self.processor = CLIPProcessor.from_pretrained(model_name)
        print("Model loaded successfully!")

    def load_image(self, image_path: str) -> "Image.Image":
        """
        Load and prepare an image for inference.

        Args:
            image_path: Path to the image file

        Returns:
            PIL Image object in RGB mode

        Raises:
            FileNotFoundError: If no file exists at ``image_path``.
        """
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image not found at {image_path}")

        # convert() returns a fully loaded copy, so the underlying file can be
        # closed right away instead of staying open until garbage collection.
        with Image.open(image_path) as source:
            return source.convert("RGB")

    def classify(self, image: Union[str, "Image.Image"], categories: List[str]) -> Dict[str, float]:
        """
        Perform zero-shot classification on PCB image.

        The image and each text prompt are embedded, the embeddings are
        L2-normalized, and the scaled cosine similarities are turned into
        probabilities with a softmax over the prompts.

        Args:
            image: Path to image or PIL Image object
            categories: List of defect categories as text prompts

        Returns:
            Dictionary of category -> probability mappings. On failure the
            dictionary contains a single ``"error"`` key with a message
            instead; this method does not raise.
        """
        # A softmax over zero prompts is meaningless; report it explicitly.
        if not categories:
            return {"error": "No categories provided"}

        try:
            if isinstance(image, str):
                try:
                    image = self.load_image(image)
                except Exception as e:
                    print(f"Error loading image: {e}")
                    return {"error": "Failed to load image"}

            # Prepare text prompts for the model
            text_inputs = self.processor(
                text=categories,
                return_tensors="pt",
                padding=True,
                truncation=True
            ).to(self.device)

            # Prepare image for the model
            image_inputs = self.processor(
                images=image,
                return_tensors="pt"
            ).to(self.device)

            # Get embeddings
            with torch.no_grad():
                image_features = self.model.get_image_features(**image_inputs)
                text_features = self.model.get_text_features(**text_inputs)

                # Normalize features
                image_features = image_features / image_features.norm(dim=1, keepdim=True)
                text_features = text_features / text_features.norm(dim=1, keepdim=True)

                # Calculate similarity scores
                logits_per_image = (100.0 * image_features @ text_features.T).squeeze(0)
                probs = logits_per_image.softmax(dim=0)

            # Create and return results dictionary
            return {
                category: float(prob)
                for category, prob in zip(categories, probs.cpu().numpy())
            }

        except Exception as e:
            print(f"Unexpected error: {e}")
            return {"error": str(e)}

In [None]:
class PCBDefectDetector:
    """
    Zero-shot PCB defect detection with prompt-based categorization.

    Builds text prompts for each defect category, scores an image against them
    with a ``PCBDefectVLM`` and folds the per-prompt scores back into
    per-category scores. An optional CNN classifier can be used to boost or
    override the zero-shot decision.
    """

    # Category label used for defect-free boards in processed results.
    NORMAL_CATEGORY = "Normal"

    # Prompts describing a defect-free board; always appended to the defect prompts.
    NORMAL_PROMPTS = (
        "A normal PCB with no defects",
        "A perfectly manufactured printed circuit board",
    )

    # Image size (height, width) used when preprocessing input for the CNN.
    CNN_INPUT_SIZE = (224, 224)

    def __init__(self, model_name: str = "openai/clip-vit-base-patch32", cnn_model=None):
        """
        Initialize the PCB defect detector.

        Args:
            model_name: Hugging Face model identifier for the VLM
            cnn_model: Optional CNN model to enhance zero-shot detection
        """
        self.vlm = PCBDefectVLM(model_name=model_name)
        self.cnn_model = cnn_model
        self.defect_categories = []
        self.defect_prompts = {}
        self._cached_prompts = None
        self._cached_prompts_params = None
        # Exact prompt text -> category, filled by get_prompts_for_detection().
        self._prompt_to_category = {}

    def load_defect_categories(self, json_path: Optional[str] = None) -> None:
        """
        Load defect categories and prompts from a JSON file or use defaults.

        The JSON file must contain a ``"defects"`` list whose items have a
        ``"category"`` name and a ``"prompts"`` list. When ``json_path`` is
        missing or does not exist, the built-in defaults are used. Any prompts
        cached for detection are discarded.

        Args:
            json_path: Path to the JSON file containing defect categories
        """
        if json_path and os.path.exists(json_path):
            with open(json_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            defects = data['defects']
            self.defect_categories = [item['category'] for item in defects]

            # Store the detailed prompts for each category
            self.defect_prompts = {item['category']: item['prompts'] for item in defects}
        else:
            print("Using default defect categories...")
            # Default prompts for each PCB defect category
            self.defect_prompts = {
                "Solder Bridge": [
                    "solder bridging between adjacent pins",
                    "short circuit between traces or pads"
                ],
                "Missing Component": [
                    "missing electronic component",
                    "component placement area with no part installed"
                ],
                "Component Misalignment": [
                    "misaligned component on the PCB",
                    "component shifted from its correct position"
                ],
                "Cold Solder Joint": [
                    "cold solder joint",
                    "dull, grainy solder connection"
                ],
                "Lifted Pad": [
                    "pad lifted from PCB substrate",
                    "copper pad delamination"
                ],
                "Excess Solder": [
                    "too much solder on joint",
                    "solder ball or blob"
                ],
                "Insufficient Solder": [
                    "not enough solder on joint",
                    "incomplete solder coverage"
                ],
                "Cracked Solder Joint": [
                    "cracked solder connection",
                    "fracture in solder joint"
                ],
                "PCB Scratch": [
                    "scratch on PCB surface",
                    "damaged trace on board"
                ],
                "Burnt Component": [
                    "burnt or charred component",
                    "blackened electronic part"
                ],
                "Reversed Component": [
                    "component installed backwards",
                    "reversed polarity component"
                ],
                "Foreign Material": [
                    "debris on PCB surface",
                    "contaminant on circuit board"
                ]
            }
            # Category order follows the order of the defaults above.
            self.defect_categories = list(self.defect_prompts)

        # Prompts generated for the previous category set are now stale.
        self._cached_prompts = None
        self._cached_prompts_params = None
        self._prompt_to_category = {}

    def get_prompts_for_detection(self, enhance_with_domain: bool = True) -> List[str]:
        """
        Generate prompts for zero-shot detection.

        Uses the first prompt of each category, optionally wrapped in two
        PCB-specific sentence templates, followed by the prompts describing a
        normal board. Each generated prompt is recorded together with its
        category so scores can later be attributed without text matching.

        Args:
            enhance_with_domain: Whether to enhance prompts with domain-specific language

        Returns:
            List of formatted prompts for the model
        """
        if not self.defect_categories:
            self.load_defect_categories()

        detection_prompts = []

        for category in self.defect_categories:
            category_prompts = self.defect_prompts.get(category) or []
            if not category_prompts:
                # A category without prompts cannot be scored; skip it
                # instead of failing on the missing first prompt.
                print(f"Warning: no prompts defined for category '{category}'; skipping")
                continue

            # Get the most generic prompt for this category
            base_prompt = category_prompts[0]

            if enhance_with_domain:
                # Format with PCB/semiconductor domain knowledge
                variants = [
                    f"A PCB with {base_prompt}",
                    f"A printed circuit board showing {base_prompt}",
                ]
            else:
                variants = [base_prompt]

            for variant in variants:
                detection_prompts.append(variant)
                self._prompt_to_category[variant] = category

        # Always add a "normal" category
        for normal_prompt in self.NORMAL_PROMPTS:
            detection_prompts.append(normal_prompt)
            self._prompt_to_category[normal_prompt] = self.NORMAL_CATEGORY

        return detection_prompts

    def detect(self, image_path: str, threshold: float = 0.2,
               top_k: int = 3, enhance_prompts: bool = True,
               use_cnn_features: bool = False) -> Dict[str, Any]:
        """
        Detect PCB defects in an image using zero-shot classification with optional
        CNN feature enhancement.

        Args:
            image_path: Path to the PCB image
            threshold: Confidence threshold for detection (0-1)
            top_k: Number of top categories to return
            enhance_prompts: Whether to enhance prompts with domain-specific language
            use_cnn_features: Whether to use CNN features to enhance detection

        Returns:
            Detection results with ``is_defective``, the top ``defects``,
            ``all_scores`` and ``cnn_prediction``. Failures are reported
            through an additional ``error`` entry rather than an exception.

        Raises:
            ValueError: If ``threshold`` is outside [0, 1] or ``top_k`` < 1.
        """
        # Validate input parameters
        if not 0 <= threshold <= 1:
            raise ValueError("Threshold must be between 0 and 1")

        if top_k < 1:
            raise ValueError("top_k must be at least 1")

        # Check if image exists
        if not os.path.exists(image_path):
            return self._error_result(f"Image not found at {image_path}")

        # Get formatted prompts (with caching)
        if not self._cached_prompts or self._cached_prompts_params != enhance_prompts:
            self._cached_prompts = self.get_prompts_for_detection(enhance_with_domain=enhance_prompts)
            self._cached_prompts_params = enhance_prompts

        prompts = self._cached_prompts

        try:
            # Basic zero-shot detection
            raw_results = self.vlm.classify(image_path, prompts)

            # Check if we got an error from the model
            if "error" in raw_results:
                return self._error_result(raw_results["error"])

            # Optional: Enhance with CNN features if available
            cnn_prediction = None
            if use_cnn_features and self.cnn_model is not None:
                cnn_prediction = self._get_cnn_prediction(image_path)
                raw_results = self._boost_with_cnn(raw_results, cnn_prediction)

            # Post-process results to combine similar categories
            processed_results = self._process_results(raw_results)

            normal_score = processed_results.get(self.NORMAL_CATEGORY, 0.0)
            defect_scores = [
                score for category, score in processed_results.items()
                if category != self.NORMAL_CATEGORY
            ]
            top_defect_score = max(defect_scores, default=0.0)

            # Determine if image is defective
            is_defective = top_defect_score > (normal_score * 1.2)  # 20% buffer for normal

            # If CNN prediction is very confident, override zero-shot decision
            if cnn_prediction is not None:
                if cnn_prediction > 0.9:  # Very confident it's defective
                    is_defective = True
                elif cnn_prediction < 0.1:  # Very confident it's normal
                    is_defective = False

            # Rank categories by score, highest first
            ranked = sorted(
                processed_results.items(),
                key=lambda item: item[1],
                reverse=True
            )
            top_results = {k: v for k, v in ranked if v >= threshold}

            # Handle case where no categories meet threshold
            if not top_results:
                # Return the highest scoring category regardless of threshold
                top_results = dict(ranked[:1]) if ranked else {"Unknown": 0.0}

            # Limit to top k
            top_k_results = dict(list(top_results.items())[:top_k])

            return {
                "is_defective": is_defective,
                "defects": top_k_results,
                "all_scores": processed_results,
                "cnn_prediction": cnn_prediction
            }

        except Exception as e:
            print(f"Error in defect detection: {e}")
            return self._error_result(str(e))

    def _error_result(self, message: str) -> Dict[str, Any]:
        """Build the result returned by detect() when detection cannot run."""
        return {
            "is_defective": False,
            "defects": {},
            "all_scores": {},
            "cnn_prediction": None,
            "error": message
        }

    def _boost_with_cnn(self, raw_results: Dict[str, float], cnn_prediction: float,
                        boost: float = 1.2) -> Dict[str, float]:
        """
        Scale up the prompt scores that agree with the CNN's verdict.

        Defect prompts are boosted when the CNN leans defective (> 0.5) and
        normal prompts when it leans normal (< 0.5); at exactly 0.5 nothing
        changes. The input mapping is left unmodified.
        """
        cnn_says_defective = cnn_prediction > 0.5
        cnn_says_normal = cnn_prediction < 0.5

        boosted = {}
        for prompt, score in raw_results.items():
            is_normal_prompt = self._category_for_prompt(prompt) == self.NORMAL_CATEGORY
            agrees = cnn_says_normal if is_normal_prompt else cnn_says_defective
            boosted[prompt] = score * boost if agrees else score
        return boosted

    def _get_cnn_prediction(self, image_path: str) -> float:
        """
        Get CNN-based prediction to enhance zero-shot detection.

        Args:
            image_path: Path to the PCB image

        Returns:
            Probability of the image being defective (0-1); 0.5 when the
            prediction fails for any reason
        """
        try:
            # Load and preprocess the image
            img = keras.preprocessing.image.load_img(
                image_path,
                target_size=self.CNN_INPUT_SIZE
            )
            img_array = keras.preprocessing.image.img_to_array(img)
            img_array = np.expand_dims(img_array, axis=0)
            img_array = img_array / 255.0

            # Make prediction using the CNN model
            predictions = self.cnn_model.predict(img_array)

            # Handle output shape based on the model's output layer
            if predictions.shape[1] == 1:  # Binary sigmoid output
                prediction = predictions[0][0]
            else:  # Softmax output
                prediction = predictions[0][1]  # Index 1 for "defective" class

            return float(prediction)
        except Exception as e:
            print(f"Error in CNN prediction: {e}")
            return 0.5  # Neutral prediction in case of error

    def _category_for_prompt(self, prompt: str) -> Optional[str]:
        """
        Return the category a prompt belongs to, or None if it is unknown.

        Prompts produced by get_prompts_for_detection() are resolved exactly.
        Matching on shared words is only a fallback for foreign prompts:
        many category names share words ("Component", "Solder", "PCB"), so
        first-word-match attribution puts scores in the wrong category.
        """
        mapped = self._prompt_to_category.get(prompt)
        if mapped is not None:
            return mapped

        lowered = prompt.lower()
        if "normal" in lowered or "no defects" in lowered:
            return self.NORMAL_CATEGORY

        # Fallback 1: the prompt embeds one of a category's known phrases.
        for category in self.defect_categories:
            for phrase in self.defect_prompts.get(category, []):
                if phrase.lower() in lowered:
                    return category

        # Fallback 2: category whose name shares the most words with the prompt
        # (ties go to the category listed first).
        prompt_words = set(lowered.split())
        best_category = None
        best_overlap = 0
        for category in self.defect_categories:
            overlap = len(prompt_words & set(category.lower().split()))
            if overlap > best_overlap:
                best_category, best_overlap = category, overlap
        return best_category

    def _process_results(self, raw_results: Dict[str, float]) -> Dict[str, float]:
        """
        Process raw classification results to combine similar categories.

        Each category receives the maximum score among its prompts; the
        category scores are then rescaled to sum to 1.

        Args:
            raw_results: Raw classification results (prompt -> score)

        Returns:
            Processed results with combined categories (category -> score)
        """
        processed = {}

        # Group by category and take maximum score
        for prompt, score in raw_results.items():
            # Skip error messages
            if prompt == "error":
                continue

            category = self._category_for_prompt(prompt)
            if category is None:
                continue

            if category in processed:
                processed[category] = max(processed[category], score)
            else:
                processed[category] = score

        # Add normalization to make scores more comparable
        total = sum(processed.values())
        if total > 0:  # Avoid division by zero
            processed = {k: v / total for k, v in processed.items()}

        return processed

## Two-Stage Integration Pipeline

In [None]:
class TwoStagePCBDetector:
    """
    Two-stage PCB defect detection pipeline that combines CNN and zero-shot approaches.
    
    Stage 1: CNN training on Keras datasets and PCB images
    Stage 2: Zero-shot detection with CNN feature enhancement
    """
    
    def __init__(self, base_dir='/content/drive/MyDrive/PCB_Defect_Detection'):
        """
        Create the pipeline with a data manager and no models yet.

        Args:
            base_dir: Base directory for the project
        """
        # Neither stage has a model at construction time; both slots start empty.
        self.cnn_model = None
        self.zero_shot_detector = None

        # Project location, shared with the data manager created below.
        self.base_dir = base_dir
        self.data_manager = PCBDataManager(base_dir=base_dir)
        
    def setup(self):
        """
        Scan for PCB images and display a sample of them.

        Both steps are delegated to the data manager.

        NOTE(review): the previous summary also mentioned creating
        directories; no directory creation happens in the visible code -
        presumably PCBDataManager handles it, confirm. The text that follows
        this method in the file appears truncated, so this method may be
        missing trailing statements.
        """
        # Scan for PCB images
        self.data_manager.scan_images()
        
        # Visualize sample images
        self.data_manager.visualize_sample_images()
        
 = PCBDefectCNN(base_model=base_model)
        
        # Build the model
        model = self.cnn_model.build_model(
            num_classes=2,  # Binary classification: normal vs defective
            dropout_rate=0.5,
            fine_tune_layers=fine_tune_layers
        )
        
        # Prepare data for training
        if use_keras_dataset:
            print(f"\nCreating hybrid dataset with {keras_dataset}...")
            train_data = self.data_manager.create_keras_dataset_hybrid(
                keras_dataset=keras_dataset,
                num_pcb_samples=min(500, len(self.data_manager.normal_images) + len(self.data_manager.defective_images))
            )
            
            # Split hybrid data into train and validation
            x_train, y_train = train_data
            
            # Split off 20% for validation
            val_split = int(0.8 * len(x_train))
            indices = np.random.permutation(len(x_train))
            train_idx, val_idx = indices[:val_split], indices[val_split:]
            
            x_val, y_val = x_train[val_idx], y_train[val_idx]
            x_train, y_train = x_train[train_idx], y_train[train_idx]
            
            print(f"Training data shape: {x_train.shape}, labels shape: {y_train.shape}")
            print(f"Validation data shape: {x_val.shape}, labels shape: {y_val.shape}")
            
            # Train the model
            history = self.cnn_model.train(
                (x_train, y_train),
                (x_val, y_val),
                epochs=epochs
            )
        else:
            # Use only PCB images with Keras ImageDataGenerator
            train_gen, val_gen = self.data_manager.prepare_data_for_cnn(
                img_size=(224, 224),
                batch_size=batch_size
            )
            
            # Train the model
            history = self.cnn_model.train(
                train_gen,
                val_gen,
                epochs=epochs
            )
        
        # Visualize training history
        self.cnn_model.visualize_training_history()
        
        # Save the trained model
        model_path, extractor_path = self.cnn_model.save_model(
            self.data_manager.dirs['models'],
            model_name=f"pcb_cnn_{base_model}_{'hybrid' if use_keras_dataset else 'pcb_only'}"
        )
        
        print(f"\nCNN model saved to {model_path}")
        print(f"Feature extractor saved to {extractor_path}")
        
        return self.cnn_model
    
    def setup_zero_shot_stage(self, vlm_model="openai/clip-vit-base-patch32", 
                             defect_categories_path=None):
        """
        Create the zero-shot detector and load its defect categories.

        If a CNN wrapper is present on this pipeline, its Keras model is handed
        to the detector; otherwise the detector runs standalone. When no
        categories file is given, a default one is written into the data
        directory and loaded from there.

        Args:
            vlm_model: Hugging Face model identifier for the VLM
            defect_categories_path: Path to defect categories JSON file

        Returns:
            Zero-shot detector
        """
        print("\n==== Stage 2: Zero-Shot Detection Setup ====")

        # Build the detector, with or without the CNN from stage 1.
        if self.cnn_model is None:
            print("Setting up standalone zero-shot detection...")
            detector = PCBDefectDetector(model_name=vlm_model)
        else:
            print("Integrating CNN model with zero-shot detection...")
            detector = PCBDefectDetector(
                model_name=vlm_model,
                cnn_model=self.cnn_model.model
            )
        self.zero_shot_detector = detector

        # Resolve the categories file, generating the default one on demand.
        categories_file = defect_categories_path
        if categories_file is None:
            data_dir = self.data_manager.dirs['data']
            categories_file = os.path.join(data_dir, 'defect_categories.json')
            self._create_default_defect_categories(categories_file)

        detector.load_defect_categories(categories_file)

        category_count = len(detector.defect_categories)
        print(f"Zero-shot detector initialized with {category_count} defect categories")

        return detector
    
    def _create_default_defect_categories(self, json_path):
        """Create and save default defect categories."""
        defect_categories = {
            "defects": [
                {
                    "category": "Solder Bridge",
                    "prompts": [
                        "solder bridging between adjacent pins",
                        "short circuit between traces or pads",
                        "excess solder creating unwanted connections"
                    ]
                },
                {
                    "category": "Missing Component",
                    "prompts": [
                        "missing electronic component",
                        "component placement area with no part installed",
                        "empty pad where a component should be"
                    ]
                },
                {
                    "category": "Component Misalignment",
                    "prompts": [
                        "misaligned component on the PCB",
                        "component shifted from its correct position",
                        "rotated or tilted component"
                    ]
                },
                {
                    "category": "Cold Solder Joint",
                    "prompts": [
                        "cold solder joint",
                        "dull, grainy solder connection",
                        "incomplete solder wetting"
                    ]
                },
                {
                    "category": "Lifted Pad",
                    "prompts": [
                        "pad lifted from PCB substrate",
                        "copper pad delamination",
                        "detached pad from board"
                    ]
                },
                {
                    "category": "Excess Solder",
                    "prompts": [
                        "too much solder on joint",
                        "solder ball or blob",
                        "overflowed solder joint"
                    ]
                },
                {
                    "category": "Insufficient Solder",
                    "prompts": [
                        "not enough solder on joint",
                        "incomplete solder coverage",
                        "starved solder joint"
                    ]
                },
                {
                    "category": "PCB Scratch",
                    "prompts": [
                        "scratch on PCB surface",
                        "damaged trace on board",
                        "visible gouge in PCB"
                    ]
                },
                {
                    "category": "Burnt Component",
                    "prompts": [
                        "burnt or charred component",
                        "blackened electronic part",
                        "component with burn marks"
                    ]
                },
                {
                    "category": "Foreign Material",
                    "prompts": [
                        "debris on PCB surface",
                        "contaminant on circuit board",
                        "flux residue on board"
                    ]
                }
            ]
        }
        
        # Create directory if it doesn't exist
        os.makedirs(os.path.dirname(json_path), exist_ok=True)
        
        # Save to file
        with open(json_path, 'w') as f:
            json.dump(defect_categories, f, indent=4)
        
        print(f"Default defect categories saved to {json_path}")
        
        return json_path
    
    def detect_defects(self, image_path, use_cnn=True, use_zero_shot=True,
                      threshold=0.2, top_k=3, enhance_prompts=True):
        """
        Detect defects in a PCB image using the two-stage pipeline.

        The CNN stage (when enabled and a model is loaded) classifies the image
        as normal/defective. The zero-shot stage (when enabled and a detector is
        set up) is queried for defect categories. When both stages yield a
        verdict, their binary decisions are merged with a fixed weighted vote.

        Args:
            image_path: Path to the PCB image
            use_cnn: Whether to use CNN-based detection
            use_zero_shot: Whether to use zero-shot detection
            threshold: Confidence threshold, forwarded to the zero-shot detector
            top_k: Number of top categories, forwarded to the zero-shot detector
            enhance_prompts: Whether to enhance prompts with domain-specific language

        Returns:
            Dict with the keys "image_path", "is_defective", "cnn_results",
            "zero_shot_results" and "combined_results". A stage that was skipped
            or failed leaves its entry as None. The optional keys "error"
            (image not found), "cnn_error" and "zero_shot_error" carry failure
            messages. Verdicts and probabilities are plain Python bool/float.
        """
        results = {
            "image_path": image_path,
            "is_defective": False,
            "cnn_results": None,
            "zero_shot_results": None,
            "combined_results": None
        }

        # Make sure the image exists
        if not os.path.exists(image_path):
            results["error"] = f"Image not found at {image_path}"
            return results

        # 1. CNN-based detection
        if use_cnn and self.cnn_model is not None:
            try:
                # Load and preprocess the image (resize, add batch axis, scale to [0, 1])
                img = keras.preprocessing.image.load_img(
                    image_path,
                    target_size=(224, 224)
                )
                img_array = keras.preprocessing.image.img_to_array(img)
                img_array = np.expand_dims(img_array, axis=0)
                img_array = img_array / 255.0

                # Make prediction
                prediction = self.cnn_model.model.predict(img_array)[0]

                # Two or more outputs: index 0 is read as "normal", index 1 as
                # "defective". A single output is read as the defective probability.
                if len(prediction) > 1:
                    normal_probability = float(prediction[0])
                    defective_probability = float(prediction[1])
                else:
                    defective_probability = float(prediction[0])
                    normal_probability = float(1 - prediction[0])

                # Cast to built-in types so the result dict never carries
                # numpy scalars (which, for example, json.dump rejects).
                results["cnn_results"] = {
                    "defective_probability": defective_probability,
                    "normal_probability": normal_probability,
                    "is_defective": bool(defective_probability > 0.5)
                }

                # Update overall defect status
                results["is_defective"] = results["cnn_results"]["is_defective"]

            except Exception as e:
                print(f"Error in CNN detection: {e}")
                results["cnn_error"] = str(e)

        # 2. Zero-shot detection
        if use_zero_shot and self.zero_shot_detector is not None:
            try:
                # Use CNN features if available
                use_cnn_features = use_cnn and self.cnn_model is not None

                # Perform zero-shot detection
                zero_shot_results = self.zero_shot_detector.detect(
                    image_path=image_path,
                    threshold=threshold,
                    top_k=top_k,
                    enhance_prompts=enhance_prompts,
                    use_cnn_features=use_cnn_features
                )

                # Store zero-shot results
                results["zero_shot_results"] = zero_shot_results

                # If CNN not used or failed, use zero-shot defect status
                if results["cnn_results"] is None:
                    results["is_defective"] = bool(zero_shot_results["is_defective"])

            except Exception as e:
                print(f"Error in zero-shot detection: {e}")
                results["zero_shot_error"] = str(e)

        # 3. Combine results (if both methods used)
        cnn_results = results["cnn_results"]
        zero_shot_results = results["zero_shot_results"]
        if cnn_results is not None and zero_shot_results is not None:
            try:
                zero_shot_defective = bool(zero_shot_results["is_defective"])
            except (KeyError, TypeError) as e:
                # The zero-shot payload carries no verdict: keep the CNN verdict
                # instead of crashing after both stages already ran.
                print(f"Error combining results: {e}")
                results.setdefault(
                    "zero_shot_error",
                    f"Zero-shot results have no usable 'is_defective' entry: {e}"
                )
                return results

            # Simple ensemble approach
            cnn_weight = 0.7  # Higher weight for CNN-based detection
            zero_shot_weight = 0.3

            cnn_defective = cnn_results["is_defective"]

            # Weighted ensemble of the two binary decisions.
            # NOTE: the votes are 0/1 and the CNN weight alone exceeds the 0.5
            # cut-off, so with these weights the combined verdict always equals
            # the CNN verdict.
            combined_score = (cnn_weight * float(cnn_defective) +
                              zero_shot_weight * float(zero_shot_defective))
            combined_defective = combined_score > 0.5

            results["combined_results"] = {
                "defective_score": combined_score,
                "is_defective": combined_defective,
                "ensemble_weights": {
                    "cnn": cnn_weight,
                    "zero_shot": zero_shot_weight
                }
            }

            # Update overall defect status
            results["is_defective"] = combined_defective

        return results
    
    def batch_detect(self, image_paths, use_cnn=True, use_zero_shot=True,
                    threshold=0.2, top_k=3, enhance_prompts=True):
        """
        Run `detect_defects` over a collection of PCB images.

        Args:
            image_paths: List of paths to PCB images
            use_cnn: Whether to use CNN-based detection
            use_zero_shot: Whether to use zero-shot detection
            threshold: Confidence threshold for detection
            top_k: Number of top categories to return
            enhance_prompts: Whether to enhance prompts with domain-specific language

        Returns:
            List with one detection result per input path, in input order
        """
        # Every image is analysed with the same settings, so bundle them once.
        detection_options = {
            "use_cnn": use_cnn,
            "use_zero_shot": use_zero_shot,
            "threshold": threshold,
            "top_k": top_k,
            "enhance_prompts": enhance_prompts,
        }

        # tqdm wraps the iterable to show a progress bar while detecting.
        progress = tqdm(image_paths, desc="Detecting defects")
        return [
            self.detect_defects(image_path=path, **detection_options)
            for path in progress
        ]
    
    def visualize_detection(self, image_path, results=None, output_path=None, show=True):
        """
        Visualize defect detection results with confidence scores.

        Draws the PCB image on the left and, on the right, a horizontal bar
        chart of the zero-shot defect scores (preferred when present) or of the
        CNN normal/defective probabilities. The right-hand title states the
        overall verdict.

        Args:
            image_path: Path to the PCB image
            results: Optional pre-computed detection results; when omitted,
                `detect_defects(image_path)` is called
            output_path: Optional path to save the visualization
            show: Whether to display the plot (the figure is closed otherwise)

        Raises:
            Whatever `PIL.Image.open` raises when the image cannot be opened
            (for example a missing file).
        """
        def _draw_score_bars(ax, labels, values, bar_colors, xlabel):
            """Draw a horizontal bar chart of `values` with a numeric label per bar."""
            y_pos = np.arange(len(labels))
            bars = ax.barh(y_pos, values, color=bar_colors, alpha=0.7)
            ax.set_yticks(y_pos)
            ax.set_yticklabels(labels)
            ax.set_xlim(0, 1.0)
            ax.set_xlabel(xlabel)

            # Add score values next to the end of each bar
            for bar, value in zip(bars, values):
                ax.text(bar.get_width() + 0.01, bar.get_y() + bar.get_height()/2,
                        f'{value:.2f}', va='center')

        # Get detection results if not provided
        if results is None:
            results = self.detect_defects(image_path)

        # Load the pixels and release the file handle right away. Converting to
        # RGB keeps grayscale/palette images from being passed to imshow as a
        # 2-D array, which matplotlib would render through a colormap.
        with Image.open(image_path) as img:
            pixels = np.array(img.convert("RGB"))

        # Create figure
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 7))

        # Plot image
        ax1.imshow(pixels)
        ax1.set_title("PCB Image")
        ax1.axis('off')

        # Determine overall status
        if results["is_defective"]:
            status_text = "DEFECTIVE"
            status_color = "red"
        else:
            status_text = "NORMAL"
            status_color = "green"

        # Create visualization based on available results
        zero_shot = results["zero_shot_results"]
        if zero_shot is not None and "defects" in zero_shot:
            # Get defect categories and scores from zero-shot detection
            defects = zero_shot["defects"]
            categories = list(defects.keys())
            scores = list(defects.values())

            if not categories:
                ax2.text(0.5, 0.5, "No defects detected",
                         ha='center', va='center', fontsize=12)
                ax2.axis('off')
            else:
                # Sort by score in descending order
                sorted_indices = np.argsort(scores)[::-1]
                categories = [categories[i] for i in sorted_indices]
                scores = [scores[i] for i in sorted_indices]

                # A category literally named "normal" is drawn green, the rest red
                colors = ['red' if category.lower() != "normal" else 'green'
                          for category in categories]

                _draw_score_bars(ax2, categories, scores, colors, 'Confidence Score')

        elif results["cnn_results"] is not None:
            # Create simple bar chart for CNN probabilities
            cnn_results = results["cnn_results"]
            _draw_score_bars(
                ax2,
                ["Normal", "Defective"],
                [cnn_results["normal_probability"], cnn_results["defective_probability"]],
                ['green', 'red'],
                'CNN Probability'
            )
        else:
            ax2.text(0.5, 0.5, "No detection results available",
                     ha='center', va='center', fontsize=12)
            ax2.axis('off')

        # Set overall title; mention the ensemble weights when both stages voted
        subtitle = ""
        if results["combined_results"] is not None:
            subtitle = f"(CNN: {results['combined_results']['ensemble_weights']['cnn']:.1f}, Zero-Shot: {results['combined_results']['ensemble_weights']['zero_shot']:.1f})"

        ax2.set_title(f"Detection Results: {status_text} {subtitle}", color=status_color, fontweight='bold')

        fig.tight_layout()

        if output_path:
            # Save through the figure object rather than the implicit current figure
            fig.savefig(output_path, dpi=300, bbox_inches='tight')
            print(f"Visualization saved to {output_path}")

        if show:
            plt.show()
        else:
            plt.close(fig)
    
    def visualize_batch_results(self, batch_results, output_dir=None):
        """
        Visualize batch detection results with summary statistics and example images.

        Always prints a defective/normal count and a per-image summary table.
        When `output_dir` is given, it is created if needed, per-image
        visualizations are saved for up to the first 10 results, and a
        normal-vs-defective comparison grid is written.

        Args:
            batch_results: List of detection results (as returned by `batch_detect`)
            output_dir: Optional directory to save visualizations
        """
        if output_dir is not None:
            os.makedirs(output_dir, exist_ok=True)

        # Count defective/normal
        num_defective = sum(1 for r in batch_results if r["is_defective"])
        num_normal = len(batch_results) - num_defective

        print(f"Detection Results: {num_defective} defective, {num_normal} normal")

        # Create summary table
        print("\nSummary of PCB Defect Detection Results")
        print("-" * 100)
        print(f"{'Image':<40} {'Defective':<10} {'CNN Conf.':<10} {'Top Defect':<20} {'Zero-Shot Conf.':<15}")
        print("-" * 100)

        for result in batch_results:
            img_name = os.path.basename(result["image_path"])

            # Get CNN confidence if available
            cnn_conf = "-"
            if result["cnn_results"] is not None:
                cnn_conf = f"{result['cnn_results']['defective_probability']:.2f}"

            # Get top defect if available
            top_defect = "-"
            zero_shot_conf = "-"
            zero_shot = result["zero_shot_results"]
            if zero_shot is not None and "defects" in zero_shot:
                defects = zero_shot["defects"]
                if defects:
                    # Pick the highest-scoring category explicitly instead of
                    # relying on the dict's insertion order.
                    top_defect, top_score = max(defects.items(), key=lambda item: item[1])
                    zero_shot_conf = f"{top_score:.2f}"

            print(f"{img_name:<40} {str(result['is_defective']):<10} {cnn_conf:<10} {top_defect:<20} {zero_shot_conf:<15}")

        print("-" * 100)

        # Create visualizations for a subset of images
        if output_dir is not None:
            num_viz = min(10, len(batch_results))
            print(f"\nCreating visualizations for {num_viz} sample images...")

            for i, result in enumerate(batch_results[:num_viz]):
                # Results flagged with "error" (e.g. image not found) have no
                # image to draw; skip them so one bad path does not abort the batch.
                if "error" in result:
                    print(f"Skipping visualization for {result['image_path']}: {result['error']}")
                    continue

                output_path = os.path.join(output_dir, f"detection_{i+1}_{os.path.basename(result['image_path'])}")
                self.visualize_detection(result["image_path"], result, output_path, show=False)

            # Create comparison grid of normal vs defective examples, again
            # leaving out results whose image could not be found.
            drawable_results = [r for r in batch_results if "error" not in r]
            self._create_comparison_grid(drawable_results, output_dir)
    
    def _create_comparison_grid(self, batch_results, output_dir):
        """
        Create a grid visualization comparing normal and defective examples.

        The top row shows up to 4 results classified as normal, the bottom row
        up to 4 classified as defective. The figure is saved as
        "comparison_grid.png" in `output_dir`. Nothing is drawn or saved unless
        at least one example of each class is present.

        Args:
            batch_results: List of detection results
            output_dir: Directory to save visualization
        """
        def _load_rgb(path):
            """Read an image as an RGB array and release the file handle."""
            # RGB conversion keeps grayscale/palette images from being rendered
            # through a colormap by imshow.
            with Image.open(path) as img:
                return np.array(img.convert("RGB"))

        # Get up to 4 examples of each class
        normal_examples = [r for r in batch_results if not r["is_defective"]][:4]
        defective_examples = [r for r in batch_results if r["is_defective"]][:4]

        if not normal_examples or not defective_examples:
            return

        # squeeze=False keeps `axes` two-dimensional even for a single column;
        # otherwise a 2x1 grid comes back 1-D and axes[row, col] raises IndexError.
        num_cols = max(len(normal_examples), len(defective_examples))
        fig, axes = plt.subplots(2, num_cols, figsize=(16, 8), squeeze=False)

        # Hide the frame/ticks of every cell up front, including unused slots
        for ax in axes.ravel():
            ax.axis('off')

        # Plot normal examples
        for i, result in enumerate(normal_examples):
            axes[0, i].imshow(_load_rgb(result["image_path"]))
            axes[0, i].set_title(f"Normal {i+1}")
            axes[0, i].axis('off')

        # Plot defective examples
        for i, result in enumerate(defective_examples):
            axes[1, i].imshow(_load_rgb(result["image_path"]))

            # Title with the highest-scoring defect category if available
            defect_title = f"Defective {i+1}"
            zero_shot = result["zero_shot_results"]
            if zero_shot is not None and "defects" in zero_shot:
                defects = zero_shot["defects"]
                if defects:
                    top_defect = max(defects.items(), key=lambda item: item[1])[0]
                    defect_title = f"{top_defect}"

            axes[1, i].set_title(defect_title, color='red')
            axes[1, i].axis('off')

        fig.tight_layout()
        fig.savefig(os.path.join(output_dir, "comparison_grid.png"), dpi=300, bbox_inches='tight')
        plt.close(fig)

## Example Implementation

# Function to wait for user input before continuing (useful for step-by-step tutorials)
def wait_for_confirmation(message="Press Enter to continue..."):
    """Show `message` as an input prompt and block until the user responds."""
    # Only the pause matters; whatever was typed is discarded.
    _ = input(message)
    return None

Example Implementation
stage(
            use_keras_dataset=use_keras,
            keras_dataset=keras_dataset,
            base_model=base_model,
            epochs=epochs
        )
    else:
        # Load existing model
        models_dir = pipeline.data_manager.dirs['models']
        print(f"\nLooking for existing models in {models_dir}...")
        
        # List available models
        model_folders = [f for f in os.listdir(models_dir) if os.path.isdir(os.path.join(models_dir, f))]
        
        if not model_folders:
            print("No existing models found. Please train a new model.")
            return
        
        print("\nAvailable models:")
        for i, folder in enumerate(model_folders):
            print(f"{i+1}. {folder}")
        
        # Get user selection
        selection = int(input("\nSelect a model (number): ")) - 1
        if 0 <= selection < len(model_folders):
            model_path = os.path.join(models_dir, model_folders[selection], "full_model")
            
            # Initialize and load the model
            pipeline.cnn_model = PCBDefectCNN()
            pipeline.cnn_model.load_model(model_path)
            print(f"Model loaded from {model_path}")
        else:
            print("Invalid selection.")
            return
    
    # Setup zero-shot stage
    print("\nSetting up zero-shot detection (Stage 2)...")
    vlm_model = input("Which VLM model? (openai/clip-vit-base-patch32, default: openai/clip-vit-base-patch32): ") or "openai/clip-vit-base-patch32"
    
    pipeline.setup_zero_shot_stage(vlm_model=vlm_model)
    
    # Select images for detection
    print("\nPreparing for defect detection...")
    
    # Ask user which set of images to use
    print("\nWhich images do you want to analyze?")
    print("1. Test images folder")
    print("2. Normal images folder")
    print("3. Defective images folder")
    print("4. All available images")
    
    image_selection = int(input("\nSelect option (1-4): "))
    
    if image_selection == 1:
        images = pipeline.data_manager.test_images
    elif image_selection == 2:
        images = pipeline.data_manager.normal_images
    elif image_selection == 3:
        images = pipeline.data_manager.defective_images
    else:
        images = (pipeline.data_manager.test_images + 
                 pipeline.data_manager.normal_images + 
                 pipeline.data_manager.defective_images)
    
    # Limit the number of images for demonstration
    max_images = int(input(f"\nMaximum number of images to analyze (found {len(images)}, default: 10): ") or 10)
    images = images[:min(max_images, len(images))]
    
    if not images:
        print("No images to analyze. Please upload images to the appropriate folders.")
        return
    
    # Run detection
    print(f"\nRunning detection on {len(images)} images...")
    
    # Configure detection parameters
    use_cnn = input("Use CNN model? (y/n, default: y): ").lower() != 'n'
    use_zero_shot = input("Use zero-shot detection? (y/n, default: y): ").lower() != 'n'
    threshold = float(input("Detection threshold (0-1, default: 0.2): ") or 0.2)
    
    # Run batch detection
    results = pipeline.batch_detect(
        images,
        use_cnn=use_cnn,
        use_zero_shot=use_zero_shot,
        threshold=threshold
    )
    
    # Visualize results
    print("\nVisualizing results...")
    output_dir = os.path.join(pipeline.data_manager.dirs['results'], 
                             f"detection_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}")
    
    pipeline.visualize_batch_results(results, output_dir=output_dir)
    
    print(f"\nResults saved to {output_dir}")
    
    # Show detailed analysis for one example
    if results:
        sample_idx = min(3, len(results)-1)
        sample_result = results[sample_idx]
        sample_image = sample_result["image_path"]
        
        print(f"\nDetailed analysis for sample image: {os.path.basename(sample_image)}")
        pipeline.visualize_detection(sample_image, sample_result)
        
        # Print detection details
        print("\nDetection details:")
        print(f"- Overall classification: {'DEFECTIVE' if sample_result['is_defective'] else 'NORMAL'}")
        
        if sample_result["cnn_results"]:
            print(f"- CNN confidence: {sample_result['cnn_results']['defective_probability']:.4f}")
        
        if sample_result["zero_shot_results"] and "defects" in sample_result["zero_shot_results"]:
            print("- Detected defects (zero-shot):")
            for defect, score in sample_result["zero_shot_results"]["defects"].items():
                print(f"  * {defect}: {score:.4f}")
    
    print("\nPipeline demonstration completed!")

## Optional Supplementary Functions

In [None]:
def upload_images_to_drive():
    """
    Helper function to upload PCB images to Google Drive.
    This is useful for the initial setup.

    Prompts for a destination folder (normal / defective / test), opens the
    Colab upload dialog and writes every uploaded file into that folder.
    Prints "Invalid selection." and returns without uploading when the answer
    is not 1, 2 or 3 (including empty or non-numeric input).
    """
    from google.colab import files

    # Initialize data manager to create directories
    data_manager = PCBDataManager()

    # Ask which folder to upload to
    print("Select destination folder:")
    print("1. Normal PCB images")
    print("2. Defective PCB images")
    print("3. Test PCB images")

    # Menu number -> key into data_manager.data_dirs
    folder_keys = {1: 'normal', 2: 'defective', 3: 'test'}

    try:
        selection = int(input("Enter selection (1-3): "))
    except ValueError:
        # Empty or non-numeric answer: treat like any other invalid choice
        # instead of crashing with a traceback.
        print("Invalid selection.")
        return

    folder_key = folder_keys.get(selection)
    if folder_key is None:
        print("Invalid selection.")
        return

    upload_folder = data_manager.data_dirs[folder_key]

    print(f"Uploading to: {upload_folder}")
    print("Please select images to upload...")

    # Upload files; the result is iterated as filename -> file content (bytes)
    uploaded = files.upload()

    # Save files to the appropriate directory
    for filename, content in uploaded.items():
        dest_path = os.path.join(upload_folder, filename)
        with open(dest_path, 'wb') as f:
            f.write(content)
        print(f"Saved {filename} to {dest_path}")

    print(f"Uploaded {len(uploaded)} images to {upload_folder}")

In [None]:
def export_model_to_tf_lite(model_path, output_path=None):
    """
    Convert a saved Keras model into a TensorFlow Lite file.

    The model is loaded from `model_path`, converted with the default TF Lite
    optimizations enabled and written to disk. Any failure along the way is
    reported on stdout instead of being raised.

    Args:
        model_path: Path to the Keras model
        output_path: Output path for the TF Lite model; when None, the file
            "model.tflite" is placed in the directory containing `model_path`

    Returns:
        The path the TF Lite model was written to, or None if conversion failed
    """
    try:
        keras_model = keras.models.load_model(model_path)

        # Build the converter and switch on the default optimizations
        tflite_converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
        tflite_converter.optimizations = [tf.lite.Optimize.DEFAULT]
        serialized_model = tflite_converter.convert()

        # Fall back to a sibling "model.tflite" when no destination was given
        destination = (
            os.path.join(os.path.dirname(model_path), "model.tflite")
            if output_path is None
            else output_path
        )

        with open(destination, 'wb') as sink:
            sink.write(serialized_model)

        print(f"Model converted and saved to {destination}")
        return destination

    except Exception as e:
        # Best-effort helper: report the problem and signal failure with None
        print(f"Error converting model: {e}")
        return None

In [None]:
def analyze_model_performance(cnn_model, zero_shot_detector, test_images, ground_truth=None,
                              cnn_weight=0.7, zero_shot_weight=0.3):
    """
    Analyze and compare the performance of CNN and zero-shot models.

    For every test image with a known label, the CNN verdict, the zero-shot
    verdict and a combined verdict are compared against the ground truth, and
    correct/incorrect counts are accumulated per method. Accuracies are printed
    when at least one image was evaluated.

    The combined verdict is a weighted vote of the two binary decisions
    (defective when the weighted score exceeds 0.5). It is computed from the
    model outputs only; the ground-truth label is never used to form a
    prediction. The default weights match the ones hard-coded in
    `detect_defects`.

    Args:
        cnn_model: Trained CNN model exposing `predict(batch)`
        zero_shot_detector: Zero-shot detector exposing `detect(image_path)`
            that returns a mapping with an "is_defective" entry
        test_images: List of test image paths
        ground_truth: Optional dict mapping image paths to ground truth labels
                     (0 for normal, 1 for defective)
        cnn_weight: Weight of the CNN vote in the combined verdict
        zero_shot_weight: Weight of the zero-shot vote in the combined verdict

    Returns:
        Dict keyed by "cnn_only", "zero_shot_only" and "combined", each holding
        integer "correct" and "incorrect" counts
    """
    results = {
        "cnn_only": {"correct": 0, "incorrect": 0},
        "zero_shot_only": {"correct": 0, "incorrect": 0},
        "combined": {"correct": 0, "incorrect": 0}
    }

    # If no ground truth provided, we assume the folder structure defines truth
    if ground_truth is None:
        ground_truth = {}
        for img_path in test_images:
            # NOTE(review): this is a substring match on the whole lower-cased
            # path, not just the folder name. Any ancestor directory or file
            # name containing "normal"/"defect" also matches - verify that the
            # project layout cannot mislabel images this way.
            lowered_path = img_path.lower()
            if "normal" in lowered_path:
                ground_truth[img_path] = 0  # Normal
            elif "defect" in lowered_path:
                ground_truth[img_path] = 1  # Defective
            # Images with unknown ground truth are left out and skipped below

    # Only images with a known label can be scored
    labelled_images = [img_path for img_path in test_images if img_path in ground_truth]
    if not labelled_images:
        return results

    # Process each labelled test image
    for img_path in tqdm(labelled_images, desc="Analyzing performance"):
        truly_defective = ground_truth[img_path] == 1

        # CNN prediction (resize, add batch axis, scale to [0, 1])
        img = keras.preprocessing.image.load_img(
            img_path,
            target_size=(224, 224)
        )
        img_array = keras.preprocessing.image.img_to_array(img)
        img_array = np.expand_dims(img_array, axis=0)
        img_array = img_array / 255.0

        cnn_pred = cnn_model.predict(img_array)[0]
        # Two or more outputs: index 1 is read as the defective score;
        # a single output is read as the defective probability.
        cnn_defective = bool(cnn_pred[1] > 0.5 if len(cnn_pred) > 1 else cnn_pred[0] > 0.5)

        # Zero-shot prediction
        zero_shot_result = zero_shot_detector.detect(img_path)
        zero_shot_defective = bool(zero_shot_result["is_defective"])

        # Combined prediction: weighted vote of the two model decisions only
        combined_score = (cnn_weight * float(cnn_defective) +
                          zero_shot_weight * float(zero_shot_defective))
        combined_defective = combined_score > 0.5

        # Update results (plain int counters, no numpy scalars)
        verdicts = (
            ("cnn_only", cnn_defective),
            ("zero_shot_only", zero_shot_defective),
            ("combined", combined_defective),
        )
        for method, predicted_defective in verdicts:
            outcome = "correct" if predicted_defective == truly_defective else "incorrect"
            results[method][outcome] += 1

    # Calculate accuracy (every method scored the same number of images)
    total = results["cnn_only"]["correct"] + results["cnn_only"]["incorrect"]

    if total > 0:
        for method in results:
            correct = results[method]["correct"]
            accuracy = correct / total
            print(f"{method} accuracy: {accuracy:.4f} ({correct}/{total})")

    return results

## Documentation on Progressive Training and Aspect Ratio Handling

In [None]:
"""
# Progressive Training Strategy Documentation

In [None]:
The progressive training approach implemented in this notebook follows a multi-round strategy
to incrementally fine-tune the model. This approach has several advantages:

In [None]:
1. Prevents catastrophic forgetting by gradually unfreezing layers
2. Enables efficient transfer learning from general to domain-specific features
3. Allows for increasingly complex data augmentation as training progresses
4. Improves performance on limited PCB defect datasets by leveraging Keras datasets

In [None]:
## Progressive Training Rounds

In [None]:
The training proceeds through three rounds:

In [None]:
### Round 1: Base Feature Learning
- All base model layers are frozen (only custom top layers are trained)
- Higher learning rate (1e-4) for faster convergence on new layers
- Minimal data augmentation to learn basic features
- Shorter training duration

In [None]:
### Round 2: Mid-level Feature Fine-tuning
- Unfreeze 30% of base model layers (from the end)
- Reduced learning rate (1e-5) to prevent destroying pre-trained weights
- Moderate data augmentation to improve generalization
- Longer training duration to allow fine-tuning to converge

In [None]:
### Round 3: Deep Feature Specialization
- Unfreeze 50% of base model layers
- Very low learning rate (1e-6) for careful fine-tuning
- Extensive data augmentation to maximize generalization
- Extended training with early stopping to prevent overfitting

In [None]:
## Aspect Ratio Handling

In [None]:
PCB images often come in various aspect ratios, which can cause issues during training.
This notebook implements robust handling of unusual aspect ratios:

In [None]:
1. **Center Cropping**: For very wide or tall images (aspect ratio > 3:1 or < 1:3)
   - Wide images are center-cropped horizontally to a more balanced ratio
   - Tall images are center-cropped vertically to maintain the most relevant content

In [None]:
2. **Custom Preprocessing Functions**: Applied in both data generators and direct loading
   - Ensures consistent handling across all data pipelines
   - Preserves as much meaningful content as possible

In [None]:
3. **Intelligent Resizing**: After aspect ratio normalization, images are resized to the target
   dimensions required by the CNN model (224×224 pixels for MobileNetV2/ResNet50V2)

In [None]:
This approach ensures that no matter what aspect ratio the input PCB images have, they will
be properly processed to maintain the most relevant defect information while conforming to
the model's input requirements.
"""

In [None]:
def display_documentation():
    """Print a banner followed by the module-level `__doc__` text."""
    rule = "=" * 80
    # NOTE(review): `__doc__` resolves to the docstring of the namespace this
    # function is defined in - confirm that it actually holds the progressive
    # training / aspect ratio documentation text.
    for line in (
        rule,
        "Progressive Training and Aspect Ratio Handling Documentation",
        rule,
        __doc__,
    ):
        print(line)

# Uncomment the next line to view the documentation
# display_documentation()

## Run the example code

# Uncomment the next line to run the example code
# main()

In [None]:
# Closing banner: notebook status, usage hint and version metadata,
# emitted one entry per line.
print(
    "\nTwo-Stage PCB Defect Detection Pipeline Notebook Ready!",
    "You can execute individual cells or run the main() function to demonstrate the full pipeline.",
    "\nAuthor: PCB Inspection Team",
    "Version: 1.1",
    "Date: 2025-04-15",
    "Features: Progressive Training, Aspect Ratio Handling, CNN-Zero-Shot Integration",
    sep="\n",
)