In [58]:
import torch

# Verify CUDA installation
print("=" * 50)
print("CUDA Verification")
print("=" * 50)
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda if torch.cuda.is_available() else 'N/A'}")
print(f"cuDNN version: {torch.backends.cudnn.version() if torch.cuda.is_available() else 'N/A'}")
print(f"Number of GPUs: {torch.cuda.device_count() if torch.cuda.is_available() else 0}")

if torch.cuda.is_available():
    print(f"\nGPU Details:")
    for i in range(torch.cuda.device_count()):
        print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")
        print(f"    Memory: {torch.cuda.get_device_properties(i).total_memory / 1024**3:.2f} GB")
    
    # Quick test: create a tensor on GPU
    print(f"\nTesting GPU computation...")
    x = torch.randn(1000, 1000).cuda()
    y = torch.randn(1000, 1000).cuda()
    z = torch.matmul(x, y)
    print(f"‚úì GPU computation successful!")
    print(f"  Result tensor shape: {z.shape}")
    print(f"  Result tensor device: {z.device}")
else:
    print("\n‚ö† CUDA is not available. Check your installation.")


CUDA Verification
PyTorch version: 2.5.1+cu121
CUDA available: True
CUDA version: 12.1
cuDNN version: 90100
Number of GPUs: 1

GPU Details:
  GPU 0: NVIDIA GeForce RTX 3050 Laptop GPU
    Memory: 4.00 GB

Testing GPU computation...
‚úì GPU computation successful!
  Result tensor shape: torch.Size([1000, 1000])
  Result tensor device: cuda:0


# Model Training Setup

Before we can use the model for predictions and interpretability, we need to train it on our art dataset.

## Training Strategy:
1. **Select subset of artists**: Take half of all available artists
2. **Split data**: Use 70% of selected artists' paintings for training, 15% for validation, 15% for testing
3. **Train CNN model**: Based on CIFAR-10 architecture, adapted for art classification


## Step 1: Dataset Preparation and Artist Selection


In [59]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import os
from pathlib import Path
import random
from collections import defaultdict

# Set random seeds for reproducibility
torch.manual_seed(42)
random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


Using device: cuda


In [60]:
# Custom Dataset class for art paintings
class ArtDataset(Dataset):
    def __init__(self, root_dir, artists=None, transform=None):
        """
        Args:
            root_dir: Root directory containing artist folders
            artists: List of artist names to include (None = all artists)
            transform: Optional transform to be applied on a sample
        """
        self.root_dir = Path(root_dir)
        self.transform = transform
        
        # If artists not specified, discover all artist folders
        if artists is None:
            self.artists = sorted([d.name for d in self.root_dir.iterdir() if d.is_dir()])
        else:
            self.artists = artists
        
        # Create artist to index mapping
        self.artist_to_idx = {artist: idx for idx, artist in enumerate(self.artists)}
        self.idx_to_artist = {idx: artist for artist, idx in self.artist_to_idx.items()}
        
        # Load all image paths with their labels
        self.samples = []
        for artist in self.artists:
            artist_dir = self.root_dir / artist
            if artist_dir.exists():
                # Support common image formats
                image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}
                for img_path in artist_dir.iterdir():
                    if img_path.suffix.lower() in image_extensions:
                        self.samples.append((str(img_path), self.artist_to_idx[artist]))
        
        print(f"Loaded {len(self.samples)} images from {len(self.artists)} artists")
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        
        # Load image (you may need to install Pillow: pip install Pillow)
        from PIL import Image
        try:
            image = Image.open(img_path).convert('RGB')
            # Note: Original image size may vary, but transform will resize to 224x224
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Return a black image as fallback
            image = Image.new('RGB', (224, 224), color='black')
        
        # Apply transforms (which includes resizing to 224x224)
        if self.transform:
            image = self.transform(image)
        
        return image, label
    
    def get_artist_counts(self):
        """Get count of images per artist"""
        counts = defaultdict(int)
        for _, label in self.samples:
            artist = self.idx_to_artist[label]
            counts[artist] += 1
        return dict(counts)


In [61]:
# Define data transforms
# IMPORTANT: All images will be resized to 224x224 regardless of original size
# ResNet expects 224x224 input (ImageNet standard)
TARGET_SIZE = (224, 224)

# Training transforms: Enhanced augmentation for art classification
# Following best practices for transfer learning with ResNet
train_transform = transforms.Compose([
    transforms.Resize(256),  # Resize to slightly larger first
    transforms.RandomCrop(TARGET_SIZE),  # Random crop to 224x224
    transforms.RandomHorizontalFlip(p=0.5),  # Horizontal flip
    transforms.RandomRotation(15),  # Rotation up to 15 degrees
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),  # Color augmentation
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),  # Slight translation
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet normalization
])

# Validation/Test transforms: No augmentation, just resize and normalize
val_test_transform = transforms.Compose([
    transforms.Resize(256),  # Resize to slightly larger
    transforms.CenterCrop(TARGET_SIZE),  # Center crop to 224x224
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet normalization
])

print(f"‚úÖ Transforms configured for ResNet (224x224 input)")
print(f"  Training: Enhanced augmentation (rotation, color jitter, random crop)")
print(f"  Test: Center crop only (no augmentation)")


