In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision import models
import numpy as np
import cv2
from skimage.filters import sobel
from skimage.measure import shannon_entropy
from tqdm import tqdm
from scipy.stats import entropy
from scipy.ndimage import sobel


In [2]:
def compute_color_variance(image):
    """
    Compute color variance for an image.
    Args:
        image (Tensor): Input image of shape (3, H, W).
    Returns:
        float: Color variance.
    """
    image = image.cpu().numpy()  # Convert to NumPy array
    variance = np.var(image, axis=(1, 2))  # Variance across H, W for each channel
    return variance.mean()  # Average variance across RGB channels

def compute_edge_density(image):
    """
    Compute edge density for an image using Sobel filter.
    Args:
        image (Tensor): Input image of shape (3, H, W).
    Returns:
        float: Edge density.
    """
    image_gray = image.mean(dim=0).cpu().numpy()  # Convert to grayscale
    edges = sobel(image_gray)  # Sobel filter for edge detection
    edge_density = np.sum(edges > 0) / edges.size  # Proportion of edge pixels
    return edge_density

def compute_entropy(image):
    """
    Compute entropy for an image.
    Args:
        image (Tensor): Input image of shape (3, H, W).
    Returns:
        float: Entropy value.
    """
    image_gray = image.mean(dim=0).cpu().numpy()  # Convert to grayscale
    hist, _ = np.histogram(image_gray, bins=256, range=(0, 1), density=True)  # Normalized histogram
    return entropy(hist + 1e-6)  # Avoid log(0) with small offset

In [None]:
class EarlyExitResNet34(nn.Module):
    def __init__(self, num_classes, input_shape=(3, 224, 224), initial_threshold=0.8, decay_rate=0.1):
        super(EarlyExitResNet34, self).__init__()
        
        # Load pre-trained ResNet34
        #self.resnet34 = models.resnet34(pretrained=True)
        self.resnet34 = models.resnet34(weights=None)
        # Replace the final classifier to match the target classes
        self.resnet34.fc = nn.Linear(self.resnet34.fc.in_features, num_classes)
        
        # Compute feature map sizes for early exits dynamically
        self.exit1_size = self._compute_flattened_size(
            nn.Sequential(
                self.resnet34.conv1,
                self.resnet34.bn1,
                self.resnet34.relu,
                self.resnet34.maxpool
            ),
            input_shape
        )
        self.exit2_size = self._compute_flattened_size(
            nn.Sequential(
                self.resnet34.conv1,
                self.resnet34.bn1,
                self.resnet34.relu,
                self.resnet34.maxpool,
                self.resnet34.layer1
            ),
            input_shape
        )
        
        # Early exit fully connected layers
        self.exit1_fc = nn.Linear(self.exit1_size, num_classes)
        self.exit2_fc = nn.Linear(self.exit2_size, num_classes)
        
        # Early exit thresholds
        self.initial_threshold = initial_threshold
        self.decay_rate = decay_rate

    def _compute_flattened_size(self, layers, input_shape):
        """
        Helper function to compute the flattened size of feature maps after a sequence of layers.
        """
        with torch.no_grad():
            dummy_input = torch.zeros(1, *input_shape)  # Batch size of 1
            output = layers(dummy_input)
        return output.view(1, -1).size(1)

    def _should_exit(self, x, fc, threshold):
        """
        Helper function to determine if the model should exit early based on confidence.
        """
        if isinstance(threshold, list) or isinstance(threshold, torch.Tensor):
            raise ValueError("Threshold should be a scalar, not a list or tensor.")

        x_exit = torch.flatten(x, start_dim=1)  # Flatten the feature map
        logits = fc(x_exit)  # Pass through the exit classifier
        probs = F.softmax(logits, dim=1)
        confidence, _ = torch.max(probs, dim=1)
        # Ensure comparison is element-wise
        return confidence > threshold, logits
    
    def _compute_dynamic_threshold(self, image):
        """
        Compute dynamic thresholds based on image metrics.
        Args:
            image (Tensor): Input image of shape (B, 3, H, W).
        Returns:
            float: Adjusted threshold for early exits.
        """
        batch_size = image.size(0)
        thresholds = []
        for i in range(batch_size):
            color_var = compute_color_variance(image[i])
            edge_density = compute_edge_density(image[i])
            entropy_value = compute_entropy(image[i])

            # Example formula for dynamic threshold
            dynamic_threshold = (
                0.5 * (color_var / 255) +  # Normalize variance
                0.4 * edge_density +
                0.3 * (entropy_value / np.log(256))  # Normalize entropy
            )
            thresholds.append(dynamic_threshold)

        return thresholds

    def forward(self, x, thresholds=None):
        #print(f"Input shape: {x.shape}")
        
        # Ensure thresholds are set
        if thresholds is None:
            thresholds = self._compute_dynamic_threshold(x)
            #print(f"Computed dynamic thresholds: {thresholds}")
        elif isinstance(thresholds, (float, int)):
            thresholds = [thresholds] * x.size(0)
        
        # Convert thresholds to scalars
        thresholds = [t.item() if isinstance(t, torch.Tensor) else t for t in thresholds]
        #print(f"Final thresholds for batch: {thresholds}")

        # Initial layers
        x = self.resnet34.conv1(x)
        x = self.resnet34.bn1(x)
        x = self.resnet34.relu(x)
        x = self.resnet34.maxpool(x)

        # Early exit 1
        if thresholds[0] is not None:
            should_exit, logits = self._should_exit(x, self.exit1_fc, thresholds[0])
            #print(f"Exit 1 decision: {should_exit}")
            if should_exit.any():  # Exit if any image in the batch should exit
                return logits

        # ResNet block 1
        x = self.resnet34.layer1(x)
        # Early exit 2
        if thresholds[0] is not None:
            should_exit, logits = self._should_exit(x, self.exit2_fc, thresholds[0])
            #print(f"Exit 2 decision: {should_exit}")
            if should_exit.any():  # Exit if any image in the batch should exit
                return logits
        

        # Remaining ResNet layers
        x = self.resnet34.layer2(x)
        
        x = self.resnet34.layer3(x)
        x = self.resnet34.layer4(x)

        x = self.resnet34.avgpool(x)
        x = torch.flatten(x, 1)
        logits = self.resnet34.fc(x)
        return logits

