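## Test Results

Evaluation of three anomaly-detection models (INP-Former, EfficientAD, GLASS) on the wood dataset test split: per-image heatmap visualizations, pixel IoU against the ground-truth masks, and an image-level F1 score.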

In [1]:
# Colab only: mount Google Drive so the dataset and model files referenced
# below under /content/drive/MyDrive/... can be read.
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [2]:
!pip install onnx onnxruntime onnxsim

Collecting onnx
  Downloading onnx-1.18.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.9 kB)
Collecting onnxruntime
  Downloading onnxruntime-1.22.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (4.5 kB)
Collecting onnxsim
  Downloading onnxsim-0.4.36-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.3 kB)
Collecting coloredlogs (from onnxruntime)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)
Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime)
  Downloading humanfriendly-10.0-py2.py3-none-any.whl.metadata (9.2 kB)
Downloading onnx-1.18.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.6/17.6 MB[0m [31m34.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading onnxruntime-1.22.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (16.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m

### Definitions

In [3]:
# Anomaly Detection Evaluation Notebook

import os
import cv2
import numpy as np
import onnxruntime
from tqdm import tqdm
from sklearn.metrics import f1_score, jaccard_score
import random
import matplotlib.pyplot as plt

# === Utility Functions ===
def load_mask(path, size=(256, 256)):
    """Load a ground-truth mask as a binary uint8 array.

    Args:
        path: image file holding the mask (read as grayscale).
        size: ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        uint8 array of shape ``(size[1], size[0])`` with 1 where the resized
        grayscale value is > 127 and 0 elsewhere, or ``None`` when the file
        cannot be read (``cv2.imread`` returns ``None`` for missing/unreadable
        files, which would otherwise make ``cv2.resize`` fail with an opaque
        assertion error).
    """
    mask = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        return None
    mask = cv2.resize(mask, size)
    return (mask > 127).astype(np.uint8)

def threshold_heatmap(heatmap, threshold):
    """Binarise a heatmap: 1 where ``heatmap >= threshold``, else 0, as uint8."""
    is_hot = heatmap >= threshold
    return is_hot.astype(np.uint8)

def compute_metrics(gt, pred):
    """Pixel-wise agreement between a ground-truth mask and a predicted mask.

    Both arrays are flattened to 1-D and compared with scikit-learn using its
    default binary averaging.

    Returns:
        ``(f1, iou)`` — ``f1_score`` (the Dice coefficient for binary masks)
        and ``jaccard_score`` of the flattened masks.
    """
    y_true = gt.reshape(-1)
    y_pred = pred.reshape(-1)
    dice = f1_score(y_true, y_pred)
    iou = jaccard_score(y_true, y_pred)
    return dice, iou

def compute_image_level_f1(gt_mask, pred_mask):
    """Image-level F1 for a single image, using "any pixel set" as the label.

    Each mask is reduced to one label: defective if any pixel is non-zero.
    For a single sample F1 = 2TP / (2TP + FP + FN) is 1.0 when both labels are
    defective and 0.0 when they disagree; when both are "good" it is 0/0.
    scikit-learn resolves that 0/0 to 0.0 by default (with a warning), which
    scores a correct "good" prediction as a miss, so the ``zero_division=1``
    convention is used instead. The result is therefore 1.0 iff the labels agree.

    Returns:
        float, 1.0 or 0.0.
    """
    gt_label = bool(np.any(gt_mask))
    pred_label = bool(np.any(pred_mask))
    return 1.0 if gt_label == pred_label else 0.0

def crop_dark_edges(image_np, darkness_threshold=30, strip_width=5):
    """Trim dark vertical borders from the left and right of an RGB image.

    Columns are examined in strips of ``strip_width`` pixels rather than one
    at a time, so a single stray column does not stop the crop. Walking inward
    from each side, strips whose mean grey level is <= ``darkness_threshold``
    are removed until a brighter strip is found or the two sides meet.

    Args:
        image_np: RGB image array accepted by ``cv2.cvtColor(..., COLOR_RGB2GRAY)``.
        darkness_threshold: mean grey level at or below which a strip is dark.
        strip_width: number of columns per strip (must be >= 1).

    Returns:
        ``image_np[:, left:right + 1]`` — a column slice (view) of the input;
        rows are untouched.

    Raises:
        ValueError: if ``strip_width < 1`` (the scans would never advance).
    """
    if strip_width < 1:
        raise ValueError(f"strip_width must be >= 1, got {strip_width}")

    gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
    width = gray.shape[1]
    left = 0
    right = width - 1  # inclusive index of the last column kept

    while left + strip_width < right:
        if np.mean(gray[:, left:left + strip_width]) > darkness_threshold:
            break
        left += strip_width

    while right - strip_width > left:
        # The strip ends at (and includes) column `right`, mirroring the left
        # scan; otherwise the outermost column is never examined but still kept.
        if np.mean(gray[:, right - strip_width + 1:right + 1]) > darkness_threshold:
            break
        right -= strip_width

    return image_np[:, left:right + 1]


# === INP-Former Inference ===
import functools


@functools.lru_cache(maxsize=None)
def _cached_onnx_session(model_path):
    """Return an ``onnxruntime.InferenceSession`` for ``model_path``, built once per path.

    Building a session loads and optimises the whole ONNX graph, so it should not
    be repeated for every image. The cache is keyed on the path string: re-run
    this cell (or call ``_cached_onnx_session.cache_clear()``) after replacing a
    model file on disk.
    """
    return onnxruntime.InferenceSession(model_path)


def run_inpformer(image, model_path, input_size=392, output_size=256):
    """Compute an INP-Former anomaly map for one image via its ONNX export.

    Args:
        image: RGB uint8 array (H, W, 3) — assumed RGB because the ImageNet
            mean/std below are in RGB order; TODO confirm with callers.
        model_path: path to the exported ONNX model.
        input_size: side length the image is resized to before inference.
        output_size: side length of the returned (square) anomaly map.

    Returns:
        2-D float array ``(output_size, output_size)``: the mean over feature
        levels of the per-location cosine distance between encoder and decoder
        features, smoothed with a 5x5 Gaussian (sigma 4).
    """
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    resized = cv2.resize(image, (input_size, input_size))
    normed = (resized / 255.0 - mean) / std
    # HWC -> NCHW with a batch of one
    input_tensor = np.transpose(normed, (2, 0, 1))[np.newaxis].astype(np.float32)

    session = _cached_onnx_session(model_path)
    outputs = session.run(None, {session.get_inputs()[0].name: input_tensor})
    # NOTE(review): assumes outputs[0:2] are encoder features and outputs[2:4]
    # the matching decoder features, each (1, C, h, w) — confirm with the export.
    enc, dec = outputs[:2], outputs[2:4]

    maps = []
    for e, d in zip(enc, dec):
        e = np.transpose(e[0], (1, 2, 0))
        d = np.transpose(d[0], (1, 2, 0))
        cos_sim = np.sum(e * d, axis=2) / (np.linalg.norm(e, axis=2) * np.linalg.norm(d, axis=2) + 1e-8)
        maps.append(cv2.resize(1 - cos_sim, (output_size, output_size)))

    anomaly_map = np.mean(maps, axis=0)
    return cv2.GaussianBlur(anomaly_map, (5, 5), sigmaX=4)


# === Load Random Test Samples ===
def collect_random_samples(dataset_root, max_samples=10, rng=None):
    """Randomly pick up to ``max_samples`` (image_path, mask_path) pairs of defective test images.

    Expects ``<root>/test/<defect_type>/*.jpg`` with masks at
    ``<root>/ground_truth/<defect_type>/<name>_mask.jpg``. The ``good`` folder,
    non-directory entries in ``test/`` and images without a mask file are skipped.

    Args:
        dataset_root: dataset root directory.
        max_samples: maximum number of pairs returned.
        rng: optional ``random.Random`` for reproducible sampling; defaults to
            the global ``random`` module state.

    Returns:
        List of ``(image_path, mask_path)`` tuples; empty if none are found.
    """
    sampler = random if rng is None else rng
    test_path = os.path.join(dataset_root, "test")
    gt_path = os.path.join(dataset_root, "ground_truth")

    image_mask_pairs = []
    # sorted() so the candidate order — and therefore a seeded sample — does
    # not depend on the filesystem's listing order.
    for defect_type in sorted(os.listdir(test_path)):
        defect_folder = os.path.join(test_path, defect_type)
        if defect_type == "good" or not os.path.isdir(defect_folder):
            continue
        gt_folder = os.path.join(gt_path, defect_type)

        for filename in sorted(os.listdir(defect_folder)):
            if not filename.endswith(".jpg"):
                continue
            mask_path = os.path.join(gt_folder, filename.replace(".jpg", "_mask.jpg"))
            if os.path.exists(mask_path):
                image_mask_pairs.append((os.path.join(defect_folder, filename), mask_path))

    return sampler.sample(image_mask_pairs, min(max_samples, len(image_mask_pairs)))


# === Paths ===
# Root of the wood dataset on the mounted Drive. Layout expected by the loaders above:
#   test/<defect_type>/*.jpg              ("good" holds the defect-free images)
#   ground_truth/<defect_type>/<name>_mask.jpg
dataset_root = "/content/drive/MyDrive/Neural_Networks_Project/wood_dataset/wood"

In [4]:
import os
import cv2
import numpy as np
import onnxruntime
from tqdm import tqdm
from sklearn.metrics import f1_score, jaccard_score
import random
import matplotlib.pyplot as plt
import sys
import torch


import functools


@functools.lru_cache(maxsize=None)
def _cached_onnx_session(model_path):
    """Return an ``onnxruntime.InferenceSession`` for ``model_path``, built once per path.

    Building a session loads and optimises the whole ONNX graph, so it should not
    be repeated for every image. The cache is keyed on the path string: re-run
    this cell (or call ``_cached_onnx_session.cache_clear()``) after replacing a
    model file on disk.
    """
    return onnxruntime.InferenceSession(model_path)


def run_inpformer(image, model_path, input_size=392, output_size=256):
    """Compute an INP-Former anomaly map for one image via its ONNX export.

    Args:
        image: RGB uint8 array (H, W, 3) — assumed RGB because the ImageNet
            mean/std below are in RGB order; TODO confirm with callers.
        model_path: path to the exported ONNX model.
        input_size: side length the image is resized to before inference.
        output_size: side length of the returned (square) anomaly map.

    Returns:
        2-D float array ``(output_size, output_size)``: the mean over feature
        levels of the per-location cosine distance between encoder and decoder
        features, smoothed with a 5x5 Gaussian (sigma 4).
    """
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    resized = cv2.resize(image, (input_size, input_size))
    normed = (resized / 255.0 - mean) / std
    # HWC -> NCHW with a batch of one
    input_tensor = np.transpose(normed, (2, 0, 1))[np.newaxis].astype(np.float32)

    session = _cached_onnx_session(model_path)
    outputs = session.run(None, {session.get_inputs()[0].name: input_tensor})
    # NOTE(review): assumes outputs[0:2] are encoder features and outputs[2:4]
    # the matching decoder features, each (1, C, h, w) — confirm with the export.
    enc, dec = outputs[:2], outputs[2:4]

    maps = []
    for e, d in zip(enc, dec):
        e = np.transpose(e[0], (1, 2, 0))
        d = np.transpose(d[0], (1, 2, 0))
        cos_sim = np.sum(e * d, axis=2) / (np.linalg.norm(e, axis=2) * np.linalg.norm(d, axis=2) + 1e-8)
        maps.append(cv2.resize(1 - cos_sim, (output_size, output_size)))

    anomaly_map = np.mean(maps, axis=0)
    return cv2.GaussianBlur(anomaly_map, (5, 5), sigmaX=4)


def load_efficientad_model(model_path, repo_path='/content/drive/MyDrive/Neural_Networks_Project/EfficientAD'):
    """Load the EfficientAD networks and bundle everything ``run_efficientad`` needs.

    Args:
        model_path: directory containing ``teacher_final.pth``,
            ``student_final.pth`` and ``autoencoder_final.pth`` (whole pickled
            modules, loaded with ``weights_only=False``).
        repo_path: directory of the EfficientAD code (``common``, ``efficientad``);
            appended to ``sys.path`` once.

    Returns:
        12-tuple ``(teacher, student, autoencoder, predict, teacher_mean,
        teacher_std, q_st_start, q_st_end, q_ae_start, q_ae_end, transform,
        device)`` in the order ``run_efficientad`` unpacks it.
    """
    # Guard so repeated calls don't keep appending the same entry.
    if repo_path not in sys.path:
        sys.path.append(repo_path)
    # NOTE(review): get_pdn_small, get_autoencoder, teacher_normalization and
    # map_normalization are imported but not used in this function.
    from common import get_pdn_small, get_autoencoder
    from efficientad import teacher_normalization, map_normalization, predict
    from torch.serialization import add_safe_globals
    import torch.nn as nn
    from torchvision import transforms as T

    # add_safe_globals expects a list of classes/functions; a dict would only
    # register its key string. (Only matters if weights_only=True is ever used.)
    add_safe_globals([nn.Sequential])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # SECURITY: weights_only=False unpickles arbitrary objects — only load
    # checkpoints you produced yourself.
    teacher = torch.load(os.path.join(model_path, "teacher_final.pth"), map_location=device, weights_only=False)
    student = torch.load(os.path.join(model_path, "student_final.pth"), map_location=device, weights_only=False)
    autoencoder = torch.load(os.path.join(model_path, "autoencoder_final.pth"), map_location=device, weights_only=False)

    teacher.to(device).eval()
    student.to(device).eval()
    autoencoder.to(device).eval()

    transform = T.Compose([
        T.Resize((256, 256)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225])
    ])

    # NOTE(review): placeholder normalisation constants, not dataset statistics
    # (teacher_normalization / map_normalization are never called). Presumably
    # the EfficientAD reference computes these from training/validation images —
    # confirm before trusting the resulting maps.
    teacher_mean, teacher_std = 0.5, 0.1
    q_st_start, q_st_end = 0.05, 0.95
    q_ae_start, q_ae_end = 0.05, 0.95

    return (teacher, student, autoencoder, predict, teacher_mean, teacher_std,
            q_st_start, q_st_end, q_ae_start, q_ae_end, transform, device)


def run_efficientad(image, model_objs):
    """Compute an EfficientAD map for one image.

    Args:
        image: uint8 RGB array, converted with ``PIL.Image.fromarray``.
        model_objs: the 12-tuple returned by ``load_efficientad_model``.

    Returns:
        numpy array equal to ``1.0 - map_combined`` after squeezing singleton
        dimensions (its spatial size is whatever ``predict`` produces).
    """
    from PIL import Image as PILImage
    (teacher, student, autoencoder, predict, teacher_mean, teacher_std,
     q_st_start, q_st_end, q_ae_start, q_ae_end, transform, device) = model_objs

    img_tensor = transform(PILImage.fromarray(image)).unsqueeze(0).to(device)

    # Inference only: no autograd graph is needed, and .numpy() below would
    # raise on a tensor that requires grad.
    with torch.no_grad():
        map_combined, _, _ = predict(img_tensor, teacher, student, autoencoder,
                                     teacher_mean, teacher_std,
                                     q_st_start, q_st_end, q_ae_start, q_ae_end)

    heatmap = map_combined.squeeze().cpu().numpy()
    # NOTE(review): this inverts the map, while evaluate_image_level treats
    # larger values as more anomalous (score > threshold => defect). Confirm the
    # inversion is intended given predict's output convention.
    return 1.0 - heatmap

import functools


@functools.lru_cache(maxsize=None)
def _cached_onnx_session(model_path):
    """Return an ``onnxruntime.InferenceSession`` for ``model_path``, built once per path.

    Building a session loads and optimises the whole ONNX graph, so it should not
    be repeated for every batch. The cache is keyed on the path string: re-run
    this cell (or call ``_cached_onnx_session.cache_clear()``) after replacing a
    model file on disk.
    """
    return onnxruntime.InferenceSession(model_path)


def run_glass(image_batch, model_path, input_size=(256, 256), expected_batch_size=8):
    """Run the GLASS ONNX model on a list of images and return one output per image.

    The exported model is used with a fixed batch size (``expected_batch_size``).
    Images are fed in chunks of that size. The last chunk is zero-padded, and the
    outputs for the padding are dropped, so callers may pass any number of images.
    Passing a batch already padded to the full size behaves as before. Use
    ``expected_batch_size=None`` to send all images in one run.

    Args:
        image_batch: sequence of uint8 RGB arrays accepted by ``PIL.Image.fromarray``.
        model_path: path to the ONNX model.
        input_size: ``(height, width)`` for torchvision ``Resize``.
        expected_batch_size: fixed model batch size, or None for a single run.

    Returns:
        The model's first output as a numpy array with ``len(image_batch)``
        entries along axis 0 — per-image shape depends on the export (e.g.
        (H, W) or (1, H, W)); callers squeeze it.

    Raises:
        ValueError: if ``image_batch`` is empty.
    """
    from torchvision import transforms
    from PIL import Image as PILImage

    images = list(image_batch)
    if not images:
        raise ValueError("run_glass needs at least one image")

    transform = transforms.Compose([
        transforms.Resize(input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    batch = torch.stack([transform(PILImage.fromarray(img)) for img in images]).numpy()

    session = _cached_onnx_session(model_path)
    input_names = [inp.name for inp in session.get_inputs()]
    # The export names its input "input"; fall back to the first input otherwise.
    input_name = "input" if "input" in input_names else input_names[0]

    if not expected_batch_size:
        return session.run(None, {input_name: batch})[0]

    heatmap_chunks = []
    for start in range(0, len(batch), expected_batch_size):
        chunk = batch[start:start + expected_batch_size]
        n_real = chunk.shape[0]
        if n_real < expected_batch_size:
            padding = np.zeros((expected_batch_size - n_real,) + chunk.shape[1:], dtype=chunk.dtype)
            chunk = np.concatenate([chunk, padding], axis=0)
        heatmap_chunks.append(session.run(None, {input_name: chunk})[0][:n_real])
    return np.concatenate(heatmap_chunks, axis=0)




In [13]:
from sklearn.metrics import jaccard_score
import cv2
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
from sklearn.metrics import f1_score # Ensure f1_score is imported if not already


def _eval_list_test_images(dataset_root):
    """List every test image as ``(image_path, label, mask_path)`` in sorted order.

    ``label`` is 1 for images under any ``test/<defect_type>`` folder other than
    ``good`` and 0 for ``good``. ``mask_path`` is the expected
    ``ground_truth/<defect_type>/<name>_mask.jpg`` for defective images (not
    checked for existence here) and None for good ones. Non-directory entries
    in ``test/`` are skipped.
    """
    test_path = os.path.join(dataset_root, "test")
    gt_path = os.path.join(dataset_root, "ground_truth")
    entries = []
    for defect_type in sorted(os.listdir(test_path)):
        defect_folder = os.path.join(test_path, defect_type)
        if not os.path.isdir(defect_folder):
            continue
        is_defective = defect_type != "good"
        for filename in sorted(os.listdir(defect_folder)):
            if not filename.endswith(".jpg"):
                continue
            mask_path = (os.path.join(gt_path, defect_type, filename.replace(".jpg", "_mask.jpg"))
                         if is_defective else None)
            entries.append((os.path.join(defect_folder, filename), int(is_defective), mask_path))
    return entries


def _eval_load_cropped_rgb(image_path):
    """Read an image as RGB with dark side borders cropped; None (with a warning) if unreadable."""
    image = cv2.imread(image_path)
    if image is None:
        print(f"Warning: Could not load image from {image_path}. Skipping.")
        return None
    return crop_dark_edges(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))


def _eval_as_2d_heatmap(heatmap, image_path, size=(256, 256)):
    """Squeeze a model output to 2-D and resize it to ``size`` (width, height) if needed."""
    heatmap = np.asarray(heatmap)
    if heatmap.ndim > 2:
        heatmap = np.squeeze(heatmap)
        if heatmap.ndim > 2:
            print(f"Warning: Heatmap for {image_path} still has unexpected dimensions ({heatmap.ndim}) after squeeze. Taking first channel slice.")
            # Assume trailing axes are channels and keep the first one.
            while heatmap.ndim > 2:
                heatmap = heatmap[..., 0]
    if heatmap.size and heatmap.shape[:2] != (size[1], size[0]):
        print(f"Warning: Heatmap for {image_path} has shape {heatmap.shape}. Resizing to {size} for visualization.")
        heatmap = cv2.resize(np.ascontiguousarray(heatmap), size, interpolation=cv2.INTER_LINEAR)
    return heatmap


def _eval_glass_heatmaps(samples, model_path, batch_size=8):
    """Run GLASS over ``samples`` in fixed-size padded batches.

    Returns ``(image_path, label, mask_path, image_256, heatmap_2d)`` per image.
    All images are loaded first, so an unreadable image anywhere in ``samples``
    cannot cause a trailing partial batch to be skipped. Each result carries its
    own image_path, so logging stays correct when images are skipped.
    """
    loaded = []
    for image_path, true_label, mask_path in tqdm(samples, desc="Processing GLASS samples (Batching)"):
        cropped = _eval_load_cropped_rgb(image_path)
        if cropped is None:
            continue
        loaded.append((image_path, true_label, mask_path, cv2.resize(cropped, (256, 256))))

    dummy_image = np.zeros((256, 256, 3), dtype=np.uint8)
    results = []
    for start in range(0, len(loaded), batch_size):
        batch = loaded[start:start + batch_size]
        images = [entry[3] for entry in batch]
        images += [dummy_image] * (batch_size - len(images))  # the model needs a full batch
        try:
            heatmap_batch = run_glass(images, model_path, input_size=(256, 256))
        except Exception as e:  # best effort: report the batch and move on
            print(f"Error processing GLASS batch starting with image index {start} ({batch[0][0]}): {e}")
            continue
        for (image_path, true_label, mask_path, image_256), heatmap in zip(batch, heatmap_batch):
            results.append((image_path, true_label, mask_path, image_256,
                            _eval_as_2d_heatmap(heatmap, image_path)))
    return results


def _eval_single_heatmaps(samples, model_name, model_path, efficientad_model):
    """Run INP-Former or EfficientAD one image at a time; same tuple format as the GLASS path."""
    results = []
    for image_path, true_label, mask_path in tqdm(samples, desc=f"Processing {model_name} samples"):
        cropped = _eval_load_cropped_rgb(image_path)
        if cropped is None:
            continue
        # The model functions resize internally; they receive the cropped image.
        if model_name == "INP-Former":
            heatmap = run_inpformer(cropped, model_path)
        else:
            heatmap = run_efficientad(cropped, efficientad_model)
        results.append((image_path, true_label, mask_path, cv2.resize(cropped, (256, 256)),
                        _eval_as_2d_heatmap(heatmap, image_path)))
    return results


def _eval_topk_score(heatmap, topk_percent):
    """Image score: mean of the largest ``topk_percent`` fraction of heatmap values (at least one)."""
    flat = heatmap.ravel()
    k = min(flat.size, max(1, int(flat.size * topk_percent)))
    return float(np.mean(np.partition(flat, -k)[-k:]))


def _eval_overlay(image_rgb, heatmap, global_min, global_max):
    """Blend a JET-coloured heatmap over the image.

    The heatmap is scaled with the fixed global_min/global_max, so colours are
    comparable across images; a degenerate range falls back to clipping to [0, 1].
    """
    value_range = global_max - global_min
    if value_range <= 1e-8:
        heatmap_vis = np.clip(heatmap, 0, 1)
    else:
        heatmap_vis = np.clip((heatmap - global_min) / value_range, 0, 1)
    heatmap_color = cv2.applyColorMap((heatmap_vis * 255).astype(np.uint8), cv2.COLORMAP_JET)
    heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)
    if heatmap_color.shape[:2] != image_rgb.shape[:2]:
        heatmap_color = cv2.resize(heatmap_color, (image_rgb.shape[1], image_rgb.shape[0]))
    return cv2.addWeighted(image_rgb.astype(np.uint8), 0.6, heatmap_color, 0.4, 0)


def _eval_plot_sample(image_rgb, overlay, heatmap, true_label, pred_label, score, mask_path, threshold):
    """Plot image and overlay; for defective images with a mask path also the GT mask and pixel IoU."""
    show_gt = true_label == 1 and bool(mask_path)
    fig, ax = plt.subplots(1, 3 if show_gt else 2, figsize=(12, 4) if show_gt else (8, 4))
    ax[0].imshow(image_rgb)
    ax[0].set_title(f"Image: {'Defect' if true_label else 'Good'}")
    ax[0].axis("off")
    ax[1].imshow(overlay)
    ax[1].set_title(f"Prediction: {'Defect' if pred_label else 'Good'} (score={score:.4f})")
    ax[1].axis("off")

    if show_gt:
        # A missing ground-truth file is reported in the panel instead of aborting the run.
        gt_mask = load_mask(mask_path, size=(256, 256)) if os.path.exists(mask_path) else None
        pred_mask = (heatmap > threshold).astype(np.uint8)
        if gt_mask is not None and gt_mask.shape == pred_mask.shape:
            iou = jaccard_score(gt_mask.flatten(), pred_mask.flatten())
            ax[2].imshow(gt_mask, cmap="gray", vmin=0, vmax=1)
            ax[2].set_title(f"Ground Truth\nIoU={iou:.2f}")
        else:
            if gt_mask is not None:
                print(f"Warning: Ground truth mask shape {gt_mask.shape} does not match predicted mask shape {pred_mask.shape} for {mask_path}. Cannot compute IoU.")
            ax[2].set_title("Ground Truth Mask Missing\nMask not loaded or size mismatch")
        ax[2].axis("off")

    plt.tight_layout()
    plt.show()
    plt.close(fig)  # free memory when plotting many samples in one cell


def evaluate_image_level(dataset_root, model_name, model_path, threshold=0.5, sample_count=15, topk_percent=0.007,
                         global_min=0.0, global_max=0.5, seed=None):
    """Sample test images, score them with one model, plot each result and print image-level F1.

    An image is predicted defective when the mean of its top ``topk_percent``
    heatmap values exceeds ``threshold``. The same threshold binarises the
    heatmap for the per-image IoU against the ground-truth mask.

    Args:
        dataset_root: folder with ``test/<defect_type>/*.jpg`` and
            ``ground_truth/<defect_type>/<name>_mask.jpg``.
        model_name: "GLASS", "INP-Former" or "EfficientAD".
        model_path: ONNX file (GLASS, INP-Former) or weights directory (EfficientAD).
        threshold: decision threshold for the image score and for IoU pixels.
        sample_count: number of test images drawn at random without replacement.
        topk_percent: fraction of heatmap values averaged into the image score.
        global_min, global_max: fixed heatmap range used to colour the overlays.
        seed: optional seed for the image sampling; None uses the global
            ``random`` state.

    Returns:
        The image-level F1 score, or None if no image could be processed.

    Raises:
        ValueError: for an unsupported ``model_name``.
    """
    if model_name not in ("GLASS", "INP-Former", "EfficientAD"):
        raise ValueError(f"Unsupported model: {model_name}")

    candidates = _eval_list_test_images(dataset_root)
    sampler = random if seed is None else random.Random(seed)
    samples = sampler.sample(candidates, min(sample_count, len(candidates)))

    if model_name == "GLASS":
        results = _eval_glass_heatmaps(samples, model_path)
    else:
        efficientad_model = load_efficientad_model(model_path) if model_name == "EfficientAD" else None
        results = _eval_single_heatmaps(samples, model_name, model_path, efficientad_model)

    if results and global_max - global_min <= 1e-8:
        print(f"Warning: Global min ({global_min}) and max ({global_max}) are too close. Using default range [0, 1] for visualization stretching.")

    y_true, y_pred = [], []
    for image_path, true_label, mask_path, image_256, heatmap in tqdm(results, desc="Generating visualizations"):
        if heatmap.size == 0:
            print(f"Warning: Heatmap for {image_path} is empty. Skipping metrics and visualization.")
            continue
        score = _eval_topk_score(heatmap, topk_percent)
        pred_label = int(score > threshold)
        y_true.append(true_label)
        y_pred.append(pred_label)

        overlay = _eval_overlay(image_256, heatmap, global_min, global_max)
        _eval_plot_sample(image_256, overlay, heatmap, true_label, pred_label, score, mask_path, threshold)

    print("\n=== Image-level F1 Score ===")
    if not y_true:
        print("No samples were processed. F1 Score cannot be computed.")
        return None
    f1 = f1_score(y_true, y_pred)
    print(f"F1 Score: {f1:.4f}")
    return f1

In [6]:
def _stats_load_image(image_path):
    """Load one image for heatmap statistics: read with OpenCV, convert BGR->RGB, crop dark edges.

    Returns None (after printing a warning) when ``cv2.imread`` cannot read the file,
    so the caller can skip that sample.
    """
    image = cv2.imread(image_path)
    if image is None:
        print(f"Warning: Could not load image from {image_path}. Skipping.")
        return None
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return crop_dark_edges(image)


def _stats_flatten_heatmap(heatmap):
    """Return ``heatmap`` as a 1-D copy, squeezing singleton axes first when it has more than 2 dims."""
    if heatmap.ndim > 2:
        heatmap = heatmap.squeeze()
    return heatmap.flatten()


def _stats_run_glass_batch(batch_images, batch_indices, samples, model_path, batch_size):
    """Run GLASS on one (possibly short) batch and return the flattened heatmaps of its real images.

    The model is run on a fixed batch size, so a short batch is padded with black images
    shaped like its first element; the heatmaps produced for the padding are discarded.
    On any exception the error is printed and an empty list is returned, so one failing
    batch does not abort the whole statistics pass.

    Args:
        batch_images: RGB images of this batch (the real ones only, no padding).
        batch_indices: index into ``samples`` of each image; used only for error reporting.
        samples: the full list of ``(image_path, mask_path)`` pairs.
        model_path: passed through unchanged to ``run_glass``.
        batch_size: batch size the model is run with.

    Returns:
        A list with one flattened heatmap per real image, or ``[]`` on error / empty batch.
    """
    if not batch_images:
        return []
    n_real = len(batch_images)
    # Separate arrays (not one shared object) in case run_glass mutates its inputs.
    padding = [np.zeros_like(batch_images[0]) for _ in range(batch_size - n_real)]
    try:
        heatmap_batch = run_glass(batch_images + padding, model_path)
        return [_stats_flatten_heatmap(heatmap_batch[j]) for j in range(n_real)]
    except Exception as e:
        first_index = batch_indices[0]
        first_image_path_in_batch = samples[first_index][0]
        print(f"Error processing batch starting with image index {first_index} ({first_image_path_in_batch}): {e}")
        return []


def get_global_heatmap_stats(dataset_root, model_name, model_path, sample_count=10, glass_batch_size=8):
    """Compute the min and max anomaly-heatmap value of a model over a random sample of images.

    Samples are ``(image_path, mask_path)`` pairs drawn by ``collect_random_samples``
    (presumably defective images, per the warning text below — confirm against that helper).
    Each image is read, converted to RGB and cropped with ``crop_dark_edges``; the model's
    heatmap is then computed and the min/max over every pixel of every heatmap is returned.

    Args:
        dataset_root: dataset directory passed to ``collect_random_samples``.
        model_name: one of ``"EfficientAD"``, ``"INP-Former"``, ``"GLASS"``.
        model_path: passed to ``load_efficientad_model`` (EfficientAD) or to
            ``run_inpformer`` / ``run_glass`` (INP-Former / GLASS).
        sample_count: maximum number of samples to draw.
        glass_batch_size: fixed batch size the GLASS model is run with.

    Returns:
        ``(global_min, global_max)``, or ``(0.0, 1.0)`` when no samples were found or
        no heatmap could be produced.

    Raises:
        ValueError: if ``model_name`` is not supported (checked before any sampling).
    """
    # Validate first so a typo fails immediately instead of after (possibly slow) sampling,
    # and is not masked by the "no samples" early return.
    if model_name not in ("EfficientAD", "INP-Former", "GLASS"):
        raise ValueError(f"Unsupported model: {model_name}")

    samples = collect_random_samples(dataset_root, max_samples=sample_count)
    if not samples:
        print("Warning: No defective samples found or processed. Returning default stats.")
        return 0.0, 1.0

    all_heatmaps = []
    desc = f"Processing {model_name} samples"

    if model_name == "EfficientAD":
        model_objs = load_efficientad_model(model_path)  # loaded once, reused for every image
        for image_path, _mask_path in tqdm(samples, desc=desc):
            image = _stats_load_image(image_path)
            if image is not None:
                all_heatmaps.append(_stats_flatten_heatmap(run_efficientad(image, model_objs)))

    elif model_name == "INP-Former":
        # NOTE(review): run_inpformer receives the path for every image; if it builds a new
        # ONNX session per call, caching one session would speed this up — confirm.
        for image_path, _mask_path in tqdm(samples, desc=desc):
            image = _stats_load_image(image_path)
            if image is not None:
                all_heatmaps.append(_stats_flatten_heatmap(run_inpformer(image, model_path)))

    else:  # "GLASS": the model is run on fixed-size batches
        batch_images, batch_indices = [], []
        for idx, (image_path, _mask_path) in enumerate(tqdm(samples, desc=desc)):
            image = _stats_load_image(image_path)
            if image is None:
                continue
            batch_images.append(image)
            batch_indices.append(idx)
            if len(batch_images) == glass_batch_size:
                all_heatmaps.extend(
                    _stats_run_glass_batch(batch_images, batch_indices, samples, model_path, glass_batch_size)
                )
                batch_images, batch_indices = [], []
        # Flush the trailing short batch after the loop. Checking "is this the last sample?"
        # inside the loop would drop this batch whenever the last image fails to load,
        # because the loop `continue`s before reaching that check.
        if batch_images:
            all_heatmaps.extend(
                _stats_run_glass_batch(batch_images, batch_indices, samples, model_path, glass_batch_size)
            )

    if not all_heatmaps:
        print("Warning: No heatmaps were generated. Returning default stats.")
        return 0.0, 1.0

    all_heatmaps_flat = np.concatenate(all_heatmaps)
    global_min = np.min(all_heatmaps_flat)
    global_max = np.max(all_heatmaps_flat)
    return global_min, global_max

### Limits: global heatmap min/max per model

In [None]:
# Set `dataset_root` here so this cell does not depend on a value left over from another
# cell; the only other visible assignments are in the Results cells further down.
dataset_root = "/content/drive/MyDrive/Neural_Networks_Project/wood_dataset/wood"
inpformer_onnx_path = "/content/drive/MyDrive/Neural_Networks_Project/INP-Former/inpformer.onnx"
global_min, global_max = get_global_heatmap_stats(dataset_root, "INP-Former", inpformer_onnx_path)

print(f"Global Min: {global_min}")
print(f"Global Max: {global_max}")


Processing INP-Former samples: 100%|██████████| 10/10 [01:27<00:00,  8.79s/it]

Global Min: 0.06759029626846313
Global Max: 0.5027773380279541





In [None]:
# Set `dataset_root` here so this cell does not depend on a value left over from another
# cell; the only other visible assignments are in the Results cells further down.
dataset_root = "/content/drive/MyDrive/Neural_Networks_Project/wood_dataset/wood"
efficientad_weights_folder = "/content/drive/MyDrive/Neural_Networks_Project/EfficientAD/weights"
global_min, global_max = get_global_heatmap_stats(dataset_root, "EfficientAD", efficientad_weights_folder)

print(f"Global Min: {global_min}")
print(f"Global Max: {global_max}")

100%|██████████| 10/10 [03:06<00:00, 18.66s/it]

Global Min: -4.71906042098999
Global Max: -1.2570910453796387





In [None]:
# Set `dataset_root` here so this cell does not depend on a value left over from another
# cell; the only other visible assignments are in the Results cells further down.
dataset_root = "/content/drive/MyDrive/Neural_Networks_Project/wood_dataset/wood"
glass_onnx_path = "/content/drive/MyDrive/Neural_Networks_Project/glass.onnx"
global_min, global_max = get_global_heatmap_stats(dataset_root, "GLASS", glass_onnx_path)

print(f"Global Min: {global_min}")
print(f"Global Max: {global_max}")

Processing GLASS samples: 100%|██████████| 10/10 [00:25<00:00,  2.54s/it]

Global Min: 0.1804734468460083
Global Max: 0.9709413051605225





### Results

In [17]:
# Image-level evaluation of INP-Former with hard-coded threshold and normalization bounds.
# The bounds are close to the sampled range printed in the Limits section (≈0.0676 .. 0.5028).
# NOTE(review): global_max=0.5 is slightly below that sampled max — confirm this is intended.
dataset_root = "/content/drive/MyDrive/Neural_Networks_Project/wood_dataset/wood"
inpformer_onnx_path = "/content/drive/MyDrive/Neural_Networks_Project/INP-Former/inpformer.onnx"
evaluate_image_level(
    dataset_root,
    "INP-Former",
    inpformer_onnx_path,
    threshold=0.29,
    global_min=0.0,
    global_max=0.5,
)

Output hidden; open in https://colab.research.google.com to view.

In [19]:
# Image-level evaluation of EfficientAD with hard-coded threshold and normalization bounds.
# The bounds enclose the sampled range printed in the Limits section (≈-4.719 .. -1.257).
dataset_root = "/content/drive/MyDrive/Neural_Networks_Project/wood_dataset/wood"
efficientad_weights_folder = "/content/drive/MyDrive/Neural_Networks_Project/EfficientAD/weights"
evaluate_image_level(
    dataset_root,
    "EfficientAD",
    efficientad_weights_folder,
    threshold=-1.9,
    global_min=-4.75,
    global_max=-1.25,
)

Output hidden; open in https://colab.research.google.com to view.

In [16]:
# Image-level evaluation of GLASS with hard-coded threshold and normalization bounds.
# The bounds enclose the sampled range printed in the Limits section (≈0.1805 .. 0.9709).
dataset_root = "/content/drive/MyDrive/Neural_Networks_Project/wood_dataset/wood"
glass_onnx_path = "/content/drive/MyDrive/Neural_Networks_Project/glass.onnx"
evaluate_image_level(
    dataset_root,
    "GLASS",
    glass_onnx_path,
    threshold=0.7,
    global_min=0.15,
    global_max=1,
)

Output hidden; open in https://colab.research.google.com to view.