0. Load libraries

In [89]:
import tensorflow as tf
import os
import sys
from tensorflow.keras import backend as K
from tensorflow.keras.models import load_model
import cv2
import numpy as np
from PIL import Image
from pathlib import Path

1. Load the segmentation model

In [None]:
# Where the trained segmentation model lives on disk
model_filename = "Vipera_SegmentationModel_V2_100+50epochs.keras"
data_root = "/mnt/c/Users/pdeschepper/Desktop/PERSONAL/DeepLearning/ImageSegmentation/Snakes_ImageSegmentation_keras/Vipera_segmentation_train_dataset/"
model_path = os.path.join(data_root, model_filename)

# The custom loss/metric functions are defined in an accessory script that sits
# in the project folder (the parent of data_root), so that folder has to be on
# the import path before the helper module can be imported.
sys.path.append("/mnt/c/Users/pdeschepper/Desktop/PERSONAL/DeepLearning/ImageSegmentation/Snakes_ImageSegmentation_keras/")
from HelperFuncs_IoU_DiceLoss_CombinedLoss import combined_loss, dice_loss, mean_iou

# Hand the custom functions to load_model under the names listed in custom_objects
model = load_model(
    model_path,
    custom_objects=dict(
        combined_loss=combined_loss,
        dice_loss=dice_loss,
        mean_iou=mean_iou,
    ),
)
print("✅ Model loaded successfully!")

# Print the architecture; the trailing expression displays the output shape
model.summary()
model.output_shape

