In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.models import resnet18, ResNet18_Weights
from PIL import Image
import numpy as np
import cv2
from typing import Tuple, Union
import os
from pathlib import Path

In [None]:
class PerceptualFeatureExtractor(nn.Module):
    """
    Per-image embedding built from two complementary views of the input.

    * Deep branch: a ResNet18 trunk (ImageNet weights when ``pretrained`` is true,
      random initialisation otherwise) followed by a small MLP that maps the
      512-d pooled trunk output to 64 values.
    * Edge branch: the Sobel gradient magnitude of the channel-averaged image,
      average-pooled onto an 8x8 grid and flattened to 64 values.

    ``forward`` concatenates both branches into a (B, 128) tensor.
    """

    def __init__(self, pretrained=True):
        super().__init__()

        # torchvision's ResNet18 ends with (avgpool, fc); dropping the last child
        # removes only the classifier, so the trunk still ends in a global pool.
        resnet = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1 if pretrained else None)
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])

        # 512 -> 256 -> 128 -> 64 MLP. The extra pooling layer is a no-op on the
        # already-pooled trunk output but must stay: layer positions define the
        # state_dict keys that saved checkpoints rely on.
        self.perceptual_head = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(512, 256), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(256, 128), nn.ReLU(),
            nn.Linear(128, 64),
        )

        # Fixed (non-trainable) Sobel kernels stored as buffers so they follow
        # .to(device) and are saved with the model. The vertical kernel is the
        # transpose of the horizontal one.
        sobel = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32)
        self.register_buffer('sobel_x', sobel.view(1, 1, 3, 3))
        self.register_buffer('sobel_y', sobel.t().contiguous().view(1, 1, 3, 3))

    def extract_edge_features(self, x):
        """
        Return the per-pixel Sobel gradient magnitude of the channel mean of ``x``.

        Expects a (B, C, H, W) batch and returns (B, 1, H, W); zero padding keeps
        the spatial size.

        NOTE: sqrt has an infinite derivative at 0, so gradients with respect to
        ``x`` are NaN wherever both Sobel responses vanish. This only matters if
        the input itself requires grad (the kernels are buffers, not parameters).
        """
        luminance = x.mean(dim=1, keepdim=True)
        grad_x = F.conv2d(luminance, self.sobel_x, padding=1)
        grad_y = F.conv2d(luminance, self.sobel_y, padding=1)
        return (grad_x.pow(2) + grad_y.pow(2)).sqrt()

    def forward(self, x):
        """Embed a batch of images as [64 deep features | 64 pooled edge values]."""
        deep_part = self.perceptual_head(self.backbone(x))
        edge_part = F.adaptive_avg_pool2d(self.extract_edge_features(x), (8, 8)).flatten(1)
        return torch.cat((deep_part, edge_part), dim=1)

In [None]:
class SiameseNetwork(nn.Module):
    """
    Twin-branch similarity model.

    Both images are embedded by one shared ``PerceptualFeatureExtractor`` and an
    MLP ending in a sigmoid scores the concatenated pair, giving a (B, 1) tensor
    of values in (0, 1).

    The head sees ``[features(img1), features(img2)]`` in that order, so the score
    is not guaranteed to be symmetric in its two arguments.
    """

    def __init__(self, pretrained=True):
        super().__init__()

        # Shared weights: the same extractor instance embeds both images.
        self.feature_extractor = PerceptualFeatureExtractor(pretrained=pretrained)

        # Each embedding is 64 (perceptual MLP) + 64 (8x8 edge grid) = 128 values;
        # the head receives two of them side by side.
        embedding_size = 64 + 64
        pair_size = 2 * embedding_size

        # Layer order defines the state_dict keys of saved checkpoints.
        head_layers = [
            nn.Linear(pair_size, 256), nn.ReLU(), nn.Dropout(0.4),
            nn.Linear(256, 128), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(128, 64), nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        ]
        self.similarity_head = nn.Sequential(*head_layers)

    def forward_once(self, x):
        """Embed a single batch of images with the shared extractor."""
        return self.feature_extractor(x)

    def forward(self, img1, img2):
        """Score each (img1, img2) pair; returns the head's sigmoid output."""
        pair = torch.cat((self.forward_once(img1), self.forward_once(img2)), dim=1)
        return self.similarity_head(pair)

