In [10]:
# Import necessary libraries
import os
import glob
import random
import pandas as pd
import numpy as np
import requests
from io import BytesIO
from PIL import Image, ImageFilter
import matplotlib
matplotlib.use('Agg')  # Use a non-interactive backend suitable for script running
import matplotlib.pyplot as plt
from tqdm import tqdm  # For progress bars
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
import torch.nn as nn

import faiss

# Additional imports for label encoding
from sklearn.preprocessing import LabelEncoder

# Import traceback for detailed error messages
import traceback

# Step 0: Define Helper Functions
print("Step 0: Defining helper functions...")

def show_results(query_image_path, results, query_index, transformed=False):
    """
    Displays the query image alongside its top matching results and saves the figure.

    The figure has one panel for the query plus up to five panels for the first
    rows of ``results``. It is written to ``results/query_results_<query_index>.png``.

    Parameters:
    - query_image_path (str): Path or URL (anything starting with 'http') of the query image.
    - results (pd.DataFrame): DataFrame containing the top matching artworks; must
      have an 'image_path' column when non-empty.
    - query_index (int): Index of the query image, used in the output filename.
    - transformed (bool): Indicates the query image was already transformed before
      searching. It only adds "(Transformed)" to the query panel title; the image
      is displayed exactly as it was searched (it is not transformed again).

    Returns:
    - None
    """
    print("\nDisplaying results visually...")
    try:
        if query_image_path.startswith('http'):
            response = requests.get(query_image_path, timeout=10)
            # Report HTTP errors (404, 500, ...) instead of a confusing decode error.
            response.raise_for_status()
            query_img = Image.open(BytesIO(response.content)).convert('RGB')
        else:
            # Context manager releases the underlying file handle.
            with Image.open(query_image_path) as src:
                query_img = src.convert('RGB')
    except Exception as e:
        print(f"Failed to load query image: {e}")
        return

    num_results = min(len(results), 5)
    num_panels = num_results + 1
    fig = plt.figure(figsize=(5 * num_panels, 5))
    try:
        # Display Query Image
        plt.subplot(1, num_panels, 1)
        plt.imshow(query_img)
        plt.title('Query Image (Transformed)' if transformed else 'Query Image')
        plt.axis('off')

        # Display Matching Images; a panel whose image cannot be loaded stays blank.
        for i in range(num_results):
            img_path = results.iloc[i]['image_path']
            plt.subplot(1, num_panels, i + 2)
            plt.axis('off')
            try:
                with Image.open(img_path) as src:
                    img = src.convert('RGB')
            except Exception as e:
                print(f"Failed to load image {img_path}: {e}")
                continue
            plt.imshow(img)
            plt.title(f"Match {i+1}")
            print(f"Loaded Match {i+1}: {img_path}")

        output_dir = 'results'
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, f'query_results_{query_index}.png')
        fig.savefig(output_path)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
    print(f"Results displayed and saved to '{output_path}'.")

def process_input_image(image_path_or_url, transform, device):
    """
    Processes the input image by loading, transforming, and sending it to the device.

    Failures are not raised. If loading fails, a black 224x224 RGB image is
    transformed instead. If the transform fails, a zero tensor of shape
    (1, 3, 224, 224) is returned. Either way the caller receives a batched tensor.

    Parameters:
    - image_path_or_url (str): Path or URL (anything starting with 'http') of the image.
    - transform (torchvision.transforms.Compose): Transformations to apply.
    - device (torch.device): Device to send the image tensor.

    Returns:
    - torch.Tensor: Processed image tensor with a leading batch dimension of 1.
    """
    print(f"  Processing input image: {image_path_or_url}")
    try:
        if image_path_or_url.startswith('http'):
            response = requests.get(image_path_or_url, timeout=10)
            # Report HTTP errors (404, 500, ...) instead of a confusing decode error.
            response.raise_for_status()
            img = Image.open(BytesIO(response.content)).convert('RGB')
        else:
            # Context manager releases the underlying file handle.
            with Image.open(image_path_or_url) as src:
                img = src.convert('RGB')
    except Exception as e:
        print(f"  Failed to load image {image_path_or_url}: {e}")
        # Return a black image in case of error
        img = Image.new('RGB', (224, 224), (0, 0, 0))
    try:
        img = transform(img).unsqueeze(0).to(device, non_blocking=True)
        print("  Image processed and transformed.")
    except Exception as e:
        print(f"  Failed to transform image {image_path_or_url}: {e}")
        # Allocate the fallback directly on the target device.
        img = torch.zeros((1, 3, 224, 224), device=device)
    return img