‚úÖ Transforms configured for ResNet (224x224 input)
  Training: Enhanced augmentation (rotation, color jitter, random crop)
  Test: Center crop only (no augmentation)


In [62]:
# Load full dataset
# TODO: Update this path to your actual dataset directory
# Expected structure: dataset_root/artist_name/image1.jpg, image2.jpg, ...
DATASET_ROOT = "art Folder/images/images"  # CHANGE THIS!
import os
# Check if dataset exists
if not os.path.exists(DATASET_ROOT):
    print(f"‚ö†Ô∏è  Dataset not found at: {DATASET_ROOT}")
    print("Please update DATASET_ROOT with your actual dataset path")
    print("\nExpected directory structure:")
    print("dataset_root/")
    print("  ‚îú‚îÄ‚îÄ Picasso/")
    print("  ‚îÇ   ‚îú‚îÄ‚îÄ painting1.jpg")
    print("  ‚îÇ   ‚îî‚îÄ‚îÄ painting2.jpg")
    print("  ‚îú‚îÄ‚îÄ Matisse/")
    print("  ‚îÇ   ‚îî‚îÄ‚îÄ ...")
    print("  ‚îî‚îÄ‚îÄ ...")
else:
    # 1) Load all artists from the full dataset (no transform yet)
    full_dataset = ArtDataset(DATASET_ROOT, transform=None)
    all_artists = full_dataset.artists

    print(f"\nüìä Found {len(all_artists)} artists in dataset:")
    artist_counts = full_dataset.get_artist_counts()
    for artist, count in sorted(artist_counts.items(), key=lambda x: x[1], reverse=True):
        print(f"  {artist}: {count} images")

    # 2) Select half of the artists (you can change this strategy if you like)
    num_artists_to_use = len(all_artists) // 2
    selected_artists = sorted(all_artists)[:num_artists_to_use]

    print(f"\nüé® Selected {len(selected_artists)} artists for training:")
    for artist in selected_artists:
        print(f"  - {artist}")

    # 3) Base dataset restricted to selected artists (no transform yet)
    base_dataset = ArtDataset(DATASET_ROOT, artists=selected_artists, transform=None)

    # 4) Split into train (70%), test (30%)
    total_size = len(base_dataset)
    train_size = int(0.70 * total_size)
    test_size = total_size - train_size  # Remaining 30% goes to test

    train_subset, test_subset = random_split(
        base_dataset,
        [train_size, test_size],
        generator=torch.Generator().manual_seed(42)
    )

    # Extract indices from the subsets
    train_indices = train_subset.indices
    test_indices = test_subset.indices

    # 5) Rebuild datasets with different transforms for each split
    #    (we assume train_transform and val_test_transform are defined earlier)
    train_dataset = torch.utils.data.Subset(
        ArtDataset(DATASET_ROOT, artists=selected_artists, transform=train_transform),
        train_indices
    )
    test_dataset = torch.utils.data.Subset(
        ArtDataset(DATASET_ROOT, artists=selected_artists, transform=val_test_transform),
        test_indices
    )

    print(f"\nüì¶ Dataset splits:")
    print(f"  Training: {len(train_dataset)} images ({len(train_dataset)/total_size*100:.1f}%)")
    print(f"  Test: {len(test_dataset)} images ({len(test_dataset)/total_size*100:.1f}%)")

    # 6) Create data loaders
    BATCH_SIZE = 32

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0,  # use 0 to avoid multiprocessing issues while debugging
        pin_memory=(device.type == "cuda")
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=0,
        pin_memory=(device.type == "cuda")
    )

    num_classes = len(selected_artists)
    print(f"\n‚úÖ Dataset ready! Number of classes: {num_classes}")

    # 7) Verify image sizes are correct
    print("\nüîç Verifying image resizing...")
    sample_image, sample_label = next(iter(train_loader))
    print(f"  Sample batch shape: {sample_image.shape}")
    print("  Expected shape: [batch_size, 3, 224, 224]")
    if sample_image.shape[2] == 224 and sample_image.shape[3] == 224:
        print("  ‚úÖ All images are correctly resized to 224x224!")
    else:
        print(f"  ‚ö†Ô∏è  Warning: Images are not 224x224! "
              f"Actual size: {sample_image.shape[2]}x{sample_image.shape[3]}")


Loaded 8774 images from 51 artists

üìä Found 51 artists in dataset:
  Vincent_van_Gogh: 877 images
  Edgar_Degas: 702 images
  Pablo_Picasso: 439 images
  Pierre-Auguste_Renoir: 336 images
  Albrecht_DuŒì√≤√°‚îú¬¨rer: 328 images
  Albrecht_Du‚ï†√™rer: 328 images
  Paul_Gauguin: 311 images
  Francisco_Goya: 291 images
  Rembrandt: 262 images
  Alfred_Sisley: 259 images
  Titian: 255 images
  Marc_Chagall: 239 images
  Rene_Magritte: 194 images
  Amedeo_Modigliani: 193 images
  Paul_Klee: 188 images
  Henri_Matisse: 186 images
  Andy_Warhol: 181 images
  Mikhail_Vrubel: 171 images
  Sandro_Botticelli: 164 images
  Leonardo_da_Vinci: 143 images
  Peter_Paul_Rubens: 141 images
  Salvador_Dali: 139 images
  Hieronymus_Bosch: 137 images
  Pieter_Bruegel: 134 images
  Diego_Velazquez: 128 images
  Kazimir_Malevich: 126 images
  Frida_Kahlo: 120 images
  Giotto_di_Bondone: 119 images
  Gustav_Klimt: 117 images
  Raphael: 109 images
  Joan_Miro: 102 images
  Andrei_Rublev: 99 images
  Camille

