In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision import models

In [4]:
# Early Exit Class
class EarlyExit(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(EarlyExit, self).__init__()
        self.conv = nn.Conv2d(in_channels, 128, kernel_size=1, stride=1, bias=False)
        self.bn = nn.BatchNorm2d(128)
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        x = F.adaptive_avg_pool2d(x, (1, 1))  # Global average pooling
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

In [None]:
# ResNet with Early Exits
class ResNet34WithExits(nn.Module):
    def __init__(self, num_classes=10, exit_thresholds=(0.55,0.8)):
        super(ResNet34WithExits, self).__init__()
        # Define layers
        self.num_classes = num_classes
        self.exit_thresholds = exit_thresholds # confidence thresholds

        # Load the ResNet34 backbone
        resnet = models.resnet34(weights=None)
        self.base = nn.Sequential(*list(resnet.children())[:-2])  # Exclude FC and Adaptive Pooling layers

        # Early exit 1
        self.exit1 = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

        # Early exit 2
        self.exit2 = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

        # Final layers (ResNet output)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)
        
    def forward(self, x, train_mode=True):
        # Early exit 1
        x1 = self.base[:3](x)
        exit1_out = self.exit1(x1)

        # Early exit 2
        x2 = self.base[3:5](x1)
        exit2_out = self.exit2(x2)

        # Final exit
        x3 = self.base[5:](x2)
        x3 = self.avgpool(x3)
        x3 = torch.flatten(x3, 1)
        final_out = self.fc(x3)

        # If in training mode, return all exits
        if train_mode:
            return exit1_out, exit2_out, final_out

        # Otherwise, use threshold logic for inference
        confidence1 = F.softmax(exit1_out, dim=1).max(1)[0]
        if confidence1.mean().item() >= self.exit_thresholds[0]:
            return exit1_out, None, None

        confidence2 = F.softmax(exit2_out, dim=1).max(1)[0]
        if confidence2.mean().item() >= self.exit_thresholds[1]:
            return None, exit2_out, None

        return None, None, final_out

In [None]:
# Instantiate the model
model = ResNet34WithExits(num_classes=10)  # Change `num_classes` as per your dataset
print(model)

In [78]:
# Define transformations for CIFAR-100
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
])

# Load CIFAR-100 dataset
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

# Define model, loss function, optimizer, and scheduler
model = ResNet34WithExits(num_classes=10)  # ResNet34WithExits defined earlier
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

Files already downloaded and verified
Files already downloaded and verified


