In [None]:
# Import necessary libraries
import os
import glob
import random
import pandas as pd
import numpy as np
import requests
from io import BytesIO
from PIL import Image
import matplotlib
matplotlib.use('Agg')  # Use a non-interactive backend suitable for script running
import matplotlib.pyplot as plt
from tqdm import tqdm  # For progress bars
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
import torch.nn as nn

import faiss

# Additional imports for label encoding
from sklearn.preprocessing import LabelEncoder

# Step 0: Define Helper Functions
# The helpers defined below cover result display, query-image loading,
# feature extraction, FAISS search, and metadata lookup; nothing is executed
# against the dataset until Step 1.
print("Step 0: Defining helper functions...")

def show_results(image_path_or_url, results, query_index):
    """
    Displays the query image alongside its top matching results and saves the figure.

    The figure is written to 'results/query_results_<query_index>.png'. At most
    five matches are drawn. A match whose image cannot be loaded is reported and
    its slot in the figure is left empty.

    Parameters:
    - image_path_or_url (str): Path or URL of the query image.
    - results (pd.DataFrame): DataFrame containing the top matching artworks;
      each row must provide an 'image_path' column.
    - query_index (int): Index of the query image (used in the output file name).

    Returns:
    - None
    """
    print("\nDisplaying results visually...")
    try:
        # Only real URL schemes are fetched over the network; a local path that
        # merely starts with "http" (e.g. 'http_cache/a.jpg') is read from disk.
        if image_path_or_url.startswith(('http://', 'https://')):
            response = requests.get(image_path_or_url, timeout=10)
            # Fail on 4xx/5xx instead of trying to decode an error page as an image.
            response.raise_for_status()
            query_img = Image.open(BytesIO(response.content)).convert('RGB')
        else:
            with Image.open(image_path_or_url) as raw_img:
                query_img = raw_img.convert('RGB')
    except Exception as e:
        # Best-effort display helper: report and give up on this query only.
        print(f"Failed to load query image: {e}")
        return

    num_results = min(len(results), 5)
    fig = plt.figure(figsize=(5 * (num_results + 1), 5))
    try:
        # Display Query Image
        plt.subplot(1, num_results + 1, 1)
        plt.imshow(query_img)
        plt.title('Query Image')
        plt.axis('off')

        # Display Matching Images
        for i in range(num_results):
            img_path = results.iloc[i]['image_path']
            try:
                with Image.open(img_path) as raw_img:
                    img = raw_img.convert('RGB')
                plt.subplot(1, num_results + 1, i + 2)
                plt.imshow(img)
                plt.title(f"Match {i+1}")
                plt.axis('off')
                print(f"Loaded Match {i+1}: {img_path}")
            except Exception as e:
                print(f"Failed to load image {img_path}: {e}")

        output_dir = 'results'
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, f'query_results_{query_index}.png')
        fig.savefig(output_path)
    finally:
        # Always release the figure, even if plotting or saving failed, so that
        # repeated calls do not accumulate open figures.
        plt.close(fig)
    print(f"Results displayed and saved to '{output_path}'.")

def process_input_image(image_path_or_url, transform, device):
    """
    Processes the input image by loading, transforming, and sending it to the device.

    If the image cannot be loaded (missing file, network error, HTTP error
    status, undecodable data), the failure is printed and a black 224x224 RGB
    image is used instead, so the caller always receives a tensor.

    Parameters:
    - image_path_or_url (str): Path or URL of the image.
    - transform (torchvision.transforms.Compose): Transformations to apply.
    - device (torch.device): Device to send the image tensor.

    Returns:
    - torch.Tensor: Processed image tensor with a leading batch dimension of 1.
    """
    print(f"  Processing input image: {image_path_or_url}")
    try:
        # Only real URL schemes are fetched over the network; a local path that
        # merely starts with "http" is read from disk.
        if image_path_or_url.startswith(('http://', 'https://')):
            response = requests.get(image_path_or_url, timeout=10)
            # Surface 4xx/5xx explicitly rather than decoding an error page.
            response.raise_for_status()
            img = Image.open(BytesIO(response.content)).convert('RGB')
        else:
            with Image.open(image_path_or_url) as raw_img:
                img = raw_img.convert('RGB')
    except Exception as e:
        print(f"  Failed to load image {image_path_or_url}: {e}")
        # Return a black image in case of error
        img = Image.new('RGB', (224, 224), (0, 0, 0))
    # unsqueeze(0) adds the batch dimension before moving to the target device.
    img = transform(img).unsqueeze(0).to(device)
    print("  Image processed and transformed.")
    return img

