In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import heapq
import pickle
import os
import re
from google.colab import drive
import networkx as nx
import matplotlib.pyplot as plt
from collections import deque
import numpy as np
import seaborn as sns

In [2]:
drive.mount('/content/gdrive', force_remount=True)

# Define function to load checkpoint
def load_checkpoint(model, save_dir, num_epochs, exact_checkpoint=None):
    os.makedirs(save_dir, exist_ok=True)
    checkpoint_files = [f for f in os.listdir(save_dir) if f.endswith('.pth')]
    training_finished = False

    if checkpoint_files:
        if "model_weights.pth" in checkpoint_files:
            latest_checkpoint_file = "model_weights.pth"
            training_finished = True
        elif exact_checkpoint is not None:
            latest_checkpoint_file = f"checkpoint_epoch_{exact_checkpoint}.pth"
        else:
            latest_checkpoint_file = max(checkpoint_files, key=lambda x: int(re.search(r'(\d+)', x).group()))
            print(f"Loading checkpoint: {latest_checkpoint_file}")

        model.load_state_dict(torch.load(os.path.join(save_dir, latest_checkpoint_file)))
        start_epoch = num_epochs if training_finished else int(re.search(r'(\d+)', latest_checkpoint_file).group()) + 1
    else:
        start_epoch = 0

    return start_epoch

Mounted at /content/gdrive


In [3]:
# Define ResNetTrainer
class ResNetTrainer:
    def __init__(self, x, epochs=20, batch_size=128, save_path=None):
        self.epochs = epochs
        self.batch_size = batch_size
        self.x = x  # Store the ResNet variant (e.g., 18, 34, 50, etc.)
        self.save_path = f"/content/gdrive/My Drive/checkpoints/resnet{x}_cifar100" if save_path is None else save_path
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._prepare_data()
        self._build_model()

    def _prepare_data(self):
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        self.trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
        self.trainloader = DataLoader(self.trainset, batch_size=self.batch_size, shuffle=True, num_workers=2)
        self.testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)
        self.testloader = DataLoader(self.testset, batch_size=self.batch_size, shuffle=False, num_workers=2)

    def _build_model(self):
        # Select the appropriate ResNet model based on self.x
        if self.x == 18:
            self.model = torchvision.models.resnet18(pretrained=False, num_classes=100).to(self.device)
        elif self.x == 34:
            self.model = torchvision.models.resnet34(pretrained=False, num_classes=100).to(self.device)
        elif self.x == 50:
            self.model = torchvision.models.resnet50(pretrained=False, num_classes=100).to(self.device)
        elif self.x == 101:
            self.model = torchvision.models.resnet101(pretrained=False, num_classes=100).to(self.device)
        elif self.x == 152:
            self.model = torchvision.models.resnet152(pretrained=False, num_classes=100).to(self.device)
        else:
            raise ValueError(f"ResNet variant {self.x} is not supported")

        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.001)

    def train(self):
        start_epoch = load_checkpoint(self.model, self.save_path, self.epochs)
        for epoch in range(start_epoch, self.epochs):
            self.model.train()
            running_loss = 0.0
            for images, labels in self.trainloader:
                images, labels = images.to(self.device), labels.to(self.device)
                self.optimizer.zero_grad()
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()
                running_loss += loss.item()
            print(f"Epoch {epoch+1}/{self.epochs}, Loss: {running_loss / len(self.trainloader)}")
            checkpoint_path = os.path.join(self.save_path, f'checkpoint_epoch_{epoch + 1}.pth')
            torch.save(self.model.state_dict(), checkpoint_path)

    def load_model(self, exact_checkpoint=None):
      if os.path.exists(self.save_path) and os.path.isdir(self.save_path):
          checkpoint_files = [f for f in os.listdir(self.save_path) if f.endswith('.pth')]
          if checkpoint_files:
              if "model_weights.pth" in checkpoint_files:
                  latest_checkpoint_file = "model_weights.pth"
              elif exact_checkpoint is not None:
                  latest_checkpoint_file = f"checkpoint_epoch_{exact_checkpoint}.pth"
              else:
                  latest_checkpoint_file = max(checkpoint_files, key=lambda x: int(re.search(r'(\d+)', x).group()))
              checkpoint_path = os.path.join(self.save_path, latest_checkpoint_file)
              print(f"Loading model from {checkpoint_path}")
              self.model.load_state_dict(torch.load(checkpoint_path, map_location=self.device))
              self.model.eval()
          else:
              raise FileNotFoundError("No checkpoints found in the directory.")
      else:
          raise FileNotFoundError("Saved model directory not found.")

