# In [None]:  (notebook cell marker left over from the Jupyter export)
import os
import time
import numpy as np
import torch
from torchvision import models, transforms
from torchvision.ops import nms
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt

import imagehash
from skimage.metrics import structural_similarity as ssim

print("--- Step 1: Differential detection pipeline configuration ---")

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ===================================================================
# 1. LOAD CHECKPOINT (MODEL + CLASS NAMES)
# ===================================================================
model_path = "/content/drive/MyDrive/Colab Notebooks/best_efficientnet_b0_pcb_defects.pth"

# map_location moves the stored tensors straight onto the selected device.
checkpoint = torch.load(model_path, map_location=device)
# checkpoint is a dict like: {"model_state_dict": ..., "class_names": ...}

# restore class names from checkpoint (preferred)
class_names = checkpoint.get("class_names", None)
if class_names is None:
    # fallback: manually define if not present
    class_names = ['missing_hole', 'mouse_bite', 'open_circuit', 'short', 'spur', 'spurious_copper']

# NOTE(review): dropping 'normal' shrinks num_classes by one. If the checkpoint
# was trained WITH a 'normal' class, the Linear head built below will not match
# the saved weights and load_state_dict should fail with a size mismatch;
# unless 'normal' was the last class, index -> label lookups would also
# shift. Confirm against the training notebook.
if 'normal' in class_names:
    class_names.remove('normal')

# Size of the classifier head; also the index range used for label lookup.
num_classes = len(class_names)

# recreate architecture exactly as during EfficientNet training
# (weights=None: no pretrained weights are downloaded; they come from the checkpoint)
defect_classifier = models.efficientnet_b0(weights=None)

# EfficientNet classifier replacement
# classifier[1] is the final Linear layer; swap it for one with num_classes outputs.
in_features = defect_classifier.classifier[1].in_features
defect_classifier.classifier[1] = torch.nn.Linear(in_features, num_classes)

# load ONLY the model_state_dict from checkpoint
defect_classifier.load_state_dict(checkpoint["model_state_dict"])

# Move the model to the inference device and switch to evaluation mode
# (disables dropout and freezes batch-norm statistics).
defect_classifier.to(device)
defect_classifier.eval()

# ===================================================================
# 2. CONFIG: GOLDEN IMAGES, WINDOW, NORMALIZATION
# ===================================================================
# Directory holding the defect-free reference ("golden") board images.
golden_images_dir = "/content/drive/MyDrive/PCB_DATASET/PCB_USED/"
# Fail fast: nothing below can work without the reference images.
if not os.path.exists(golden_images_dir):
    raise FileNotFoundError(f"Reference images directory '{golden_images_dir}' not found.")

# Side length, in pixels, of the square sliding window compared between images.
WINDOW_SIZE = 128
# Step between consecutive windows: a quarter window, i.e. 75% overlap.
STRIDE = WINDOW_SIZE // 4
# Windows whose SSIM against the reference falls below this are sent to the classifier.
SIMILARITY_THRESHOLD = 0.95
# Minimum softmax confidence for a classified window to be kept as a detection.
CLASSIFIER_CONFIDENCE_THRESHOLD = 0.80

# same normalization as training
# (these mean/std values are the commonly used ImageNet statistics)
inference_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

# ===================================================================
# 3. GOLDEN IMAGE DATABASE
# ===================================================================
def create_golden_image_database(golden_dir):
    """Load every reference image in *golden_dir* and compute its perceptual hash.

    Args:
        golden_dir: Path of the directory containing the reference images.

    Returns:
        A list of dicts, one per successfully loaded image, with keys
        'filename' (name inside *golden_dir*), 'image' (RGB PIL image) and
        'hash' (the image's perceptual hash). Entries follow sorted filename
        order so that ties between equally distant references are resolved
        the same way on every run.
    """
    db = []
    print("Creating reference image database...")
    # os.listdir returns entries in arbitrary order; sort for reproducibility.
    for fname in sorted(os.listdir(golden_dir)):
        path = os.path.join(golden_dir, fname)
        if os.path.isdir(path):
            # Sub-directories can never be opened as images; skip them quietly.
            continue
        try:
            # The context manager closes the underlying file as soon as the
            # pixel data has been copied into the converted image.
            with Image.open(path) as source:
                img = source.convert('RGB')
            hash_val = imagehash.phash(img)
        except Exception as e:
            # Best effort: one unreadable file must not abort the whole scan.
            print(f"Warning: Could not load '{path}'. Error: {e}")
            continue
        db.append({'filename': fname, 'image': img, 'hash': hash_val})
    return db

# Build the reference database once, at start-up; it is reused by the pipeline run below.
golden_db = create_golden_image_database(golden_images_dir)
print(f"{len(golden_db)} reference images loaded successfully.")

