In [3]:
# Import necessary libraries
import os
import glob
import pandas as pd
import numpy as np
import requests
from io import BytesIO
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm  # For progress bars
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
import torch.nn as nn

import faiss

# Additional imports for fine-tuning
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

# Ensure inline plotting for Jupyter notebooks (if using a notebook)
# %matplotlib inline

# Step 0: Define Helper Functions
print("Step 0: Defining helper functions...")

def download_images(final_df, save_dir='./art_images/'):
    """
    Downloads images from URLs specified in the final_df DataFrame.

    Parameters:
    - final_df (pd.DataFrame): DataFrame containing 'objectid' and 'image_urls' columns.
    - save_dir (str): Directory where images will be saved.

    Returns:
    - None
    """
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
        print(f"Created directory: {save_dir}")
    else:
        print(f"Directory already exists: {save_dir}")

    print("Starting image downloads...")
    for idx, row in tqdm(final_df.iterrows(), total=final_df.shape[0], desc="Downloading images"):
        objectid = str(row['objectid'])
        image_urls = row['image_urls']
        for i, image_url in enumerate(image_urls):
            image_filename = f"{objectid}_{i}.jpg"
            image_path = os.path.join(save_dir, image_filename)

            # Skip downloading if the file already exists
            if os.path.exists(image_path):
                continue

            try:
                response = requests.get(image_url, timeout=10)
                response.raise_for_status()  # Check if the request was successful
                with open(image_path, 'wb') as f:
                    f.write(response.content)
            except requests.exceptions.RequestException as e:
                print(f"Failed to download {image_filename} from {image_url}: {e}")

def show_results(image_path_or_url, results):
    """
    Displays the query image alongside its top matching results.

    Parameters:
    - image_path_or_url (str): Path or URL of the query image.
    - results (pd.DataFrame): DataFrame containing the top matching artworks.

    Returns:
    - None
    """
    print("\nDisplaying results visually...")
    try:
        if image_path_or_url.startswith('http'):
            response = requests.get(image_path_or_url, timeout=10)
            query_img = Image.open(BytesIO(response.content)).convert('RGB')
        else:
            query_img = Image.open(image_path_or_url).convert('RGB')
    except Exception as e:
        print(f"Failed to load query image: {e}")
        return

    num_results = min(len(results), 5)
    plt.figure(figsize=(5 * (num_results + 1), 5))

    # Display Query Image
    plt.subplot(1, num_results + 1, 1)
    plt.imshow(query_img)
    plt.title('Query Image')
    plt.axis('off')

    # Display Matching Images
    for i in range(num_results):
        img_path = results.iloc[i]['image_path']
        try:
            img = Image.open(img_path).convert('RGB')
            plt.subplot(1, num_results + 1, i + 2)
            plt.imshow(img)
            plt.title(f"Match {i+1}")
            plt.axis('off')
            print(f"Loaded Match {i+1}: {img_path}")
        except Exception as e:
            print(f"Failed to load image {img_path}: {e}")

    plt.show()
    print("Results displayed.")

def process_input_image(image_path_or_url, transform, device):
    """
    Processes the input image by loading, transforming, and sending it to the device.

    Parameters:
    - image_path_or_url (str): Path or URL of the image.
    - transform (torchvision.transforms.Compose): Transformations to apply.
    - device (torch.device): Device to send the image tensor.

    Returns:
    - torch.Tensor: Processed image tensor.
    """
    print(f"  Processing input image: {image_path_or_url}")
    try:
        if image_path_or_url.startswith('http'):
            response = requests.get(image_path_or_url, timeout=10)
            img = Image.open(BytesIO(response.content)).convert('RGB')
        else:
            img = Image.open(image_path_or_url).convert('RGB')
    except Exception as e:
        print(f"  Failed to load image {image_path_or_url}: {e}")
        # Return a black image in case of error
        img = Image.new('RGB', (224, 224), (0, 0, 0))
    img = transform(img).unsqueeze(0).to(device)
    print("  Image processed and transformed.")
    return img

def get_image_features(model, img_tensor):
    """
    Extracts and normalizes features from the image tensor.

    Parameters:
    - model (torch.nn.Module): The feature extraction model.
    - img_tensor (torch.Tensor): Image tensor.

    Returns:
    - np.ndarray: Normalized feature vector.
    """
    print("  Extracting features from the query image...")
    with torch.no_grad():
        features = model(img_tensor)
        features = features.view(features.size(0), -1)  # Flatten
        features = features / features.norm()
    print("  Features extracted and normalized.")
    return features.cpu().numpy().astype('float32')

