### Vision Transformer on Modified Mini-GCD Dataset

#### 0. Setup Environment

In [None]:
import torch
from sympy.strategies.core import switch

# Select the compute device: CUDA when a GPU is usable, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(torch.cuda.is_available())
# Get the number of GPUs
print(torch.cuda.device_count())
# The current-device queries raise when no CUDA device is usable, so they are
# only issued when a GPU is actually present; the CPU fallback stays runnable.
if torch.cuda.is_available():
    # Get the current GPU device
    print(torch.cuda.current_device())
    print(torch.cuda.get_device_name())

#### 1. Load Simplified and Modified Mini-GCD Dataset

In [None]:
import torch
from diffusers import StableDiffusionPipeline
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torch import nn
import matplotlib.pyplot as plt
import os
import numpy as np
import transformers
import accelerate

# Define transformations: resize to 256x256, convert to a tensor, then apply
# the ImageNet channel statistics.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet normalization
])

# Load Dataset
data_dir = "../modified-mini-GCD"
train_dir = os.path.join(data_dir, "train")
test_dir = os.path.join(data_dir, "test")

train_dataset = ImageFolder(root=train_dir, transform=transform)
test_dataset = ImageFolder(root=test_dir, transform=transform)

# Class Names
class_names = train_dataset.classes  # ['1_clearsky', '2_cloudy', '3_overcast']
print(f"Classes: {class_names}")

# Show number of training samples per class.
# The labels are read from `ImageFolder.targets` (one class index per sample)
# rather than by iterating the dataset, which would decode and transform every
# image only to discard it.
train_class_counts = {name: 0 for name in class_names}
for label in train_dataset.targets:
    train_class_counts[class_names[label]] += 1
print("Train Class Counts:", train_class_counts)

In [None]:
from torch import nn
from torch.utils.data import WeightedRandomSampler, DataLoader

# Calculate sampling weights for each sample
class_sample_count = [train_class_counts[name] for name in class_names]  # Class counts
# Inverse of class frequency; a class with no samples gets weight 0 instead of
# raising ZeroDivisionError.
class_weights = [1.0 / count if count else 0.0 for count in class_sample_count]
# One weight per sample, looked up from the stored labels so that no image has
# to be decoded just to learn its class.
sample_weights = [class_weights[label] for label in train_dataset.targets]
class_weights_tensor = torch.tensor(class_weights, dtype=torch.float).to(device)

# Draw len(train_dataset) samples per epoch, with replacement, proportionally
# to the per-sample weights.
sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)

train_loader = DataLoader(train_dataset, batch_size=8, sampler=sampler)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

# Sanity check: walk one full pass of the sampled loader and print each batch.
for i, (image, label) in enumerate(train_loader):
    print(i, image.shape, label)