def get_image_features(model, img_tensor):
    """
    Extracts and L2-normalizes features from the image tensor.

    Normalization uses ``nn.functional.normalize``, which divides by
    ``max(norm, eps)``. An all-zero feature vector therefore stays zero instead
    of turning into NaN (0 / 0), which would poison the FAISS inner-product search.

    Parameters:
    - model (torch.nn.Module): The feature extraction model.
    - img_tensor (torch.Tensor): Batched image tensor.

    Returns:
    - np.ndarray: float32 array of shape (batch, flattened_feature_dim) with
      unit-norm rows (all-zero rows stay zero), or None if extraction fails.
    """
    print("  Extracting features from the query image...")
    try:
        with torch.no_grad():
            features = model(img_tensor)
            features = features.flatten(start_dim=1)  # (batch, feature_dim)
            features = nn.functional.normalize(features, p=2, dim=1)
        print("  Features extracted and normalized.")
        return features.cpu().numpy().astype('float32')
    except Exception as e:
        print(f"  Failed to extract features: {e}")
        traceback.print_exc()
        return None

def search_artwork(query_features, index, k=5):
    """
    Searches for the top k similar artworks using FAISS.

    FAISS pads its result with index -1 when fewer than ``k`` vectors are
    available. Those slots are dropped here because ``image_df.iloc[-1]`` would
    silently resolve them to the last dataset row. The returned arrays can
    therefore be shorter than ``k``. Only the first query row's results are returned.

    Parameters:
    - query_features (np.ndarray): Feature vector(s) of the query image; a 1-D
      vector is treated as a single query.
    - index (faiss.Index): FAISS index.
    - k (int): Number of top matches to retrieve.

    Returns:
    - distances (np.ndarray): Similarity scores, or None on failure.
    - indices (np.ndarray): Indices of the top matches, or None on failure.
    """
    print(f"  Searching for top {k} similar artworks...")
    try:
        # FAISS requires a C-contiguous float32 matrix of shape (n_queries, dim).
        queries = np.ascontiguousarray(np.atleast_2d(query_features), dtype='float32')
        distances, indices = index.search(queries, k)
        print("  Search completed.")
        found = indices[0] >= 0
        return distances[0][found], indices[0][found]
    except Exception as e:
        print(f"  FAISS search failed: {e}")
        traceback.print_exc()
        return None, None

def get_artwork_info(indices, image_df):
    """
    Looks up the metadata rows for the given positional indices.

    Parameters:
    - indices (np.ndarray): Positional row indices of the top matches, in rank order.
    - image_df (pd.DataFrame): DataFrame containing image metadata.

    Returns:
    - pd.DataFrame: The selected rows in the order of ``indices``, re-indexed
      0..n-1. An empty DataFrame is returned if the lookup fails.
    """
    print("  Retrieving artwork information for the top matches...")
    try:
        selected_rows = image_df.iloc[indices]
        ranked = selected_rows.reset_index(drop=True)
        print("  Artwork information retrieved.")
    except Exception as err:
        print(f"  Error retrieving artwork information: {err}")
        traceback.print_exc()
        return pd.DataFrame()  # Empty frame signals failure to the caller
    return ranked

def identify_artwork(image_path_or_url, model, index, image_df, transform, device, k=5):
    """
    Runs the full query pipeline: load/transform the image, embed it, search the
    FAISS index, and look up metadata for the hits.

    Parameters:
    - image_path_or_url (str): Path or URL of the query image.
    - model (torch.nn.Module): Feature extraction model.
    - index (faiss.Index): FAISS index.
    - image_df (pd.DataFrame): DataFrame containing image metadata.
    - transform (torchvision.transforms.Compose): Transformations to apply.
    - device (torch.device): Device to send the image tensor.
    - k (int): Number of top matches to retrieve.

    Returns:
    - results (pd.DataFrame): Top matching artworks (empty if any stage fails).
    - distances (np.ndarray): Similarity scores (empty if any stage fails).
    """
    print("\nIdentifying artwork...")

    def _abort(reason):
        # Report why the pipeline stopped and hand back empty results.
        print(reason)
        return pd.DataFrame(), np.array([])

    img_tensor = process_input_image(image_path_or_url, transform, device)
    if img_tensor is None:
        return _abort("  Image tensor is None. Skipping identification.")

    query_features = get_image_features(model, img_tensor)
    if query_features is None:
        return _abort("  Query features are None. Skipping identification.")

    distances, indices = search_artwork(query_features, index, k)
    if indices is None:
        return _abort("  Indices are None. Skipping retrieval.")

    results = get_artwork_info(indices, image_df)
    print("Artwork identification completed.")
    return results, distances