In [7]:
# --- Data Preparation ---
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet normalization
])

# Load CIFAR-10 as an example dataset
train_dataset = datasets.CIFAR10(root="./data", train=True, transform=transform, download=True)
val_dataset = datasets.CIFAR10(root="./data", train=False, transform=transform, download=True)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)


Files already downloaded and verified
Files already downloaded and verified


In [8]:
# --- Model, Loss, and Optimizer ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = EarlyExitResNet34(num_classes=10, input_shape=(3, 224, 224)).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [33]:
def train(model, train_loader, optimizer, criterion, device, num_epochs):
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        correct = 0
        total = 0

        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
            images, labels = images.to(device), labels.to(device)

            # Compute dynamic thresholds for the batch
            thresholds = model._compute_dynamic_threshold(images)  # A list of thresholds for each sample in the batch

            # Zero the gradients
            optimizer.zero_grad()

            # Forward pass with thresholds
            outputs = model(images, thresholds=thresholds)

            # Compute loss
            loss = criterion(outputs, labels)

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            # Update metrics
            running_loss += loss.item()
            _, predicted = torch.max(outputs, dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

        epoch_loss = running_loss / len(train_loader)
        epoch_accuracy = correct / total * 100
        print(f"Epoch {epoch + 1}: Loss = {epoch_loss:.4f}, Accuracy = {epoch_accuracy:.2f}%")
    torch.save(model.state_dict(), "resnet34_early_exit_cifar10_adaptive_plsv3.pth")
    print("Model saved successfully!")

In [34]:
# --- Validation Function ---
def validate(model, val_loader, criterion, device):
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images, thresholds=None)  # No early exit during validation
            loss = criterion(outputs, labels)

            # Update metrics
            val_loss += loss.item()
            _, predicted = torch.max(outputs, dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    val_loss /= len(val_loader)
    val_accuracy = correct / total * 100
    print(f"Validation: Loss = {val_loss:.4f}, Accuracy = {val_accuracy:.2f}%")
    return val_loss, val_accuracy

In [35]:
# --- Training and Validation Loop ---
for epoch in range(10):
    train(model, train_loader, optimizer, criterion, device, num_epochs=1)
    validate(model, val_loader, criterion, device)

  return n/db/n.sum(), bin_edges
Epoch 1/1: 100%|██████████| 391/391 [11:24<00:00,  1.75s/it]


Epoch 1: Loss = 22.8180, Accuracy = 27.56%
Model saved successfully!
Validation: Loss = 9.3171, Accuracy = 31.34%


Epoch 1/1: 100%|██████████| 391/391 [08:36<00:00,  1.32s/it]


Epoch 1: Loss = 6.4269, Accuracy = 34.20%
Model saved successfully!
Validation: Loss = 4.7840, Accuracy = 39.54%


Epoch 1/1: 100%|██████████| 391/391 [08:43<00:00,  1.34s/it]


Epoch 1: Loss = 3.5384, Accuracy = 42.21%
Model saved successfully!
Validation: Loss = 2.8325, Accuracy = 40.72%


Epoch 1/1: 100%|██████████| 391/391 [08:41<00:00,  1.33s/it]


Epoch 1: Loss = 2.1503, Accuracy = 49.06%
Model saved successfully!
Validation: Loss = 2.1258, Accuracy = 48.71%


Epoch 1/1: 100%|██████████| 391/391 [08:42<00:00,  1.34s/it]


Epoch 1: Loss = 1.6608, Accuracy = 52.90%
Model saved successfully!
Validation: Loss = 1.8034, Accuracy = 48.92%


Epoch 1/1: 100%|██████████| 391/391 [08:43<00:00,  1.34s/it]


Epoch 1: Loss = 1.3842, Accuracy = 56.48%
Model saved successfully!
Validation: Loss = 1.6645, Accuracy = 49.59%


Epoch 1/1: 100%|██████████| 391/391 [08:59<00:00,  1.38s/it]


Epoch 1: Loss = 1.2576, Accuracy = 58.62%
Model saved successfully!
Validation: Loss = 1.4751, Accuracy = 53.10%


Epoch 1/1: 100%|██████████| 391/391 [08:51<00:00,  1.36s/it]


Epoch 1: Loss = 1.1945, Accuracy = 60.09%
Model saved successfully!
Validation: Loss = 1.5035, Accuracy = 52.50%


Epoch 1/1: 100%|██████████| 391/391 [08:55<00:00,  1.37s/it]

Epoch 1: Loss = 1.1554, Accuracy = 61.07%
Model saved successfully!





Validation: Loss = 1.4115, Accuracy = 54.04%


Epoch 1/1: 100%|██████████| 391/391 [09:14<00:00,  1.42s/it]


Epoch 1: Loss = 1.1157, Accuracy = 62.34%
Model saved successfully!
Validation: Loss = 1.4583, Accuracy = 53.20%


In [4]:
def test_model(model, test_loader, device):
    model.eval()  # Set the model to evaluation mode
    correct = 0
    total = 0
    #early_exit_counts = [0] * (model.num_exits + 1)  # Count per exit point

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Testing"):
            images, labels = images.to(device), labels.to(device)

            # Forward pass with dynamic thresholds
            logits = model(images)

            # Count exits
            #early_exit_counts[exit_point] += logits.size(0)

            # Compute predictions
            _, predicted = torch.max(logits, dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    # Overall accuracy
    accuracy = correct / total * 100
    print(f"Test Accuracy: {accuracy:.2f}%")

In [5]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

test_dataset = datasets.CIFAR10(root="./data", train=False, transform=transform, download=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


Files already downloaded and verified


In [12]:
# --- Load Trained Model ---
model = EarlyExitResNet34(num_classes=10, input_shape=(3, 224, 224))
model.load_state_dict(torch.load("resnet34_early_exit_cifar10_adaptive.pth"))  # Replace with the correct path
model.to(device)

# --- Test the Model ---
test_model(model, test_loader, device)

  model.load_state_dict(torch.load("resnet34_early_exit_cifar10_adaptive.pth"))  # Replace with the correct path
Testing: 100%|██████████| 313/313 [04:39<00:00,  1.12it/s]

Test Accuracy: 91.83%