def search_artwork(query_features, index, k=5):
    """
    Searches for the top k similar artworks using FAISS.

    Parameters:
    - query_features (np.ndarray): Feature vector of the query image.
    - index (faiss.Index): FAISS index.
    - k (int): Number of top matches to retrieve.

    Returns:
    - distances (np.ndarray): Similarity scores.
    - indices (np.ndarray): Indices of the top matches.
    """
    print(f"  Searching for top {k} similar artworks...")
    distances, indices = index.search(query_features, k)
    print("  Search completed.")
    return distances[0], indices[0]

def get_artwork_info(indices, image_df):
    """
    Retrieves artwork information based on the indices.

    Parameters:
    - indices (np.ndarray): Indices of the top matches.
    - image_df (pd.DataFrame): DataFrame containing image metadata.

    Returns:
    - pd.DataFrame: DataFrame of the top matching artworks.
    """
    print("  Retrieving artwork information for the top matches...")
    try:
        results = image_df.iloc[indices].reset_index(drop=True)
        print("  Artwork information retrieved.")
    except Exception as e:
        print(f"  Error retrieving artwork information: {e}")
        results = pd.DataFrame()  # Return empty DataFrame on error
    return results

def identify_artwork(image_path_or_url, model, index, image_df, transform, device, k=5):
    """
    Identifies the top k artworks similar to the query image.

    Parameters:
    - image_path_or_url (str): Path or URL of the query image.
    - model (torch.nn.Module): Feature extraction model.
    - index (faiss.Index): FAISS index.
    - image_df (pd.DataFrame): DataFrame containing image metadata.
    - transform (torchvision.transforms.Compose): Transformations to apply.
    - device (torch.device): Device to send the image tensor.
    - k (int): Number of top matches to retrieve.

    Returns:
    - results (pd.DataFrame): DataFrame of the top matching artworks.
    - distances (np.ndarray): Similarity scores.
    """
    print("\nIdentifying artwork...")
    img_tensor = process_input_image(image_path_or_url, transform, device)
    query_features = get_image_features(model, img_tensor)
    distances, indices = search_artwork(query_features, index, k)
    results = get_artwork_info(indices, image_df)
    print("Artwork identification completed.")
    return results, distances

def extract_features(model, dataloader, device):
    """
    Extracts features from images using the provided model and dataloader.

    Parameters:
    - model (torch.nn.Module): The feature extraction model.
    - dataloader (DataLoader): DataLoader for the dataset.
    - device (torch.device): Device to perform computations on.

    Returns:
    - torch.Tensor: Extracted features.
    """
    print("Starting feature extraction...")
    features = []
    with torch.no_grad():
        for batch_idx, (imgs, idxs) in enumerate(tqdm(dataloader, desc="Extracting features"), 1):
            imgs = imgs.to(device)
            try:
                outputs = model(imgs)
                outputs = outputs.view(outputs.size(0), -1)  # Flatten to (batch_size, feature_dim)
                features.append(outputs.cpu())
            except Exception as e:
                print(f"Error processing batch {batch_idx}: {e}")
    features = torch.cat(features, dim=0)
    print("Feature extraction completed.")
    return features

# Step 1: Load and Prepare Datasets
print("\nStep 1: Loading and preparing datasets...")

try:
    objects_df = pd.read_csv('./datasets/objects.csv')
    constituents_df = pd.read_csv('./datasets/constituents.csv')
    images_df = pd.read_csv('./datasets/published_images.csv')
    print("Datasets loaded successfully.")
except Exception as e:
    print(f"Error loading datasets: {e}")
    raise

# Ensure relevant columns are strings and stripped
for df, key in [
    (objects_df, 'objectid'),
    (constituents_df, 'artistofngaobject'),
    (images_df, 'depictstmsobjectid')
]:
    df[key] = df[key].astype(str).str.strip()

# Aggregate artists
print("Aggregating artist information...")
artists_agg = constituents_df.groupby('artistofngaobject')['preferreddisplayname'] \
    .apply(lambda names: ', '.join(sorted(set(names)))).reset_index()
artists_agg.rename(columns={
    'artistofngaobject': 'objectid',
    'preferreddisplayname': 'artists'
}, inplace=True)

# Aggregate image URLs
print("Aggregating image URLs...")
images_agg = images_df.groupby('depictstmsobjectid')['iiifthumburl'] \
    .apply(list).reset_index()