def extract_features(model, dataloader, device):
    """
    Extracts L2-normalized features for every image yielded by the dataloader.

    Row order of the result follows the dataloader's sample order. That order
    must be preserved because the FAISS index maps row i back to ``image_df`` row i.
    If a batch fails, it is NOT dropped. Instead it is back-filled with zero rows
    of the same length, so later rows stay aligned. A zero row has inner product
    0 with every query, so it never ranks as a strong match. Normalization uses
    ``nn.functional.normalize`` so an all-zero output row stays zero instead of NaN.

    Parameters:
    - model (torch.nn.Module): The feature extraction model.
    - dataloader (DataLoader): DataLoader for the dataset (must not shuffle).
    - device (torch.device): Device to perform computations on.

    Returns:
    - torch.Tensor: Extracted features of shape (num_samples, feature_dim) on the
      CPU, or an empty tensor if no batch could be processed.
    """
    print("Starting feature extraction...")
    # Each entry is either a (batch, dim) CPU tensor or, for a failed batch,
    # the int number of rows that must be back-filled with zeros.
    chunks = []
    reference = None  # first successful chunk; provides feature_dim and dtype
    with torch.no_grad():
        for batch_idx, (imgs, _) in enumerate(tqdm(dataloader, desc="Extracting features"), 1):
            try:
                outputs = model(imgs.to(device, non_blocking=True))
                outputs = outputs.flatten(start_dim=1)  # (batch_size, feature_dim)
                outputs = nn.functional.normalize(outputs, p=2, dim=1)
                chunk = outputs.cpu()
            except Exception as e:
                print(f"Error processing batch {batch_idx}: {e}")
                traceback.print_exc()
                chunks.append(len(imgs))  # placeholder keeps row alignment
                continue
            chunks.append(chunk)
            if reference is None:
                reference = chunk

    if reference is None:
        features = torch.tensor([])
    else:
        dim = reference.size(1)
        features = torch.cat(
            [c if isinstance(c, torch.Tensor) else reference.new_zeros((c, dim)) for c in chunks],
            dim=0,
        )
    print("Feature extraction completed.")
    return features

# Step 1: Load and Prepare the WikiArt Dataset (All Classes)
print("\nStep 1: Loading and preparing the WikiArt dataset (all classes)...")

# Specify the path to the WikiArt dataset
wikiart_dir = '../../scratch/mexas.v'  # Update this path as necessary

# Check if the dataset directory exists
if not os.path.exists(wikiart_dir):
    raise FileNotFoundError(f"WikiArt directory '{wikiart_dir}' does not exist. Please update the path.")

# Class directories and image paths are sorted so image_df has the same row
# order on every run. os.listdir/glob return entries in arbitrary filesystem
# order, while the cached feature matrix (Step 5) and the FAISS index are
# aligned with image_df purely by row position.
# NOTE: a features file produced before this ordering was introduced should be
# regenerated.
classes = sorted(
    d for d in os.listdir(wikiart_dir) if os.path.isdir(os.path.join(wikiart_dir, d))
)
print(f"Found {len(classes)} classes in the dataset.")

# One record per image: its path and the class (directory) name as label.
data = []
for label in classes:
    image_paths = sorted(glob.glob(os.path.join(wikiart_dir, label, '*.jpg')))
    print(f"  Found {len(image_paths)} images for class '{label}'.")
    data.extend({'image_path': path, 'label': label} for path in image_paths)

# Create a DataFrame (explicit columns so an empty result still has them)
image_df = pd.DataFrame(data, columns=['image_path', 'label'])
print(f"Prepared image_df with {len(image_df)} entries from all classes.")

if image_df.empty:
    raise ValueError(f"No '*.jpg' images found in the class directories under '{wikiart_dir}'.")

# Encode labels
label_encoder = LabelEncoder()
image_df['encoded_label'] = label_encoder.fit_transform(image_df['label'])

# Step 2: Define Image Transformations
print("\nStep 2: Defining image transformations...")

