In [None]:
!pip install torchinfo

In [None]:
import copy
import os
import random
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torchvision
from sklearn.metrics import f1_score, precision_score, recall_score
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm.auto import tqdm

In [None]:
# Seed the Python, NumPy and PyTorch RNGs up front so subsampling, head initialization
# and DataLoader shuffling repeat across runs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

In [None]:
# Pretrained EfficientNetV2-S used as a frozen feature extractor; by default only the
# new classifier head defined below is trainable.
pretrained_efficientnet_weights = torchvision.models.EfficientNet_V2_S_Weights.DEFAULT

# 1. Create the EfficientNetV2-S model with its pretrained weights.
pretrained_efficientnet = torchvision.models.efficientnet_v2_s(weights=pretrained_efficientnet_weights).to(device)

# 2. Freeze every pretrained parameter.
for parameter in pretrained_efficientnet.parameters():
    parameter.requires_grad = False

# 3. Optional partial fine-tuning: set to True to also train the last module of the
#    backbone's `features`. Indexing with [-1] avoids a hard-coded position.
#    NOTE(review): the earlier commented-out version used features[8], which appears
#    to be out of range for efficientnet_v2_s — confirm before relying on an index.
UNFREEZE_LAST_FEATURE_BLOCK = False
if UNFREEZE_LAST_FEATURE_BLOCK:
    for parameter in pretrained_efficientnet.features[-1].parameters():
        parameter.requires_grad = True

# 4. Replace the classifier head with a small binary head (one probability per image).
# Placeholder names; the dataloader cell reassigns class_names from ImageFolder.
class_names = ['defective', 'non-defective']
# Take the backbone's output width from the stock head's final Linear layer instead of
# hard-coding 1280, so the new head stays consistent with the backbone.
backbone_out_features = pretrained_efficientnet.classifier[-1].in_features
pretrained_efficientnet.classifier = nn.Sequential(
    nn.Dropout(p=0.2),
    nn.Linear(in_features=backbone_out_features, out_features=32),
    nn.ReLU(),
    nn.Dropout(p=0.2),
    nn.Linear(in_features=32, out_features=1),
    # Sigmoid output pairs with nn.BCELoss and the 0.5 decision threshold used in training.
    nn.Sigmoid()
).to(device)

In [None]:
# from torchinfo import summary
# summary(model=pretrained_efficientnet,
#         input_size=(32, 3, 224, 224),
#         col_names=["input_size", "output_size", "num_params", "trainable"],
#         col_width=20,
#         row_settings=["var_names"])

In [None]:
# Preprocessing pipeline bundled with the pretrained weights; it is applied to every
# image loaded by the dataloader cell.
pretrained_efficientnet_transforms = pretrained_efficientnet_weights.transforms()
print(f"{pretrained_efficientnet_transforms}")


In [None]:
# For Full DataSet
# NUM_WORKERS = os.cpu_count()

# def create_dataloaders_with_cross_validation(
#     dataset_dir: str,
#     transform: transforms.Compose,
#     batch_size: int,
#     num_splits: int = 5,
#     num_workers: int = NUM_WORKERS
# ):
#     # Use ImageFolder to create the dataset
#     full_dataset = datasets.ImageFolder(dataset_dir, transform=transform)

#     # Initialize StratifiedKFold for cross-validation
#     skf = StratifiedKFold(n_splits=num_splits, shuffle=True, random_state=42)

#     # Initialize lists to store train and test data loaders for each split
#     train_dataloaders = []
#     test_dataloaders = []

#     # Get class names
#     class_names = full_dataset.classes

#     for train_indices, test_indices in skf.split(full_dataset.imgs, full_dataset.targets):
#         # Create train and test datasets for the current split
#         train_dataset = torch.utils.data.Subset(full_dataset, train_indices)
#         test_dataset = torch.utils.data.Subset(full_dataset, test_indices)