images_agg.rename(columns={
    'depictstmsobjectid': 'objectid',
    'iiifthumburl': 'image_urls'
}, inplace=True)

# Merge datasets
print("Merging datasets into final_df...")
merged_with_artists = objects_df.merge(
    artists_agg,
    on='objectid',
    how='left'
)
final_df = merged_with_artists.merge(
    images_agg,
    on='objectid',
    how='left'
)

# Select and clean relevant columns
final_df = final_df[['objectid', 'title', 'artists', 'image_urls']]
final_df['artists'] = final_df['artists'].fillna('Unknown Artist')
final_df['image_urls'] = final_df['image_urls'].apply(lambda x: x if isinstance(x, list) else [])
final_df = final_df[final_df['image_urls'].map(len) > 0].reset_index(drop=True)

# Print summary statistics
print("\nSummary of final_df:")
print(f"Unique artworks in objects_df: {objects_df['objectid'].nunique()}")
print(f"Unique artists in constituents_df: {constituents_df['constituentid'].nunique()}")
print(f"Unique image URLs in images_df: {images_df['iiifthumburl'].nunique()}")
print(f"Unique artworks in final_df: {final_df['objectid'].nunique()}")
print(f"Columns in final_df: {final_df.columns.tolist()}")

# Step 2: Download Images (Optional)
# Uncomment the following lines if you need to download images
# print("\nStep 2: Downloading images...")
# download_images(final_df)

print("\nStep 3: Preparing image_df from downloaded images...")

# Check if 'image_df.pkl' already exists
if os.path.exists('image_df.pkl'):
    # Load image_df from the pickle file
    image_df = pd.read_pickle('image_df.pkl')
    print(f"Loaded image_df with {len(image_df)} entries from 'image_df.pkl'.")
else:
    image_paths = glob.glob('./art_images/*.jpg')
    print(f"Found {len(image_paths)} images in './art_images/' directory.")
    
    # Use all images
    subset_image_paths = image_paths
    print(f"Selected {len(subset_image_paths)} images for processing.")
    
    # Prepare image_data
    image_data = []
    for path in tqdm(subset_image_paths, desc="Preparing image data"):
        filename = os.path.basename(path)
        objectid = filename.split('_')[0]
        row = final_df[final_df['objectid'] == objectid]
        if row.empty:
            # Skip images without metadata
            continue
        row = row.iloc[0]
        image_data.append({
            'image_path': path,
            'objectid': objectid,
            'artists': row['artists'],
            'title': row['title']
        })
    
    # Create image_df
    image_df = pd.DataFrame(image_data)
    print(f"Prepared image_df with {len(image_df)} entries.")
    
    # Save image_df to a pickle file for future use
    image_df.to_pickle('image_df.pkl')
    print("Saved image_df to 'image_df.pkl' for future use.")

# Step 4: Preparing Dataset for Fine-Tuning
print("\nStep 4: Preparing dataset for fine-tuning...")

# Analyze the number of samples per title and artist
title_counts = image_df['title'].value_counts()
artist_counts = image_df['artists'].value_counts()

print(f"Number of unique titles: {len(title_counts)}")
print(f"Number of unique artists: {len(artist_counts)}")

# Decide whether to classify by title or artist based on counts
min_samples_per_class = 2

# Check if titles have enough samples
titles_with_enough_samples = title_counts[title_counts >= min_samples_per_class]
if len(titles_with_enough_samples) > 0:
    print(f"Proceeding with classification by title using {len(titles_with_enough_samples)} titles.")
    # Filter image_df
    image_df = image_df[image_df['title'].isin(titles_with_enough_samples.index)]
    # Encode labels
    label_encoder = LabelEncoder()
    image_df['label'] = label_encoder.fit_transform(image_df['title'])
else:
    print("Not enough samples per title. Proceeding with classification by artist.")
    # Filter artists with enough samples
    artists_with_enough_samples = artist_counts[artist_counts >= min_samples_per_class]
    image_df = image_df[image_df['artists'].isin(artists_with_enough_samples.index)]
    # Encode labels
    label_encoder = LabelEncoder()
    image_df['label'] = label_encoder.fit_transform(image_df['artists'])

print(f"Filtered image_df now has {len(image_df)} entries.")

# Split the dataset into training and validation sets
print("\nSplitting dataset into training and validation sets...")
train_df, val_df = train_test_split(
    image_df,
    test_size=0.2,
    random_state=42,
    # stratify=image_df['label']
)

