### Stable Diffusion-based Image Classifier Fine-Tuning with Vision Transformer as Feature Extractor

#### 0. Setup Environment

In [None]:
import torch
from sympy.strategies.core import switch

# Report the CUDA environment. The per-device queries raise when no GPU is
# visible, so they are only issued once CUDA is confirmed to be available.
print(torch.cuda.is_available())
# Get the number of GPUs
print(torch.cuda.device_count())
if torch.cuda.is_available():
    # Get the current GPU device
    print(torch.cuda.current_device())
    print(torch.cuda.get_device_name())
else:
    print("No CUDA device available; running on CPU.")

#### 1. Load Modified Mini-GCD Dataset

In [None]:
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from transformers import ViTForImageClassification, ViTFeatureExtractor
from sklearn.metrics import confusion_matrix, classification_report, ConfusionMatrixDisplay
import random

# Check GPU availability. current_device()/get_device_name() raise on a
# machine without CUDA, so they are only queried when a GPU is present.
print(torch.cuda.is_available())
print(torch.cuda.device_count())
if torch.cuda.is_available():
    print(torch.cuda.current_device())
    print(torch.cuda.get_device_name())
else:
    print("No CUDA device available; running on CPU.")

# Define transformations: resize to 224x224, convert to a tensor and map each
# channel with (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Load Dataset (one sub-directory per class under train/ and test/)
data_dir = "modified-mini-GCD"
train_dir = os.path.join(data_dir, "train")
test_dir = os.path.join(data_dir, "test")

train_dataset = ImageFolder(root=train_dir, transform=transform)
test_dataset = ImageFolder(root=test_dir, transform=transform)

# The test loader is deliberately not shuffled so that predictions collected
# from it stay aligned with test_dataset indices.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

# Class Names
class_names = train_dataset.classes
print(f"Classes: {class_names}")

# Show number of training samples per class.
# Read the labels from `targets` instead of iterating the dataset itself:
# iterating would open and transform every image just to obtain its label.
train_class_counts = {name: 0 for name in class_names}
for label in train_dataset.targets:
    train_class_counts[class_names[label]] += 1
print("Train Class Counts:", train_class_counts)

#### 1.1. Display Sample Images

In [None]:
# Display Sample Images
def show_images(dataloader, class_names):
    """Plot the first batch yielded by ``dataloader`` in a single row.

    Args:
        dataloader: Iterable yielding ``(images, labels)`` batches. Each image
            is expected to be channel-first and normalized with mean 0.5 and
            std 0.5, which is undone here before display.
        class_names: Sequence mapping a label index to its display name.
    """
    images, labels = next(iter(dataloader))
    # squeeze=False keeps `axes` a 2-D array even when the batch holds a single
    # image; otherwise subplots() returns a bare Axes that cannot be iterated.
    _, axes = plt.subplots(1, len(images), figsize=(15, 5), squeeze=False)
    for image, label, ax in zip(images, labels, axes.ravel()):
        img = image.permute(1, 2, 0).cpu().numpy()  # CHW -> HWC for imshow
        img = (img * 0.5 + 0.5)  # Unnormalize
        img = np.clip(img, 0, 1)
        ax.imshow(img)
        ax.axis('off')
        ax.set_title(class_names[int(label)])
    plt.tight_layout()
    plt.show()


show_images(train_loader, class_names)

#### 2. Load ViT Feature Extractor (Pretrained)

In [None]:
# Load ViT Feature Extractor (Pretrained)
from transformers import ViTModel, ViTForImageClassification, ViTFeatureExtractor

# Instantiate ViT feature extractor
# NOTE(review): `feature_extractor` is not referenced anywhere else in the
# visible notebook; images are preprocessed by the torchvision `transform`
# pipeline instead. Confirm whether it is still needed. ViTFeatureExtractor is
# presumably deprecated in newer transformers releases in favour of
# ViTImageProcessor — TODO confirm against the installed version.
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
# Pretrained ViT backbone; it is handed to DiffusionClassifier, which reads
# its `last_hidden_state` output and `config.hidden_size`.
vit_model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

#### 2.1. Setup Diffusion-based Image Classifier with ViT Feature Extractor