#         # Turn images into data loaders
#         train_dataloader = DataLoader(
#             train_dataset,
#             batch_size=batch_size,
#             shuffle=True,
#             num_workers=num_workers,
#             pin_memory=True,
#         )
#         test_dataloader = DataLoader(
#             test_dataset,
#             batch_size=batch_size,
#             shuffle=False,
#             num_workers=num_workers,
#             pin_memory=True,
#         )

#         train_dataloaders.append(train_dataloader)
#         test_dataloaders.append(test_dataloader)

#     return train_dataloaders, test_dataloaders, class_names

In [None]:
# For Half Dataset (pass sampling_ratio=1.0 to use every image)
# os.cpu_count() can return None; fall back to loading in the main process.
NUM_WORKERS = os.cpu_count() or 0


def _stratified_subsample(targets: List[int],
                          sampling_ratio: float,
                          rng: np.random.Generator) -> List[int]:
    """Pick int(class_size * sampling_ratio) random indices per class, shuffled together."""
    targets_arr = np.asarray(targets)
    per_class = []
    for label in np.unique(targets_arr):
        class_indices = np.flatnonzero(targets_arr == label)
        n_samples = int(len(class_indices) * sampling_ratio)
        per_class.append(rng.choice(class_indices, size=n_samples, replace=False))
    if not per_class:
        return []
    sampled = np.concatenate(per_class)
    rng.shuffle(sampled)
    return sampled.tolist()


def create_dataloaders_with_cross_validation(
    dataset_dir: str,
    transform: transforms.Compose,
    batch_size: int,
    sampling_ratio: float = 0.5,
    num_splits: int = 5,
    num_workers: int = NUM_WORKERS,
    random_state: Optional[int] = 42,
):
    """Build stratified K-fold train/test DataLoaders over a per-class subsample.

    Args:
        dataset_dir: Root folder read with ``torchvision.datasets.ImageFolder``.
        transform: Transform applied to every image (train and test alike).
        batch_size: Batch size for every loader.
        sampling_ratio: Fraction of each class to keep, in (0, 1].
        num_splits: Number of StratifiedKFold folds.
        num_workers: DataLoader worker processes.
        random_state: Seed for both the subsample and the fold shuffling.
            ``None`` makes both non-deterministic.

    Returns:
        ``(train_dataloaders, test_dataloaders, class_names)``. The two lists hold one
        loader per fold, and ``class_names`` is ``ImageFolder.classes``.

    Raises:
        ValueError: if ``sampling_ratio`` is outside (0, 1].
    """
    if not 0.0 < sampling_ratio <= 1.0:
        raise ValueError(f"sampling_ratio must be in (0, 1], got {sampling_ratio}")

    full_dataset = datasets.ImageFolder(dataset_dir, transform=transform)

    # ImageFolder already records every sample's label in `.targets`. Iterating the
    # dataset itself would decode and transform every image just to read labels.
    rng = np.random.default_rng(random_state)
    sampled_indices = _stratified_subsample(full_dataset.targets, sampling_ratio, rng)
    sampled_dataset = torch.utils.data.Subset(full_dataset, sampled_indices)
    sampled_targets = [full_dataset.targets[i] for i in sampled_indices]

    print(f"Original dataset size: {len(full_dataset)}")
    print(f"Sampled dataset size: {len(sampled_dataset)} ({sampling_ratio*100}%)")

    skf = StratifiedKFold(n_splits=num_splits, shuffle=True, random_state=random_state)
    # Pinned memory only helps host-to-GPU copies; skip it (and its warning) on CPU.
    pin_memory = torch.cuda.is_available()

    train_dataloaders = []
    test_dataloaders = []
    # StratifiedKFold only uses X for its length; the labels drive the stratification.
    for train_indices, test_indices in skf.split(np.zeros(len(sampled_targets)), sampled_targets):
        train_dataset = torch.utils.data.Subset(sampled_dataset, train_indices)
        test_dataset = torch.utils.data.Subset(sampled_dataset, test_indices)
        train_dataloaders.append(DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=pin_memory,
        ))
        test_dataloaders.append(DataLoader(
            test_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=pin_memory,
        ))
    return train_dataloaders, test_dataloaders, full_dataset.classes