print(f"Training set size: {len(train_df)}")
print(f"Validation set size: {len(val_df)}")

# Step 5: Define Image Transformations with Data Augmentation
print("\nStep 5: Defining image transformations with data augmentation...")

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

print("Image transformations defined.")

# Step 6: Create Custom Dataset and DataLoaders
print("\nStep 6: Creating custom dataset and data loaders...")

class ArtImageDataset(Dataset):
    def __init__(self, image_df, transform=None):
        self.image_df = image_df.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return len(self.image_df)

    def __getitem__(self, idx):
        img_path = self.image_df.loc[idx, 'image_path']
        label = self.image_df.loc[idx, 'label']

        try:
            img = Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            img = Image.new('RGB', (224, 224), (0, 0, 0))

        if self.transform:
            img = self.transform(img)

        return img, label

# Create datasets
train_dataset = ArtImageDataset(train_df, transform=train_transform)
val_dataset = ArtImageDataset(val_df, transform=val_transform)

# Create DataLoaders
batch_size = 32  # Adjust based on your GPU memory
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=12)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=12)

print(f"Training DataLoader created with {len(train_loader)} batches.")
print(f"Validation DataLoader created with {len(val_loader)} batches.")

# Step 7: Set Up Device and Modify the Model
print("\nStep 7: Setting up the device and modifying the ResNet-50 model for fine-tuning...")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Get the number of classes
num_classes = image_df['label'].nunique()
print(f"Number of classes: {num_classes}")

# Load the pre-trained ResNet-50 model
model_ft = models.resnet50(pretrained=True)

# Optionally freeze initial layers
for param in model_ft.parameters():
    param.requires_grad = False

# Replace the final fully connected layer
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, num_classes)

# Move the model to the device
model_ft = model_ft.to(device)
print("Model modified for fine-tuning.")

# Step 8: Define Loss Function and Optimizer
print("\nStep 8: Defining loss function and optimizer...")

criterion = nn.CrossEntropyLoss()

# Only parameters of the final layer are being optimized
optimizer = torch.optim.Adam(model_ft.fc.parameters(), lr=0.001)

# Optionally, use a learning rate scheduler
from torch.optim import lr_scheduler
scheduler = lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

print("Loss function and optimizer defined.")

num_epochs = 5  # Adjust based on your needs

# Step 9: Implement the Training Loop
print("\nStep 9: Starting the training loop...")