def get_image_features(model, img_tensor):
    """
    Extracts and normalizes features from the image tensor.

    The model output is flattened to (batch, feature_dim) and each row is
    scaled to unit L2 norm. A row whose norm is zero is left as all zeros
    instead of becoming NaN.

    Parameters:
    - model (torch.nn.Module): The feature extraction model.
    - img_tensor (torch.Tensor): Image tensor.

    Returns:
    - np.ndarray: Normalized feature array of dtype float32 and shape
      (batch, feature_dim).
    """
    print("  Extracting features from the query image...")
    with torch.no_grad():
        features = model(img_tensor)
        features = features.view(features.size(0), -1)  # Flatten
        # Clamp the norm so an all-zero feature row does not divide by zero
        # (0 / 0 would produce NaN and poison the similarity search).
        norms = features.norm(dim=1, keepdim=True).clamp_min(1e-12)
        features = features / norms
    print("  Features extracted and normalized.")
    return features.cpu().numpy().astype('float32')

def search_artwork(query_features, index, k=5):
    """
    Searches for the top k similar artworks using FAISS.

    Only the first query row is considered. When the index cannot supply k
    neighbours, FAISS fills the missing slots with the label -1; those slots
    are dropped, so the returned arrays may be shorter than k.

    Parameters:
    - query_features (np.ndarray): Feature vector of the query image.
    - index (faiss.Index): FAISS index.
    - k (int): Number of top matches to retrieve.

    Returns:
    - distances (np.ndarray): Similarity scores.
    - indices (np.ndarray): Indices of the top matches.
    """
    print(f"  Searching for top {k} similar artworks...")
    distances, indices = index.search(query_features, k)
    print("  Search completed.")
    top_distances, top_indices = distances[0], indices[0]
    # A label of -1 means "no neighbour found". Passing it on would make a
    # positional DataFrame lookup silently return the last row.
    found = top_indices >= 0
    return top_distances[found], top_indices[found]

def get_artwork_info(indices, image_df):
    """
    Retrieves artwork information based on the indices.

    Rows are selected by position and returned in the order given, with a
    fresh 0..n-1 index. Negative entries (such as the -1 that FAISS uses for
    "no neighbour") are ignored rather than being interpreted as positions
    counted from the end of the frame. If the lookup fails (for example, a
    position is out of range), the error is printed and an empty DataFrame is
    returned.

    Parameters:
    - indices (np.ndarray): Indices of the top matches.
    - image_df (pd.DataFrame): DataFrame containing image metadata.

    Returns:
    - pd.DataFrame: DataFrame of the top matching artworks.
    """
    print("  Retrieving artwork information for the top matches...")
    try:
        positions = np.asarray(indices)
        # .iloc would treat -1 as "last row"; drop such sentinel values first.
        positions = positions[positions >= 0]
        results = image_df.iloc[positions].reset_index(drop=True)
        print("  Artwork information retrieved.")
    except (IndexError, TypeError, ValueError) as e:
        print(f"  Error retrieving artwork information: {e}")
        results = pd.DataFrame()  # Return empty DataFrame on error
    return results