In [4]:
# Define Complexity Calculator
class ComplexityCalculator:
    def __init__(self, model, dataloader, device):
        self.model = model
        self.dataloader = dataloader
        self.device = device

    def compute_complexities(self):
        self.model.eval()
        all_confidences = []
        correct_predictions = []
        complexities = []
        confidences = []
        predictions = []
        all_images = []
        all_labels = []

        with torch.no_grad():
            for images, labels in self.dataloader:
                images, labels = images.to(self.device), labels.to(self.device)
                outputs = self.model(images)
                probabilities = torch.softmax(outputs, dim=1)
                conf, preds = torch.max(probabilities, dim=1)
                all_confidences.extend(conf.cpu().numpy())
                correct_predictions.extend((preds == labels).cpu().numpy())
                predictions.extend(preds.cpu().numpy())
                # true_labels.extend(labels.cpu().numpy())
                all_images.append(images.cpu())  # Move to CPU before storing
                all_labels.append(labels.cpu())

        # mean_confidence = sum(all_confidences) / len(all_confidences)
        # print("mean confidence is:"+ str(mean_confidence))
        test_images = torch.cat(all_images, dim=0)  # Concatenate into a single tensor
        test_labels = torch.cat(all_labels, dim=0)

        for conf, correct in zip(all_confidences, correct_predictions):
            if correct:
                complexity = 1 - (conf)
            else:
                complexity = 1 + conf
            complexities.append(complexity)
            confidences.append(conf)

        return complexities, confidences, predictions, test_images, test_labels

In [36]:
import json
import copy
# Load pre-trained ResNet models
# resnet18 = models.resnet18(pretrained=False)
# resnet101 = models.resnet101(pretrained=False)

def count_conv_layers(model):
    return sum(1 for layer in model.modules() if isinstance(layer, nn.Conv2d))

#def get_exit_layer(complexity, complexity_bins, num_conv_resnet18):
#    exit_index = np.digitize(complexity, complexity_bins) #-1
#    return num_conv_resnet18 + exit_index

def get_exit_layer(complexity, complexity_bins, min_layers):
    bin_index = np.digitize(complexity, complexity_bins) - 1
    return min_layers + bin_index