In [79]:
# Define training function
def train_model(model, train_loader, optimizer, criterion, device, epochs):
    model = model.to(device)
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass with early exits
            exit1_out, exit2_out, final_out = model(inputs)

            # Compute loss (using only the final output for simplicity)
            loss = criterion(final_out, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

    # Save the trained model
    torch.save(model.state_dict(), "resnet34_early_exit_cifar10_loweestT.pth")
    print("Model saved successfully!")


In [80]:
def train_model_with_all_exits(model, train_loader, optimizer, criterion, device, exit_weights, epochs=10):
    """
    Train a multi-exit model using a unified loss function.
    
    Parameters:
        model: The model with early exits.
        train_loader: DataLoader for the training data.
        optimizer: Optimizer for the model parameters.
        criterion: Loss function (e.g., CrossEntropyLoss).
        device: The device to run the training on ('cpu' or 'cuda').
        exit_weights: List of weights for the losses of each exit [w1, w2, w3].
        epochs: Number of training epochs.
    """
    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        correct = [0, 0, 0]  # Track correct predictions for each exit
        total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            # Zero the gradients
            optimizer.zero_grad()

            # Forward pass (train_mode=True ensures all exits produce outputs)
            exit1_out, exit2_out, final_out = model(inputs, train_mode=True)

            # Compute losses for each exit
            loss1 = criterion(exit1_out, labels)
            loss2 = criterion(exit2_out, labels)
            loss3 = criterion(final_out, labels)

            # Combine losses using exit weights
            total_loss = (exit_weights[0] * loss1 +
                          exit_weights[1] * loss2 +
                          exit_weights[2] * loss3)

            # Backward pass and optimizer step
            total_loss.backward()
            optimizer.step()

            # Update statistics
            running_loss += total_loss.item()
            total += labels.size(0)

            # Compute accuracy for each exit
            for i, output in enumerate([exit1_out, exit2_out, final_out]):
                _, predictions = torch.max(output, 1)
                correct[i] += (predictions == labels).sum().item()

        # Print epoch statistics
        epoch_loss = running_loss / len(train_loader)
        epoch_acc = [100 * c / total if total > 0 else 0 for c in correct]
        print(f"Epoch [{epoch + 1}/{epochs}] - Loss: {epoch_loss:.4f}, "
              f"Accuracies: Exit 1: {epoch_acc[0]:.2f}%, Exit 2: {epoch_acc[1]:.2f}%, Final Exit: {epoch_acc[2]:.2f}%")
    torch.save(model.state_dict(), "resnet34_early_exit_cifar10_loweestT.pth")
    print("Model saved successfully!")

In [81]:
# Define weights for each exit's loss
exit_weights = [0.3, 0.5, 0.2]  # Example: Emphasize Exit 1 more

# Train the model
train_model_with_all_exits(
    model=model,
    train_loader=train_loader,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
    exit_weights=exit_weights,
    epochs=10
)

Epoch [1/10] - Loss: 1.8460, Accuracies: Exit 1: 24.01%, Exit 2: 33.71%, Final Exit: 35.52%
Epoch [2/10] - Loss: 1.5280, Accuracies: Exit 1: 30.77%, Exit 2: 48.69%, Final Exit: 50.29%
Epoch [3/10] - Loss: 1.3710, Accuracies: Exit 1: 32.65%, Exit 2: 57.21%, Final Exit: 57.98%
Epoch [4/10] - Loss: 1.2630, Accuracies: Exit 1: 34.11%, Exit 2: 62.82%, Final Exit: 63.10%
Epoch [5/10] - Loss: 1.1959, Accuracies: Exit 1: 35.07%, Exit 2: 66.27%, Final Exit: 66.36%
Epoch [6/10] - Loss: 1.1446, Accuracies: Exit 1: 36.12%, Exit 2: 68.95%, Final Exit: 68.96%
Epoch [7/10] - Loss: 1.1117, Accuracies: Exit 1: 36.42%, Exit 2: 70.53%, Final Exit: 70.14%
Epoch [8/10] - Loss: 1.0866, Accuracies: Exit 1: 36.99%, Exit 2: 71.77%, Final Exit: 71.46%
Epoch [9/10] - Loss: 1.0667, Accuracies: Exit 1: 37.18%, Exit 2: 72.66%, Final Exit: 72.24%
Epoch [10/10] - Loss: 1.0456, Accuracies: Exit 1: 37.77%, Exit 2: 73.49%, Final Exit: 73.09%
Model saved successfully!


In [71]:
def test_model_with_exit_tracking(model, test_loader, device, exit_thresholds=None):
    """
    Test the multi-exit model, calculate overall accuracy, and track exit usage.
    
    Parameters:
        model: The trained model with early exits.
        test_loader: DataLoader for the test dataset.
        device: The device to perform testing ('cpu' or 'cuda').
        exit_thresholds: List of thresholds for early exits (optional; defaults to model thresholds).
        
    Returns:
        overall_accuracy: The overall accuracy of the model across all exits.
        exit_usage: A list with the count of samples exiting at each exit.
    """
    model.eval()  # Set the model to evaluation mode
    exit_counts = [0, 0, 0]  # Track how many samples exit at each point
    correct_predictions = 0
    total_samples = 0

    # Optionally override the model's thresholds
    if exit_thresholds is not None:
        model.exit_thresholds = exit_thresholds

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass with early exit logic
            exit1_out, exit2_out, final_out = model(inputs, train_mode=False)

            # Determine which exit is used for each batch
            for i in range(inputs.size(0)):  # Iterate over batch size
                outputs = None
                if exit1_out is not None and exit1_out[i] is not None:
                    outputs = exit1_out[i]
                    exit_counts[0] += 1
                elif exit2_out is not None and exit2_out[i] is not None:
                    outputs = exit2_out[i]
                    exit_counts[1] += 1
                elif final_out is not None and final_out[i] is not None:
                    outputs = final_out[i]
                    exit_counts[2] += 1

                # Compute predictions
                if outputs is not None:
                    _, prediction = torch.max(outputs.unsqueeze(0), 1)
                    correct_predictions += (prediction == labels[i]).item()
                    total_samples += 1

    # Calculate overall accuracy
    overall_accuracy = 100 * correct_predictions / total_samples if total_samples > 0 else 0

    # Print results
    print(f"Overall Accuracy: {overall_accuracy:.2f}%")
    for i, count in enumerate(exit_counts, start=1):
        print(f"Exit {i} Usage: {count} samples ({100 * count / total_samples:.2f}%)")

    return overall_accuracy, exit_counts

In [85]:
#load the model 
model = ResNet34WithExits(num_classes=10, exit_thresholds=(0.55, 0.6))
model.load_state_dict(torch.load("resnet34_early_exit_cifar10_loweestT.pth"))
model = model.to(device)
# Test the model
accuracy, exit_usage = test_model_with_exit_tracking(
    model=model,
    test_loader=test_loader,
    device=device,
    exit_thresholds=[0.55, 0.7]  # Optional custom thresholds
)

  model.load_state_dict(torch.load("resnet34_early_exit_cifar10_loweestT.pth"))


Overall Accuracy: 69.83%
Exit 1 Usage: 0 samples (0.00%)
Exit 2 Usage: 10000 samples (100.00%)
Exit 3 Usage: 0 samples (0.00%)


this is where teh dapative starts


In [112]:
# ResNet with Early Exits(start of implementation)
class ResNet34WithExits(nn.Module):
    def __init__(self, num_classes=10, exit_thresholds=(0.55,0.8)):
        super(ResNet34WithExits, self).__init__()
        # Define layers
        self.num_classes = num_classes
        self.exit_thresholds = exit_thresholds # confidence thresholds

        # Load the ResNet34 backbone
        resnet = models.resnet34(weights=None)
        self.base = nn.Sequential(*list(resnet.children())[:-2])  # Exclude FC and Adaptive Pooling layers

        # Early exit 1
        self.exit1 = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

        # Early exit 2
        self.exit2 = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

        # Final layers (ResNet output)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x, train_mode=True):
        # Early exit 1
        x1 = self.base[:3](x)
        exit1_out = self.exit1(x1)

        # Early exit 2
        x2 = self.base[3:5](x1)
        exit2_out = self.exit2(x2)

        # Final exit
        x3 = self.base[5:](x2)
        x3 = self.avgpool(x3)
        x3 = torch.flatten(x3, 1)
        final_out = self.fc(x3)

        # If in training mode, return all exits
        if train_mode:
            return exit1_out, exit2_out, final_out

        # Otherwise, use threshold logic for inference
        confidence1 = F.softmax(exit1_out, dim=1).max(1)[0]
        if confidence1.mean().item() >= self.exit_thresholds[0]:
            return exit1_out, None, None

        confidence2 = F.softmax(exit2_out, dim=1).max(1)[0]
        if confidence2.mean().item() >= self.exit_thresholds[1]:
            return None, exit2_out, None

        return None, None, final_out

In [113]:
import numpy as np
import cv2
from skimage.filters import sobel
from skimage.measure import shannon_entropy

def compute_texture_variance(image):
    """Compute the variance of pixel intensities in the image."""
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    return np.var(gray)

def compute_edge_density(image):
    """Compute the density of edges in the image using Sobel filter."""
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    edges = sobel(gray)  # Sobel edge detection
    return np.mean(edges > 0.1)  # Ratio of edge pixels

def compute_color_variance(image):
    """Compute the variance of the color channels."""
    return np.var(image, axis=(0, 1)).mean()  # Variance across all channels

def compute_entropy(image):
    """Compute the entropy of the image."""
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    return shannon_entropy(gray)

def adjust_thresholds(image, base_threshold):
    texture_var = compute_texture_variance(image)
    edge_density = compute_edge_density(image)
    color_var = compute_color_variance(image)
    entropy = compute_entropy(image)

    adjusted_thresholds = [
        base_threshold[0] + 0.01 * texture_var,
        base_threshold[1] + 0.02 * edge_density - 0.1 * color_var,
        base_threshold[2] + 0.01 * entropy
    ]

    adjusted_thresholds = [min(max(t,0),1) for t in adjusted_thresholds]
    return adjusted_thresholds

#model = ResNet34WithExits(num_classes=10, exit_thresholds=(e1, e2))  # Change `num_classes` as per your dataset
#print(model)

In [114]:
# Define transformations for CIFAR-100
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
])