## Model Architecture: Transfer Learning with ResNet50

**Improvements for Better Accuracy:**

1. **Transfer Learning**: Using pretrained ResNet50 (trained on ImageNet) instead of training from scratch
   - ResNet50 has learned rich visual features that transfer well to art classification
   - Much faster convergence and higher accuracy

2. **Enhanced Data Augmentation**:
   - Random crop, rotation, color jitter, and translation
   - Helps model generalize better to different art styles and orientations

3. **Differential Learning Rates**:
   - Lower LR (0.0001) for pretrained backbone layers
   - Higher LR (0.001) for new classifier layers
   - Allows fine-tuning without destroying pretrained features

4. **Better Architecture**:
   - ResNet50 backbone (50 layers, 2048 features)
   - Custom classifier with embedding layer for similarity search
   - Dropout for regularization

This approach follows best practices from art classification research and should significantly improve accuracy!


## Step 2: Model Architecture

Based on CIFAR-10 tutorial, adapted for art classification with embedding support.


In [63]:
import torchvision.models as models

class ArtClassifier(nn.Module):
    """
    Art classification model using transfer learning with ResNet50
    Based on best practices for art classification tasks
    Modified to support embedding extraction for similarity search
    """
    def __init__(self, num_classes, use_pretrained=True):
        super(ArtClassifier, self).__init__()
        
        # Load pretrained ResNet50
        resnet = models.resnet50(pretrained=use_pretrained)
        
        # Freeze early layers for transfer learning (optional - can unfreeze later)
        # for param in list(resnet.parameters())[:-10]:
        #     param.requires_grad = False
        
        # Replace the final fully connected layer
        # ResNet50's fc layer expects 2048 features (from avgpool)
        num_features = resnet.fc.in_features
        
        # Remove the original classifier
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])  # Remove final fc layer
        
        # Add custom classifier with embedding layer
        self.fc1 = nn.Linear(num_features, 512)
        self.dropout1 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, 256)  # Embedding layer for similarity search
        self.dropout2 = nn.Dropout(0.3)
        self.fc3 = nn.Linear(256, num_classes)  # Final classification layer
        
    def forward(self, x, return_embedding=False):
        # Extract features using ResNet backbone
        x = self.backbone(x)
        # Flatten: ResNet avgpool outputs [batch_size, 2048, 1, 1]
        x = x.view(x.size(0), -1)  # [batch_size, 2048]
        
        # Fully connected layers
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        embedding = F.relu(self.fc2(x))  # Embedding for similarity search
        x = self.dropout2(embedding)
        logits = self.fc3(x)  # Classification logits
        
        if return_embedding:
            return logits, embedding
        return logits

# Initialize model
if 'num_classes' in locals():
    model = ArtClassifier(num_classes=num_classes).to(device)
    print(f"\n‚úÖ Model initialized with {num_classes} classes")
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
else:
    print("‚ö†Ô∏è  Please run the dataset preparation cell first!")



‚úÖ Model initialized with 25 classes
Model parameters: 24,694,873


## Step 3: Training Setup


In [64]:
# Training hyperparameters for transfer learning
# Use lower LR for pretrained backbone, higher for new classifier
BACKBONE_LR = 0.0001  # Lower LR for pretrained ResNet layers
CLASSIFIER_LR = 0.001  # Higher LR for new classifier layers
NUM_EPOCHS = 30  # More epochs for fine-tuning
WEIGHT_DECAY = 1e-4

# Loss function
criterion = nn.CrossEntropyLoss()

# Initialize optimizer and scheduler (after model is created)
if 'model' in locals():
    # Use different learning rates for backbone and classifier
    # This is a common practice in transfer learning
    optimizer = optim.Adam([
        {'params': model.backbone.parameters(), 'lr': BACKBONE_LR},
        {'params': [p for n, p in model.named_parameters() if 'backbone' not in n], 'lr': CLASSIFIER_LR}
    ], weight_decay=WEIGHT_DECAY)
    
    # Use cosine annealing for smoother learning rate decay
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
    
    print("Training configuration (Transfer Learning):")
    print(f"  Backbone LR (pretrained): {BACKBONE_LR}")
    print(f"  Classifier LR (new layers): {CLASSIFIER_LR}")
    print(f"  Epochs: {NUM_EPOCHS}")
    print(f"  Batch size: {BATCH_SIZE}")
    print(f"  Optimizer: Adam (with different LRs)")
    print(f"  Scheduler: CosineAnnealingLR")
    print(f"  Loss: CrossEntropyLoss")
    print(f"  Device: {device}")
else:
    print("‚ö†Ô∏è  Please run the model initialization cell first!")


Training configuration (Transfer Learning):
  Backbone LR (pretrained): 0.0001
  Classifier LR (new layers): 0.001
  Epochs: 30
  Batch size: 32
  Optimizer: Adam (with different LRs)
  Scheduler: CosineAnnealingLR
  Loss: CrossEntropyLoss
  Device: cuda


