In [4]:
# Import necessary libraries
import os
import glob
import random
import pandas as pd
import numpy as np
import requests
from io import BytesIO
from PIL import Image, ImageFilter
import matplotlib
matplotlib.use('Agg')  # Use a non-interactive backend suitable for script running
import matplotlib.pyplot as plt
from tqdm import tqdm  # For progress bars
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
import torch.nn as nn
import torch.optim as optim

import faiss

# Additional imports for label encoding and model evaluation
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

# Step 0: Define helper functions for query-time retrieval: batch feature
# extraction/normalisation, FAISS search, metadata lookup, and plotting of the
# results. NOTE(review): none of these helpers is called within this cell.
print("Step 0: Defining helper functions...")

def show_results(query_image, results, query_index, transformed=False, output_dir='results'):
    """
    Displays the query image alongside its top matching results and saves the figure.

    The figure has one panel for the query plus up to five match panels. It is
    written to ``<output_dir>/query_results_<query_index>.png`` and then closed.

    Parameters:
    - query_image (PIL.Image.Image): PIL image of the query.
    - results (pd.DataFrame): Top matching artworks, read in row order; each row
      must have an 'image_path' column.
    - query_index (int): Index of the query image, used in the output file name.
    - transformed (bool): If True, the displayed query is blurred (Gaussian,
      radius 2) and rotated by 10 degrees, and its title gets a ' (Transformed)'
      suffix.
    - output_dir (str): Directory to save the figure in; created if missing.
      Defaults to 'results'.

    Returns:
    - None
    """
    print("\nDisplaying results visually...")
    if transformed:
        # Visual indication that the query image was transformed
        query_image_display = query_image.filter(ImageFilter.GaussianBlur(radius=2)).rotate(10)
    else:
        query_image_display = query_image

    num_results = min(len(results), 5)
    num_panels = num_results + 1
    fig = plt.figure(figsize=(5 * num_panels, 5))
    try:
        # Display Query Image
        plt.subplot(1, num_panels, 1)
        plt.imshow(query_image_display)
        plt.title('Query Image (Transformed)' if transformed else 'Query Image')
        plt.axis('off')

        # Display Matching Images; an unreadable file just leaves its panel empty.
        for i in range(num_results):
            img_path = results.iloc[i]['image_path']
            try:
                # The context manager closes the file handle; convert() returns an
                # independent image that remains usable after the block.
                with Image.open(img_path) as src:
                    img = src.convert('RGB')
            except Exception as e:
                print(f"Failed to load image {img_path}: {e}")
                continue
            plt.subplot(1, num_panels, i + 2)
            plt.imshow(img)
            plt.title(f"Match {i+1}")
            plt.axis('off')
            print(f"Loaded Match {i+1}: {img_path}")

        # exist_ok avoids the race between an exists() check and makedirs().
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, f'query_results_{query_index}.png')
        fig.savefig(output_path)
    finally:
        # Always release the figure, even if plotting or saving raised, so
        # repeated calls do not accumulate open figures.
        plt.close(fig)
    print(f"Results displayed and saved to '{output_path}'.")

def get_image_features_batch(model, img_tensors):
    """
    Extracts and L2-normalizes features from a batch of image tensors.

    The model output is flattened to one row per sample, and each row is scaled
    to unit L2 norm. Rows that are entirely zero stay zero instead of becoming
    NaN. Runs under ``torch.no_grad()``.

    Parameters:
    - model (torch.nn.Module): The feature extraction model.
    - img_tensors (torch.Tensor): Batch of image tensors; the first dimension is
      treated as the batch dimension.

    Returns:
    - np.ndarray: float32 array of shape (batch, flattened_feature_dim).
    """
    print("  Extracting features from the batch...")
    with torch.no_grad():
        features = model(img_tensors)
        # reshape (unlike view) also works when the model output is non-contiguous.
        features = features.reshape(features.size(0), -1)  # Flatten
        # normalize() divides by max(norm, eps), so a zero vector cannot produce NaN.
        features = torch.nn.functional.normalize(features, p=2, dim=1)
    print("  Features extracted and normalized.")
    return features.cpu().numpy().astype('float32')

