In [86]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision import models
import numpy as np
import cv2
from skimage.filters import sobel
from skimage.measure import shannon_entropy
from tqdm import tqdm
from scipy.stats import entropy
from scipy.ndimage import sobel


In [87]:
def compute_color_variance(image):
    """
    Compute color variance for an image.
    Args:
        image (Tensor): Input image of shape (3, H, W).
    Returns:
        float: Color variance.
    """
    image = image.cpu().numpy()  # Convert to NumPy array
    variance = np.var(image, axis=(1, 2))  # Variance across H, W for each channel
    return variance.mean()  # Average variance across RGB channels

def compute_edge_density(image):
    """
    Compute edge density for an image using Sobel filter.
    Args:
        image (Tensor): Input image of shape (3, H, W).
    Returns:
        float: Edge density.
    """
    image_gray = image.mean(dim=0).cpu().numpy()  # Convert to grayscale
    edges = sobel(image_gray)  # Sobel filter for edge detection
    edge_density = np.sum(edges > 0) / edges.size  # Proportion of edge pixels
    return edge_density

def compute_entropy(image):
    """
    Compute entropy for an image.
    Args:
        image (Tensor): Input image of shape (3, H, W).
    Returns:
        float: Entropy value.
    """
    image_gray = image.mean(dim=0).cpu().numpy()  # Convert to grayscale
    hist, _ = np.histogram(image_gray, bins=256, range=(0, 1), density=True)  # Normalized histogram
    return entropy(hist + 1e-6)  # Avoid log(0) with small offset

In [88]:
class EarlyExitResNet34(nn.Module):
    def __init__(self, num_classes, input_shape=(3, 224, 224), initial_threshold=0.8, decay_rate=0.1):
        super(EarlyExitResNet34, self).__init__()
        
        # Load pre-trained ResNet34
        self.resnet34 = models.resnet34(pretrained=True)
        
        # Replace the final classifier to match the target classes
        self.resnet34.fc = nn.Linear(self.resnet34.fc.in_features, num_classes)
        
        # Compute feature map sizes for early exits dynamically
        self.exit1_size = self._compute_flattened_size(
            nn.Sequential(
                self.resnet34.conv1,
                self.resnet34.bn1,
                self.resnet34.relu,
                self.resnet34.maxpool
            ),
            input_shape
        )
        self.exit2_size = self._compute_flattened_size(
            nn.Sequential(
                self.resnet34.conv1,
                self.resnet34.bn1,
                self.resnet34.relu,
                self.resnet34.maxpool,
                self.resnet34.layer1
            ),
            input_shape
        )
        
        # Early exit fully connected layers
        self.exit1_fc = nn.Linear(self.exit1_size, num_classes)
        self.exit2_fc = nn.Linear(self.exit2_size, num_classes)
        
        # Early exit thresholds
        self.initial_threshold = initial_threshold
        self.decay_rate = decay_rate

    def _compute_flattened_size(self, layers, input_shape):
        """
        Helper function to compute the flattened size of feature maps after a sequence of layers.
        """
        with torch.no_grad():
            dummy_input = torch.zeros(1, *input_shape)  # Batch size of 1
            output = layers(dummy_input)
        return output.view(1, -1).size(1)

    def _should_exit(self, x, fc, threshold):
        """
        Helper function to determine if the model should exit early based on confidence.
        """
        x_exit = torch.flatten(x, start_dim=1)  # Flatten the feature map
        logits = fc(x_exit)  # Pass through the exit classifier
        probs = F.softmax(logits, dim=1)
        confidence, _ = torch.max(probs, dim=1)
        return confidence.item() > threshold, logits
    
    def _compute_dynamic_threshold(self, image):
        """
        Compute dynamic thresholds based on image metrics.
        Args:
            image (Tensor): Input image of shape (B, 3, H, W).
        Returns:
            float: Adjusted threshold for early exits.
        """
        batch_size = image.size(0)
        thresholds = []
        for i in range(batch_size):
            color_var = compute_color_variance(image[i])
            edge_density = compute_edge_density(image[i])
            entropy_value = compute_entropy(image[i])

            # Example formula for dynamic threshold
            dynamic_threshold = (
                0.5 * (color_var / 255) +  # Normalize variance
                0.3 * edge_density +
                0.2 * (entropy_value / np.log(256))  # Normalize entropy
            )
            thresholds.append(dynamic_threshold)

        return thresholds

    def forward(self, x, thresholds=None):
        if thresholds is None:
            thresholds = self._compute_dynamic_threshold(x)  # Dynamically compute thresholds

        # Initial layers
        x = self.resnet34.conv1(x)
        x = self.resnet34.bn1(x)
        x = self.resnet34.relu(x)
        x = self.resnet34.maxpool(x)

        # Early exit 1
        if thresholds[0] is not None:
            should_exit, logits = self._should_exit(x, self.exit1_fc, thresholds[0])
            if should_exit:
                return logits

        # ResNet block 1
        x = self.resnet34.layer1(x)

        # Early exit 2
        if thresholds[1] is not None:
            should_exit, logits = self._should_exit(x, self.exit2_fc, thresholds[1])
            if should_exit:
                return logits

        # Remaining ResNet layers
        x = self.resnet34.layer2(x)
        x = self.resnet34.layer3(x)
        x = self.resnet34.layer4(x)

        # Final classifier
        x = self.resnet34.avgpool(x)
        x = torch.flatten(x, 1)
        logits = self.resnet34.fc(x)

        return logits

