In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [2]:
class EarlyExitBlock(nn.Module):
    """Classifier head attached to an intermediate feature map.

    The feature map is global-average-pooled to one value per channel, fed to a
    linear classifier, and the top softmax probability of each sample is
    reported as that sample's confidence.

    Args:
        in_channels: channel count C of the incoming (N, C, H, W) feature map.
        num_classes: number of output classes.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(in_channels, num_classes)

    def forward(self, x):
        """Return ``(logits, confidence)``.

        ``logits`` has shape (N, num_classes); ``confidence`` has shape (N,) and
        holds the largest softmax probability of each row of ``logits``.
        """
        pooled = torch.flatten(self.pool(x), start_dim=1)  # (N, C, 1, 1) -> (N, C)
        logits = self.fc(pooled)
        confidence = F.softmax(logits, dim=1).max(dim=1).values
        return logits, confidence

In [3]:
# A small CNN block used as one stage of the early-exit backbone.
class StageCNN(nn.Module):
    """One backbone stage: widen, downsample, then project back.

    conv1 (3x3, channels -> 2*channels) + ReLU, 2x2 max-pool, then
    conv2 (3x3, 2*channels -> channels) + ReLU. Both convolutions use
    padding=1, so the channel count is preserved and H, W are halved (floored)
    by the pool.

    Args:
        channels: number of input and output channels (default 16).
    """

    def __init__(self, channels=16):
        super().__init__()
        widened = channels * 2
        self.conv1 = nn.Conv2d(channels, widened, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(widened, channels, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        """Map (N, channels, H, W) to (N, channels, H // 2, W // 2)."""
        downsampled = self.pool(F.relu(self.conv1(x)))
        return F.relu(self.conv2(downsampled))

In [4]:
class EarlyExitNetwork(nn.Module):
    """Three-stage CNN with a classifier head after each stage.

    Training mode: all three heads are evaluated and
    ``([exit1_logits, exit2_logits, final_logits], None)`` is returned so every
    exit can be supervised jointly.

    Eval mode: inference stops at the first intermediate head whose confidence
    is above ``confidence_threshold`` for *every* sample in the batch, and
    ``(logits, exit_index)`` is returned with exit_index in {1, 2, 3}.

    Args:
        in_channels: channels of the input images (default 1).
        num_classes: number of output classes (default 10).
    """

    def __init__(self, in_channels=1, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, padding=1)
        self.layer1 = StageCNN(32)
        self.exit1 = EarlyExitBlock(32, num_classes)
        self.layer2 = StageCNN(32)
        self.exit2 = EarlyExitBlock(32, num_classes)
        self.layer3 = StageCNN(32)
        self.final_exit = EarlyExitBlock(32, num_classes)

    def forward(self, x, confidence_threshold=0.7):
        """Run the network, exiting early in eval mode when confident.

        Args:
            x: input batch, presumably (N, in_channels, H, W) -- TODO confirm.
            confidence_threshold: eval-mode threshold on each head's top
                softmax probability; an exit is taken only when it is strictly
                exceeded by every sample in the batch.

        Returns:
            Training mode: ``([exit1_logits, exit2_logits, final_logits], None)``.
            Eval mode: ``(logits, exit_index)`` where exit_index is 1, 2 or 3.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.layer1(x))
        exit1_logits, exit1_conf = self.exit1(x)
        # One exit decision is shared by the whole batch, so require *all*
        # samples to be confident (min). Using max() let a single confident
        # sample drag the rest of the batch out through a weak early head.
        if not self.training and exit1_conf.min().item() > confidence_threshold:
            return exit1_logits, 1

        x = F.relu(self.layer2(x))
        exit2_logits, exit2_conf = self.exit2(x)
        if not self.training and exit2_conf.min().item() > confidence_threshold:
            return exit2_logits, 2

        x = F.relu(self.layer3(x))
        final_logits, _ = self.final_exit(x)
        if not self.training:
            return final_logits, 3

        return [exit1_logits, exit2_logits, final_logits], None

In [5]:
class MultiExitLoss(nn.Module):
    """Weighted sum of cross-entropy losses over several exit heads.

    Args:
        weights: per-exit loss weights; ``weights[i]`` scales the cross-entropy
            of the i-th logits tensor passed to ``forward``. Defaults to
            ``(0.2, 0.2, 0.6)``, which weights the last exit most heavily.
    """

    def __init__(self, weights=(0.2, 0.2, 0.6)):
        super().__init__()
        # Store a private copy: a mutable list default would be shared by every
        # instance, and a caller-supplied list could be mutated afterwards.
        self.weights = list(weights)
        self.ce = nn.CrossEntropyLoss()

    def forward(self, exits, target):
        """Return ``sum_i weights[i] * CrossEntropy(exits[i], target)``.

        Args:
            exits: sequence of logits tensors, one per exit head.
            target: class-index targets shared by all exits.

        Raises:
            IndexError: if more exits are given than there are weights.
        """
        loss = 0.0
        for i, exit_logits in enumerate(exits):
            loss += self.weights[i] * self.ce(exit_logits, target)
        return loss

In [6]:
# Preprocessing only (no augmentation): convert images to tensors and
# standardise them with the values commonly quoted as the MNIST pixel mean/std.
MNIST_MEAN, MNIST_STD = (0.1307,), (0.3081,)
DATA_ROOT = './data'
TRAIN_BATCH_SIZE = 128
# EarlyExitNetwork picks one exit for a whole batch, so evaluating one image at
# a time lets every test sample take its own exit.
TEST_BATCH_SIZE = 1

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(MNIST_MEAN, MNIST_STD),
])

train_dataset = datasets.MNIST(root=DATA_ROOT, train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root=DATA_ROOT, train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=TEST_BATCH_SIZE, shuffle=False)

In [7]:
SEED = 0
PREFERRED_GPU = 2  # GPU index this notebook was originally run on

# Seed before building the model so weight init and DataLoader shuffling
# are reproducible across Restart & Run All.
torch.manual_seed(SEED)

# A hard-coded "cuda:2" only fails later, when the model is moved to it, on
# machines with fewer than three GPUs; fall back to cuda:0 in that case.
if torch.cuda.is_available():
    gpu_index = PREFERRED_GPU if PREFERRED_GPU < torch.cuda.device_count() else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")

model = EarlyExitNetwork().to(device)
criterion = MultiExitLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [10]:
# NOTE(review): this cell ran as In [10] directly after In [7], and epoch 1
# already reports a loss of ~0.046, which suggests the model had been trained
# by an earlier run in the same kernel. Restart & Run All for from-scratch numbers.
model.train()  # training mode: the model returns logits from all three exits
batches_per_epoch = len(train_loader)
for epoch in range(10):
    running_loss = 0.0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        all_exit_logits, _ = model(images)
        batch_loss = criterion(all_exit_logits, labels)
        batch_loss.backward()
        optimizer.step()
        running_loss += batch_loss.item()
    mean_loss = running_loss / batches_per_epoch
    print(f"Epoch {epoch+1}, Loss: {mean_loss:.4f}")

Epoch 1, Loss: 0.0456
Epoch 2, Loss: 0.0443
Epoch 3, Loss: 0.0394
Epoch 4, Loss: 0.0368
Epoch 5, Loss: 0.0332
Epoch 6, Loss: 0.0334
Epoch 7, Loss: 0.0300
Epoch 8, Loss: 0.0285
Epoch 9, Loss: 0.0276
Epoch 10, Loss: 0.0250


In [11]:
confidence_threshold = 0.7
model.eval()
# exits_used[k] = number of test samples that left through exit k+1
correct, total, exits_used = 0, 0, [0, 0, 0]
with torch.no_grad():
    for inputs, targets in test_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        logits, exit_idx = model(inputs, confidence_threshold)  # exit_idx is 1, 2 or 3
        _, predicted = torch.max(logits, 1)
        n_in_batch = targets.size(0)
        correct += (predicted == targets).sum().item()
        total += n_in_batch
        # The whole batch shares one exit, so count samples (not batches);
        # exit_rates then stays a percentage of test samples for any batch size.
        exits_used[exit_idx - 1] += n_in_batch
accuracy = 100 * correct / total
exit_rates = [100 * e / total for e in exits_used]

print(f"Test Accuracy: {accuracy:.2f}%")
print(f"Exit 1: {exit_rates[0]:.1f}%, Exit 2: {exit_rates[1]:.1f}%, Final Exit: {exit_rates[2]:.1f}%")

Test Accuracy: 98.49%
Exit 1: 96.5%, Exit 2: 3.2%, Final Exit: 0.3%