def search_artwork_batch(query_features, index, k=5):
    """
    Searches for the top k similar artworks using FAISS in batch.

    FAISS only accepts C-contiguous float32 matrices of shape (n, d). The
    queries are therefore converted to that layout first (no copy if they
    already match). A single 1-D vector is treated as a one-row batch.

    Parameters:
    - query_features (np.ndarray): Batch of feature vectors (n, d), or one vector (d,).
    - index (faiss.Index): FAISS index.
    - k (int): Number of top matches to retrieve.

    Returns:
    - distances (np.ndarray): Similarity scores, shape (n, k).
    - indices (np.ndarray): Indices of the top matches, shape (n, k); FAISS uses
      -1 for slots it could not fill.
    """
    print(f"  Searching for top {k} similar artworks in batch...")
    queries = np.ascontiguousarray(np.atleast_2d(query_features), dtype='float32')
    distances, indices = index.search(queries, k)
    print("  Batch search completed.")
    return distances, indices

def get_artwork_info_batch(indices, image_df):
    """
    Retrieves artwork information based on the indices for a batch.

    FAISS marks neighbour slots it could not fill with -1. Those ids are
    dropped here, because ``iloc[-1]`` would otherwise silently return the last
    row of ``image_df``. A 1-D ``indices`` is treated as a single query.

    Parameters:
    - indices (np.ndarray): Positional row indices of the top matches, shape (n_queries, k).
    - image_df (pd.DataFrame): DataFrame containing image metadata.

    Returns:
    - List[pd.DataFrame]: One DataFrame per query holding its matched rows in
      rank order, re-indexed from 0. On any lookup error every entry is an
      empty DataFrame.
    """
    print("  Retrieving artwork information for the batch...")
    indices = np.atleast_2d(np.asarray(indices))
    try:
        results = [image_df.iloc[row[row >= 0]].reset_index(drop=True) for row in indices]
        print("  Artwork information retrieved for the batch.")
    except Exception as e:
        # Deliberate best-effort: report and fall back to empty results.
        print(f"  Error retrieving artwork information for the batch: {e}")
        results = [pd.DataFrame() for _ in range(indices.shape[0])]  # Return empty DataFrames on error
    return results

# Step 1: Load and Prepare the WikiArt Dataset (All Classes)
print("\nStep 1: Loading and preparing the WikiArt dataset (all classes)...")

# Specify the path to the WikiArt dataset
wikiart_dir = '../../scratch/mexas.v'  # Update this path as necessary

# Class names (styles, artists, or genres) are the sub-directory names.
# Both directory and file listings are sorted: os.listdir and glob return
# entries in arbitrary, filesystem-dependent order, which would change the row
# order of image_df and therefore the seeded train/val/test split (and any
# feature files cached from a previous run).
classes = sorted(
    d for d in os.listdir(wikiart_dir) if os.path.isdir(os.path.join(wikiart_dir, d))
)

# One record per *.jpg file, labelled with its class directory name.
data = [
    {'image_path': path, 'label': label}
    for label in classes
    for path in sorted(glob.glob(os.path.join(wikiart_dir, label, '*.jpg')))
]

# Create a DataFrame (explicit columns so an empty result still has a schema)
image_df = pd.DataFrame(data, columns=['image_path', 'label'])
if image_df.empty:
    raise FileNotFoundError(
        f"No '*.jpg' images found in the class sub-directories of '{wikiart_dir}'."
    )
print(f"Prepared image_df with {len(image_df)} entries from all classes.")

# Encode labels as integers 0..n_classes-1 (LabelEncoder sorts the class names)
label_encoder = LabelEncoder()
image_df['encoded_label'] = label_encoder.fit_transform(image_df['label'])

# ---------------------------------------------------------------------------
# Step 2: image preprocessing
# ---------------------------------------------------------------------------
print("\nStep 2: Defining image transformations...")