def identify_artwork(image_path_or_url, model, index, image_df, transform, device, k=5):
    """
    Runs the full query pipeline for one image and returns its k nearest artworks.

    The image is loaded and preprocessed, embedded with the model, looked up in
    the FAISS index, and the matching positions are resolved to metadata rows.

    Parameters:
    - image_path_or_url (str): Path or URL of the query image.
    - model (torch.nn.Module): Feature extraction model.
    - index (faiss.Index): FAISS index.
    - image_df (pd.DataFrame): DataFrame containing image metadata.
    - transform (torchvision.transforms.Compose): Transformations to apply.
    - device (torch.device): Device to send the image tensor.
    - k (int): Number of top matches to retrieve.

    Returns:
    - results (pd.DataFrame): DataFrame of the top matching artworks.
    - distances (np.ndarray): Similarity scores.
    """
    print("\nIdentifying artwork...")

    # Query image -> preprocessed tensor -> normalized embedding.
    query_features = get_image_features(
        model,
        process_input_image(image_path_or_url, transform, device),
    )

    # Embedding -> nearest-neighbour positions -> metadata rows.
    distances, match_positions = search_artwork(query_features, index, k)
    results = get_artwork_info(match_positions, image_df)

    print("Artwork identification completed.")
    return results, distances

def extract_features(model, dataloader, device):
    """
    Extracts features from images using the provided model and dataloader.

    Each batch is passed through the model, flattened to (batch_size,
    feature_dim) and L2-normalized row by row. The returned tensor has exactly
    one row per input sample, in dataloader order. If a batch fails, the error
    is printed and that batch contributes all-zero rows, so later rows keep
    their position (row i still corresponds to sample i).

    Parameters:
    - model (torch.nn.Module): The feature extraction model.
    - dataloader (DataLoader): DataLoader for the dataset.
    - device (torch.device): Device to perform computations on.

    Returns:
    - torch.Tensor: Extracted features (on CPU).

    Raises:
    - RuntimeError: If the dataloader yields no batches or every batch fails,
      so that no feature dimension can be determined.
    """
    print("Starting feature extraction...")
    # Each entry is either a (batch, dim) CPU tensor, or an int holding the
    # size of a batch that failed and must be back-filled with zeros.
    chunks = []
    template = None  # first successful output; supplies dim and dtype for fill rows
    with torch.no_grad():
        for batch_idx, (imgs, _) in enumerate(tqdm(dataloader, desc="Extracting features"), 1):
            batch_len = imgs.size(0)
            imgs = imgs.to(device)
            try:
                outputs = model(imgs)
                outputs = outputs.view(outputs.size(0), -1)  # Flatten to (batch_size, feature_dim)
                # Clamp the norm so an all-zero row cannot turn into NaN.
                outputs = outputs / outputs.norm(dim=1, keepdim=True).clamp_min(1e-12)
                outputs = outputs.cpu()
            except Exception as e:
                print(f"Error processing batch {batch_idx}: {e}")
                # Keep a placeholder instead of dropping the batch: dropping it
                # would shift every following row relative to the dataset.
                chunks.append(batch_len)
                continue
            if template is None:
                template = outputs
            chunks.append(outputs)
    if template is None:
        raise RuntimeError(
            "Feature extraction produced no features: "
            "the dataloader was empty or every batch failed."
        )
    features = torch.cat(
        [
            chunk if torch.is_tensor(chunk)
            else torch.zeros(chunk, template.size(1), dtype=template.dtype)
            for chunk in chunks
        ],
        dim=0,
    )
    print("Feature extraction completed.")
    return features

# Step 1: Load and Prepare the WikiArt Dataset (Impressionism Only)
print("\nStep 1: Loading and preparing the WikiArt dataset (Impressionism only)...")

# Specify the path to the WikiArt dataset
wikiart_dir = '../../scratch/mexas.v'

# Only use the 'Impressionism' folder
selected_class = 'Impressionism'

# Verify that the folder exists (and really is a directory)
class_path = os.path.join(wikiart_dir, selected_class)
if not os.path.isdir(class_path):
    print(f"Class folder '{selected_class}' not found in '{wikiart_dir}'.")
    # SystemExit works in any interpreter; the bare `exit` name is an
    # interactive helper and is not guaranteed to exist.
    raise SystemExit(1)