# Preprocessing shared by dataset and query images: a fixed 224x224 resize
# (aspect ratio is not preserved), tensor conversion, and per-channel
# normalization with the standard ImageNet mean/std.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

print("Image transformations defined.")

# Step 3: Create Custom Dataset and DataLoader
print("\nStep 3: Creating custom dataset and data loader...")

class ArtImageDataset(Dataset):
    """
    Map-style dataset over the rows of an image metadata DataFrame.

    Each item is ``(image_tensor, idx)``. The positional row index is returned
    (not the label) so outputs can be mapped back to DataFrame rows. An
    unreadable image is replaced by a black 224x224 image, and a failed
    transform by a zero tensor of shape (3, 224, 224), so one bad file does not
    abort a whole DataLoader pass.
    """

    def __init__(self, image_df, transform=None):
        # Reset to a 0..n-1 RangeIndex so the positional idx supplied by the
        # DataLoader can be used directly as a row label.
        self.image_df = image_df.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return len(self.image_df)

    def __getitem__(self, idx):
        img_path = self.image_df.at[idx, 'image_path']
        try:
            # Context manager closes the file handle; convert() returns a fully
            # loaded copy, so the image stays usable after the file is closed.
            with Image.open(img_path) as src:
                img = src.convert('RGB')
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            img = Image.new('RGB', (224, 224), (0, 0, 0))
        if self.transform is not None:
            try:
                img = self.transform(img)
            except Exception as e:
                print(f"Error transforming image {img_path}: {e}")
                img = torch.zeros((3, 224, 224))  # Fallback to a tensor of zeros
        return img, idx  # Return index to keep track

# Create dataset and DataLoader
batch_size = 16  # Adjust based on your GPU memory
dataset = ArtImageDataset(image_df, transform=transform)
# shuffle must stay False: extracted feature rows are matched to image_df rows
# by position. The worker count is capped at the machine's CPU count. Pinned
# host memory only speeds up host-to-GPU copies, so it is requested only when
# CUDA is available.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=min(12, os.cpu_count() or 1),
    pin_memory=torch.cuda.is_available(),
)

print(f"DataLoader created with {len(dataloader)} batches.")

# Step 4: Set Up Device and Load Pre-trained Model
print("\nStep 4: Setting up the device and loading the pre-trained model...")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the ImageNet-pretrained ResNet-50. torchvision 0.13 deprecated the
# `pretrained=True` flag in favour of `weights=`. IMAGENET1K_V1 is the
# checkpoint that `pretrained=True` loads, so features are the same either way.
if hasattr(models, 'ResNet50_Weights'):
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
else:  # torchvision < 0.13
    model = models.resnet50(pretrained=True)

# Remove the final classification layer so the network outputs pooled features
model = nn.Sequential(*list(model.children())[:-1])
model = model.to(device)
model.eval()

print("Pre-trained model loaded and modified for feature extraction.")
# Step 5: Extract Features with Conditional Loading
print("\nStep 5: Extracting features from all images...")

# Define the path to save/load features
features_file = 'wikiart_features.npy'

# Rows of the cached matrix are matched to image_df rows purely by position.
# A cache is therefore reused only when it is 2-D with exactly one row per
# image. Anything else (unreadable file, stale cache from a different dataset)
# falls through to a fresh extraction below.
features_np = None
if os.path.exists(features_file):
    print(f"Feature file '{features_file}' found. Loading existing features...")
    try:
        cached_features = np.load(features_file)
        print(f"Loaded features with shape: {cached_features.shape}")
    except Exception as e:
        print(f"Error loading features from '{features_file}': {e}")
        traceback.print_exc()
        print("Proceeding to extract features anew.")
    else:
        if cached_features.ndim == 2 and cached_features.shape[0] == len(image_df):
            features_np = cached_features
        else:
            print(f"Cached features do not match the dataset ({len(image_df)} images). "
                  "Proceeding to extract features anew.")
else:
    print(f"Feature file '{features_file}' not found. Extracting features...")

if features_np is None:
    features = extract_features(model, dataloader, device)
    print(f"Extracted features shape: {features.shape}")

    if features.numel() == 0:
        print("No features extracted. Exiting.")
        # SystemExit works even where the site-provided exit() helper is absent.
        raise SystemExit(1)

    features_np = features.cpu().numpy().astype('float32')
    np.save(features_file, features_np)
    print(f"Features extracted and saved to '{features_file}'.")