# The usual ImageNet per-channel mean / std, matching the torchvision
# pretrained weights loaded later.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Deterministic pipeline (no augmentation): resize to 224x224, convert to a
# tensor, then normalize each channel.
_feature_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
]
feature_transform = transforms.Compose(_feature_steps)

# No separate obfuscation/augmentation pipeline is defined here.

print("Image transformations defined.")

# ---------------------------------------------------------------------------
# Step 3: datasets and data loaders
# ---------------------------------------------------------------------------
print("\nStep 3: Creating custom datasets and data loaders...")

class ArtImageDataset(Dataset):
    """Map-style dataset yielding ``(image, encoded_label)`` pairs.

    Reads rows from a metadata DataFrame with 'image_path' and
    'encoded_label' columns. Unreadable images are replaced by a black
    224x224 RGB placeholder so a single bad file does not abort an epoch.
    """

    def __init__(self, image_df, transform=None):
        # Re-index to 0..n-1 so the DataLoader's positional indices are valid .loc labels.
        self.image_df = image_df.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return len(self.image_df)

    def __getitem__(self, idx):
        img_path = self.image_df.loc[idx, 'image_path']
        label = self.image_df.loc[idx, 'encoded_label']
        try:
            # The context manager closes the underlying file right away; convert()
            # returns a fully loaded copy that outlives the with-block.
            with Image.open(img_path) as src:
                img = src.convert('RGB')
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            img = Image.new('RGB', (224, 224), (0, 0, 0))
        if self.transform:
            img = self.transform(img)
        return img, label  # Return image and label

# Split the dataset into training, validation, and test sets:
# 70% train, then the remaining 30% halved into validation and test,
# stratified by encoded label and seeded with random_state=42.
print("\nSplitting dataset into training, validation, and test sets...")
train_df, temp_df = train_test_split(
    image_df, test_size=0.3, random_state=42, stratify=image_df['encoded_label']
)
val_df, test_df = train_test_split(
    temp_df, test_size=0.5, random_state=42, stratify=temp_df['encoded_label']
)
print(f"Training set: {len(train_df)} images")
print(f"Validation set: {len(val_df)} images")
print(f"Test set: {len(test_df)} images")

# Ensure there are no duplicates between the sets. Use explicit raises rather
# than `assert`, which Python strips when run with -O.
train_paths = set(train_df['image_path'])
val_paths = set(val_df['image_path'])
test_paths = set(test_df['image_path'])

if train_paths & val_paths:
    raise RuntimeError("Overlap between training and validation sets!")
if train_paths & test_paths:
    raise RuntimeError("Overlap between training and test sets!")
if val_paths & test_paths:
    raise RuntimeError("Overlap between validation and test sets!")

# Create datasets (the same deterministic preprocessing for every split)
train_dataset = ArtImageDataset(train_df, transform=feature_transform)
val_dataset = ArtImageDataset(val_df, transform=feature_transform)
test_dataset = ArtImageDataset(test_df, transform=feature_transform)

# Create DataLoaders
batch_size = 32  # Adjust based on your GPU memory
# Do not request more loader worker processes than the machine has CPUs.
num_workers = min(8, os.cpu_count() or 1)
# Pinned host memory only helps host-to-GPU copies; it has no benefit without CUDA.
pin_memory = torch.cuda.is_available()
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                          num_workers=num_workers, pin_memory=pin_memory)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                        num_workers=num_workers, pin_memory=pin_memory)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                         num_workers=num_workers, pin_memory=pin_memory)

print("DataLoaders created for training, validation, and test sets.")

# Step 4: Set Up Device and Load Pre-trained Model
print("\nStep 4: Setting up the device and loading the pre-trained model...")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the pre-trained ResNet-50 model.
# NOTE(review): `pretrained=` is deprecated since torchvision 0.13 in favour of
# `weights=models.ResNet50_Weights.IMAGENET1K_V1`; switch once the installed
# torchvision version is confirmed.
model = models.resnet50(pretrained=True)