#### 1.1. Display Sample Images

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Display Sample Images
def show_images(dataloader, class_names, num_samples=5):
    """Plot up to ``num_samples`` images from ``dataloader`` in a 5-column grid.

    Args:
        dataloader: Iterable yielding ``(images, labels)`` batches, with images
            normalized using the ImageNet mean/std applied by this notebook.
        class_names: Sequence mapping a label index to its display name.
        num_samples: Maximum number of images to show. When the loader yields
            fewer images, only those are drawn.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    if num_samples <= 0:
        return

    images, labels = [], []
    collected = 0
    for batch_images, batch_labels in dataloader:
        images.append(batch_images)
        labels.append(batch_labels)
        # Count the images actually gathered instead of estimating the total
        # from the size of the most recent batch.
        collected += batch_images.size(0)
        if collected >= num_samples:
            break

    if not images:
        return  # empty loader: nothing to draw

    images = torch.cat(images)[:num_samples]
    labels = torch.cat(labels)[:num_samples]
    # The loader may have run out early, so size the grid from what was collected.
    num_shown = images.size(0)

    num_cols = 5
    num_rows = (num_shown + num_cols - 1) // num_cols

    # squeeze=False always yields a 2-D axes array, so flatten() is safe for
    # any grid shape.
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(15, 3 * num_rows), squeeze=False)
    axes = axes.flatten()

    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    for idx, ax in enumerate(axes):
        ax.axis('off')
        if idx >= num_shown:
            continue  # unused cell in the last row
        img = images[idx].permute(1, 2, 0).cpu().numpy()  # Convert to HWC format
        # Unnormalize according to ImageNet normalization
        img = np.clip(img * std + mean, 0, 1)
        ax.imshow(img)
        ax.set_title(class_names[int(labels[idx])])

    plt.tight_layout()
    plt.show()

# Show images before training
# The batches come from train_loader, which draws through the weighted sampler,
# so the 50 images shown reflect the re-balanced sampling, not dataset order.
show_images(train_loader, class_names, num_samples=50)

#### 2. Load Pretrained Stable Diffusion Model

In [None]:
# Load Pretrained Stable Diffusion Model
# NOTE(review): `pipeline` is not referenced by any other code visible in this
# notebook; confirm whether this load is still required.
pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

# Turn on attention slicing for the pipeline.
# NOTE(review): this call does not disable generation; presumably the intent
# was to reuse part of the pipeline as a feature extractor - confirm.
pipeline.enable_attention_slicing()

#### 2.1. Define Fine-Tuning Classifier

In [None]:
from timm import create_model

def create_resnet_feature_extractor() -> nn.Module:
    """Return a pretrained ResNet-18 with its last child module removed.

    The remaining child modules are chained, in order, inside an
    ``nn.Sequential``.
    """
    backbone = resnet18(pretrained=True)
    layers = list(backbone.children())
    # Drop the final child (the classification layer) and keep the rest.
    return nn.Sequential(*layers[:-1])

class _ViTTokenExtractor(nn.Module):
    """Expose a timm ViT's ``forward_features`` as a plain ``nn.Module``."""

    def __init__(self, vit: nn.Module):
        super().__init__()
        self.vit = vit

    def forward(self, x):
        tokens = self.vit.forward_features(x)
        # Some timm releases return an already-pooled (batch, dim) tensor here.
        # Add a length-1 token axis so callers that average over dim=1 keep
        # working either way.
        if tokens.dim() == 2:
            tokens = tokens.unsqueeze(1)
        return tokens


def create_vit_feature_extractor() -> nn.Module:
    """Return a pretrained ViT-B/16 (256x256 input) that outputs token features.

    Chaining ``model.children()`` in an ``nn.Sequential`` skips the steps the
    ViT performs outside its child modules (adding the class token and the
    positional embeddings). Calling ``forward_features`` runs the full encoder.

    Returns:
        A module mapping an image batch to a ``(batch, tokens, dim)`` tensor.
    """
    # num_classes=0 builds the model without a classification head.
    model = create_model('vit_base_patch16_224', pretrained=True, img_size=256, num_classes=0)
    return _ViTTokenExtractor(model)

# Build the ViT backbone; this is the extractor handed to the classifier below
# (create_resnet_feature_extractor is defined but not called in the visible code).
feature_extractor = create_vit_feature_extractor()

# Update DiffusionClassifier
class DiffusionClassifier(nn.Module):
    """Linear classification head on top of a frozen feature extractor.

    Args:
        feature_extractor: Module mapping an input batch to features shaped
            ``(batch, dim)``, ``(batch, tokens, dim)`` or
            ``(batch, dim, H, W)``.
        num_classes: Number of output classes.
        feature_dim: Width of the pooled feature vector fed to the head.
            Defaults to 768, the value previously hard-coded here.
    """

    def __init__(self, feature_extractor, num_classes, feature_dim=768):
        super().__init__()
        self.feature_extractor = feature_extractor
        # Freeze the backbone: forward() never builds a graph through it, so
        # its parameters must not be reported or registered as trainable.
        for param in self.feature_extractor.parameters():
            param.requires_grad_(False)
        self.feature_extractor.eval()
        self.fc = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(feature_dim, num_classes)  # Input size must match the pooled feature width
        )

    def train(self, mode=True):
        """Switch the head between train/eval while keeping the backbone in eval mode.

        Without this, ``classifier.train()`` would re-enable dropout and
        normalization-statistics updates inside the frozen extractor.
        """
        super().train(mode)
        self.feature_extractor.eval()
        return self

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for ``x``."""
        # Use frozen feature extractor
        with torch.no_grad():
            features = self.feature_extractor(x)
            if features.dim() == 4:
                # (batch, dim, H, W): average over the spatial positions.
                features = features.flatten(2).mean(dim=2)
            elif features.dim() == 3:
                # (batch, tokens, dim): global average pooling over the tokens.
                features = features.mean(dim=1)
        out = self.fc(features)
        return out

def count_trainable_params(model):
    """Return the total element count of ``model``'s parameters that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Instantiate Classifier and Print Trainable Parameters