# Modify ResNet to allow early exiting
class EarlyExitResNet(nn.Module):
  def __init__(self, original_resnet, num_exit_layers, num_classes=100):
      super().__init__()

      self.num_exit_layers = num_exit_layers
      # original_resnet = copy.deepcopy(original_resnet)

      # Store basic parts of ResNet
      self.features = nn.Sequential()
      self.features.add_module('conv1', original_resnet.conv1)
      self.features.add_module('bn1', original_resnet.bn1)
      self.features.add_module('relu', original_resnet.relu)
      self.features.add_module('maxpool', original_resnet.maxpool)

      # Add layers up to the specified conv count
      layer_names = ['layer1', 'layer2', 'layer3', 'layer4']
      current_conv_count = 1  # Starting with 1 for the conv1 layer

      for layer_idx, layer_name in enumerate(layer_names):
          layer = getattr(original_resnet, layer_name)

          if isinstance(layer, nn.Sequential):
              for block_idx, block in enumerate(layer):
                  # Count convolutions in this block
                  block_conv_count = sum(1 for m in block.modules() if isinstance(m, nn.Conv2d))

                  # If adding this block would exceed our limit, break
                  if current_conv_count + block_conv_count > num_exit_layers:
                      break

                  self.features.add_module(f'layer{layer_idx+1}_block{block_idx}', block)
                  current_conv_count += block_conv_count

          # If we've reached our conv count limit, stop adding layers
          if current_conv_count >= num_exit_layers:
              break

      # Add global average pooling
      self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

      # Initialize FC to None - we'll set it after determining input size
      self.fc = None
      self.num_classes = num_classes
      self.in_features = None  # Store the input features for the FC layer

  def forward(self, x):
      x = self.features(x)
      x = self.avgpool(x)
      x = torch.flatten(x, 1)

      if self.fc is None:
          self.in_features = x.shape[1]
          self.fc = nn.Linear(self.in_features, self.num_classes).to(x.device)
          nn.init.kaiming_normal_(self.fc.weight, mode='fan_out', nonlinearity='relu')
          if self.fc.bias is not None:
              nn.init.constant_(self.fc.bias, 0)
          # print(f"Created new FC layer with input size {self.in_features} and output size {self.num_classes}")

      x = self.fc(x)
      return x

  def save_model(self, save_path, epoch=None):
      """Save the model weights to the specified path."""
      os.makedirs(save_path, exist_ok=True)

      if epoch is not None:
          checkpoint_path = os.path.join(save_path, f"exit_{self.num_exit_layers}_epoch_{epoch}.pth")
      else: #
          checkpoint_path = os.path.join(save_path, f"exit_{self.num_exit_layers}_weights.pth") #

      # Save model state dictionary
      torch.save(self.state_dict(), checkpoint_path)
      print(f"Model saved to {checkpoint_path}")

      # Also save the input feature dimension for the FC layer
      # config_path = os.path.join(save_path, f"exit_{self.num_exit_layers}_config.json")
      # config = {
      #     "num_exit_layers": self.num_exit_layers,
      #     "in_features": self.in_features,
      #     "num_classes": self.num_classes
      # }
      # with open(config_path, 'w') as f:
      #     json.dump(config, f)
      # print(f"Model configuration saved to {config_path}")

  @classmethod
  def load_model(cls, original_resnet, save_path, num_exit_layers, device='cuda', exact_checkpoint=None):
      """Load a model from the specified path."""
      # First, load the configuration
      # config_path = os.path.join(save_path, f"exit_{num_exit_layers}_config.json")
      # if os.path.exists(config_path):
      #     with open(config_path, 'r') as f:
      #         config = json.load(f)
      #     num_classes = config.get("num_classes", 100)
      # else:
      #     print("No configuration file found, using default values.")
      #     num_classes = 100

      # Create the model
      model = cls(original_resnet, num_exit_layers, 100) # CIFAR-100

      # Initialize the FC layer with a dummy forward pass
      dummy_input = torch.randn(1, 3, 32, 32).to(device)  # CIFAR-100 images
      _ = model(dummy_input)

      # Find the checkpoint file
      if os.path.exists(save_path) and os.path.isdir(save_path):
          checkpoint_files = [f for f in os.listdir(save_path)
                            if f.startswith(f"exit_{num_exit_layers}") and f.endswith('.pth')]

          if checkpoint_files:
              if f"exit_{num_exit_layers}_weights.pth" in checkpoint_files:
                  checkpoint_file = f"exit_{num_exit_layers}_weights.pth"
              elif exact_checkpoint is not None:
                  checkpoint_file = f"exit_{num_exit_layers}_epoch_{exact_checkpoint}.pth"
              else:
                  # Get the latest epoch checkpoint
                  epoch_checkpoints = [f for f in checkpoint_files if "epoch" in f]
                  if epoch_checkpoints:
                      checkpoint_file = max(epoch_checkpoints,
                                          key=lambda x: int(re.search(r'epoch_(\d+)', x).group(1)))
                  else:
                      checkpoint_file = checkpoint_files[0]

              checkpoint_path = os.path.join(save_path, checkpoint_file)
              print(f"Loading model from {checkpoint_path}")
              model.load_state_dict(torch.load(checkpoint_path, map_location=device))
              model.eval()

              match = re.search(r'epoch_(\d+)', checkpoint_file)
              start_epoch = int(match.group(1)) + 1 if match else 0

              return model.to(device), start_epoch
          else:
              raise FileNotFoundError(f"No checkpoints found for exit layer {num_exit_layers}.")
      else:
          raise FileNotFoundError("Saved model directory not found.")