# Get all image paths for the selected class.
# Sorted so that row i of image_df (and therefore row i of the saved feature
# matrix / FAISS index) refers to the same file on every run; glob's own order
# is arbitrary.
image_paths = sorted(glob.glob(os.path.join(class_path, '*.jpg')))
if not image_paths:
    # An empty list would give a DataFrame without columns and a confusing
    # KeyError on 'label' below.
    print(f"No '.jpg' images found in '{class_path}'.")
    raise SystemExit(1)

# One record per image: its path and its (single) class label
data = [{'image_path': path, 'label': selected_class} for path in image_paths]

# Create a DataFrame
image_df = pd.DataFrame(data)
print(f"Prepared image_df with {len(image_df)} entries from class '{selected_class}'.")

# Encode labels (though we have only one class here)
label_encoder = LabelEncoder()
image_df['encoded_label'] = label_encoder.fit_transform(image_df['label'])

# Step 2: Define Image Transformations
print("\nStep 2: Defining image transformations...")

# Preprocessing applied to every dataset image and every query image:
# resize to 224x224, convert to a tensor, then normalize each channel.
# NOTE(review): the mean/std values look like the usual ImageNet statistics
# for pretrained torchvision models; confirm if the backbone changes.
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocessing_steps)

print("Image transformations defined.")

# Step 3: Create Custom Dataset and DataLoader
print("\nStep 3: Creating custom dataset and data loader...")

class ArtImageDataset(Dataset):
    """
    Dataset that yields (image, row position) pairs for the rows of a DataFrame.

    Parameters:
    - image_df (pd.DataFrame): Must provide an 'image_path' column. The frame
      is re-indexed to 0..n-1, so item idx is always row idx.
    - transform (callable, optional): Applied to the loaded RGB PIL image.
    """

    def __init__(self, image_df, transform=None):
        # Re-index so positional idx values can be used as labels in .loc.
        self.image_df = image_df.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        """Returns the number of images (rows) in the dataset."""
        return len(self.image_df)

    def __getitem__(self, idx):
        """
        Loads the image at row idx.

        An image that cannot be opened is reported and replaced by a black
        224x224 RGB image, so a single bad file does not abort iteration.

        Returns:
        - tuple: (image, idx) where image is the transformed image (or the
          PIL image when no transform is set) and idx is the row position.
        """
        img_path = self.image_df.loc[idx, 'image_path']
        try:
            # Context manager closes the underlying file as soon as the pixel
            # data has been converted.
            with Image.open(img_path) as raw_img:
                img = raw_img.convert('RGB')
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            img = Image.new('RGB', (224, 224), (0, 0, 0))
        if self.transform:
            img = self.transform(img)
        return img, idx  # Return index to keep track

# Wrap image_df in the dataset and iterate it in order (shuffle=False), so the
# extracted feature rows follow the same order as the rows of image_df.
batch_size = 32  # Adjust based on your GPU memory
dataset = ArtImageDataset(image_df, transform=transform)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=12,
)

print(f"DataLoader created with {len(dataloader)} batches.")

# Step 4: Set Up Device and Load Pre-trained Model
print("\nStep 4: Setting up the device and loading the pre-trained model...")

# Prefer the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Pre-trained ResNet-50 with its last child module (the classification layer)
# stripped off, leaving a feature extractor
model = nn.Sequential(*list(models.resnet50(pretrained=True).children())[:-1])

# Move to the chosen device and switch to inference mode
model = model.to(device).eval()

print("Pre-trained model loaded and modified for feature extraction.")

# Step 5: Extract Features
print("\nStep 5: Extracting features from all images...")

features = extract_features(model, dataloader, device)
print(f"Extracted features shape: {features.shape}")

