In [2]:
import torch
from torchvision import models, transforms
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import glob

# Load an ImageNet-pretrained ResNet18 to use as a fixed feature extractor.
# `pretrained=True` is deprecated since torchvision 0.13 in favour of the
# `weights=` enum; IMAGENET1K_V1 is the same checkpoint `pretrained=True`
# loaded. Fall back to the old flag on torchvision versions without the enum.
_resnet18_weights = getattr(models, "ResNet18_Weights", None)
if _resnet18_weights is not None:
    _backbone = models.resnet18(weights=_resnet18_weights.IMAGENET1K_V1)
else:
    _backbone = models.resnet18(pretrained=True)

# Drop the last child module (the final classification layer) so the network
# returns the pooled feature map instead of class scores; extract_features
# flattens that output into a 1-D embedding.
model = torch.nn.Sequential(*list(_backbone.children())[:-1])
# Inference mode: BatchNorm uses its running statistics (and any dropout is off),
# so the same image always yields the same embedding.
model.eval()

# Preprocessing applied to every image before it enters the model.
transform = transforms.Compose([
    # An explicit (h, w) tuple resizes both sides to 224, ignoring aspect ratio.
    transforms.Resize((224, 224)),
    # PIL image -> float tensor of shape (C, H, W) with values in [0, 1].
    transforms.ToTensor(),
    # Per-channel ImageNet mean/std, matching the statistics the pretrained
    # weights were trained with.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Feature extraction function
def extract_features(image_path):
    """Return the flattened embedding of one image as a 1-D numpy array.

    The image is converted to RGB, preprocessed with the module-level
    ``transform`` and passed through the module-level ``model`` (ResNet18
    with its final classification layer removed) with gradients disabled.

    Args:
        image_path: Path to an image file that PIL can open.

    Returns:
        numpy.ndarray: the model output for this single image, flattened to 1-D.
    """
    # Image.open is lazy and keeps the file handle open until the image is
    # garbage-collected; use a context manager so it is released immediately.
    # convert() loads the pixels into a new image that does not need the file.
    with Image.open(image_path) as img:
        rgb_image = img.convert("RGB")
    batch = transform(rgb_image).unsqueeze(0)  # add batch dimension -> (1, C, H, W)
    with torch.no_grad():
        features = model(batch).flatten().numpy()
    return features

# Paths to images (example paths, update with actual paths to your images)
OTHER_IMAGE_GLOB = "/raid/home/somayeh.shami/project/somayeh_workspace/DWRL7/data/BigBag4_7_Other/images_cutout/*.jpg"  # Replace with your "Other" images directory
PET_IMAGE_GLOB = "/raid/home/somayeh.shami/project/somayeh_workspace/DWRL7/data/BigBag4_1_PET/images_cutout/*.jpg"  # Replace with your PET images directory
MAX_IMAGES_PER_CLASS = 5  # limit per class to keep the comparison quick
TOP_K = 2  # number of most-similar (Other, PET) pairs to report

# glob() returns files in arbitrary, filesystem-dependent order; sort so the
# [:MAX_IMAGES_PER_CLASS] subset is the same on every run.
other_images = sorted(glob.glob(OTHER_IMAGE_GLOB))
pet_images = sorted(glob.glob(PET_IMAGE_GLOB))

# Extract features as (path, embedding) pairs
other_features = [(img_path, extract_features(img_path)) for img_path in other_images[:MAX_IMAGES_PER_CLASS]]
pet_features = [(img_path, extract_features(img_path)) for img_path in pet_images[:MAX_IMAGES_PER_CLASS]]

# Score every (Other, PET) combination with a single vectorized call instead
# of one cosine_similarity call per pair. The guard keeps empty directories
# from passing empty arrays to np.stack / cosine_similarity.
similar_pairs = []
if other_features and pet_features:
    similarity_matrix = cosine_similarity(
        np.stack([feat for _, feat in other_features]),
        np.stack([feat for _, feat in pet_features]),
    )
    for i, (other_img_path, _) in enumerate(other_features):
        for j, (pet_img_path, _) in enumerate(pet_features):
            similar_pairs.append((similarity_matrix[i, j], other_img_path, pet_img_path))

# Sort pairs by similarity (highest first) and keep the TOP_K best
similar_pairs = sorted(similar_pairs, reverse=True)[:TOP_K]

# Display results
for similarity, other_img_path, pet_img_path in similar_pairs:
    print(f"Similarity: {similarity:.2f}")
    print(f"Other Image: {other_img_path}")
    print(f"PET Image: {pet_img_path}")


Similarity: 0.73
Other Image: /raid/home/somayeh.shami/project/somayeh_workspace/DWRL7/data/BigBag4_7_Other/images_cutout/4994_3_99.jpg
PET Image: /raid/home/somayeh.shami/project/somayeh_workspace/DWRL7/data/BigBag4_1_PET/images_cutout/6127_1_98.jpg
Similarity: 0.70
Other Image: /raid/home/somayeh.shami/project/somayeh_workspace/DWRL7/data/BigBag4_7_Other/images_cutout/4994_3_99.jpg
PET Image: /raid/home/somayeh.shami/project/somayeh_workspace/DWRL7/data/BigBag4_1_PET/images_cutout/6622_1_99.jpg