def fine_tune_early_exit_model(original_model, train_loader, num_exit_layers, save_path,start_epoch,
                               num_classes=100, device='cuda', epochs=20, lr=0.001):
    # Create early exit model
    early_exit_model = EarlyExitResNet(original_model, num_exit_layers, num_classes).to(device)

    # Initialize the FC layer with a forward pass
    dummy_input = torch.randn(1, 3, 32, 32).to(device)  # Assuming CIFAR-100 images
    _ = early_exit_model(dummy_input)

    # Freeze all layers except the final FC layer
    for param in early_exit_model.parameters():
        param.requires_grad = False

    # Ensure the final FC layer's parameters are trainable
    for param in early_exit_model.fc.parameters():
        param.requires_grad = True

    # Set up optimizer and loss function
    optimizer = torch.optim.Adam(early_exit_model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # Fine-tuning loop
    early_exit_model.train()
    for epoch in range(start_epoch, epochs):
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = early_exit_model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        # Print statistics
        print(f'Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(train_loader):.4f}, '
              f'Acc: {100.*correct/total:.2f}%')

        # Save checkpoint
        early_exit_model.save_model(save_path, epoch=epoch+1)

    # Save final model
    early_exit_model.save_model(save_path) #

    return early_exit_model

In [None]:
# Apply early exit logic to test images
complexity_trainer = ResNetTrainer(34, epochs=20, batch_size=128, save_path=f"/content/gdrive/My Drive/checkpoints/resnet34_cifar100")
complexity_trainer.train()
complexity_trainer.load_model(exact_checkpoint=20)

test_loader = DataLoader(complexity_trainer.testset, batch_size=128, shuffle=False, num_workers=2)
complexity_calculator = ComplexityCalculator(complexity_trainer.model, test_loader, complexity_trainer.device)
complexities, confidences, predictions, test_images, test_labels = complexity_calculator.compute_complexities()
complexities = np.array(complexities)

resnet101 = ResNetTrainer(101, epochs=100, batch_size=128, save_path=f"/content/gdrive/My Drive/checkpoints/resnet101_cifar100")
resnet101.train()
resnet101.load_model(exact_checkpoint=100)

resnet18 = ResNetTrainer(18, epochs=100, batch_size=128, save_path=f"/content/gdrive/My Drive/checkpoints/resnet18_cifar100")
resnet18.train()
resnet18.load_model()

num_conv_resnet18 = count_conv_layers(resnet18.model)
num_conv_resnet101 = count_conv_layers(resnet101.model)

x = num_conv_resnet101 - num_conv_resnet18
complexity_bins = np.linspace(0, 2, x + 1)


if isinstance(test_images, torch.Tensor):
    batch_images = test_images.to(resnet101.device)
else:
    batch_images = torch.stack(test_images).to(resnet101.device)

exit_layers = [get_exit_layer(c, complexity_bins, num_conv_resnet18) for c in complexities]

# Define save path for early exit models
early_exit_save_path = "/content/gdrive/My Drive/checkpoints/early_exit_models"
os.makedirs(early_exit_save_path, exist_ok=True)

# First, identify the unique exit layers
unique_exit_layers = sorted(set(exit_layers))

# Create a dictionary to store the fine-tuned models
fine_tuned_models = {}

# Fine-tune a model for each unique exit layer
for exit_layer in unique_exit_layers:
    try:
        # Try to load existing model
        print(f"Attempting to load model for exit layer {exit_layer}")
        model, latest_epoch = EarlyExitResNet.load_model(
            resnet101.model,
            early_exit_save_path,
            exit_layer,
            device=resnet101.device
        )
        fine_tune_early_exit_model(model, torch.utils.data.DataLoader(resnet101.trainset, batch_size=128, shuffle=True), exit_layer, early_exit_save_path, latest_epoch, epochs=15)
        fine_tuned_models[exit_layer] = model
        print(f"Successfully loaded model for exit layer {exit_layer}")
    except FileNotFoundError:
        # Fine-tune new model if not found
        print(f"Fine-tuning new model for exit layer {exit_layer}")
        fine_tuned_models[exit_layer] = fine_tune_early_exit_model(
            resnet101.model,
            torch.utils.data.DataLoader(resnet101.trainset, batch_size=128, shuffle=True),
            exit_layer,
            early_exit_save_path,
            start_epoch=0
        )

# Evaluate each fine-tuned model on the appropriate test samples
for unique_exit in unique_exit_layers:
    indices = [i for i, e in enumerate(exit_layers) if e == unique_exit]
    if not indices:
        continue

    sub_batch = batch_images[indices]
    sub_batch_labels = torch.tensor([test_labels[i] for i in indices]).to(resnet101.device)

    model = fine_tuned_models[unique_exit]
    model.eval()

    with torch.no_grad():
        outputs = model(sub_batch)

    _, predicted = torch.max(outputs.data, 1)
    correct = (predicted == sub_batch_labels).sum().item()
    accuracy = correct / len(indices)
    print(f"Exit layer {unique_exit} accuracy: {accuracy:.4f}")

Loading checkpoint: checkpoint_epoch_100.pth
Loading model from /content/gdrive/My Drive/checkpoints/resnet34_cifar100/checkpoint_epoch_20.pth
Loading checkpoint: checkpoint_epoch_100.pth
Loading model from /content/gdrive/My Drive/checkpoints/resnet101_cifar100/checkpoint_epoch_100.pth
Loading checkpoint: checkpoint_epoch_100.pth
Loading model from /content/gdrive/My Drive/checkpoints/resnet18_cifar100/checkpoint_epoch_100.pth
Attempting to load model for exit layer 20
Fine-tuning new model for exit layer 20
Epoch 1/20, Loss: 4.3362, Acc: 9.52%
Model saved to /content/gdrive/My Drive/checkpoints/early_exit_models/exit_20_epoch_1.pth
Epoch 2/20, Loss: 3.5796, Acc: 16.90%
Model saved to /content/gdrive/My Drive/checkpoints/early_exit_models/exit_20_epoch_2.pth
Epoch 3/20, Loss: 3.3778, Acc: 20.31%
Model saved to /content/gdrive/My Drive/checkpoints/early_exit_models/exit_20_epoch_3.pth
Epoch 4/20, Loss: 3.2562, Acc: 22.49%
Model saved to /content/gdrive/My Drive/checkpoints/early_exit_m

In [None]:
import numpy as np

resnet_predictions = {}
test_labels = None  # Will store labels (assumed to be the same for both)

resnet_variants = [18, 101]

for variant in resnet_variants:
    trainer = ResNetTrainer(variant, epochs=20, batch_size=128, save_path=f"/content/gdrive/My Drive/checkpoints/resnet{variant}_cifar100")
    trainer.train()
    trainer.load_model(exact_checkpoint=100)

    test_loader = DataLoader(trainer.testset, batch_size=128, shuffle=False, num_workers=2)
    complexity_calculator = ComplexityCalculator(trainer.model, test_loader, trainer.device)
    _, confidences, predictions, test_images, test_labels = complexity_calculator.compute_complexities()

    resnet_predictions[variant] = np.array(predictions)  # Store predictions

test_labels = np.array(test_labels)  # Store ground truth

# Ensure both models have the same number of predictions
assert resnet_predictions[18].shape == resnet_predictions[101].shape == test_labels.shape

# Correct classifications
correct_18 = resnet_predictions[18] == test_labels
correct_101 = resnet_predictions[101] == test_labels

# Category classification
correct_both = correct_18 & correct_101
wrong_both = ~correct_18 & ~correct_101
correct_18_wrong_101 = correct_18 & ~correct_101
correct_101_wrong_18 = correct_101 & ~correct_18

# Count occurrences
print(f"Correct in both: {correct_both.sum()}")
print(f"Wrong in both: {wrong_both.sum()}")
print(f"Correct in ResNet-18 but wrong in ResNet-101: {correct_18_wrong_101.sum()}")
print(f"Correct in ResNet-101 but wrong in ResNet-18: {correct_101_wrong_18.sum()}")

In [None]:
resnet_variants = [18, 34, 50]

for variant in resnet_variants:
    trainer = ResNetTrainer(variant, epochs=20, batch_size=128, save_path=f"/content/gdrive/My Drive/checkpoints/resnet{variant}_cifar100")
    trainer.train()
    trainer.load_model(exact_checkpoint=50)

    test_loader = DataLoader(trainer.testset, batch_size=128, shuffle=False, num_workers=2)
    complexity_calculator = ComplexityCalculator(trainer.model, test_loader, trainer.device)
    complexities, confidences, predictions, test_images, test_labels = complexity_calculator.compute_complexities()

    complexities = np.array(complexities)

    # Plot histogram
    plt.figure(figsize=(10, 5))
    sns.histplot(complexities, bins=50, kde=True)
    plt.xlabel("Complexity Score")
    plt.ylabel("Frequency")
    plt.title(f"Distribution of Photo Complexities in resnet{variant}")
    plt.show()