# ===================================================================
# 4. DETECTION FUNCTIONS
# ===================================================================
def find_best_match(input_image, golden_database):
    """Pick the reference image whose perceptual hash is closest to the input's.

    Args:
        input_image: PIL image to look up.
        golden_database: Entries as built by create_golden_image_database
            (dicts with 'filename', 'image' and 'hash').

    Returns:
        The 'image' of the closest entry, or None when the database is empty.
        A warning is printed when the best hash distance exceeds 10.
    """
    if not golden_database:
        return None

    query_hash = imagehash.phash(input_image)

    # Single pass keeping the first entry with the smallest distance
    # (strict '<' keeps the earliest entry on ties).
    closest_entry = None
    closest_distance = None
    for entry in golden_database:
        distance = query_hash - entry['hash']
        if closest_distance is None or distance < closest_distance:
            closest_entry = entry
            closest_distance = distance

    print(f"Best match found: '{closest_entry['filename']}' (hash distance: {closest_distance}).")
    if closest_distance > 10:
        print("Warning: High hash distance, match may be incorrect.")
    return closest_entry['image']


def _sliding_window_starts(length, window, stride):
    """Return window start offsets that cover [0, length), including the far edge."""
    if length < window:
        return []
    starts = list(range(0, length - window + 1, stride))
    last_start = length - window
    if starts[-1] != last_start:
        # The regular grid stopped short of the border; add one window that
        # is flush with it so the trailing margin is inspected too.
        starts.append(last_start)
    return starts


def detect_anomalies_by_comparison(input_image, golden_image, classifier):
    """Find defect candidates by comparing an image with its reference, window by window.

    Each WINDOW_SIZE x WINDOW_SIZE window (stepped by STRIDE) is compared with
    the same region of the reference using SSIM on grayscale data. Windows
    scoring below SIMILARITY_THRESHOLD are classified; predictions above
    CLASSIFIER_CONFIDENCE_THRESHOLD are kept and then reduced with
    non-maximum suppression.

    Relies on the module-level names WINDOW_SIZE, STRIDE, SIMILARITY_THRESHOLD,
    CLASSIFIER_CONFIDENCE_THRESHOLD, inference_transform, device and class_names.

    Args:
        input_image: PIL image under inspection.
        golden_image: PIL reference image with the same size.
        classifier: Model called on a (1, C, H, W) patch tensor; its output is
            passed through softmax over dim 1.

    Returns:
        A list of dicts with keys 'box' ([x1, y1, x2, y2]), 'label' and
        'confidence'; empty when nothing is detected.

    Raises:
        ValueError: If the two images differ in size.
    """
    if input_image.size != golden_image.size:
        raise ValueError("Input image and reference image must have the same dimensions!")

    detections = []
    start_time = time.time()

    img_width, img_height = input_image.size

    # Grayscale conversion is per-pixel, so converting each image once and
    # slicing gives the same values as converting every crop separately.
    input_gray = np.asarray(input_image.convert('L'))
    golden_gray = np.asarray(golden_image.convert('L'))

    x_starts = _sliding_window_starts(img_width, WINDOW_SIZE, STRIDE)
    y_starts = _sliding_window_starts(img_height, WINDOW_SIZE, STRIDE)

    for y in y_starts:
        rows = slice(y, y + WINDOW_SIZE)
        for x in x_starts:
            cols = slice(x, x + WINDOW_SIZE)

            # Only the mean score is needed, so the full SSIM map is not requested.
            ssim_score = ssim(golden_gray[rows, cols], input_gray[rows, cols])
            if not ssim_score < SIMILARITY_THRESHOLD:
                continue

            # Crop the colour patch lazily: most windows match the reference.
            window_input = input_image.crop((x, y, x + WINDOW_SIZE, y + WINDOW_SIZE))
            patch_tensor = inference_transform(window_input).unsqueeze(0).to(device)
            with torch.no_grad():
                outputs = classifier(patch_tensor)
                probabilities = torch.nn.functional.softmax(outputs, dim=1)
                confidence, predicted_idx = torch.max(probabilities, 1)

            score = confidence.item()
            if score > CLASSIFIER_CONFIDENCE_THRESHOLD:
                detections.append({
                    'box': [x, y, x + WINDOW_SIZE, y + WINDOW_SIZE],
                    'label': class_names[predicted_idx.item()],
                    'confidence': score
                })

    print(f"Initial detection completed in {time.time() - start_time:.2f}s. {len(detections)} raw anomalies found.")

    if not detections:
        return []

    # Overlapping windows fire on the same defect; keep the most confident one.
    # NMS is class-agnostic here: boxes of different labels suppress each other.
    boxes = torch.tensor([d['box'] for d in detections], dtype=torch.float32)
    scores = torch.tensor([d['confidence'] for d in detections], dtype=torch.float32)
    keep_indices = nms(boxes, scores, iou_threshold=0.2)

    # tolist() yields plain ints for list indexing instead of 0-dim tensors.
    final_detections = [detections[i] for i in keep_indices.tolist()]
    print(f"{len(final_detections)} final anomalies after Non-Max Suppression.")
    return final_detections