In [None]:
class ImageSimilarityDetector:
    """
    Main class for image similarity detection with preprocessing and inference pipeline.

    Bundles a ``SiameseNetwork`` with image loading and denoising, an optional
    OpenCV-based structural check, a decision threshold, checkpoint save/load and
    a simple binary-cross-entropy training loop.
    """

    # ImageNet normalization statistics. Used by ``self.transform`` and undone
    # again in ``extract_structural_features``; defined once so they cannot drift.
    _NORM_MEAN = (0.485, 0.456, 0.406)
    _NORM_STD = (0.229, 0.224, 0.225)

    def __init__(self, model_path=None, device=None):
        """
        Args:
            model_path: Optional checkpoint written by ``save_model``. If a path is
                given but does not exist, a ``RuntimeWarning`` is issued and the
                model keeps its freshly initialised similarity head.
            device: torch device (or device string) to run on; defaults to CPU.
        """
        import warnings

        # CPU-only as specified, unless the caller explicitly asks otherwise.
        self.device = torch.device('cpu') if device is None else device

        # Default decision threshold. It must be set *before* loading a
        # checkpoint: load_model may replace it with the saved threshold, and
        # assigning the default afterwards would silently discard that value.
        self.similarity_threshold = 0.5

        has_checkpoint = bool(model_path) and os.path.exists(model_path)

        # A checkpoint is loaded strictly and overwrites every backbone weight, so
        # the ImageNet weights are only fetched when nothing will be loaded.
        self.model = SiameseNetwork(pretrained=not has_checkpoint)
        self.model.to(self.device)

        if has_checkpoint:
            self.load_model(model_path)
        elif model_path:
            warnings.warn(
                f"Model checkpoint not found at {model_path!r}; "
                "using an untrained similarity head.",
                RuntimeWarning,
                stacklevel=2,
            )

        # Image preprocessing pipeline
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=self._NORM_MEAN, std=self._NORM_STD),
        ])

    def preprocess_image(self, image_input: Union[str, np.ndarray, Image.Image]) -> torch.Tensor:
        """
        Load an image, denoise it and convert it to a normalized model input.

        Args:
            image_input: A file path (``str`` or path-like), a PIL image, or a
                NumPy array. NumPy arrays follow the OpenCV convention:
                (H, W, 3) is BGR, (H, W, 4) is BGRA, and (H, W) or (H, W, 1) is
                grayscale.

        Returns:
            Tensor of shape (1, 3, 224, 224), normalized with ImageNet statistics.

        Raises:
            ValueError: For unsupported input types or array shapes.
        """
        if isinstance(image_input, (str, os.PathLike)):
            # Context manager closes the underlying file; convert() returns a
            # fully-loaded copy that outlives it.
            with Image.open(image_input) as opened:
                image = opened.convert('RGB')
        elif isinstance(image_input, np.ndarray):
            array = image_input
            if array.ndim == 3 and array.shape[2] == 3:  # BGR to RGB
                array = cv2.cvtColor(array, cv2.COLOR_BGR2RGB)
            elif array.ndim == 3 and array.shape[2] == 4:  # BGRA to RGB
                array = cv2.cvtColor(array, cv2.COLOR_BGRA2RGB)
            elif array.ndim == 3 and array.shape[2] == 1:  # single-channel -> 2-D
                array = array[:, :, 0]
            elif array.ndim != 2:
                raise ValueError(
                    f"Unsupported image array shape {image_input.shape}; expected "
                    "(H, W), (H, W, 1), (H, W, 3) or (H, W, 4)"
                )
            # convert('RGB') guarantees 3 channels (grayscale is replicated),
            # which bilateralFilter and the 3-channel normalization require.
            image = Image.fromarray(array).convert('RGB')
        elif isinstance(image_input, Image.Image):
            image = image_input.convert('RGB')
        else:
            raise ValueError("Unsupported image input type")

        # Edge-preserving smoothing to damp compression artifacts before the model
        img_array = cv2.bilateralFilter(np.array(image), 9, 75, 75)
        image = Image.fromarray(img_array)

        tensor = self.transform(image)
        return tensor.unsqueeze(0)  # Add batch dimension

    def extract_structural_features(self, image_tensor: torch.Tensor) -> dict:
        """
        Compute OpenCV structural descriptors for one preprocessed image.

        Args:
            image_tensor: Normalized tensor as returned by ``preprocess_image``;
                a leading batch dimension of size 1 is squeezed away.

        Returns:
            dict with ``'keypoints'`` (number of ORB keypoints, ORB capped at
            ``nfeatures=100``), ``'descriptors'`` (raw ORB descriptors, may be
            ``None`` when nothing is detected) and ``'histogram'`` (256-bin
            grayscale histogram normalized to sum to 1).
        """
        # Undo the ImageNet normalization to recover an 8-bit RGB image
        img_np = image_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
        img_np = (img_np * np.array(self._NORM_STD) + np.array(self._NORM_MEAN)) * 255
        img_np = np.clip(img_np, 0, 255).astype(np.uint8)

        gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)

        orb = cv2.ORB_create(nfeatures=100)
        keypoints, descriptors = orb.detectAndCompute(gray, None)

        hist = cv2.calcHist([gray], [0], None, [256], [0, 256])
        hist = hist.flatten() / hist.sum()  # Normalize

        return {
            'keypoints': len(keypoints) if keypoints else 0,
            'descriptors': descriptors,
            'histogram': hist,
        }

    def compute_similarity(self, img1_path: Union[str, np.ndarray, Image.Image],
                          img2_path: Union[str, np.ndarray, Image.Image],
                          use_structural_validation=True) -> Tuple[bool, float, dict]:
        """
        Main inference function to determine if two images are similar.

        Args:
            img1_path: First image (any input accepted by ``preprocess_image``).
            img2_path: Second image (same accepted types).
            use_structural_validation: If true, also compute ORB keypoint and
                histogram statistics. When the histogram correlation exceeds 0.8,
                the decision threshold is raised by 0.1.

        Returns:
            - bool: True if images are similar, False if different
            - float: Raw similarity score (0-1)
            - dict: Additional analysis metrics
        """
        self.model.eval()

        with torch.no_grad():
            img1_tensor = self.preprocess_image(img1_path).to(self.device)
            img2_tensor = self.preprocess_image(img2_path).to(self.device)

            similarity_score = self.model(img1_tensor, img2_tensor).item()

            if not use_structural_validation:
                is_similar = similarity_score > self.similarity_threshold
                return is_similar, similarity_score, {'raw_similarity': similarity_score}

            features1 = self.extract_structural_features(img1_tensor)
            features2 = self.extract_structural_features(img2_tensor)

            # Keypoint difference analysis (max(..., 1) avoids division by zero)
            keypoint_diff = abs(features1['keypoints'] - features2['keypoints'])
            keypoint_ratio = keypoint_diff / max(features1['keypoints'], features2['keypoints'], 1)

            # Histogram correlation (lighting robustness check)
            hist_corr = cv2.compareHist(features1['histogram'], features2['histogram'], cv2.HISTCMP_CORREL)

            analysis_metrics = {
                'keypoint_difference': keypoint_diff,
                'keypoint_ratio': keypoint_ratio,
                'histogram_correlation': hist_corr,
                'raw_similarity': similarity_score,
            }

            # If histograms are very similar (lighting similar), use a stricter threshold
            adjusted_threshold = self.similarity_threshold + (0.1 if hist_corr > 0.8 else 0.0)
            is_similar = similarity_score > adjusted_threshold

            return is_similar, similarity_score, analysis_metrics

    def set_threshold(self, threshold: float):
        """Set similarity threshold, clamped to the range [0, 1]."""
        self.similarity_threshold = max(0.0, min(1.0, threshold))

    def save_model(self, path: str):
        """Save the model weights and the current threshold to ``path``."""
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'threshold': self.similarity_threshold,
        }, path)
        print(f"Model saved to {path}")

    def load_model(self, path: str):
        """
        Load weights (and the threshold, if present) written by ``save_model``.

        SECURITY NOTE(review): depending on the installed torch version,
        ``torch.load`` may unpickle arbitrary objects. Only load trusted
        checkpoints, or pass ``weights_only=True`` where supported (torch >= 1.13).
        The format written by ``save_model`` is compatible with it.
        """
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        if 'threshold' in checkpoint:
            self.similarity_threshold = checkpoint['threshold']
        print(f"Model loaded from {path}")

    def train_model(self, train_dataloader, val_dataloader=None, epochs=20, lr=0.001,
                    best_model_path='best_model.pth'):
        """
        Train the Siamese network with binary cross-entropy.

        Args:
            train_dataloader: Iterable of ``(img1, img2, label)`` batches, where
                label=1 means similar and 0 means different. Labels may be shaped
                (B,) or (B, 1).
            val_dataloader: Optional iterable with the same batch format. When
                given, it drives the LR scheduler and best-checkpoint saving.
            epochs: Number of passes over ``train_dataloader``.
            lr: Initial Adam learning rate.
            best_model_path: Where the checkpoint with the lowest validation loss
                is written.

        Raises:
            ValueError: If ``train_dataloader`` yields no batches in an epoch.
        """
        self.model.train()
        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, weight_decay=1e-4)
        criterion = nn.BCELoss()
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3)

        best_val_loss = float('inf')

        for epoch in range(epochs):
            train_loss = 0.0
            train_correct = 0
            train_total = 0
            num_batches = 0  # counted rather than len() so unsized loaders work

            for batch_idx, (img1, img2, labels) in enumerate(train_dataloader):
                img1, img2 = img1.to(self.device), img2.to(self.device)
                labels = labels.float().to(self.device).reshape(-1)

                optimizer.zero_grad()
                # reshape(-1) instead of squeeze(): squeeze() collapses a batch of
                # size 1 into a 0-d tensor, which BCELoss rejects against (1,) labels.
                outputs = self.model(img1, img2).reshape(-1)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                num_batches += 1
                train_loss += loss.item()
                predicted = (outputs > 0.5).float()
                train_total += labels.size(0)
                train_correct += (predicted == labels).sum().item()

                if batch_idx % 10 == 0:
                    print(f'Epoch {epoch+1}/{epochs}, Batch {batch_idx}, Loss: {loss.item():.4f}')

            if num_batches == 0:
                raise ValueError("train_dataloader yielded no batches")

            avg_train_loss = train_loss / num_batches
            train_acc = 100. * train_correct / train_total

            print(f'Epoch {epoch+1}/{epochs}:')
            print(f'  Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.2f}%')

            # Validation
            if val_dataloader is not None:
                val_loss, val_acc = self.validate(val_dataloader, criterion)
                print(f'  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
                scheduler.step(val_loss)

                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    self.save_model(best_model_path)

            print('-' * 50)

    def validate(self, val_dataloader, criterion):
        """
        Evaluate the model on ``val_dataloader`` without gradient tracking.

        Returns:
            (mean loss per batch, accuracy in percent) using a 0.5 cut-off.

        Raises:
            ValueError: If ``val_dataloader`` yields no batches.

        The model is always put back into training mode, even if an error occurs.
        """
        self.model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        num_batches = 0

        try:
            with torch.no_grad():
                for img1, img2, labels in val_dataloader:
                    img1, img2 = img1.to(self.device), img2.to(self.device)
                    labels = labels.float().to(self.device).reshape(-1)
                    outputs = self.model(img1, img2).reshape(-1)
                    loss = criterion(outputs, labels)

                    num_batches += 1
                    val_loss += loss.item()
                    predicted = (outputs > 0.5).float()
                    val_total += labels.size(0)
                    val_correct += (predicted == labels).sum().item()
        finally:
            self.model.train()

        if num_batches == 0:
            raise ValueError("val_dataloader yielded no batches")

        return val_loss / num_batches, 100. * val_correct / val_total

In [None]:
# Utility functions for easy integration
def create_detector(model_path=None):
    """
    Factory for an ``ImageSimilarityDetector``.

    Args:
        model_path: Optional checkpoint path forwarded to the detector.

    Returns:
        A newly constructed ``ImageSimilarityDetector``.
    """
    detector = ImageSimilarityDetector(model_path=model_path)
    return detector

def detect_image_similarity(img1_path, img2_path, model_path=None, threshold=0.5):
    """
    One-shot helper: build a detector, apply ``threshold`` and compare two images.

    A new ``ImageSimilarityDetector`` (and therefore a new model) is built on
    every call, so keep a long-lived detector when comparing many pairs.

    Args:
        img1_path: First image as a file path, NumPy array or PIL Image.
        img2_path: Second image, with the same accepted types as ``img1_path``.
        model_path: Optional checkpoint to load into the detector.
        threshold: Decision threshold passed to ``set_threshold`` (default 0.5).

    Returns:
        bool: The similarity verdict from ``compute_similarity``.
    """
    checker = ImageSimilarityDetector(model_path=model_path)
    checker.set_threshold(threshold)
    verdict, _score, _metrics = checker.compute_similarity(img1_path, img2_path)
    return verdict

In [None]:
# Example usage. No checkpoint is loaded here, so the similarity head is freshly
# initialised and its scores are not meaningful until the model is trained or a
# trained checkpoint is passed as model_path.
detector = ImageSimilarityDetector()

# Example: Compare two images
# is_similar, score, metrics = detector.compute_similarity('image1.jpg', 'image2.jpg')
# print(f"Images are similar: {is_similar}")
# print(f"Similarity score: {score:.3f}")
# print(f"Analysis metrics: {metrics}")

print("\n".join([
    "Image Similarity Detector initialized successfully!",
    "Model is ready for CPU inference.",
    "\nKey features:",
    "- Robust to lighting changes and compression artifacts",
    "- Detects meaningful structural changes",
    "- CPU-optimized for production use",
    "- Easy integration with existing Python systems",
]))