In [None]:
# Define the DiffusionClassifier with ViT feature extractor
class DiffusionClassifier(nn.Module):
    """Linear classification head on top of a frozen ViT backbone.

    Args:
        vit_model: Backbone module. Its forward pass must return an object with
            a ``last_hidden_state`` tensor, and it must expose
            ``config.hidden_size``. Its parameters are frozen in place
            (``requires_grad`` set to ``False``) by this constructor.
        num_classes: Number of output logits produced by the head.
    """

    def __init__(self, vit_model, num_classes):
        super().__init__()
        self.vit_model = vit_model
        # forward() only ever runs the backbone under torch.no_grad(), so its
        # weights can never receive gradients. Mark them as frozen so that
        # trainable-parameter counts and optimizers built from parameters()
        # reflect that only `fc` is actually trained.
        for param in self.vit_model.parameters():
            param.requires_grad_(False)
        self.fc = nn.Linear(vit_model.config.hidden_size, num_classes)  # Hidden size of ViT model

    def forward(self, x):
        """Return class logits for the input batch ``x``."""
        with torch.no_grad():  # Freeze ViT feature extractor
            # Mean-pool over dim 1 of the backbone output (assumed to be the
            # token axis). This averages all tokens rather than selecting the
            # CLS token alone.
            features = self.vit_model(x).last_hidden_state.mean(dim=1)
        return self.fc(features)

def count_trainable_params(model):
    """Return the total element count of ``model``'s parameters that have ``requires_grad`` set."""
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

# Instantiate classifier with ViT feature extractor
num_classes = len(class_names)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
classifier = DiffusionClassifier(vit_model, num_classes).to(device)
print(f"Trainable Parameters: {count_trainable_params(classifier):,}")

# Define Loss and Optimizer
criterion = nn.CrossEntropyLoss()
# Only the linear head is handed to the optimizer: the ViT backbone runs under
# torch.no_grad() inside forward(), so its parameters never receive gradients
# and registering them with Adam would have no effect on training.
optimizer = torch.optim.Adam(classifier.fc.parameters(), lr=1e-4)

#### 3. Train the model

In [None]:
# Training Loop
def train_model(classifier, dataloader, epochs=5, *, optim=None, loss_fn=None,
                target_device=None):
    """Train ``classifier`` and print the mean batch loss after every epoch.

    Args:
        classifier: Module to train; it is switched to training mode.
        dataloader: Sized iterable yielding ``(inputs, labels)`` batches.
        epochs: Number of passes over ``dataloader``.
        optim: Optimizer to step. Defaults to the module-level ``optimizer``.
            Pass one explicitly when training a model other than the one the
            global optimizer was built for, otherwise no weights are updated.
        loss_fn: Loss callable ``(outputs, labels) -> scalar tensor``.
            Defaults to the module-level ``criterion``.
        target_device: Device the batches are moved to. Defaults to the
            module-level ``device``.
    """
    # Fall back to the notebook-level objects so existing calls keep working.
    optim = optimizer if optim is None else optim
    loss_fn = criterion if loss_fn is None else loss_fn
    target_device = device if target_device is None else target_device

    classifier.train()
    num_batches = len(dataloader)
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(target_device), labels.to(target_device)
            optim.zero_grad()
            outputs = classifier(inputs)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optim.step()
            running_loss += loss.item()
        # An empty dataloader has no mean loss; report NaN instead of raising
        # ZeroDivisionError.
        mean_loss = running_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch+1}/{epochs}, Loss: {mean_loss}")

# Run the training loop for 20 passes over the training DataLoader.
train_model(classifier, train_loader, epochs=20)

#### 4. Evaluate the Classifier