# Replace the final layer with a head sized for our classes. Use the encoder's
# class list (the labels that actually occur in image_df) rather than the raw
# directory listing: a directory with no *.jpg files would otherwise add an
# output unit that no training sample can target.
num_ftrs = model.fc.in_features
num_classes = len(label_encoder.classes_)
model.fc = nn.Linear(num_ftrs, num_classes)

model = model.to(device)

print("Pre-trained model loaded and modified for fine-tuning.")

# Step 5: Fine-Tune the Model
print("\nStep 5: Fine-tuning the model...")

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

num_epochs = 4  # Adjust based on convergence
best_val_acc = 0.0
checkpoint_path = 'best_resnet50_wikiart_finetuned.pth'
# Tracks whether THIS run wrote a checkpoint, so a stale file left by an
# earlier run is never reloaded by mistake.
checkpoint_saved = False


def _run_epoch(loader, desc, train):
    """Run one pass of `model` over `loader` and return (mean loss, accuracy).

    With train=True the model is put in training mode and `optimizer` is
    stepped after every batch. Otherwise the model is in eval mode and autograd
    is disabled. Both metrics are averaged over len(loader.dataset) samples.
    """
    model.train(train)
    total_loss = 0.0
    total_correct = 0
    with torch.set_grad_enabled(train):
        for inputs, targets in tqdm(loader, desc=desc):
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            if train:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            if train:
                loss.backward()
                optimizer.step()

            # Plain Python numbers avoid accumulating device tensors across batches.
            total_loss += loss.item() * inputs.size(0)
            total_correct += (outputs.argmax(dim=1) == targets).sum().item()

    num_samples = len(loader.dataset)
    if num_samples == 0:
        return 0.0, 0.0
    return total_loss / num_samples, total_correct / num_samples


for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")

    # Training phase
    epoch_loss, epoch_acc = _run_epoch(train_loader, "Training", train=True)
    print(f"Training Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")

    # Validation phase
    val_epoch_loss, val_epoch_acc = _run_epoch(val_loader, "Validation", train=False)
    print(f"Validation Loss: {val_epoch_loss:.4f} Acc: {val_epoch_acc:.4f}")

    # Checkpointing
    if val_epoch_acc > best_val_acc:
        best_val_acc = val_epoch_acc
        torch.save(model.state_dict(), checkpoint_path)
        checkpoint_saved = True
        print("Best model updated.")

print(f"\nTraining complete. Best Validation Accuracy: {best_val_acc:.4f}")

# Load the best model (only if this run produced one)
if checkpoint_saved:
    # map_location lets the checkpoint load regardless of the device it was saved from.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    print("Best fine-tuned model loaded.")
else:
    print(f"No checkpoint was written to '{checkpoint_path}' during this run; keeping the current weights.")
model.eval()

# Step 6: Extract Features with Conditional Loading
print("\nStep 6: Extracting features from training set...")

features_file = 'wikiart_features_finetuned.npy'
labels_file = 'wikiart_labels_finetuned.npy'
model_checkpoint_file = 'best_resnet50_wikiart_finetuned.pth'

# Extract from a NON-shuffled loader. train_loader reshuffles on every pass, so
# row i of the saved features would not correspond to train_df row i, and the
# cached file would differ from run to run.
feature_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=False,
    num_workers=train_loader.num_workers, pin_memory=train_loader.pin_memory,
)

features_np = None  # stays None whenever a fresh extraction is required
if os.path.exists(features_file):
    print(f"Feature file '{features_file}' found. Loading existing features...")
    try:
        features_np = np.load(features_file)
        labels = np.load(labels_file)
        # Reject caches that cannot belong to the current training split.
        if not (len(features_np) == len(labels) == len(train_dataset)):
            raise ValueError(
                f"cached features ({len(features_np)}) / labels ({len(labels)}) "
                f"do not match the training set ({len(train_dataset)})"
            )
        # Reject caches computed with an older model than the current checkpoint.
        if (os.path.exists(model_checkpoint_file)
                and os.path.getmtime(features_file) < os.path.getmtime(model_checkpoint_file)):
            raise ValueError(f"cached features are older than '{model_checkpoint_file}'")
        print(f"Loaded features with shape: {features_np.shape}")
    except Exception as e:
        features_np = None
        print(f"Error loading features from '{features_file}': {e}")
        print("Proceeding to extract features anew.")