num_classes = len(class_names)
classifier = DiffusionClassifier(feature_extractor, num_classes).to(device)
print(f"Trainable Parameters: {count_trainable_params(classifier):,}")

# Define Loss and Optimizer
# NOTE(review): the loss is class-weighted and train_loader also re-balances
# classes through its weighted sampler; confirm that applying both is intended.
criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)
# Only the head can receive gradients (the extractor runs under torch.no_grad()
# in forward), so only the head's parameters are handed to the optimizer.
optimizer = torch.optim.Adam(classifier.fc.parameters(), lr=1e-4)

#### 3. Train the Classifier

In [None]:
import random

# Training Loop
def train_model(classifier, dataloader, epochs=5):
    """Train ``classifier`` on ``dataloader`` for ``epochs`` epochs.

    Uses the module-level ``optimizer``, ``criterion`` and ``device``. After
    each epoch the mean per-batch loss is printed.

    Args:
        classifier: Model to optimize; it is put in training mode.
        dataloader: Iterable of ``(inputs, labels)`` batches that supports ``len()``.
        epochs: Number of passes over ``dataloader``.
    """
    classifier.train()  # Set model to training mode
    num_batches = len(dataloader)
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()  # Reset gradients
            outputs = classifier(inputs)  # Forward pass
            loss = criterion(outputs, labels)  # Compute loss
            loss.backward()  # Backward pass
            optimizer.step()  # Update weights
            running_loss += loss.item()  # Accumulate loss

        # An empty dataloader would otherwise divide by zero; report NaN instead.
        mean_loss = running_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch+1}/{epochs}, Loss: {mean_loss}")

# Train the model
# Runs 20 epochs over train_loader (train_model's own default is 5).
train_model(classifier, train_loader, epochs=20)

#### 4. Evaluate the Classifier

In [None]:
import torch
from sklearn.metrics import confusion_matrix, classification_report, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import numpy as np
import random

# Function to plot a given image with its real label
def plot_image(image, label, class_names):
    """Display one ImageNet-normalized CHW image tensor, titled with its label name."""
    channel_std = np.array([0.229, 0.224, 0.225])
    channel_mean = np.array([0.485, 0.456, 0.406])

    hwc = image.permute(1, 2, 0).cpu().numpy()  # CHW -> HWC for matplotlib
    # Undo the normalization, then clamp into the displayable [0, 1] range.
    pixels = np.clip(hwc * channel_std + channel_mean, 0, 1)

    plt.imshow(pixels)
    plt.title(f"Label: {class_names[label]}")
    plt.axis('off')
    plt.show()

def validate_model(classifier, dataloader, plot_samples=True):
    """Compute and print the accuracy of ``classifier`` over ``dataloader``.

    Uses the module-level ``device`` and ``class_names``.

    Args:
        classifier: Model to evaluate; it is put in evaluation mode.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        plot_samples: When True (the default), every evaluated image is shown
            with its true label via ``plot_image``. Pass False to compute only
            the accuracy.

    Returns:
        The accuracy as a float in [0, 1], or None when ``dataloader`` yields
        no samples.
    """
    classifier.eval()  # Set model to evaluation mode
    total, correct = 0, 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = classifier(inputs)
            preds = outputs.argmax(dim=1)  # Get class with highest score
            print(preds, labels)
            if plot_samples:
                for image, label in zip(inputs, labels):
                    plot_image(image, label, class_names)
            total += labels.size(0)
            correct += (preds == labels).sum().item()
    if total == 0:
        # Nothing was evaluated; avoid dividing by zero.
        print("Validation Accuracy: n/a (no samples)")
        return None
    accuracy = correct / total
    print(f"Validation Accuracy: {accuracy:.2%}")
    return accuracy

def evaluate_model(classifier, dataloader, class_names):
    """Run ``classifier`` over ``dataloader`` and collect its outputs.

    Uses the module-level ``device``. ``class_names`` is accepted but not used
    in the body.

    Returns:
        A tuple ``(preds, labels, probs)`` of NumPy arrays: the predicted class
        per sample, the true label per sample, and the softmax probabilities
        per sample, in the order the loader produced them.
    """
    classifier.eval()
    collected_preds, collected_labels, collected_probs = [], [], []

    with torch.no_grad():
        for batch_inputs, batch_labels in dataloader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            # Logits -> probabilities -> most probable class.
            batch_probs = torch.softmax(classifier(batch_inputs), dim=1)
            batch_preds = torch.argmax(batch_probs, dim=1)

            collected_preds.extend(batch_preds.cpu().numpy())
            collected_labels.extend(batch_labels.cpu().numpy())
            collected_probs.extend(batch_probs.cpu().numpy())

    preds_array = np.array(collected_preds)
    labels_array = np.array(collected_labels)
    probs_array = np.array(collected_probs)
    return preds_array, labels_array, probs_array