# Convert features to numpy array.
# extract_features already returns a 2-D (num_images, feature_dim) tensor, so
# no squeeze() is applied: squeezing would collapse a single-image result to
# 1-D and break the features_np.shape[1] lookup used to size the FAISS index.
features_np = features.numpy().astype('float32')

# Save features
np.save('impressionism_features.npy', features_np)
print("Features saved to 'impressionism_features.npy'.")

# Step 6: Create FAISS Index and Add Features
print("\nStep 6: Creating FAISS index with extracted features...")

# Flat inner-product index sized to the width of the feature matrix; with the
# unit-normalized features produced above, inner product acts as cosine similarity.
feature_dim = features_np.shape[1]
index = faiss.IndexFlatIP(feature_dim)

# Row i of features_np becomes vector id i in the index
index.add(features_np)

# Persist the populated index next to the saved features
faiss.write_index(index, 'impressionism_faiss_index.bin')
print("FAISS index created with extracted features and saved to 'impressionism_faiss_index.bin'.")

# Step 7: Perform Queries on Random Sample of Images
print("\nStep 7: Performing queries on a random sample of images...")

num_queries = 100  # Number of random images to query
if num_queries > len(image_df):
    num_queries = len(image_df)
    print(f"Only {num_queries} images available. Adjusting number of queries to {num_queries}.")

# Sample distinct row positions of image_df to use as query images
random_indices = random.sample(range(len(image_df)), num_queries)

correct_matches = 0  # Counter for correct matches

for idx, query_image_index in enumerate(random_indices):
    print(f"\nQuery {idx+1}/{num_queries}")
    image_to_query = image_df.iloc[query_image_index]['image_path']
    print(f"Selected image for querying: {image_to_query}")

    try:
        # Use the same model and FAISS index
        results, distances = identify_artwork(
            image_to_query,
            model,
            index,
            image_df,
            transform,
            device,
            k=5
        )
    except Exception as e:
        print(f"Error identifying artwork: {e}")
        continue

    # get_artwork_info returns an empty frame when the lookup fails; there is
    # nothing to score or display in that case.
    if results.empty:
        print(f"No results returned for Query {idx+1}; skipping.")
        continue

    # Check if the top match is the same as the query image.
    # The results frame is re-indexed from 0 by get_artwork_info, so its index
    # says nothing about the dataset position; compare the image paths instead.
    if results.iloc[0]['image_path'] == image_to_query:
        correct_matches += 1
        match_status = "Correct Match"
    else:
        match_status = "Incorrect Match"

    # Display Top Matching Artworks
    print(f"\nTop matching artworks for Query {idx+1}:")
    for rank, ((_, row), dist) in enumerate(zip(results.iterrows(), distances), start=1):
        print(f"Rank {rank}:")
        print(f"  Label: {row['label']}")
        print(f"  Similarity Score: {dist:.4f}")
        print(f"  Image Path: {row['image_path']}\n")

    # Display and save results
    show_results(image_to_query, results, query_index=idx+1)

    print(f"Query {idx+1} result: {match_status}")

# Step 8: Print Summary Statistics
# Queries that errored out or returned nothing count as misses. The guard
# avoids a ZeroDivisionError when there were no queries at all.
accuracy = correct_matches / num_queries * 100 if num_queries else 0.0
print("\nSummary:")
print(f"Total Queries: {num_queries}")
print(f"Correct Matches: {correct_matches}")
print(f"Accuracy: {accuracy:.2f}%")


Step 0: Defining helper functions...

Step 1: Loading and preparing the WikiArt dataset (Impressionism only)...
Prepared image_df with 13060 entries from class 'Impressionism'.

Step 2: Defining image transformations...
Image transformations defined.

Step 3: Creating custom dataset and data loader...
DataLoader created with 409 batches.

Step 4: Setting up the device and loading the pre-trained model...
Using device: cuda
Pre-trained model loaded and modified for feature extraction.

Step 5: Extracting features from all images...
Starting feature extraction...


Extracting features:  70%|███████   | 288/409 [07:51<00:50,  2.39it/s]