In [3]:
import torch
from diffusers import StableDiffusionPipeline
from diffusers.models.resnet import ResnetBlock2D
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import nn
import os
import matplotlib.pyplot as plt
import numpy as np

# 1. Define Dataset and Transformations
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

data_dir = "../modified-mini-GCD"
train_dir = os.path.join(data_dir, "train")
test_dir = os.path.join(data_dir, "test")

train_dataset = ImageFolder(root=train_dir, transform=transform)
test_dataset = ImageFolder(root=test_dir, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

# Class Names
class_names = train_dataset.classes
print(f"Classes: {class_names}")

# 2. Load Stable Diffusion and Access U-Net
pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
unet = pipeline.unet  # Access the U-Net
unet.eval()  # Set to evaluation mode

def adjust_groupnorm(unet):
    for module in unet.modules():
        if isinstance(module, ResnetBlock2D):
            for name, sub_module in module.named_children():
                if isinstance(sub_module, nn.GroupNorm):
                    num_channels = sub_module.num_channels
                    # Determine a valid number of groups
                    num_groups = min(32, num_channels)  # Use a reasonable maximum (e.g., 32 groups)
                    while num_channels % num_groups != 0:
                        num_groups -= 1  # Reduce until divisible
                    if num_groups <= 0:
                        num_groups = 1  # Fallback to 1 group if no valid groups found

                    # Replace the GroupNorm layer
                    new_groupnorm = nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)
                    print(f"Replacing GroupNorm with {num_groups} groups in {name}")
                    setattr(module, name, new_groupnorm)
    return unet

# Apply this adjustment to your U-Net
unet = adjust_groupnorm(unet)

# 3. Modify U-Net to Extract Encoder Features
class UNetFeatureExtractor(nn.Module):
    def __init__(self, unet, in_channels=3):
        super(UNetFeatureExtractor, self).__init__()
        # Input mapping layer to match U-Net's expected in_channels
        self.input_conv = nn.Conv2d(in_channels, unet.config.in_channels, kernel_size=3, padding=1)
        self.encoder = nn.ModuleList(unet.down_blocks)

    def forward(self, x):
        x = self.input_conv(x)  # Match input channels to U-Net's expected channels
        features = []
        for block in self.encoder:
            x = block(x)
            features.append(x)
        return features[-1]

# Instantiate the feature extractor
feature_extractor = UNetFeatureExtractor(unet, in_channels=3)  # Use 3 channels for RGB input

# 4. Build Classifier Using Extracted Features
class DiffusionClassifier(nn.Module):
    def __init__(self, feature_extractor, num_classes):
        super(DiffusionClassifier, self).__init__()
        self.feature_extractor = feature_extractor
        # Assume the output feature size is large (adjust based on U-Net output)
        self.fc = nn.Linear(1280, num_classes)  # Adjust input size accordingly

    def forward(self, x):
        with torch.no_grad():  # Freeze feature extractor
            features = self.feature_extractor(x)
            features = features.mean(dim=(2, 3))  # Global Average Pooling
        out = self.fc(features)
        return out

# Instantiate classifier
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes = len(class_names)
classifier = DiffusionClassifier(feature_extractor, num_classes).to(device)

# 5. Define Loss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(classifier.fc.parameters(), lr=1e-4)

def preprocess_input(inputs):
    # Ensure the input has 3 channels (RGB)
    if inputs.shape[1] != 3:  # Check if input is not 3 channels (RGB)
        raise ValueError("Input must have 3 channels (RGB).")
    return inputs

# 6. Training Loop
def train_model(classifier, dataloader, epochs=5):
    classifier.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in dataloader:
            inputs, labels = preprocess_input(inputs).to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = classifier(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(dataloader)}")

# Train the model
train_model(classifier, train_loader, epochs=5)

# Save the model
torch.save(classifier.state_dict(), "diffusion_classifier.pth")

# Load the model
classifier.load_state_dict(torch.load("diffusion_classifier.pth"))

# 7. Evaluate the Model
def evaluate_model(classifier, dataloader):
    classifier.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = classifier(inputs)
            preds = torch.argmax(outputs, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    return all_preds, all_labels

preds, labels = evaluate_model(classifier, test_loader)

# 8. Calculate Accuracy
accuracy = np.mean(np.array(preds) == np.array(labels))
print(f"Test Accuracy: {accuracy * 100:.2f}%")

Classes: ['1_clearsky', '2_cloudy']


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

Replacing GroupNorm with 32 groups in norm1
Replacing GroupNorm with 32 groups in norm2
Replacing GroupNorm with 32 groups in norm1
Replacing GroupNorm with 32 groups in norm2
Replacing GroupNorm with 32 groups in norm1
Replacing GroupNorm with 32 groups in norm2
Replacing GroupNorm with 32 groups in norm1
Replacing GroupNorm with 32 groups in norm2
Replacing GroupNorm with 32 groups in norm1
Replacing GroupNorm with 32 groups in norm2
Replacing GroupNorm with 32 groups in norm1
Replacing GroupNorm with 32 groups in norm2
Replacing GroupNorm with 32 groups in norm1
Replacing GroupNorm with 32 groups in norm2
Replacing GroupNorm with 32 groups in norm1
Replacing GroupNorm with 32 groups in norm2
Replacing GroupNorm with 32 groups in norm1
Replacing GroupNorm with 32 groups in norm2
Replacing GroupNorm with 32 groups in norm1
Replacing GroupNorm with 32 groups in norm2
Replacing GroupNorm with 32 groups in norm1
Replacing GroupNorm with 32 groups in norm2
Replacing GroupNorm with 32 grou

AttributeError: 'UNetFeatureExtractor' object has no attribute 'unet'