# U-Net Fine-Tuning for Image Classification on Modified Mini-GCD Dataset with Classification Head

### 0. Setup Environment

#### 0.1. Install Required Libraries

In [29]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import segmentation_models_pytorch as smp
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns

#### 0.2. Check GPU Availability

In [30]:
print(torch.cuda.is_available())
print(torch.cuda.device_count())
print(torch.cuda.current_device())
print(torch.cuda.get_device_name())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

True
1
0
NVIDIA GeForce GTX 1080 Ti
cuda


In [31]:
# Hyperparameters
num_epochs = 50

#### 1. Define Dataset and Transformations

In [32]:
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

data_dir = "modified-mini-GCD"
train_dir = os.path.join(data_dir, "train")
val_dir = os.path.join(data_dir, "test")  # Assuming you have a validation set
test_dir = os.path.join(data_dir, "test")

train_dataset = ImageFolder(root=train_dir, transform=transform)
val_dataset = ImageFolder(root=val_dir, transform=transform)
test_dataset = ImageFolder(root=test_dir, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

#### 2. Load the pre-trained U-Net model

#### 2.1. Load the Pre-trained U-Net Model and Print the Number of Trainable Parameters

In [33]:
# 2. Load the pre-trained U-Net model
unet = smp.Unet(encoder_name="resnet34", encoder_weights="imagenet", classes=1, activation=None)

# Print the number of trainable parameters
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'Number of trainable parameters in U-Net: {count_parameters(unet)}')

Number of trainable parameters in U-Net: 24436369


#### 2.2. Add a Classification Head to the U-Net

In [34]:
class UNetClassifier(nn.Module):
    def __init__(self, unet, num_classes):
        super(UNetClassifier, self).__init__()
        self.unet = unet
        # Get the number of output channels from the final layer of the U-Net
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(unet.segmentation_head[0].out_channels, num_classes)
        )

    def forward(self, x):
        features = self.unet(x)
        out = self.classifier(features)
        return out

num_classes = len(train_dataset.classes)
model = UNetClassifier(unet, num_classes)

### 3. Prepare for Training

#### 3.1. Define Loss Function and Optimizer

In [35]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

#### 3.2. Train the Classifier

In [36]:
# Training and validation loop
num_epochs = 10
train_losses = []
val_losses = []

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)

    train_loss = running_loss / len(train_loader.dataset)
    train_losses.append(train_loss)

    # Validation loop
    model.eval()
    val_running_loss = 0.0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_running_loss += loss.item() * inputs.size(0)

    val_loss = val_running_loss / len(val_loader.dataset)
    val_losses.append(val_loss)

    print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

Epoch 1/50, Loss: 0.2645


KeyboardInterrupt: 

#### 3.3. Plot Training and Validation Losses

In [None]:
plt.figure(figsize=(10, 5))
plt.plot(range(1, num_epochs+1), train_losses, label='Train Loss')
plt.plot(range(1, num_epochs+1), val_losses, label='Val Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

### 4. Persist the Model

In [None]:
# Save the model
def save_model(model, path):
    torch.save(model.state_dict(), path)

# Load the model
def load_model(model, path):
    model.load_state_dict(torch.load(path))
    model.eval()

# Save the trained model
save_model(model, 'unet_classifier.pth')

# Load the model back
loaded_model = UNetClassifier(unet, num_classes)
load_model(loaded_model, 'unet_classifier.pth')

### 5. Evaluate the model

In [None]:
# Evaluate the model and print metrics
def evaluate_model(model, test_loader, device):
    model.eval()
    all_labels = []
    all_preds = []
    all_probs = []
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            probs = torch.softmax(outputs, dim=1)
            _, predicted = torch.max(outputs, 1)
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(predicted.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    # Convert lists to numpy arrays
    all_labels = np.array(all_labels)
    all_preds = np.array(all_preds)
    all_probs = np.array(all_probs)

    # Print evaluation metrics
    print(f'Accuracy: {accuracy_score(all_labels, all_preds):.4f}')
    print(f'Precision: {precision_score(all_labels, all_preds, average="weighted"):.4f}')
    print(f'Recall: {recall_score(all_labels, all_preds, average="weighted"):.4f}')
    print(f'F1 Score: {f1_score(all_labels, all_preds, average="weighted"):.4f}')
    print(f'AUC-ROC: {roc_auc_score(all_labels, all_probs, multi_class="ovr"):.4f}')

    # Print classification report
    print(classification_report(all_labels, all_preds, target_names=train_dataset.classes))

    # Plot confusion matrix
    cm = confusion_matrix(all_labels, all_preds)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=train_dataset.classes, yticklabels=train_dataset.classes)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.show()

# Evaluate the loaded model
evaluate_model(loaded_model, test_loader, device)