# FAISS only accepts C-contiguous float32 arrays.
features_np = np.ascontiguousarray(features_np, dtype='float32')

# Step 6: Create FAISS Index and Add Features
print("\nStep 6: Creating FAISS index with extracted features...")

try:
    # FAISS requires a C-contiguous float32 matrix. Row i of the index
    # corresponds to image_df row i.
    features_np = np.ascontiguousarray(features_np, dtype='float32')
    feature_dim = features_np.shape[1]
    print(f"Feature dimension: {feature_dim}")

    # Inner product equals cosine similarity because the feature rows are L2-normalized
    index = faiss.IndexFlatIP(feature_dim)
    print("FAISS IndexFlatIP created.")

    # Add features to the index
    index.add(features_np)
    print(f"Added {index.ntotal} features to the FAISS index.")

    # Save the FAISS index
    faiss_index_path = 'wikiart_faiss_index.bin'
    faiss.write_index(index, faiss_index_path)
    print(f"FAISS index saved to '{faiss_index_path}'.")
except Exception as e:
    print(f"Failed to create or save FAISS index: {e}")
    traceback.print_exc()
    # SystemExit works even where the site-provided exit() helper is absent.
    raise SystemExit(1)

# # Step 7: Perform Queries on 100 Random Obfuscated Images
# print("\nStep 7: Performing queries on 100 random obfuscated images...")

# num_queries = 100  # Number of random images to query
# k = 5  # Number of top matches to retrieve

# # Ensure the dataset has enough images
# if num_queries > len(image_df):
#     num_queries = len(image_df)
#     print(f"Only {num_queries} images available. Adjusting number of queries to {num_queries}.")

# # Select 100 random indices from the dataset
# random_indices = random.sample(range(len(image_df)), num_queries)

# # Initialize counters
# correct_at_1 = 0  # Top-1 accuracy
# correct_at_k = 0  # Top-k accuracy

# # Directory to save transformed query images
# transformed_dir = 'transformed_queries'
# if not os.path.exists(transformed_dir):
#     os.makedirs(transformed_dir)

# for idx, query_image_index in enumerate(tqdm(random_indices, desc="Processing Queries"), 1):
#     print(f"\nQuery {idx}/{num_queries}")
#     original_image_path = image_df.iloc[query_image_index]['image_path']
#     original_label = image_df.iloc[query_image_index]['encoded_label']
#     image_filename = os.path.basename(original_image_path)
#     transformed_image_path = os.path.join(transformed_dir, f"transformed_{image_filename}")
    
#     # Apply obfuscations: rotation and blur
#     try:
#         img = Image.open(original_image_path).convert('RGB')
#         # Random rotation between -15 and +15 degrees
#         rotation_angle = random.uniform(-15, 15)
#         img = img.rotate(rotation_angle)
#         # Random Gaussian blur with radius between 0 and 2
#         blur_radius = random.uniform(0, 2)
#         img = img.filter(ImageFilter.GaussianBlur(radius=blur_radius))
#         img.save(transformed_image_path)
#         print(f"  Applied rotation of {rotation_angle:.2f} degrees and blur radius of {blur_radius:.2f}.")
#     except Exception as e:
#         print(f"  Error transforming image {original_image_path}: {e}")
#         traceback.print_exc()
#         continue
    
#     # Perform retrieval using the obfuscated image
#     try:
#         results, distances = identify_artwork(
#             transformed_image_path,
#             model,
#             index,
#             image_df,
#             transform,
#             device,
#             k=k
#         )
        
#         if results.empty or distances.size == 0:
#             print(f"  No results returned for query {idx}. Skipping accuracy checks.")
#             continue
#     except Exception as e:
#         print(f"  Error identifying artwork for {transformed_image_path}: {e}")
#         traceback.print_exc()
#         continue
    
#     # Check if the correct match is in the top k
#     retrieved_labels = results['encoded_label'].values
#     retrieved_image_paths = results['image_path'].values
    
#     # Check for Top-1 accuracy
#     if retrieved_labels[0] == original_label:
#         correct_at_1 += 1
    
#     # Check for Top-k accuracy
#     if original_label in retrieved_labels:
#         correct_at_k += 1
    
#     # Display and save results
#     show_results(transformed_image_path, results, query_index=idx, transformed=True)