else:
    print(f"Feature file '{features_file}' not found. Extracting features...")

if features_np is None:
    # NOTE(review): extract_features is not defined in this cell and must come
    # from an earlier notebook cell. If it calls model(...) directly, these
    # "features" are the fine-tuned classifier's outputs (one per class) rather
    # than penultimate-layer embeddings — confirm.
    features, labels = extract_features(model, feature_loader, device)
    print(f"Extracted features shape: {features.shape}")

    # Convert features to numpy array
    features_np = features.numpy().astype('float32')

    # Save features and labels
    np.save(features_file, features_np)
    np.save(labels_file, labels)
    print(f"Features extracted and saved to '{features_file}'.")

# Step 7: Create FAISS Index with Extracted Features
print("\nStep 7: Creating FAISS index with extracted features...")

index_file = 'wikiart_faiss_index_finetuned.bin'

# Use a single GPU only when this FAISS build has GPU support and a GPU is
# visible. CPU-only FAISS builds lack StandardGpuResources, so calling it
# unconditionally crashes there.
use_faiss_gpu = (
    hasattr(faiss, 'StandardGpuResources')
    and getattr(faiss, 'get_num_gpus', lambda: 0)() > 0
)
res = faiss.StandardGpuResources() if use_faiss_gpu else None

index_cpu = None  # stays None whenever the index must be (re)built
if os.path.exists(index_file):
    print(f"FAISS index file '{index_file}' found. Loading existing index...")
    try:
        index_cpu = faiss.read_index(index_file)
        # An index built from other features returns ids that do not line up with `labels`.
        if index_cpu.ntotal != features_np.shape[0] or index_cpu.d != features_np.shape[1]:
            raise ValueError(
                f"index holds {index_cpu.ntotal} vectors of dim {index_cpu.d}, "
                f"but features are {features_np.shape[0]} x {features_np.shape[1]}"
            )
        if os.path.exists(features_file) and os.path.getmtime(index_file) < os.path.getmtime(features_file):
            raise ValueError(f"index is older than '{features_file}'")
        print(f"Loaded FAISS index from '{index_file}'.")
    except Exception as e:
        index_cpu = None
        print(f"Error loading FAISS index from '{index_file}': {e}")
        print("Proceeding to create a new FAISS index.")
else:
    print(f"FAISS index file '{index_file}' not found. Creating a new index...")

if index_cpu is None:
    # Inner product equals cosine similarity only for L2-normalised rows.
    # NOTE(review): confirm extract_features normalises its output.
    index_cpu = faiss.IndexFlatIP(features_np.shape[1])
    index_cpu.add(features_np)
    faiss.write_index(index_cpu, index_file)
    print(f"FAISS index created and saved to '{index_file}'.")

index = faiss.index_cpu_to_gpu(res, 0, index_cpu) if use_faiss_gpu else index_cpu
print(f"FAISS index ready on {'GPU' if use_faiss_gpu else 'CPU'} with {index.ntotal} vectors.")

# Step 8: Evaluate the Model on the Test Set
print("\nStep 8: Evaluating the model on the test set...")

test_features_file = 'wikiart_test_features_finetuned.npy'
test_labels_file = 'wikiart_test_labels_finetuned.npy'
model_checkpoint_file = 'best_resnet50_wikiart_finetuned.pth'

test_features_np = None  # stays None whenever a fresh extraction is required
if os.path.exists(test_features_file):
    print(f"Feature file '{test_features_file}' found. Loading existing features...")
    try:
        test_features_np = np.load(test_features_file)
        test_labels = np.load(test_labels_file)
        # Reject caches that cannot belong to the current test split.
        if not (len(test_features_np) == len(test_labels) == len(test_dataset)):
            raise ValueError(
                f"cached test features ({len(test_features_np)}) / labels ({len(test_labels)}) "
                f"do not match the test set ({len(test_dataset)})"
            )
        # Reject caches computed with an older model than the current checkpoint.
        if (os.path.exists(model_checkpoint_file)
                and os.path.getmtime(test_features_file) < os.path.getmtime(model_checkpoint_file)):
            raise ValueError(f"cached test features are older than '{model_checkpoint_file}'")
        print(f"Loaded test features with shape: {test_features_np.shape}")
    except Exception as e:
        test_features_np = None
        print(f"Error loading features from '{test_features_file}': {e}")
        print("Proceeding to extract test features anew.")