✅ Model loaded successfully!
[1mModel: "functional_2"[0m
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃[1m [0m[1mLayer (type)       [0m[1m [0m┃[1m [0m[1mOutput Shape     [0m[1m [0m┃[1m [0m[1m   Param #[0m[1m [0m┃[1m [0m[1mConnected to     [0m[1m [0m┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ input_layer_2       │ ([96mNone[0m, [32m512[0m, [32m512[0m,  │          [32m0[0m │ -                 │
│ ([94mInputLayer[0m)        │ [32m3[0m)                │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ rescaling           │ ([96mNone[0m, [32m512[0m, [32m512[0m,  │          [32m0[0m │ input_layer_2[[32m0[0m]… │
│ ([94mRescaling[0m)         │ [32m3[0m)                │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ normalization       │ ([96mNone

  saveable.load_own_variables(weights_store.get(inner_path))


(None, 512, 512, 2)

2. Segmenting and Extracting the Snakes


In [93]:
# Configuration
image_folder = r"/mnt/c/Users/pdeschepper/Desktop/PERSONAL/DeepLearning/ImageSegmentation/Snakes_ImageSegmentation_keras/Vipera_segmentation_test_dataset/"
output_folder = os.path.join(image_folder, "Extracted_snakes")
MODEL_INPUT_SIZE = (512, 512)  # (width, height) each image is resized to before prediction
MAX_IMAGES = 100               # process at most this many images per run; set to None for all
SNAKE_CLASS = 1                # class index in the argmax mask that is kept as foreground

# Create the output folder if needed (exist_ok=True: no error when it already exists)
os.makedirs(output_folder, exist_ok=True)
print(f"Output folder created/verified: {output_folder}\n")

# Get list of image files.
# os.listdir returns entries in arbitrary order, so the list is sorted to make the
# MAX_IMAGES subset (and the log) identical from run to run.
image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}
image_files = sorted(
    f for f in os.listdir(image_folder)
    if os.path.splitext(f)[1].lower() in image_extensions
    and not os.path.isdir(os.path.join(image_folder, f))
)

print(f"Found {len(image_files)} images to process\n")

# Slicing with None keeps the whole list, so MAX_IMAGES = None processes everything
images_to_process = image_files[:MAX_IMAGES]
if len(images_to_process) < len(image_files):
    print(f"Only the first {len(images_to_process)} images are processed in this run "
          f"(MAX_IMAGES = {MAX_IMAGES})\n")

saved_count = 0
failed_images = []          # names of images that raised during processing
used_output_names = set()   # output file names written in this run (collision guard)

for idx, img_name in enumerate(images_to_process, 1):
    print(f"[{idx}] Processing {img_name}...", end=" ")
    try:
        img_path = os.path.join(image_folder, img_name)

        # Load the original image for saving later. The context manager closes the
        # underlying file handle right away; convert() returns an independent copy.
        with Image.open(img_path) as img_file:
            img_original = img_file.convert('RGB')
        original_size = (img_original.width, img_original.height)
        img_array = np.array(img_original)

        # Resize to the network input size and add the batch dimension
        img_resized = img_original.resize(MODEL_INPUT_SIZE)
        img_input = np.array(img_resized, dtype=np.uint8)  # Keep as uint8, no normalization!
        img_input = np.expand_dims(img_input, axis=0)

        # Get prediction: per-pixel class scores, shape (1, H, W, n_classes)
        mask_output = model.predict(img_input, verbose=0)

        # Most likely class per pixel, batch dimension removed -> shape (H, W)
        mask_pred = np.argmax(mask_output[0], axis=-1)

        # Boolean foreground mask. Comparing against the class index (instead of
        # scaling the label values) stays correct even if more classes are added.
        snake_mask = mask_pred == SNAKE_CLASS

        # Check mask coverage
        mask_percentage = (np.sum(snake_mask) / snake_mask.size) * 100
        print(f"(class 1 coverage: {mask_percentage:.1f}%)", end=" ")

        # Resize mask back to original size; NEAREST keeps it strictly 0 / 255
        mask_pil = Image.fromarray(snake_mask.astype(np.uint8) * 255)
        mask_pil = mask_pil.resize(original_size, Image.NEAREST)
        alpha = np.array(mask_pil)  # uint8, 0 = background, 255 = snake

        # Create RGBA image with transparency: background pixels are zeroed in RGB
        # and fully transparent in the alpha channel
        extracted = np.zeros((original_size[1], original_size[0], 4), dtype=np.uint8)
        extracted[:, :, :3] = img_array * (alpha[:, :, np.newaxis] > 0)
        extracted[:, :, 3] = alpha

        # Save as PNG. If two inputs share a stem (e.g. a.jpg and a.png) the second
        # one gets its original extension appended instead of overwriting the first.
        stem, ext = os.path.splitext(img_name)
        output_name = stem + '.png'
        if output_name in used_output_names:
            output_name = f"{stem}_{ext.lstrip('.').lower()}.png"
        used_output_names.add(output_name)
        output_path = os.path.join(output_folder, output_name)

        # Mode is inferred from the (H, W, 4) uint8 array as RGBA; the explicit
        # `mode` argument of Image.fromarray is deprecated in recent Pillow releases.
        result_img = Image.fromarray(extracted)
        result_img.save(output_path)
        saved_count += 1

        print(f"✓ Saved")

    except Exception as e:
        # Best effort: report the failure, remember it for the summary, move on
        failed_images.append(img_name)
        print(f"ERROR processing {img_name}: {e}")

print(f"\nProcessing complete! Results saved to: {output_folder}")
print(f"Saved {saved_count} of {len(images_to_process)} images; {len(failed_images)} failed")
if failed_images:
    print("Failed images: " + ", ".join(failed_images))

Output folder created/verified: /mnt/c/Users/pdeschepper/Desktop/PERSONAL/DeepLearning/ImageSegmentation/Snakes_ImageSegmentation_keras/Vipera_segmentation_test_dataset/Extracted_snakes

Found 1998 images to process

[1] Processing 101346513.jpg... (class 1 coverage: 4.0%) ✓ Saved
[2] Processing 101347110.jpg... (class 1 coverage: 3.6%) ✓ Saved
[3] Processing 101450839.jpg... (class 1 coverage: 3.6%) ✓ Saved
[4] Processing 101450848.jpg... (class 1 coverage: 3.4%) ✓ Saved
[5] Processing 103439081.jpg... (class 1 coverage: 1.6%) ✓ Saved
[6] Processing 104032787.jpg... (class 1 coverage: 4.5%) ✓ Saved
[7] Processing 104347213.jpg... (class 1 coverage: 33.5%) ✓ Saved
[8] Processing 104347312.jpg... (class 1 coverage: 16.4%) ✓ Saved
[9] Processing 104347469.jpg... (class 1 coverage: 23.6%) ✓ Saved
[10] Processing 106144681.jpg... (class 1 coverage: 3.5%) ✓ Saved
[11] Processing 10705266.jpg... (class 1 coverage: 18.1%) ✓ Saved
[12] Processing 10867733.jpg... (class 1 coverage: 1.2%) ✓ Save