# # Step 8: Summary of Evaluation Results
# print("\nStep 8: Summary of Evaluation Results:")
# print(f"Total Queries: {num_queries}")
# print(f"Top-1 Correct Matches: {correct_at_1}")
# print(f"Top-{k} Correct Matches: {correct_at_k}")
# if num_queries > 0:
#     top1_accuracy = (correct_at_1 / num_queries) * 100
#     topk_accuracy = (correct_at_k / num_queries) * 100
#     print(f"Top-1 Accuracy: {top1_accuracy:.2f}%")
#     print(f"Top-{k} Accuracy: {topk_accuracy:.2f}%")
# else:
#     print("No valid queries were performed.")


# Step 7: Perform Queries on User-Specified Images
print("\nStep 7: Performing queries on user-specified images...")

# Define a list of images to query
# You can add image paths or URLs to this list
query_images = [
    "starry.png",
]

correct_matches = 0  # Counter for correct matches
valid_queries = 0    # Counter for queries where the image is in the dataset

# Dataset paths are built from wikiart_dir, which may be relative, so membership
# is tested on absolute paths. The same file spelled as a different relative
# path still counts as "in the dataset". The set is built once for O(1) lookups.
dataset_paths = {os.path.abspath(p) for p in image_df['image_path']}

for idx, image_to_query in enumerate(query_images, 1):
    print(f"\nQuery {idx}/{len(query_images)}")
    print(f"Selected image for querying: {image_to_query}")

    try:
        results, distances = identify_artwork(
            image_to_query,
            model,
            index,
            image_df,
            transform,
            device,
            k=5
        )
    except Exception as e:
        print(f"Error identifying artwork: {e}")
        traceback.print_exc()
        continue

    # identify_artwork returns an empty DataFrame on failure; results.iloc[0]
    # below would raise IndexError on it.
    if results.empty:
        print(f"No results returned for query {idx}. Skipping.")
        continue

    # URLs are never dataset members; local paths are compared in absolute form.
    query_abs = None if image_to_query.startswith('http') else os.path.abspath(image_to_query)
    is_in_dataset = query_abs in dataset_paths

    if is_in_dataset:
        valid_queries += 1
        # Retrieve the top match image path
        top_match_image_path = results.iloc[0]['image_path']
        if os.path.abspath(top_match_image_path) == query_abs:
            correct_matches += 1
            match_status = "Correct Match"
        else:
            match_status = "Incorrect Match"
    else:
        match_status = "Top match is not the query image (query image not in dataset)"

    # Display Top Matching Artworks
    print(f"\nTop matching artworks for Query {idx}:")
    for rank, (row, dist) in enumerate(zip(results.to_dict('records'), distances), 1):
        print(f"Rank {rank}:")
        print(f"  Label: {row['label']}")
        print(f"  Similarity Score: {dist:.4f}")
        print(f"  Image Path: {row['image_path']}\n")

    # Display and save results
    show_results(image_to_query, results, query_index=idx)

    print(f"Query {idx} result: {match_status}")

# Step 8: Print Summary Statistics
print("\nSummary:")
print(f"Total Queries: {len(query_images)}")
print(f"Valid Queries (Images in Dataset): {valid_queries}")
print(f"Correct Matches: {correct_matches}")
if valid_queries > 0:
    accuracy = (correct_matches / valid_queries) * 100
    print(f"Accuracy: {accuracy:.2f}%")
else:
    print("No valid queries (no query images were found in the dataset).")



Step 0: Defining helper functions...

Step 1: Loading and preparing the WikiArt dataset (all classes)...
Found 28 classes in the dataset.
  Found 2782 images for class 'Abstract_Expressionism'.
  Found 98 images for class 'Action_painting'.
  Found 110 images for class 'Analytical_Cubism'.
  Found 4334 images for class 'Art_Nouveau_Modern'.
  Found 4240 images for class 'Baroque'.
  Found 1615 images for class 'Color_Field_Painting'.
  Found 481 images for class 'Contemporary_Realism'.
  Found 2235 images for class 'Cubism'.
  Found 1391 images for class 'Early_Renaissance'.
  Found 6736 images for class 'Expressionism'.
  Found 934 images for class 'Fauvism'.
  Found 1343 images for class 'High_Renaissance'.
  Found 13060 images for class 'Impressionism'.
  Found 1279 images for class 'Mannerism_Late_Renaissance'.
  Found 1337 images for class 'Minimalism'.
  Found 2405 images for class 'Naive_Art_Primitivism'.
  Found 314 images for class 'New_Realism'.
  Found 2552 images for class 