def draw_detections_on_image(image, detections):
    """Return a copy of *image* with a labelled rectangle drawn for every detection.

    Args:
        image: PIL image to annotate; it is copied, never modified.
        detections: Sequence of dicts with keys 'box' ([x1, y1, x2, y2]),
            'label' and 'confidence'.

    Returns:
        The annotated copy of *image*.
    """
    img_with_boxes = image.copy()
    draw = ImageDraw.Draw(img_with_boxes)

    try:
        font = ImageFont.truetype("DejaVuSans.ttf", 32)
    except IOError:
        print("Font 'DejaVuSans.ttf' not found. Using default font (may be small).")
        font = ImageFont.load_default()

    # Sorted so each label gets the same colour on every run (set iteration
    # order of strings changes with the interpreter's hash seed).
    unique_labels = sorted({d['label'] for d in detections})
    # pyplot.get_cmap is used because matplotlib.cm.get_cmap was removed in
    # matplotlib 3.9.
    colors = plt.get_cmap('hsv', len(unique_labels) + 1)
    # Plain Python ints (truncated, as before) rather than numpy integers,
    # which is what PIL expects for colour tuples.
    color_map = {
        label: tuple(int(channel * 255) for channel in colors(i)[:3])
        for i, label in enumerate(unique_labels)
    }

    for det in detections:
        box = det['box']
        label = det['label']
        confidence = det['confidence']

        color = color_map.get(label, (255, 50, 50))
        draw.rectangle(box, outline=color, width=5)

        text = f"{label} ({confidence:.2f})"

        try:
            text_bbox = draw.textbbox((0, 0), text, font=font)
            text_width = text_bbox[2] - text_bbox[0]
            text_height = text_bbox[3] - text_bbox[1]
        except AttributeError:
            # Older Pillow releases without textbbox.
            text_width, text_height = draw.textsize(text, font=font)

        # The banner normally sits just above the box. For boxes touching the
        # top edge that position is off-canvas, so place it inside the box.
        banner_top = box[1] - text_height - 5
        if banner_top < 0:
            banner_top = box[1]

        background_box = [
            box[0],
            banner_top,
            box[0] + text_width + 10,
            banner_top + text_height + 5
        ]
        draw.rectangle(background_box, fill=color)
        draw.text((box[0] + 5, banner_top), text, fill="white", font=font)

    return img_with_boxes

# ===================================================================
# 5. RUN PIPELINE ON EXAMPLE IMAGE (STANDALONE)
# ===================================================================
test_image_path = "/content/drive/MyDrive/PCB_DATASET/images/Mouse_bite/01_mouse_bite_05.jpg"
test_image_name = os.path.basename(test_image_path)

# Define output directory and create if it doesn't exist
output_dir = "/content/drive/MyDrive/PCB_DATASET/output_images/"
os.makedirs(output_dir, exist_ok=True)

print(f"\n--- Running pipeline on image: {test_image_name} ---")

# Loading is guarded on its own so that the "test file not found" message can
# only be triggered by the test image itself, not by a later step (for
# example a failing save) that also raises FileNotFoundError.
try:
    # The context manager releases the file handle once the pixel data has
    # been copied into the converted image.
    with Image.open(test_image_path) as source_image:
        input_image = source_image.convert('RGB')
except FileNotFoundError:
    print(f"ERROR: Test file '{test_image_path}' not found.")
except Exception as e:
    print(f"An unexpected error occurred: {e}")
else:
    try:
        golden_image_ref = find_best_match(input_image, golden_db)

        # find_best_match signals "no reference available" with None.
        if golden_image_ref is not None:
            anomalies = detect_anomalies_by_comparison(
                input_image,
                golden_image_ref,
                defect_classifier
            )

            result_image = draw_detections_on_image(input_image, anomalies)

            # Save the output image
            output_image_path = os.path.join(output_dir, f"detected_{test_image_name}")
            result_image.save(output_image_path)
            print(f"Output image saved to: {output_image_path}")

            plt.figure(figsize=(20, 15))
            plt.imshow(result_image)
            plt.title(
                f"Detected Anomalies on '{test_image_name}' by Differential Comparison",
                fontsize=20
            )
            plt.axis('off')
            plt.show()
        else:
            print("No matching golden image found.")
    except Exception as e:
        # Top-level boundary of the demo run: report the failure and carry on.
        print(f"An unexpected error occurred: {e}")