train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print("-" * 10)

    # Training phase
    model_ft.train()
    running_loss = 0.0
    running_corrects = 0

    for inputs, labels in tqdm(train_loader, desc="Training"):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model_ft(inputs)
        _, preds = torch.max(outputs, 1)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Statistics
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)

    scheduler.step()  # Step the learning rate scheduler

    epoch_loss = running_loss / len(train_dataset)
    epoch_acc = running_corrects.double() / len(train_dataset)

    train_losses.append(epoch_loss)
    train_accuracies.append(epoch_acc.item())

    print(f"Training Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")

    # Validation phase
    model_ft.eval()
    val_running_loss = 0.0
    val_running_corrects = 0

    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc="Validation"):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model_ft(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

            val_running_loss += loss.item() * inputs.size(0)
            val_running_corrects += torch.sum(preds == labels.data)

    val_epoch_loss = val_running_loss / len(val_dataset)
    val_epoch_acc = val_running_corrects.double() / len(val_dataset)

    val_losses.append(val_epoch_loss)
    val_accuracies.append(val_epoch_acc.item())

    print(f"Validation Loss: {val_epoch_loss:.4f} Acc: {val_epoch_acc:.4f}")

# After training, create plots and save them
print("\nStep 9.1: Plotting and saving training curves...")

# Ensure the 'training_images' directory exists
if not os.path.exists('training_images'):
    os.makedirs('training_images')
    print("Created directory 'training_images'.")

# Plot Training and Validation Loss
epochs_range = range(1, num_epochs + 1)
plt.figure(figsize=(10, 5))
plt.plot(epochs_range, train_losses, 'b-', label='Training Loss')
plt.plot(epochs_range, val_losses, 'r-', label='Validation Loss')
plt.title(f'Training and Validation Loss\nModel: ResNet50, Epochs: {num_epochs}')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.savefig('training_images/loss_plot.png')
plt.close()
print("Saved loss plot to 'training_images/loss_plot.png'.")

# Plot Training and Validation Accuracy
plt.figure(figsize=(10, 5))
plt.plot(epochs_range, train_accuracies, 'b-', label='Training Accuracy')
plt.plot(epochs_range, val_accuracies, 'r-', label='Validation Accuracy')
plt.title(f'Training and Validation Accuracy\nModel: ResNet50, Epochs: {num_epochs}')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.grid(True)
plt.savefig('training_images/accuracy_plot.png')
plt.close()
print("Saved accuracy plot to 'training_images/accuracy_plot.png'.")

# Step 10: Save the Fine-Tuned Model
print("\nStep 10: Saving the fine-tuned model...")
    
torch.save(model_ft.state_dict(), 'fine_tuned_resnet50_artworks.pth')
print("Fine-tuned model saved to 'fine_tuned_resnet50_artworks.pth'.")

# Step 11: Load Fine-Tuned Model and Extract Features
print("\nStep 11: Loading fine-tuned model and extracting features...")

# Load the fine-tuned model weights
model_ft.load_state_dict(torch.load('fine_tuned_resnet50_artworks.pth'))
print("Loaded fine-tuned model weights.")

# Remove the classification layer for feature extraction
model_extractor = nn.Sequential(*list(model_ft.children())[:-1])
model_extractor = model_extractor.to(device)
model_extractor.eval()

# Use the same DataLoader without data augmentation for feature extraction
feature_dataset = ArtImageDataset(image_df, transform=val_transform)
feature_loader = DataLoader(feature_dataset, batch_size=batch_size, shuffle=False, num_workers=12)

# Extract features
art_features = extract_features(model_extractor, feature_loader, device)
print(f"Extracted features shape: {art_features.shape}")

# Normalize features
art_features_norm = art_features / art_features.norm(dim=1, keepdim=True)
art_features_np = art_features_norm.numpy().astype('float32')

# Save features
np.save('fine_tuned_art_features.npy', art_features_np)
print("Features from fine-tuned model saved to 'fine_tuned_art_features.npy'.")

# Step 12: Create FAISS Index and Add Features
print("\nStep 12: Creating FAISS index with fine-tuned features...")

index = faiss.IndexFlatIP(art_features_np.shape[1])
index.add(art_features_np)
faiss.write_index(index, 'fine_tuned_faiss_index.bin')
print("FAISS index created with fine-tuned features and saved to 'fine_tuned_faiss_index.bin'.")

# Step 13: Perform the Query on a Sample Image
print("\nStep 13: Performing query on a sample image...")

# Use any image from the dataset as the query
query_image_index = 80  # Adjust index as needed
if query_image_index >= len(image_df):
    query_image_index = len(image_df) - 1  # Ensure the index is within range
image_to_query = image_df.iloc[query_image_index]['image_path']
print(f"Selected image for querying: {image_to_query}")

try:
    # Use the fine-tuned feature extractor and FAISS index
    results, distances = identify_artwork(
        image_to_query,
        model_extractor,
        index,
        image_df,
        val_transform,
        device,
        k=5
    )
except Exception as e:
    print(f"Error identifying artwork: {e}")
    results, distances = None, None

# Step 14: Display Top Matching Artworks
if results is not None and not results.empty:
    print("\nTop matching artworks:")
    for i, (idx, dist) in enumerate(zip(results.index, distances)):
        row = results.iloc[i]
        print(f"Rank {i+1}:")
        print(f"  Artwork Title: {row['title']}")
        print(f"  Artist: {row['artists']}")
        print(f"  Similarity Score: {dist:.4f}")
        print(f"  Image Path: {row['image_path']}\n")
else:
    print("No results to display due to an error in identifying the artwork.")

# Step 15: Display Results Visually
if results is not None and not results.empty:
    print("\nStep 15: Displaying results visually...")
    show_results(image_to_query, results)
else:
    print("Skipping visual display due to earlier errors.")


UnicodeDecodeError: 'utf-8' codec can't decode byte 0xdf in position 3: invalid continuation byte

In [2]:

print("Is CUDA available?:", torch.cuda.is_available())
print("Number of GPUs available:", torch.cuda.device_count())
if torch.cuda.is_available():
    print("GPU Device Name:", torch.cuda.get_device_name(0))
else:
    print("CUDA is not available.")
print("PyTorch Version:", torch.__version__)
print("CUDA Version used by PyTorch:", torch.version.cuda)


NameError: name 'torch' is not defined