In [65]:
# Check for old model file and handle architecture mismatch
model_file = 'art_classifier_model.pth'
if os.path.exists(model_file):
    try:
        # Try to load and check if it's compatible
        checkpoint = torch.load(model_file, map_location=device, weights_only=False)
        saved_keys = set(checkpoint.get('model_state_dict', {}).keys())
        
        # Check if it's the old architecture (has conv1, conv2, etc.) or new (has backbone)
        if 'conv1.weight' in saved_keys and 'backbone.0.weight' not in saved_keys:
            print("‚ö†Ô∏è  Found old model file with incompatible architecture (custom CNN).")
            print("   Backing up old model and will train new ResNet50 model...")
            import shutil
            backup_name = 'art_classifier_model_old_CNN.pth'
            if not os.path.exists(backup_name):
                shutil.move(model_file, backup_name)
                print(f"   ‚úÖ Old model backed up to: {backup_name}")
            else:
                os.remove(model_file)
                print(f"   ‚úÖ Old model removed (backup already exists)")
        else:
            print(f"‚úÖ Found compatible model file: {model_file}")
    except Exception as e:
        print(f"‚ö†Ô∏è  Error checking model file: {e}")
        print("   Will start fresh training...")
else:
    print("‚ÑπÔ∏è  No existing model file found. Will train new model.")


‚ö†Ô∏è  Found old model file with incompatible architecture (custom CNN).
   Backing up old model and will train new ResNet50 model...
   ‚úÖ Old model backed up to: art_classifier_model_old_CNN.pth


## Step 4: Training Loop


