In [3]:
import torch
import torchvision
from torchinfo import summary
import scipy
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.datasets import Flowers102
from torch.utils.data import DataLoader, Subset

# Define transforms for preprocessing
transforms = transforms.Compose([
    transforms.Resize(256),   # Resize the image to 256x256
    transforms.CenterCrop(224),  # Crop the center 224x224 portion of the image
    transforms.ToTensor(),    # Convert the image to a PyTorch tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize the image (ImageNet RGB means)
])

# we resize and then center crop to remove borders 
# https://stackoverflow.com/questions/71341354/cnn-why-do-we-first-resize-the-image-to-256-and-then-center-crop-to-224
# 224 is the expected input for ImageNet models


# Load Flowers102 dataset
flowers_dataset = Flowers102(root="./data/Flowers102_dataset", split='train', transform=transforms, download=True)
val_dataset = Flowers102(root="./data/Flowers102_dataset", split='val', transform=transforms, download=True)

# Select a subset of 1000 random training images
subset_indices = torch.randperm(len(flowers_dataset))[:100]
subset_dataset = Subset(flowers_dataset, subset_indices)

# Select a subset of 200 random validation images
valset_indices = torch.randperm(len(val_dataset))[:100]
valset_dataset = Subset(flowers_dataset, subset_indices)

# Define data loader
batch_size = 32
data_loader = DataLoader(subset_dataset, batch_size=batch_size, shuffle=True)
val_loader =  DataLoader(valset_dataset, batch_size=batch_size, shuffle=True)
# Load EfficientNet-B1 model
model = torchvision.models.efficientnet_b1(weights="DEFAULT")

# update the number of classes in the last layer
model.classifier[-1] = nn.Linear(in_features=model.classifier[-1].in_features, out_features=102, bias=True)

# # Freeze all layers except the last one
# for param in model.parameters():
#     param.requires_grad = False

In [4]:
summary(model)

Layer (type:depth-idx)                                  Param #
EfficientNet                                            --
├─Sequential: 1-1                                       --
│    └─Conv2dNormActivation: 2-1                        --
│    │    └─Conv2d: 3-1                                 864
│    │    └─BatchNorm2d: 3-2                            64
│    │    └─SiLU: 3-3                                   --
│    └─Sequential: 2-2                                  --
│    │    └─MBConv: 3-4                                 1,448
│    │    └─MBConv: 3-5                                 612
│    └─Sequential: 2-3                                  --
│    │    └─MBConv: 3-6                                 6,004
│    │    └─MBConv: 3-7                                 10,710
│    │    └─MBConv: 3-8                                 10,710
│    └─Sequential: 2-4                                  --
│    │    └─MBConv: 3-9                                 15,350
│    │    └─MBConv: 3-10       

In [6]:
from tqdm.auto import tqdm

In [7]:
# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Train the model
num_epochs = 3
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    print('-' * 10)
    
    # Training Phase
    model.train()
    running_loss = 0.0
    for images, labels in tqdm(data_loader):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)
        
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)

    epoch_loss = running_loss / len(subset_dataset)
    print(f"Loss: {epoch_loss:.4f}")
    
    # Validation phase
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(val_loader, leave=False):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    epoch_val_loss = val_loss / len(val_dataset)
    epoch_val_acc = correct / total
    print(f"Validation Loss: {epoch_val_loss:.4f}, Accuracy: {epoch_val_acc:.4f}")

print("Training complete!")

Epoch 1/3
----------


  0%|          | 0/4 [00:00<?, ?it/s]

Loss: 4.5894


  0%|          | 0/4 [00:00<?, ?it/s]

Validation Loss: 0.4321, Accuracy: 0.3600
Epoch 2/3
----------


  0%|          | 0/4 [00:00<?, ?it/s]

Loss: 3.8969


  0%|          | 0/4 [00:00<?, ?it/s]

Validation Loss: 0.3642, Accuracy: 0.8600
Epoch 3/3
----------


  0%|          | 0/4 [00:00<?, ?it/s]

Loss: 2.9562


  0%|          | 0/4 [00:00<?, ?it/s]

Validation Loss: 0.2586, Accuracy: 0.8900
Training complete!


In [9]:
import matplotlib.pyplot as plt

In [None]:
# Visualize 20 random test images with true and predicted labels as titles
model.eval()
test_loader = DataLoader(val_dataset, batch_size=1, shuffle=True)

num_images = 5
count = 0

plt.figure(figsize=(16, 12))

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)

        if predicted.item() == labels.item() and count < num_images:
            image = images.squeeze().cpu().permute(1, 2, 0)
            image = (image * torch.tensor([0.229, 0.224, 0.225]) + torch.tensor([0.485, 0.456, 0.406])).clamp(0, 1)
            plt.subplot(4, 5, count + 1)
            plt.imshow(image)
            plt.title(f'True: {flowers_dataset.classes[labels.item()]}\nPredicted: {flowers_dataset.classes[predicted.item()]}')
            plt.axis('off')
            count += 1

        if count >= num_images:
            break

plt.tight_layout()
plt.show()

In [None]:
# # Replace the last layer with a new fully connected layer
# num_ftrs = model._fc.in_features
# model._fc = nn.Linear(num_ftrs, 102)  # 102 output classes for Flowers102 dataset

# # Add a 3-layer MLP
# class MLP(nn.Module):
#     def __init__(self, input_size, hidden_size, output_size):
#         super(MLP, self).__init__()
#         self.fc1 = nn.Linear(input_size, hidden_size)
#         self.fc2 = nn.Linear(hidden_size, hidden_size)
#         self.fc3 = nn.Linear(hidden_size, output_size)

#     def forward(self, x):
#         x = torch.relu(self.fc1(x))
#         x = torch.relu(self.fc2(x))
#         x = self.fc3(x)
#         return x

# mlp = MLP(num_ftrs, 256, 102)  # Adjust hidden size as needed

# # Combine EfficientNet and MLP
# class CombinedModel(nn.Module):
#     def __init__(self, backbone, mlp):
#         super(CombinedModel, self).__init__()
#         self.backbone = backbone
#         self.mlp = mlp

#     def forward(self, x):
#         x = self.backbone(x)
#         x = self.mlp(x)
#         return x

# combined_model = CombinedModel(model, mlp)