In [None]:
from google.colab import drive

# Make Google Drive visible to this Colab runtime.
drive.mount(mountpoint="/content/gdrive")

# Root folder passed to ImageFolder by the dataloader cell.
dataset_dir = "/content/gdrive/MyDrive/Colab Notebooks/CMID/"

In [None]:
# For Full Dataset
# train_dataloader_pretrained, test_dataloader_pretrained, class_names = create_dataloaders_with_cross_validation(
#     dataset_dir=dataset_dir,
#     transform=pretrained_efficientnet_transforms,
#     batch_size=32,
#     num_splits=5
# )

In [None]:
# For Half Dataset: keep 50% of each class, then build 5 stratified folds of DataLoaders.
train_dataloader_pretrained, test_dataloader_pretrained, class_names = create_dataloaders_with_cross_validation(
    dataset_dir,
    pretrained_efficientnet_transforms,
    32,                   # batch_size
    num_splits=5,
    sampling_ratio=0.5,   # fraction of each class to use
)

In [None]:
def train_step(model: torch.nn.Module,
               dataloader: torch.utils.data.DataLoader,
               loss_fn: torch.nn.Module,
               optimizer: torch.optim.Optimizer,
               device: torch.device) -> Tuple[float, float]:
    """Run one training epoch over ``dataloader``.

    Assumes ``model`` returns probabilities of shape (batch, 1), as the sigmoid-terminated
    head in this notebook does. Samples are classified as positive above 0.5.

    Args:
        model: Network to train; it is put in train mode.
        dataloader: Iterable of ``(X, y)`` batches with integer 0/1 labels.
        loss_fn: Loss taking ``(probabilities, float_targets)``. Assumed to average over
            the batch (e.g. ``nn.BCELoss()`` with its default reduction).
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss_per_sample, accuracy)`` for the epoch.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        probs = model(X).squeeze(1)
        loss = loss_fn(probs, y.float())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = y.size(0)
        # loss is a batch mean; re-weight by batch size so a smaller final batch
        # does not skew the epoch average.
        total_loss += loss.item() * batch_size
        preds = (probs.detach() > 0.5).long()
        correct += (preds == y).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("train_step received a dataloader with no samples")
    return total_loss / total, correct / total

def test_step(model: torch.nn.Module,
              dataloader: torch.utils.data.DataLoader,
              loss_fn: torch.nn.Module,
              device: torch.device) -> Tuple[float, float, List[int], List[int]]:
    """Performs one epoch of evaluation."""
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    all_labels, all_preds = [], []

    with torch.inference_mode():
        for batch, (X, y) in enumerate(dataloader):
            X, y = X.to(device), y.to(device)

            # Forward pass
            y_pred = model(X)
            loss = loss_fn(y_pred.squeeze(1), y.type(torch.float32))  
            test_loss += loss.item()

            # Predictions and accuracy
            y_pred_class = (y_pred.squeeze(1) > 0.5).float()
            correct += (y_pred_class == y).sum().item()
            total += y.size(0)

            # Collect labels and predictions for precision/recall/F1
            all_labels.extend(y.cpu().tolist())
            all_preds.extend(y_pred_class.cpu().tolist())

    test_loss /= len(dataloader)
    test_acc = correct / total
    return test_loss, test_acc, all_labels, all_preds
    
def train_with_cross_validation(model: torch.nn.Module,
                              train_dataloaders: List[torch.utils.data.DataLoader],
                              test_dataloaders: List[torch.utils.data.DataLoader],
                              optimizer: torch.optim.Optimizer,
                              loss_fn: torch.nn.Module,
                              epochs: int,
                              device: torch.device) -> List[Dict[str, List]]:
    """Train and evaluate ``model`` independently on each cross-validation split.

    Every split restarts from the weights and optimizer state that ``model`` and
    ``optimizer`` had when this function was called. When it returns, ``model``
    holds the weights trained on the last split.

    Args:
        model: Network to train; moved to ``device``.
        train_dataloaders: One training loader per split.
        test_dataloaders: One evaluation loader per split, aligned with ``train_dataloaders``.
        optimizer: Optimizer over ``model``'s parameters; its state is reset per split.
        loss_fn: Loss passed to ``train_step`` / ``test_step``.
        epochs: Epochs to train on each split.
        device: Device used for training and evaluation.

    Returns:
        One dict per split mapping each metric name ("train_loss", "train_acc",
        "test_loss", "test_acc", "precision", "recall", "f1") to its per-epoch values.

    Raises:
        ValueError: if the two loader lists differ in length.
    """
    if len(train_dataloaders) != len(test_dataloaders):
        raise ValueError(
            f"Got {len(train_dataloaders)} train loaders but {len(test_dataloaders)} test loaders"
        )

    model.to(device)
    # Snapshot the starting point. With K-fold splits, split k-1's training data contains
    # split k's test images, so continuing training across splits leaks test data into
    # training and inflates every score after the first split.
    initial_model_state = copy.deepcopy(model.state_dict())
    initial_optimizer_state = copy.deepcopy(optimizer.state_dict())

    all_results = []
    for split, (train_loader, test_loader) in enumerate(zip(train_dataloaders, test_dataloaders)):
        model.load_state_dict(initial_model_state)
        # Load a fresh copy so in-place optimizer updates can never alter the snapshot.
        optimizer.load_state_dict(copy.deepcopy(initial_optimizer_state))

        results = {"train_loss": [], "train_acc": [],
                   "test_loss": [], "test_acc": [],
                   "precision": [], "recall": [], "f1": []}

        for epoch in tqdm(range(epochs)):
            train_loss, train_acc = train_step(model, train_loader, loss_fn, optimizer, device)
            test_loss, test_acc, all_labels, all_preds = test_step(model, test_loader, loss_fn, device)

            # zero_division=0 reports an ill-defined score as 0 (the default's value)
            # without emitting a warning every epoch.
            precision = precision_score(all_labels, all_preds, average="weighted", zero_division=0)
            recall = recall_score(all_labels, all_preds, average="weighted", zero_division=0)
            f1 = f1_score(all_labels, all_preds, average="weighted", zero_division=0)

            print(
                f"Split: {split+1} | Epoch: {epoch+1} | "
                f"train_loss: {train_loss:.4f} | train_acc: {train_acc:.4f} | "
                f"test_loss: {test_loss:.4f} | test_acc: {test_acc:.4f} | "
                f"precision: {precision:.4f} | recall: {recall:.4f} | f1: {f1:.4f}"
            )

            results["train_loss"].append(train_loss)
            results["train_acc"].append(train_acc)
            results["test_loss"].append(test_loss)
            results["test_acc"].append(test_acc)
            results["precision"].append(precision)
            results["recall"].append(recall)
            results["f1"].append(f1)

        # Average each metric over the epochs of this split.
        mean_precision = np.mean(results["precision"])
        mean_recall = np.mean(results["recall"])
        mean_f1 = np.mean(results["f1"])
        print(f"Split {split+1} Mean Metrics - Precision: {mean_precision:.4f}, Recall: {mean_recall:.4f}, F1: {mean_f1:.4f}")

        all_results.append(results)

    return all_results

In [None]:
# Binary cross-entropy on the sigmoid probabilities produced by the classifier head.
loss_fn = nn.BCELoss()
# Adam over all parameters; only those left with requires_grad=True receive gradients.
optimizer = torch.optim.Adam(pretrained_efficientnet.parameters(), lr=1e-3)

# Train and evaluate on every cross-validation split.
pretrained_efficientnet_results = train_with_cross_validation(
    pretrained_efficientnet,
    train_dataloader_pretrained,
    test_dataloader_pretrained,
    optimizer,
    loss_fn,
    epochs=100,
    device=device,
)