else:
    print(f"Feature file '{test_features_file}' not found. Extracting test features...")

if test_features_np is None:
    # NOTE(review): extract_features is not defined in this cell and must come
    # from an earlier notebook cell. test_loader is not shuffled, so row order
    # follows test_df.
    test_features, test_labels = extract_features(model, test_loader, device)
    print(f"Extracted test features shape: {test_features.shape}")

    # Convert features to numpy array
    test_features_np = test_features.numpy().astype('float32')

    # Save features and labels
    np.save(test_features_file, test_features_np)
    np.save(test_labels_file, test_labels)
    print(f"Test features extracted and saved to '{test_features_file}'.")

# Perform batch similarity search for all test images
print("\nPerforming similarity search on the test set...")

k = 5  # Number of top matches to retrieve

# Plain 1-D NumPy label arrays, so batch comparisons vectorise whether the
# labels came from np.load or directly from extract_features.
train_labels_np = np.asarray(labels).reshape(-1)  # entry i = label of FAISS index row i
test_labels_np = np.asarray(test_labels).reshape(-1)

total_queries = len(test_features_np)
batch_size_eval = 64  # Adjust based on your GPU memory
num_batches = (total_queries + batch_size_eval - 1) // batch_size_eval

correct_at_1 = 0
correct_at_k = 0

for batch_idx in tqdm(range(num_batches), desc="Evaluating"):
    start_idx = batch_idx * batch_size_eval
    end_idx = min(start_idx + batch_size_eval, total_queries)
    # FAISS requires C-contiguous float32 query matrices.
    query_features_batch = np.ascontiguousarray(test_features_np[start_idx:end_idx], dtype='float32')
    true_labels_batch = test_labels_np[start_idx:end_idx]

    _, indices = index.search(query_features_batch, k)

    # FAISS fills slots it cannot populate (fewer than k vectors in the index)
    # with -1. Plain indexing would wrap that to the last training label and
    # could count a spurious hit, so those slots are masked out explicitly.
    valid = indices >= 0
    retrieved_labels_batch = train_labels_np[np.where(valid, indices, 0)]
    hits = (retrieved_labels_batch == true_labels_batch[:, None]) & valid

    correct_at_1 += int(hits[:, 0].sum())
    correct_at_k += int(hits.any(axis=1).sum())

if total_queries == 0:
    print("No test queries available; accuracies reported as 0.")
    top1_accuracy = topk_accuracy = 0.0
else:
    top1_accuracy = correct_at_1 / total_queries * 100
    topk_accuracy = correct_at_k / total_queries * 100

print("\nEvaluation Results:")
print(f"Total Queries: {total_queries}")
print(f"Top-1 Accuracy: {top1_accuracy:.2f}%")
print(f"Top-{k} Accuracy: {topk_accuracy:.2f}%")


Step 0: Defining helper functions...

Step 1: Loading and preparing the WikiArt dataset (all classes)...
Prepared image_df with 200111 entries from all classes.

Step 2: Defining image transformations...
Image transformations defined.

Step 3: Creating custom datasets and data loaders...

Splitting dataset into training, validation, and test sets...
Training set: 140077 images
Validation set: 30017 images
Test set: 30017 images
DataLoaders created for training, validation, and test sets.

Step 4: Setting up the device and loading the pre-trained model...
Using device: cuda
Pre-trained model loaded and modified for fine-tuning.

Step 5: Fine-tuning the model...

Epoch 1/4


Training:   4%|▍         | 171/4378 [04:45<1:57:13,  1.67s/it]


KeyboardInterrupt: 