In [66]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Train for one epoch"""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        
        # Zero gradients
        optimizer.zero_grad()
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        # Statistics
        running_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    
    epoch_loss = running_loss / len(train_loader)
    epoch_acc = 100 * correct / total
    return epoch_loss, epoch_acc

def evaluate(model, test_loader, criterion, device):
    """Evaluate the model on test set"""
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    epoch_loss = running_loss / len(test_loader)
    epoch_acc = 100 * correct / total
    return epoch_loss, epoch_acc


In [67]:
images, labels = next(iter(train_loader))
images = images.to(device)

print("Images shape:", images.shape)  # expect [B, 3, H, W]
with torch.no_grad():
    outputs = model(images)
print("Forward pass OK, output shape:", outputs.shape)

Images shape: torch.Size([32, 3, 224, 224])
Forward pass OK, output shape: torch.Size([32, 25])


In [None]:
# Training loop
if 'train_loader' in locals() and 'val_loader' in locals() and 'optimizer' in locals():
    # if the model is already trained, skip the training loop
    # check if the path 'art_classifier_model.pth' exists
    if os.path.exists('art_classifier_model.pth'):
        print("üöÄ Model already trained! Skipping training...")
        # load the model from the path 'art_classifier_model.pth'
        model.load_state_dict(torch.load('art_classifier_model.pth')['model_state_dict'])
        
    else:
        print("üöÄ Starting training...\n")
        
        train_losses, train_accs = [], []
        val_losses, val_accs = [], []
        best_val_acc = 0.0
        best_model_state = None
        
        for epoch in range(NUM_EPOCHS):
            print(f"Epoch {epoch+1} of {NUM_EPOCHS}")
            # Train
            train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
            train_losses.append(train_loss)
            train_accs.append(train_acc)
            
            # Validate
            val_loss, val_acc = validate(model, val_loader, criterion, device)
            val_losses.append(val_loss)
            val_accs.append(val_acc)
            
            # Update learning rate
            scheduler.step()
            
            # Save best model
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_model_state = model.state_dict().copy()
            
            # Print progress
            print(f"Epoch [{epoch+1}/{NUM_EPOCHS}]")
            print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
            print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
            print(f"  LR: {scheduler.get_last_lr()[0]:.6f}")
            print()
        
    # Load best model
        model.load_state_dict(best_model_state)
        print(f"‚úÖ Training complete! Best validation accuracy: {best_val_acc:.2f}%")
        
        # Test on test set
        print("\nüìä Evaluating on test set...")
        test_loss, test_acc = validate(model, test_loader, criterion, device)
        print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.2f}%")
        
        # Save model
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'num_classes': num_classes,
            'artists': selected_artists,
            'test_acc': test_acc,
        }, 'art_classifier_model.pth')
        print("\nüíæ Model saved to 'art_classifier_model.pth'")
        
else:
    print("‚ö†Ô∏è  Please run the dataset preparation cell first!")


üöÄ Starting training...

Epoch 1 of 30
Epoch [1/30]
  Train Loss: 2.3196, Train Acc: 33.62%
  Val Loss: 1.6919, Val Acc: 46.72%
  LR: 0.000100

Epoch 2 of 30
Error loading image art Folder\images\images\Albrecht_DuŒì√≤√°‚îú¬¨rer\Albrecht_DuŒì√≤√°‚îú¬¨rer_181.jpg: [Errno 2] No such file or directory: 'art Folder\\images\\images\\Albrecht_DuŒì√≤√°‚îú¬¨rer\\Albrecht_DuŒì√≤√°‚îú¬¨rer_181.jpg'
Error loading image art Folder\images\images\Albrecht_Du‚ï†√™rer\Albrecht_Du‚ï†√™rer_8.jpg: [Errno 2] No such file or directory: 'art Folder\\images\\images\\Albrecht_Du‚ï†√™rer\\Albrecht_Du‚ï†√™rer_8.jpg'
Error loading image art Folder\images\images\Andy_Warhol\Andy_Warhol_175.jpg: [Errno 2] No such file or directory: 'art Folder\\images\\images\\Andy_Warhol\\Andy_Warhol_175.jpg'
Error loading image art Folder\images\images\Albrecht_Du‚ï†√™rer\Albrecht_Du‚ï†√™rer_207.jpg: [Errno 2] No such file or directory: 'art Folder\\images\\images\\Albrecht_Du‚ï†√™rer\\Albrecht_Du‚ï†√™rer_207.jpg'
Error loadin

## Training Summary

After training, you'll have:
- ‚úÖ Trained model saved to `art_classifier_model.pth`
- ‚úÖ Model ready for predictions on selected artists
- ‚úÖ Embedding layer available for similarity search
- ‚úÖ Model ready for interpretability analysis (Captum)

**Next steps:** Use the trained model with the interpretability features described in the sections below!


# Artfluence: Art Classification System Architecture

This notebook breaks down how to extend the CIFAR-10 tutorial to build a comprehensive art classification system with interpretability features.

## Overview

The CIFAR-10 tutorial provides the **core engine** (a CNN that outputs logits for classes). Everything else is additional layers of logic on top of that foundation.


## Part 1: What the CIFAR-10 Tutorial Already Gives You

The CIFAR-10 tutorial trains a CNN, gets raw outputs ("energies") for each class, and picks the argmax as the predicted label.

### ‚úÖ Directly Supported Features


### 1. Predicted Artist

**What it is:** The class with the highest logit value.

**How to get it:**
- Replace `CIFAR-10 classes = ['airplane', 'car', ...]` with `artists = ['Picasso', 'Matisse', ...]`
- Use `torch.max(outputs, 1)` to get the predicted class

**Code:**


In [None]:
# Example: Getting predicted artist from model outputs
# Assuming you have a trained model and an input image

# artists = ['Picasso', 'Matisse', 'Van Gogh', 'Monet', ...]
# outputs = model(input_image)  # Shape: [batch_size, num_artists]

# Get predicted artist (argmax)
_, predicted = torch.max(outputs, 1)
# predicted contains the index of the predicted artist

# Convert to artist name
# predicted_artist = artists[predicted.item()]


### 2. Influence Distribution

**What it is:** Probability distribution over all artists showing how much each artist influenced the prediction.

**How to get it:**
- Apply softmax to the logits to convert them to probabilities
- This gives you a probability for every artist = an "influence distribution"

**Code:**


In [None]:

# Convert logits to probability distribution
probs = torch.softmax(outputs, dim=1)

# probs now contains probabilities for each artist
# Example output shape: [batch_size, num_artists]
# Each row sums to 1.0

# You can visualize this as:
# - A bar chart showing probability for each artist
# - A sorted list of artists by influence



### 3. Confidence / Uncertainty Estimates

**What it is:** 
- **Confidence**: How sure the model is about its prediction
- **Uncertainty**: How uncertain the model is (opposite of confidence)

**How to get it:**
- **Confidence** = `max(probs)` for the chosen artist
- **Uncertainty** = entropy of probs or `1 - confidence`

**Code:**


In [None]:
import torch.nn.functional as F
import pandas as pd
artists = pd.read_csv('art Folder/artists.csv')

# Get confidence (probability of predicted class)
# Note: probs can be a batch, so confidence will be a tensor with shape [batch_size]
confidence = torch.max(probs, dim=1)[0]  # Max probability

# Get uncertainty using entropy
# Higher entropy = more uncertainty (probabilities are spread out)
# Lower entropy = less uncertainty (one probability dominates)
entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=1)
uncertainty = entropy / torch.log(torch.tensor(len(artists), dtype=torch.float32))  # Normalized

# Alternative: Simple uncertainty as 1 - confidence
uncertainty_simple = 1 - confidence

# Handle both single samples and batches
if confidence.numel() == 1:
    # Single sample
    print(f"Confidence: {confidence.item():.3f}")
    print(f"Uncertainty (entropy): {uncertainty.item():.3f}")
    print(f"Uncertainty (simple): {uncertainty_simple.item():.3f}")
else:
    # Batch - show statistics
    print(f"Batch size: {confidence.shape[0]}")
    print(f"Confidence - Mean: {confidence.mean().item():.3f}, Min: {confidence.min().item():.3f}, Max: {confidence.max().item():.3f}")
    print(f"Uncertainty (entropy) - Mean: {uncertainty.mean().item():.3f}, Min: {uncertainty.min().item():.3f}, Max: {uncertainty.max().item():.3f}")
    print(f"Uncertainty (simple) - Mean: {uncertainty_simple.mean().item():.3f}, Min: {uncertainty_simple.min().item():.3f}, Max: {uncertainty_simple.max().item():.3f}")
    
    # Show first sample in batch as example
    print(f"\nFirst sample in batch:")
    print(f"  Confidence: {confidence[0].item():.3f}")
    print(f"  Uncertainty (entropy): {uncertainty[0].item():.3f}")
    print(f"  Uncertainty (simple): {uncertainty_simple[0].item():.3f}")


Batch size: 32
Confidence - Mean: 0.043, Min: 0.042, Max: 0.043
Uncertainty (entropy) - Mean: 0.823, Min: 0.823, Max: 0.823
Uncertainty (simple) - Mean: 0.957, Min: 0.957, Max: 0.958

First sample in batch:
  Confidence: 0.043
  Uncertainty (entropy): 0.823
  Uncertainty (simple): 0.957


## Part 2: Features Requiring Small Design Extensions

These features aren't in the CIFAR-10 tutorial but can be added with standard techniques.


### 4. Top-K Nearest Paintings

**What it is:** Find the k most similar paintings in your database to the query painting.

**How to implement:**
1. Modify the model's `forward()` method to return embeddings from the penultimate layer (e.g., `fc2`)
2. Pre-compute and store embeddings for all paintings in your database
3. For a query painting, compute its embedding and do k-NN search (cosine or Euclidean distance)

**Code:**


In [38]:
# Step 1: Modify model to return embeddings
import torch.nn as nn
import torch.nn.functional as F

class ArtClassifier(nn.Module):
    def __init__(self, num_artists):
        super().__init__()
        # ... CNN layers ...
        self.fc1 = nn.Linear(14 * 14 * 512, 512)  # First FC layer after CNN
        self.fc2 = nn.Linear(512, 256)  # Penultimate layer (embedding)
        self.fc3 = nn.Linear(256, num_artists)  # Final classification layer
    
    def forward(self, x, return_embedding=False):
        # ... CNN forward pass (conv layers, pooling, flattening) ...
        # After CNN: x has shape [batch_size, 14*14*512]
        features = F.relu(self.fc1(x))  # Process through first FC layer
        embedding = F.relu(self.fc2(features))  # Get embedding from penultimate layer
        logits = self.fc3(embedding)  # Get logits from final layer
        
        if return_embedding:
            return logits, embedding
        return logits

# Step 2: Pre-compute embeddings for database
def build_embedding_database(model, dataloader, device):
    """Pre-compute embeddings for all paintings in database"""
    model.eval()
    embeddings = []
    painting_ids = []
    
    with torch.no_grad():
        for images, ids in dataloader:
            images = images.to(device)
            _, emb = model(images, return_embedding=True)
            embeddings.append(emb.cpu())
            painting_ids.extend(ids)
    
    embeddings = torch.cat(embeddings, dim=0)
    return embeddings, painting_ids

# Step 3: Find top-k nearest paintings
def find_top_k_nearest(query_embedding, database_embeddings, k=5, metric='cosine'):
    """Find k nearest paintings using cosine or Euclidean distance"""
    if metric == 'cosine':
        # Normalize embeddings
        query_norm = F.normalize(query_embedding, p=2, dim=1)
        db_norm = F.normalize(database_embeddings, p=2, dim=1)
        # Compute cosine similarity
        similarities = torch.mm(query_norm, db_norm.t())
        top_k_values, top_k_indices = torch.topk(similarities, k, dim=1)
    else:  # Euclidean
        distances = torch.cdist(query_embedding, database_embeddings)
        top_k_values, top_k_indices = torch.topk(-distances, k, dim=1)  # Negative for top-k
    
    return top_k_indices, top_k_values


In [43]:
# Simple threshold-based approach
def is_unknown_artist(probs, confidence_threshold=0.3):
    """
    Flag as unknown if confidence is too low
    Returns: tensor of booleans (one per sample in batch)
    """
    confidence = torch.max(probs, dim=1)[0]
    return confidence < confidence_threshold

# Example usage
probs = torch.softmax(outputs, dim=1)
unknown_flags = is_unknown_artist(probs, confidence_threshold=0.3)

# Handle both single samples and batches
if unknown_flags.numel() == 1:
    # Single sample
    if unknown_flags.item():
        print("Flagged as: Unknown Artist")
    else:
        predicted_idx = torch.argmax(probs, dim=1)
        print(f"Predicted: {artists[predicted_idx.item()]}")
else:
    # Batch - process each sample
    predicted_indices = torch.argmax(probs, dim=1)
    for i in range(len(unknown_flags)):
        if unknown_flags[i].item():
            print(f"Sample {i}: Flagged as Unknown Artist")
        else:
            print(f"Sample {i}: Predicted {artists[predicted_indices[i].item()]}")

# Better: OOD detection using embedding distance
def is_unknown_ood(query_embedding, database_embeddings, threshold_percentile=95):
    """
    Flag as unknown if query embedding is far from all known artist embeddings
    Uses percentile of distances as threshold
    """
    # Compute distances to all database embeddings
    distances = torch.cdist(query_embedding, database_embeddings)
    min_distance = torch.min(distances, dim=1)[0]
    
    # Threshold: if min distance > 95th percentile of all distances, it's OOD
    threshold = torch.quantile(distances, threshold_percentile / 100.0)
    
    return min_distance > threshold


Sample 0: Flagged as Unknown Artist
Sample 1: Flagged as Unknown Artist
Sample 2: Flagged as Unknown Artist
Sample 3: Flagged as Unknown Artist
Sample 4: Flagged as Unknown Artist
Sample 5: Flagged as Unknown Artist
Sample 6: Flagged as Unknown Artist
Sample 7: Flagged as Unknown Artist
Sample 8: Flagged as Unknown Artist
Sample 9: Flagged as Unknown Artist
Sample 10: Flagged as Unknown Artist
Sample 11: Flagged as Unknown Artist
Sample 12: Flagged as Unknown Artist
Sample 13: Flagged as Unknown Artist
Sample 14: Flagged as Unknown Artist
Sample 15: Flagged as Unknown Artist
Sample 16: Flagged as Unknown Artist
Sample 17: Flagged as Unknown Artist
Sample 18: Flagged as Unknown Artist
Sample 19: Flagged as Unknown Artist
Sample 20: Flagged as Unknown Artist
Sample 21: Flagged as Unknown Artist
Sample 22: Flagged as Unknown Artist
Sample 23: Flagged as Unknown Artist
Sample 24: Flagged as Unknown Artist
Sample 25: Flagged as Unknown Artist
Sample 26: Flagged as Unknown Artist
Sample 27: 

## Part 3: Interpretability Features (Requires Additional Tooling)

These features go beyond the CIFAR-10 tutorial and require **Captum** (PyTorch's interpretability library) and additional image analysis code.


### 6. Per-Artist Factor Explanation

**What it is:** Understand which visual regions/features the model focuses on when predicting a specific artist.

**How to implement:**
- Use **Captum** methods like Integrated Gradients, Grad-CAM, or Guided Backpropagation
- Generate heatmaps showing which pixels/regions contributed most to the prediction
- Summarize as: *"For Picasso, the model focused on high-contrast angular shapes in the upper left and bold outlines in faces"*

**Code:**


In [None]:
# Install Captum: pip install captum

from captum.attr import IntegratedGradients, GradCAM, GuidedBackprop
from captum.attr import visualization as viz

# Initialize attribution methods
integrated_gradients = IntegratedGradients(model)
grad_cam = GradCAM(model, model.layer4)  # Use appropriate layer
guided_backprop = GuidedBackprop(model)

# Get attributions for a specific artist prediction
def get_artist_attributions(model, input_image, target_artist_idx):
    """
    Get attribution maps showing which pixels contributed to predicting a specific artist
    """
    model.eval()
    
    # Method 1: Integrated Gradients
    attributions_ig = integrated_gradients.attribute(
        input_image, 
        target=target_artist_idx,
        n_steps=50
    )
    
    # Method 2: Grad-CAM (class activation maps)
    attributions_gradcam = grad_cam.attribute(
        input_image,
        target=target_artist_idx
    )
    
    # Method 3: Guided Backpropagation
    attributions_gbp = guided_backprop.attribute(
        input_image,
        target=target_artist_idx
    )
    
    return {
        'integrated_gradients': attributions_ig,
        'gradcam': attributions_gradcam,
        'guided_backprop': attributions_gbp
    }

# Visualize attributions
def visualize_attributions(input_image, attributions, artist_name):
    """Visualize which regions contributed to the artist prediction"""
    # Convert to numpy for visualization
    input_np = input_image.squeeze().cpu().detach().numpy().transpose(1, 2, 0)
    attr_np = attributions.squeeze().cpu().detach().numpy()
    
    # Use Captum's visualization
    viz.visualize_image_attr(
        attr_np,
        input_np,
        method="heat_map",
        sign="all",
        title=f"Attribution for {artist_name}"
    )

# Example usage
# predicted_idx = torch.argmax(outputs, dim=1).item()
# attributions = get_artist_attributions(model, input_image, predicted_idx)
# visualize_attributions(input_image, attributions['gradcam'], artists[predicted_idx])


### 7. Explanation of Visual Elements (Colors, Texture, Brush Strokes)

**What it is:** Explain predictions in terms of human-understandable visual features like colors, brush stroke thickness, and texture.

**Reality Check:**
- **Colors**: Easy to quantify (RGB channels, color histograms, attribution across color channels)
- **Texture/Brush Strokes**: Detectable qualitatively via CNN filters and attribution maps, but won't output exact "brush thickness = 4px"

**How to implement:**
1. Combine attribution maps with image statistics
2. Analyze which color channels/regions get high attribution
3. Inspect early conv filters that respond to edges/strokes
4. Generate human-readable descriptions

**Code:**


In [None]:
import numpy as np
from PIL import Image
import cv2

def analyze_colors(input_image, attributions):
    """
    Analyze which colors contributed most to the prediction
    """
    # Convert to numpy
    img_np = input_image.squeeze().cpu().detach().numpy().transpose(1, 2, 0)
    attr_np = attributions.squeeze().cpu().detach().numpy()
    
    # Get color channels
    r, g, b = img_np[:, :, 0], img_np[:, :, 1], img_np[:, :, 2]
    
    # Weight by attributions
    r_weighted = np.sum(r * attr_np)
    g_weighted = np.sum(g * attr_np)
    b_weighted = np.sum(b * attr_np)
    
    # Determine dominant color influence
    total = r_weighted + g_weighted + b_weighted
    color_contributions = {
        'red': r_weighted / total,
        'green': g_weighted / total,
        'blue': b_weighted / total
    }
    
    # Convert to hue description
    dominant_hue = max(color_contributions, key=color_contributions.get)
    
    return {
        'dominant_color': dominant_hue,
        'contributions': color_contributions,
        'description': f"Model focused on {dominant_hue} tones"
    }

def analyze_texture_brush_strokes(input_image, attributions):
    """
    Analyze texture and brush stroke patterns (qualitative)
    """
    img_np = input_image.squeeze().cpu().detach().numpy().transpose(1, 2, 0)
    attr_np = attributions.squeeze().cpu().detach().numpy()
    
    # Convert to grayscale for texture analysis
    gray = cv2.cvtColor((img_np * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
    
    # Edge detection (rough proxy for brush strokes)
    edges = cv2.Canny(gray, 50, 150)
    edge_density = np.sum(edges > 0) / edges.size
    
    # Weight edges by attributions
    attr_resized = cv2.resize(attr_np, (gray.shape[1], gray.shape[0]))
    edge_attribution = np.sum(edges * attr_resized) / np.sum(attr_resized + 1e-10)
    
    # Texture analysis using local variance
    kernel = np.ones((5, 5), np.float32) / 25
    local_mean = cv2.filter2D(gray.astype(np.float32), -1, kernel)
    local_var = cv2.filter2D((gray.astype(np.float32) - local_mean)**2, -1, kernel)
    texture_roughness = np.mean(local_var)
    
    # Generate description
    if edge_attribution > 0.5 and texture_roughness > 1000:
        description = "Energetic, visible brush strokes with high texture"
    elif edge_attribution > 0.3:
        description = "Moderate brush stroke visibility"
    else:
        description = "Smooth, fine brushwork"
    
    return {
        'edge_density': edge_density,
        'texture_roughness': texture_roughness,
        'brush_stroke_attribution': edge_attribution,
        'description': description
    }

def generate_visual_explanation(input_image, attributions, artist_name):
    """
    Generate human-readable explanation combining all visual factors
    """
    color_analysis = analyze_colors(input_image, attributions)
    texture_analysis = analyze_texture_brush_strokes(input_image, attributions)
    
    explanation = f"""
    For {artist_name}:
    - Color: {color_analysis['description']}
    - Texture: {texture_analysis['description']}
    - Dominant color influence: {color_analysis['dominant_color']}
    """
    
    return explanation

# Example usage
# attributions = get_artist_attributions(model, input_image, predicted_idx)['gradcam']
# explanation = generate_visual_explanation(input_image, attributions, artists[predicted_idx])
# print(explanation)


## Part 4: Architecture Overview

### Complete Artfluence Pipeline

```
Input Painting
    ‚Üì
[CNN Feature Extraction]
    ‚Üì
[Penultimate Layer (fc2)] ‚Üí Embeddings ‚Üí Top-K Nearest Paintings
    ‚Üì
[Final Layer (fc3)] ‚Üí Logits
    ‚Üì
[Softmax] ‚Üí Probabilities (Influence Distribution)
    ‚Üì
[Argmax] ‚Üí Predicted Artist
    ‚Üì
[Confidence Check] ‚Üí Unknown Artist Flag (if confidence < threshold)
    ‚Üì
[Captum Attribution] ‚Üí Per-Artist Factor Explanations
    ‚Üì
[Visual Analysis] ‚Üí Color/Texture/Brush Stroke Explanations
```

### Key Components:

1. **Core Model**: CNN architecture from CIFAR-10 tutorial (modified for artists)
2. **Embedding Layer**: Penultimate layer outputs for similarity search
3. **Classification Layer**: Final layer outputs logits for artist prediction
4. **Post-Processing**: Softmax, confidence, uncertainty calculations
5. **Interpretability**: Captum for attribution maps
6. **Visual Analysis**: Custom functions for color/texture analysis


## Summary: What's Possible Based on CIFAR-10 Tutorial

### ‚úÖ Directly Supported (from tutorial):
- Train a CNN on images
- Get logits, pick predicted class (artist)
- Turn logits into probability distribution (influence distribution) with softmax
- Derive confidence/uncertainty from probabilities

### ‚úÖ Requires Small Extensions:
- **Top-k nearest paintings**: Use penultimate-layer embeddings + k-NN search
- **Unknown artist flag**: Thresholding or OOD detection on logits/embeddings

### ‚ö†Ô∏è Requires Additional Tooling (Captum + custom code):
- **Per-artist factor explanations**: Feature attribution methods (Grad-CAM, Integrated Gradients)
- **Visual element explanations**: Colors (easy), textures/brush strokes (qualitative inference)

### üìù Notes:
- The CIFAR-10 tutorial gives you the **right architectural starting point**
- It does **not** implement all interpretability features by itself
- Those features are **doable** but require the next layer: **Captum + your own analysis code**
- Brush stroke "thickness" won't be perfect‚Äîit's qualitative inference from texture/edge patterns


## Next Steps

1. **Install Captum**: `pip install captum`
2. **Modify CIFAR-10 model** to return embeddings
3. **Build embedding database** for your painting collection
4. **Implement k-NN search** for top-k nearest paintings
5. **Add Captum attribution** methods for interpretability
6. **Create visual analysis functions** for colors/textures
7. **Combine everything** into the complete Artfluence pipeline