# Load CIFAR-100 dataset
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

# Define model, loss function, optimizer, and scheduler
model = ResNet34WithExits(num_classes=10)  # ResNet34WithExits defined earlier
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

Files already downloaded and verified
Files already downloaded and verified


In [115]:
def train_model_dynamic_thresholds(
    model, train_loader, optimizer, criterion, device, base_thresholds, exit_weights, epochs
):
    """
    Train the model with dynamic thresholds based on image properties.
    
    Parameters:
        model: The model with early exits.
        train_loader: DataLoader for the training dataset.
        optimizer: Optimizer for training.
        criterion: Loss function (e.g., CrossEntropyLoss).
        device: Device to train on ('cpu' or 'cuda').
        base_thresholds: Base thresholds for dynamic adjustment.
        exit_weights: Weights for losses of each exit.
        epochs: Number of training epochs.
    
    Returns:
        None
    """
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        total_samples = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            # Zero the gradients
            optimizer.zero_grad()

            # Initialize total loss for this batch
            total_loss = 0.0

            # Process each sample in the batch to calculate thresholds
            for i in range(inputs.size(0)):
                # Convert the image to numpy (H, W, C) for property calculations
                image = inputs[i].permute(1, 2, 0).cpu().numpy()
                adjusted_thresholds = adjust_thresholds(image, base_thresholds)

                # Update the model's thresholds
                model.exit_thresholds = adjusted_thresholds

                # Forward pass for the sample
                exit1_out, exit2_out, final_out = model(inputs[i].unsqueeze(0), train_mode=True)

                '''
                # Compute losses for all exits
                if exit1_out is not None:
                    loss1 = criterion(exit1_out, labels[i].unsqueeze(0))
                    total_loss += exit_weights[0] * loss1
                if exit2_out is not None:
                    loss2 = criterion(exit2_out, labels[i].unsqueeze(0))
                    total_loss += exit_weights[1] * loss2
                if final_out is not None:
                    loss3 = criterion(final_out, labels[i].unsqueeze(0))
                    total_loss += exit_weights[2] * loss3
                '''
                # Compute losses for each exit
                loss1 = criterion(exit1_out, labels)
                loss2 = criterion(exit2_out, labels)
                loss3 = criterion(final_out, labels)

                # Combine losses using exit weights
                total_loss = (exit_weights[0] * loss1 +
                            exit_weights[1] * loss2 +
                            exit_weights[2] * loss3)
                
                # Backpropagation and optimization
                total_loss.backward()
                optimizer.step()

                # Update running loss
                running_loss += total_loss.item()
                total_samples += inputs.size(0)

        # Print epoch summary
        print(
            f"Epoch [{epoch + 1}/{epochs}], Loss: {running_loss / total_samples:.4f}"
        )
    torch.save(model.state_dict(), "resnet34_early_exit_cifar10_loweestT.pth")
    print("Model saved successfully!")
    

In [116]:
# Define parameters
base_thresholds = [0.5, 0.6, 0.7]  # Base thresholds for dynamic adjustment
exit_weights = [0.4, 0.3, 0.3]  # Weights for exit losses
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
epochs = 10

# Train the model
train_model_dynamic_thresholds(
    model=model,
    train_loader=train_loader,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
    base_thresholds=base_thresholds,
    exit_weights=exit_weights,
    epochs=epochs
)

ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 512, 1, 1])