In [81]:
# --- Data Preparation ---
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet normalization
])

# Load CIFAR-10 as an example dataset
train_dataset = datasets.CIFAR10(root="./data", train=True, transform=transform, download=True)
val_dataset = datasets.CIFAR10(root="./data", train=False, transform=transform, download=True)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)


Files already downloaded and verified
Files already downloaded and verified


In [82]:
# --- Model, Loss, and Optimizer ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = EarlyExitResNet34(num_classes=10, input_shape=(3, 224, 224)).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [93]:
def train(model, train_loader, optimizer, criterion, device, num_epochs):
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        correct = 0
        total = 0

        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
            images, labels = images.to(device), labels.to(device)

            # Zero the gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(images)  # Dynamic thresholds applied here

            # Compute loss
            loss = criterion(outputs, labels)

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            # Update metrics
            running_loss += loss.item()
            _, predicted = torch.max(outputs, dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

        epoch_loss = running_loss / len(train_loader)
        epoch_accuracy = correct / total * 100
        print(f"Epoch {epoch + 1}: Loss = {epoch_loss:.4f}, Accuracy = {epoch_accuracy:.2f}%")
    torch.save(model.state_dict(), "resnet34_early_exit_cifar10_adaptive.pth")
    print("Model saved successfully!")

In [79]:
# --- Validation Function ---
def validate(model, val_loader, criterion, device):
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images, thresholds=None)  # No early exit during validation
            loss = criterion(outputs, labels)

            # Update metrics
            val_loss += loss.item()
            _, predicted = torch.max(outputs, dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    val_loss /= len(val_loader)
    val_accuracy = correct / total * 100
    print(f"Validation: Loss = {val_loss:.4f}, Accuracy = {val_accuracy:.2f}%")
    return val_loss, val_accuracy

In [94]:
# --- Training and Validation Loop ---
for epoch in range(10):
    train(model, train_loader, optimizer, criterion, device, num_epochs=1)
    validate(model, val_loader, criterion, device)

Epoch 1/1: 100%|██████████| 391/391 [1:11:02<00:00, 10.90s/it]


Epoch 1: Loss = 0.1658, Accuracy = 94.25%
Model saved successfully!
Validation: Loss = 0.3554, Accuracy = 89.00%


Epoch 1/1: 100%|██████████| 391/391 [1:10:25<00:00, 10.81s/it]


Epoch 1: Loss = 0.1243, Accuracy = 95.68%
Model saved successfully!
Validation: Loss = 0.2959, Accuracy = 91.07%


Epoch 1/1: 100%|██████████| 391/391 [1:10:07<00:00, 10.76s/it]


Epoch 1: Loss = 0.1024, Accuracy = 96.40%
Model saved successfully!
Validation: Loss = 0.3274, Accuracy = 90.11%


Epoch 1/1: 100%|██████████| 391/391 [1:10:09<00:00, 10.76s/it]


Epoch 1: Loss = 0.0758, Accuracy = 97.41%
Model saved successfully!
Validation: Loss = 0.3435, Accuracy = 90.17%


Epoch 1/1: 100%|██████████| 391/391 [1:10:17<00:00, 10.79s/it]


Epoch 1: Loss = 0.0634, Accuracy = 97.79%
Model saved successfully!
Validation: Loss = 0.3071, Accuracy = 91.32%


Epoch 1/1: 100%|██████████| 391/391 [1:10:02<00:00, 10.75s/it]


Epoch 1: Loss = 0.0601, Accuracy = 97.94%
Model saved successfully!
Validation: Loss = 0.3794, Accuracy = 90.07%


Epoch 1/1: 100%|██████████| 391/391 [1:10:12<00:00, 10.77s/it]


Epoch 1: Loss = 0.0554, Accuracy = 98.07%
Model saved successfully!
Validation: Loss = 0.2808, Accuracy = 91.97%


Epoch 1/1: 100%|██████████| 391/391 [1:11:35<00:00, 10.99s/it]


Epoch 1: Loss = 0.0390, Accuracy = 98.68%
Model saved successfully!
Validation: Loss = 0.3540, Accuracy = 91.08%


Epoch 1/1: 100%|██████████| 391/391 [1:08:53<00:00, 10.57s/it]


Epoch 1: Loss = 0.0442, Accuracy = 98.49%
Model saved successfully!
Validation: Loss = 0.4110, Accuracy = 90.18%


Epoch 1/1: 100%|██████████| 391/391 [1:07:59<00:00, 10.43s/it]


Epoch 1: Loss = 0.0430, Accuracy = 98.50%
Model saved successfully!
Validation: Loss = 0.3273, Accuracy = 91.83%


In [95]:
def test_model(model, test_loader, device):
    """
    Test the model on a given test dataset.
    Args:
        model (nn.Module): Trained model.
        test_loader (DataLoader): DataLoader for the test dataset.
        device (torch.device): Device for computation.
    """
    model.eval()  # Set the model to evaluation mode
    correct = 0
    total = 0
    early_exit_counts = [0] * (model.num_exits + 1)  # Count per exit point

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Testing"):
            images, labels = images.to(device), labels.to(device)

            # Forward pass with dynamic thresholds
            outputs = model(images)

            # Determine the exit point
            if isinstance(outputs, tuple):  # Outputs can be (logits, exit_point)
                logits, exit_point = outputs
                early_exit_counts[exit_point] += 1
            else:
                logits = outputs
                early_exit_counts[-1] += 1  # Count for the final exit

            # Compute predictions
            _, predicted = torch.max(logits, dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    # Overall accuracy
    accuracy = correct / total * 100
    print(f"Test Accuracy: {accuracy:.2f}%")

    # Exit point statistics
    for i, count in enumerate(early_exit_counts):
        if i < model.num_exits:
            print(f"Exit {i + 1}: {count} samples exited")
        else:
            print(f"Final Exit: {count} samples exited")

In [96]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

test_dataset = datasets.CIFAR10(root="./data", train=False, transform=transform, download=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


Files already downloaded and verified


In [99]:
# --- Load Trained Model ---
model = EarlyExitResNet34(num_classes=10, input_shape=(3, 224, 224))
model.load_state_dict(torch.load("resnet34_early_exit_cifar10_adaptive.pth"))  # Replace with the correct path
model.to(device)

# --- Test the Model ---
test_model(model, test_loader, device)

  model.load_state_dict(torch.load("resnet34_early_exit_cifar10_adaptive.pth"))  # Replace with the correct path


AttributeError: 'EarlyExitResNet34' object has no attribute 'num_exits'