In [None]:
# Evaluate Model
def evaluate_model(classifier, dataloader, class_names, device=None):
    """Run ``classifier`` over ``dataloader`` and collect its predictions.

    Args:
        classifier: Module producing per-class logits; it is switched to
            evaluation mode.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.
        class_names: Accepted for interface compatibility; not used here.
        device: Device the inputs are moved to. When omitted, the device of
            the classifier's first parameter is used (CPU if it has none), so
            inputs always land on the same device as the model.

    Returns:
        Tuple ``(preds, labels, probs)`` of NumPy arrays in dataloader order:
        predicted class indices, ground-truth labels, and the softmax
        probabilities with one row per sample. All three are empty arrays when
        the dataloader yields nothing.
    """
    if device is None:
        first_param = next(classifier.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    classifier.eval()
    pred_batches = []
    label_batches = []
    prob_batches = []

    with torch.no_grad():
        for inputs, labels in dataloader:
            outputs = classifier(inputs.to(device))

            # Get probabilities and predictions
            probs = torch.softmax(outputs, dim=1)
            preds = torch.argmax(probs, dim=1)

            # Collect whole batches; they are concatenated once at the end
            # instead of extending Python lists element by element.
            pred_batches.append(preds.cpu().numpy())
            label_batches.append(labels.cpu().numpy())
            prob_batches.append(probs.cpu().numpy())

    if not pred_batches:
        return np.array([]), np.array([]), np.array([])
    return (
        np.concatenate(pred_batches),
        np.concatenate(label_batches),
        np.concatenate(prob_batches),
    )

# Evaluate on Test Data
# test_loader is built with shuffle=False, so row i of `preds`/`labels`/`probs`
# lines up with test_dataset[i]; the sample-plotting cell indexes them by
# dataset position and relies on that ordering.
preds, labels, probs = evaluate_model(classifier, test_loader, class_names)

#### 4.1. Display Confusion Matrix and Classification Report

In [None]:
# Create Confusion Matrix
# The same explicit label list is used for the matrix and for the report, so
# both always cover every class, including classes that happen to be absent
# from the test labels and predictions.
label_ids = list(range(len(class_names)))
cm = confusion_matrix(labels, preds, labels=label_ids)

# Plot Confusion Matrix
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
disp.plot(cmap=plt.cm.Blues, xticks_rotation=45)
plt.title("Confusion Matrix")
plt.show()

# Print Classification Report
# Without `labels=`, the report infers the classes from the data and raises
# when their number differs from len(target_names).
print(classification_report(labels, preds, labels=label_ids, target_names=class_names))

#### 4.2. Plot Random Samples

In [None]:
# Plot Random Samples
def plot_random_samples(dataset, preds, labels, probs, class_names, num_samples=5):
    """Plot randomly chosen dataset images with their true and predicted class.

    Args:
        dataset: Indexable dataset returning ``(image_tensor, label)``; images
            are expected to be channel-first and normalized with mean 0.5 and
            std 0.5, which is undone before display.
        preds: Predicted class index per dataset position.
        labels: Accepted for interface compatibility; the true label is read
            from ``dataset`` itself.
        probs: Per-class probabilities, indexed as ``probs[position, class]``.
        class_names: Sequence mapping a class index to its display name.
        num_samples: Number of images to draw; capped at ``len(dataset)``.

    ``preds`` and ``probs`` are indexed by dataset position, so they must have
    been produced in dataset order (i.e. from a non-shuffled loader).
    """
    # random.sample raises if asked for more items than exist, and a grid with
    # zero rows cannot be created, so clamp the request and bail out early.
    num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        return

    # Randomly select indices
    random_indices = random.sample(range(len(dataset)), num_samples)

    # Calculate number of rows needed
    num_cols = 5
    num_rows = (num_samples + num_cols - 1) // num_cols

    _, axes = plt.subplots(num_rows, num_cols, figsize=(15, 3 * num_rows), squeeze=False)
    axes = axes.flatten()

    for ax, idx in zip(axes, random_indices):
        image, true_label = dataset[idx]
        # Unnormalize and clamp on the tensor itself; passing a tensor through
        # np.clip does not reliably return a tensor for the permute() below.
        image = (image * 0.5 + 0.5).clamp(0, 1)
        pred_label = int(preds[idx])
        prob = probs[idx, pred_label]  # Probability of predicted class

        ax.imshow(image.permute(1, 2, 0).numpy())  # Convert CHW to HWC for plotting
        ax.axis("off")
        ax.set_title(
            f"True: {class_names[true_label]}\n"
            f"Pred: {class_names[pred_label]} ({prob:.2f})"
        )

    # Hide any remaining empty subplots
    for ax in axes[num_samples:]:
        ax.axis("off")

    plt.tight_layout()
    plt.show()


# Plot random samples
plot_random_samples(test_dataset, preds, labels, probs, class_names, num_samples=15)