# Validate
# validate_model(classifier, test_loader)
validate_model(classifier, train_loader)

# Evaluate
# preds, labels, probs = evaluate_model(classifier, test_loader, class_names)
# train_loader draws through the weighted random sampler, so its batches are
# not in dataset order. Evaluate through a sequential loader instead so that
# preds[i] / probs[i] belong to train_dataset[i], which the index-based
# plotting further down relies on.
eval_loader = DataLoader(train_dataset, batch_size=8, shuffle=False)
preds, labels, probs = evaluate_model(classifier, eval_loader, class_names)

#### 4.1. Display Confusion Matrix and Classification Report

In [None]:
# Create Confusion Matrix
# Fix the label set explicitly so that both the matrix and the report cover
# every class, even one that never occurs in `labels` or `preds`.
class_indices = list(range(len(class_names)))
cm = confusion_matrix(labels, preds, labels=class_indices)

# Plot Confusion Matrix
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
disp.plot(cmap=plt.cm.Blues, xticks_rotation=45)
plt.title("Confusion Matrix")
plt.show()

# Print Classification Report
# Without `labels=`, the report infers the classes from the data and raises
# when their number differs from len(target_names).
print(classification_report(labels, preds, labels=class_indices, target_names=class_names))


#### 4.2. Plot Random Samples

In [None]:
def plot_random_samples(dataset, preds, labels, probs, class_names, num_samples=5):
    """Plot randomly chosen dataset images with their true and predicted class.

    ``preds`` and ``probs`` are indexed by dataset position, so they must come
    from an evaluation pass over ``dataset`` in dataset order.

    Args:
        dataset: Indexable dataset returning ``(image, label)`` with a
            CHW, ImageNet-normalized image tensor.
        preds: Predicted class index per dataset position.
        labels: True labels per dataset position (not used in the body; the
            true label is read from ``dataset``).
        probs: Array of shape ``(n, num_classes)`` with class probabilities.
        class_names: Sequence mapping a class index to its display name.
        num_samples: Maximum number of images to plot; capped at the number of
            positions for which both an image and a prediction exist.
    """
    # Only positions that have both an image and a prediction can be drawn, and
    # random.sample raises when asked for more items than the population holds.
    population = min(len(dataset), len(preds))
    num_samples = min(num_samples, population)
    if num_samples <= 0:
        return

    # Randomly select indices
    random_indices = random.sample(range(population), num_samples)

    # Calculate number of rows needed
    num_cols = 5
    num_rows = (num_samples + num_cols - 1) // num_cols

    # squeeze=False always yields a 2-D axes array, so flatten() is safe.
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(15, 3 * num_rows), squeeze=False)
    axes = axes.flatten()

    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    for ax, idx in zip(axes, random_indices):
        image, true_label = dataset[idx]
        print(idx, len(image), image.shape, true_label)
        # Undo the ImageNet normalization and clamp with torch, so the value
        # stays a tensor for the permute() below.
        image = torch.clamp(image * std + mean, 0, 1)
        pred_label = int(preds[idx])
        prob = float(probs[idx, pred_label])  # Probability of predicted class

        ax.imshow(image.permute(1, 2, 0).numpy())  # Convert CHW to HWC for plotting
        ax.axis("off")
        ax.set_title(
            f"True: {class_names[true_label]}\n"
            f"Pred: {class_names[pred_label]} ({prob:.2f})"
        )

    # Hide any remaining empty subplots
    for ax in axes[num_samples:]:
        ax.axis("off")

    plt.tight_layout()
    plt.show()

# Plot Random Samples
# plot_random_samples(test_dataset, preds, labels, probs, class_names)
# NOTE(review): plot_random_samples looks up preds/probs by dataset index, so
# they must come from a pass over train_dataset in dataset order - confirm.
plot_random_samples(train_dataset, preds, labels, probs, class_names, num_samples=50)
