In [11]:
import torchvision.transforms as transforms
import torchvision
from torchvision import datasets, transforms
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision.utils import save_image
import os, cv2, sys, time, math, os
from torchvision import transforms, utils, datasets
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, WeightedRandomSampler
from torch.utils.data import random_split
from PIL import Image
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from itertools import product
import pandas as pd
import torchvision.models as models
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim import lr_scheduler
from torchvision import datasets, transforms
!pip install pthflops
from pthflops import count_ops
from pthflops import count_ops
from torch import Tensor
from typing import Callable, Any, Optional, List
import functools



In [12]:
class LoadDataset():
  def __init__(self, input_dim, batch_size_train, batch_size_test, save_idx, model_id, seed=42):
    self.input_dim = input_dim
    self.batch_size_train = batch_size_train
    self.batch_size_test = batch_size_test
    self.seed = seed
    self.save_idx = save_idx
    self.model_id = model_id

    #To normalize the input images data.
    mean = [0.457342265910642, 0.4387686270106377, 0.4073427106250871]
    std = [0.26753769276329037, 0.2638145880487105, 0.2776826934044154]

    # Note that we apply data augmentation in the training dataset.
    self.transformations_train = transforms.Compose([transforms.Resize((input_dim, input_dim)),
                                                     transforms.RandomChoice([
                                                                              transforms.ColorJitter(brightness=(0.80, 1.20)),
                                                                              transforms.RandomGrayscale(p = 0.25)]),
                                                     transforms.RandomHorizontalFlip(p = 0.25),
                                                     transforms.RandomRotation(25),
                                                     transforms.ToTensor(), 
                                                     transforms.Normalize(mean = mean, std = std),
                                                     ])

    # Note that we do not apply data augmentation in the test dataset.
    self.transformations_test = transforms.Compose([
                                                     transforms.Resize(input_dim), 
                                                     transforms.ToTensor(), 
                                                     transforms.Normalize(mean = mean, std = std),
                                                     ])

  def cifar_10(self, root_path, split_ratio):
    # This method loads Cifar-10 dataset. 
    
    # saves the seed
    torch.manual_seed(self.seed)

    # This downloads the training and test CIFAR-10 datasets and also applies transformation  in the data.
    train_set = datasets.CIFAR10(root=root_path, train=True, download=True, transform=self.transformations_train)
    test_set = datasets.CIFAR10(root=root_path, train=False, download=True, transform=self.transformations_test)

    classes_list = train_set.classes

    # This line defines the size of validation dataset.
    val_size = int(split_ratio*len(train_set))

    # This line defines the size of training dataset.
    train_size = int(len(train_set) - val_size)

    #This line splits the training dataset into train and validation, according split ratio provided as input.
    train_dataset, val_dataset = random_split(train_set, [train_size, val_size])

    #This block creates data loaders for training, validation and test datasets.
    train_loader = DataLoader(train_dataset, self.batch_size_train, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, self.batch_size_test, num_workers=4, pin_memory=True)
    test_loader = DataLoader(test_set, self.batch_size_test, num_workers=4, pin_memory=True)

    return train_loader, val_loader, test_loader

  def cifar_100(self, root_path, split_ratio):
    # This method loads Cifar-100 dataset
    root = "cifar_100"
    torch.manual_seed(self.seed)

    # This downloads the training and test Cifar-100 datasets and also applies transformation  in the data.
    train_set = datasets.CIFAR100(root=root_path, train=True, download=True, transform=self.transformations_train)
    test_set = datasets.CIFAR100(root=root_path, train=False, download=True, transform=self.transformations_train)

    classes_list = train_set.classes

    # This line defines the size of validation dataset.
    val_size = int(split_ratio*len(train_set))

    # This line defines the size of training dataset.
    train_size = int(len(train_set) - val_size)

    #This line splits the training dataset into train and validation, according split ratio provided as input.
    train_dataset, val_dataset = random_split(train_set, [train_size, val_size])

    #This block creates data loaders for training, validation and test datasets.
    train_loader = DataLoader(train_dataset, self.batch_size_train, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, self.batch_size_test, num_workers=4, pin_memory=True)
    test_loader = DataLoader(test_set, self.batch_size_test, num_workers=4, pin_memory=True)

    return train_loader, val_loader, test_loader
  
  def get_indices(self, dataset, split_ratio):
    nr_samples = len(dataset)
    indices = list(range(nr_samples))
    
    train_size = nr_samples - int(np.floor(split_ratio * nr_samples))

    np.random.shuffle(indices)

    train_idx, test_idx = indices[:train_size], indices[train_size:]

    return train_idx, test_idx

  def caltech_256(self, root_path, split_ratio, savePath_idx):
    # This method loads the Caltech-256 dataset.

    torch.manual_seed(self.seed)
    np.random.seed(seed=self.seed)

    # This block receives the dataset path and applies the transformation data. 
    train_set = datasets.ImageFolder(root_path, transform=self.transformations_train)

    val_set = datasets.ImageFolder(root_path, transform=self.transformations_test)
    test_set = datasets.ImageFolder(root_path, transform=self.transformations_test)

    if (os.path.exists(os.path.join(savePath_idx, "training_idx_caltech256_id_%s.npy"%(self.model_id)))):
      
      train_idx = np.load(os.path.join(savePath_idx, "training_idx_caltech256_id_%s.npy"%(self.model_id)))
      val_idx = np.load(os.path.join(savePath_idx, "validation_idx_caltech256_id_%s.npy"%(self.model_id)))
      test_idx = np.load(os.path.join(savePath_idx, "test_idx_caltech256_id_%s.npy"%(self.model_id)))

    else:

      # This line get the indices of the samples which belong to the training dataset and test dataset. 
      train_idx, test_idx = self.get_indices(train_set, split_ratio)

      # This line mounts the training and test dataset, selecting the samples according indices. 
      train_data = torch.utils.data.Subset(train_set, indices=train_idx)
      ##essa linha parecia estar faltando. copiei da versão anterior##

      # This line gets the indices to split the train dataset into training dataset and validation dataset.
      train_idx, val_idx = self.get_indices(train_data, split_ratio)

      np.save(os.path.join(savePath_idx, "traning_idx_caltech256_id_%s.npy"%(self.model_id)), train_idx)
      np.save(os.path.join(savePath_idx, "validation_idx_caltech256_id_%s.npy"%(self.model_id)), val_idx)
      np.save(os.path.join(savePath_idx, "test_idx_caltech256_id_%s.npy"%(self.model_id)), test_idx)

    # This line mounts the training and test dataset, selecting the samples according indices. 
    train_data = torch.utils.data.Subset(train_set, indices=train_idx)
    val_data = torch.utils.data.Subset(val_set, indices=val_idx)
    test_data = torch.utils.data.Subset(test_set, indices=test_idx)

    train_loader = torch.utils.data.DataLoader(train_data, batch_size=self.batch_size_train, shuffle=True, num_workers=4)
    val_loader = torch.utils.data.DataLoader(val_data, batch_size=self.batch_size_test, num_workers=4)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=self.batch_size_test, num_workers=4)

    return train_loader, val_loader, test_loader 

  def getDataset(self, root_path, dataset_name, split_ratio, savePath_idx):
    self.dataset_name = dataset_name
    def func_not_found():
      print("No dataset %s is found"%(self.dataset_name))

    func_name = getattr(self, self.dataset_name, func_not_found)
    train_loader, val_loader, test_loader = func_name(root_path, split_ratio, savePath_idx)
    return train_loader, val_loader, test_loader

In [13]:
def load_early_exit_dnn_model(model, model_path, device):
  
  model.load_state_dict(torch.load(model_path, map_location=device)["model_state_dict"])

  return model

def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
  """3x3 convolution with padding"""
  return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation)

class BasicBlock(nn.Module):
  """Basic Block defition.
  Basic 3X3 convolution blocks for use on ResNets with layers <= 34.
  Follows improved proposed scheme in http://arxiv.org/pdf/1603.05027v2.pdf
  """
  expansion = 1

  def __init__(self, inplanes, planes, stride=1, downsample=None):
    super(BasicBlock, self).__init__()
    self.conv1 = conv3x3(inplanes, planes, stride)
    self.bn1 = nn.BatchNorm2d(planes)
    self.relu = nn.ReLU(inplace=True)
    self.conv2 = conv3x3(planes, planes)
    self.bn2 = nn.BatchNorm2d(planes)
    self.downsample = downsample
    self.stride = stride

  def forward(self, x):
    identity = x

    out = self.conv1(x)
    out = self.bn1(out)
    out = self.relu(out)

    out = self.conv2(out)
    out = self.bn2(out)

    if self.downsample is not None:
      identity = self.downsample(x)

    out += identity
    out = self.relu(out)

    return out


def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class ConvBNActivation(nn.Sequential):
    def __init__(
        self,
        in_planes: int,
        out_planes: int,
        kernel_size: int = 3,
        stride: int = 1,
        groups: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        activation_layer: Optional[Callable[..., nn.Module]] = None,
        dilation: int = 1,
    ) -> None:
        padding = (kernel_size - 1) // 2 * dilation
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if activation_layer is None:
            activation_layer = nn.ReLU6
        super().__init__(
            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, dilation=dilation, groups=groups,
                      bias=False),
            norm_layer(out_planes),
            activation_layer(inplace=True)
        )
        self.out_channels = out_planes


# necessary for backwards compatibility
ConvBNReLU = ConvBNActivation


class InvertedResidual(nn.Module):
    def __init__(
        self,
        inp: int,
        oup: int,
        stride: int,
        expand_ratio: int,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = self.stride == 1 and inp == oup

        layers: List[nn.Module] = []
        if expand_ratio != 1:
            # pw
            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer))
        layers.extend([
            # dw
            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer),
            # pw-linear
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            norm_layer(oup),
        ])
        self.conv = nn.Sequential(*layers)
        self.out_channels = oup
        self._is_cn = stride > 1

    def forward(self, x: Tensor) -> Tensor:
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)

class EarlyExitBlock(nn.Module):
  """
  This EarlyExitBlock allows the model to terminate early when it is confident for classification.
  """
  def __init__(self, input_shape, n_classes, exit_type, device):
    super(EarlyExitBlock, self).__init__()
    self.input_shape = input_shape

    _, channel, width, height = input_shape
    self.expansion = width * height if exit_type == 'plain' else 1

    self.layers = nn.ModuleList()

    if (exit_type == 'bnpool'):
      self.layers.append(nn.BatchNorm2d(channel))

    if (exit_type != 'plain'):
      self.layers.append(nn.AdaptiveAvgPool2d(1))
    
    #This line defines the data shape that fully-connected layer receives.
    current_channel, current_width, current_height = self.get_current_data_shape()

    self.layers = self.layers.to(device)

    #This line builds the fully-connected layer
    self.classifier = nn.Sequential(nn.Linear(current_channel*current_width*current_height, n_classes)).to(device)

    self.softmax_layer = nn.Softmax(dim=1)


  def get_current_data_shape(self):
    _, channel, width, height = self.input_shape
    temp_layers = nn.Sequential(*self.layers)

    input_tensor = torch.rand(1, channel, width, height)
    _, output_channel, output_width, output_height = temp_layers(input_tensor).shape
    return output_channel, output_width, output_height
        
  def forward(self, x):
    for layer in self.layers:
      x = layer(x)
    x = x.view(x.size(0), -1)
    output = self.classifier(x)
    #confidence = self.softmax_layer()
    return output

def conv1x1(in_planes, out_planes, stride=1):
  """1x1 convolution"""
  return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class Early_Exit_DNN(nn.Module):
  def __init__(self, model_name: str, n_classes: int, 
               pretrained: bool, n_branches: int, input_shape:tuple, 
               exit_type: str, device, distribution="linear"):
    super(Early_Exit_DNN, self).__init__()

    """
    This classes builds an early-exit DNNs architectures
    Args:

    model_name: model name 
    n_classes: number of classes in a classification problem, according to the dataset
    pretrained: 
    n_branches: number of branches (early exits) inserted into middle layers
    input_shape: shape of the input image
    exit_type: type of the exits
    distribution: distribution method of the early exit blocks.
    device: indicates if the model will processed in the cpu or in gpu
    
    Note: the term "backbone model" refers to a regular DNN model, considering no early exits.

    """
    self.model_name = model_name
    self.n_classes = n_classes
    self.pretrained = pretrained
    self.n_branches = n_branches
    self.input_shape = input_shape
    self.exit_type = exit_type
    self.distribution = distribution
    self.device = device
    self.channel, self.width, self.height = input_shape


    build_early_exit_dnn = self.select_dnn_architecture_model()

    build_early_exit_dnn()

  def select_dnn_architecture_model(self):
    """
    This method selects the backbone to insert the early exits.
    """

    architecture_dnn_model_dict = {"alexnet": self.early_exit_alexnet,
                                   "mobilenet": self.early_exit_mobilenet,
                                   "resnet18": self.early_exit_resnet18,
                                   "resnet34": self.early_exit_resnet34}

    return architecture_dnn_model_dict.get(self.model_name, self.invalid_model)

  def select_distribution_method(self):
    """
    This method selects the distribution method to insert early exits into the middle layers.
    """
    distribution_method_dict = {"linear":self.linear_distribution,
                                "pareto":self.paretto_distribution,
                                "fibonacci":self.fibo_distribution}
    return distribution_method_dict.get(self.distribution, self.invalid_distribution)
    
  def linear_distribution(self, i):
    """
    This method defines the Flops to insert an early exits, according to a linear distribution.
    """
    flop_margin = 1.0 / (self.n_branches+1)
    return self.total_flops * flop_margin * (i+1)

  def paretto_distribution(self, i):
    """
    This method defines the Flops to insert an early exits, according to a pareto distribution.
    """
    return self.total_flops * (1 - (0.8**(i+1)))

  def fibo_distribution(self, i):
    """
    This method defines the Flops to insert an early exits, according to a fibonacci distribution.
    """
    gold_rate = 1.61803398875
    return total_flops * (gold_rate**(i - self.num_ee))

  def verifies_nr_exits(self, backbone_model):
    """
    This method verifies if the number of early exits provided is greater than a number of layers in the backbone DNN model.
    """
    
    total_layers = len(list(backbone_model.children()))
    if (self.n_branches >= total_layers):
      raise Exception("The number of early exits is greater than number of layers in the DNN backbone model.")

  def countFlops(self, model):
    """
    This method counts the numper of Flops in a given full DNN model or intermediate DNN model.
    """
    input = torch.rand(1, self.channel, self.width, self.height)
    flops, all_data = count_ops(model, input, print_readable=False, verbose=False)
    return flops

  def where_insert_early_exits(self):
    """
    This method defines where insert the early exits, according to the dsitribution method selected.
    Args:

    total_flops: Flops of the backbone (full) DNN model.
    """
    threshold_flop_list = []
    distribution_method = self.select_distribution_method()

    for i in range(self.n_branches):
      threshold_flop_list.append(distribution_method(i))

    return threshold_flop_list

  def invalid_model(self):
    raise Exception("This DNN model has not implemented yet.")
  def invalid_distribution(self):
    raise Exception("This early-exit distribution has not implemented yet.")

  def is_suitable_for_exit(self):
    """
    This method answers the following question. Is the position to place an early exit?
    """
    intermediate_model = nn.Sequential(*(list(self.stages)+list(self.layers)))
    current_flop = self.countFlops(intermediate_model)
    return self.stage_id < self.n_branches and current_flop >= self.threshold_flop_list[self.stage_id]

  def add_exit_block(self):
    """
    This method adds an early exit in the suitable position.
    """
    input_tensor = torch.rand(1, self.channel, self.width, self.height)

    self.stages.append(nn.Sequential(*self.layers))

    feature_shape = nn.Sequential(*self.stages)(input_tensor).shape

    self.exits.append(EarlyExitBlock(feature_shape, self.n_classes, self.exit_type, self.device).to(self.device))
    self.layers = nn.ModuleList()
    self.stage_id += 1    

  def set_device(self):
    """
    This method sets the device that will run the DNN model.
    """

    self.stages.to(self.device)
    self.exits.to(self.device)
    self.layers.to(self.device)
    self.classifier.to(self.device)


  def early_exit_alexnet(self):
    """
    This method inserts early exits into a Alexnet model
    """

    self.stages = nn.ModuleList()
    self.exits = nn.ModuleList()
    self.layers = nn.ModuleList()
    self.cost = []
    self.stage_id = 0

    # Loads the backbone model. In other words, Alexnet architecture provided by Pytorch.
    backbone_model = models.alexnet(self.pretrained)

    # It verifies if the number of early exits provided is greater than a number of layers in the backbone DNN model.
    self.verifies_nr_exit_alexnet(backbone_model.features)
    
    # This obtains the flops total of the backbone model
    self.total_flops = self.countFlops(backbone_model)

    # This line obtains where inserting an early exit based on the Flops number and accordint to distribution method
    self.threshold_flop_list = self.where_insert_early_exits()

    for layer in backbone_model.features:
      self.layers.append(layer)
      if (isinstance(layer, nn.ReLU)) and (self.is_suitable_for_exit()):
        self.add_exit_block()

    
    
    self.layers.append(nn.AdaptiveAvgPool2d(output_size=(6, 6)))
    self.stages.append(nn.Sequential(*self.layers))

    
    self.classifier = backbone_model.classifier
    self.classifier[6] = nn.Linear(in_features=4096, out_features=self.n_classes, bias=True)
    self.softmax = nn.Softmax(dim=1)
    self.set_device()

  def verifies_nr_exit_alexnet(self, backbone_model):
    """
    This method verifies if the number of early exits provided is greater than a number of layers in the backbone DNN model.
    In AlexNet, we consider a convolutional block composed by: Convolutional layer, ReLU and he Max-pooling layer.
    Hence, we consider that it makes no sense to insert side branches between these layers or only after the convolutional layer.
    """

    count_relu_layer = 0
    for layer in backbone_model:
      if (isinstance(layer, nn.ReLU)):
        count_relu_layer += 1

    if (count_relu_layer > self.n_branches):
      raise Exception("The number of early exits is greater than number of layers in the DNN backbone model.")

  def early_exit_resnet18(self):
    """
    This method inserts early exits into a Resnet18 model
    """

    self.stages = nn.ModuleList()
    self.exits = nn.ModuleList()
    self.layers = nn.ModuleList()
    self.cost = []
    self.stage_id = 0

    self.inplanes = 64

    n_blocks = 4

    backbone_model = models.resnet18(self.pretrained)

    # It verifies if the number of early exits provided is greater than a number of layers in the backbone DNN model.
    self.verifies_nr_exits(backbone_model)

    # This obtains the flops total of the backbone model
    self.total_flops = self.countFlops(backbone_model)

    # This line obtains where inserting an early exit based on the Flops number and accordint to distribution method
    self.threshold_flop_list = self.where_insert_early_exits()

    building_first_layer = ["conv1", "bn1", "relu", "maxpool"]
    for layer in building_first_layer:
      self.layers.append(getattr(backbone_model, layer))

    if (self.is_suitable_for_exit()):
      self.add_exit_block()

    for i in range(1, n_blocks+1):
      
      block_layer = getattr(backbone_model, "layer%s"%(i))

      for l in block_layer:
        self.layers.append(l)

        if (self.is_suitable_for_exit()):
          self.add_exit_block()
    
    self.layers.append(nn.AdaptiveAvgPool2d(1))
    self.classifier = nn.Sequential(nn.Linear(512, self.n_classes))
    self.stages.append(nn.Sequential(*self.layers))
    self.softmax = nn.Softmax(dim=1)
    self.set_device()

  def early_exit_resnet34(self):
    return True
  

  def early_exit_mobilenet(self):
    """
    This method inserts early exits into a Mobilenet V2 model
    """

    self.stages = nn.ModuleList()
    self.exits = nn.ModuleList()
    self.layers = nn.ModuleList()
    self.cost = []
    self.stage_id = 0

    last_channel = 1280
    
    # Loads the backbone model. In other words, Mobilenet architecture provided by Pytorch.
    backbone_model = models.mobilenet_v2(self.pretrained)

    # It verifies if the number of early exits provided is greater than a number of layers in the backbone DNN model.
    self.verifies_nr_exits(backbone_model.features)
    
    # This obtains the flops total of the backbone model
    self.total_flops = self.countFlops(backbone_model)

    # This line obtains where inserting an early exit based on the Flops number and accordint to distribution method
    self.threshold_flop_list = self.where_insert_early_exits()

    for i, layer in enumerate(backbone_model.features.children()):
      
      self.layers.append(layer)    
      if (self.is_suitable_for_exit()):
        self.add_exit_block()

    self.layers.append(nn.AdaptiveAvgPool2d(1))
    self.stages.append(nn.Sequential(*self.layers))
    

    self.classifier = nn.Sequential(
        nn.Dropout(0.2),
        nn.Linear(last_channel, self.n_classes),)

    self.set_device()
    self.softmax = nn.Softmax(dim=1)

  def forwardTrain(self, x):
    """
    This method is used to train the early-exit DNN model
    """
    
    output_list, conf_list, class_list  = [], [], []

    for i, exitBlock in enumerate(self.exits):
      
      x = self.stages[i](x)
      output_branch = exitBlock(x)
      output_list.append(output_branch)

      #Confidence is the maximum probability of belongs one of the predefined classes and inference_class is the argmax
      conf, infered_class = torch.max(self.softmax(output_branch), 1)
      conf_list.append(conf)
      class_list.append(infered_class)

    x = self.stages[-1](x)

    x = torch.flatten(x, 1)

    output = self.classifier(x)
    infered_conf, infered_class = torch.max(self.softmax(output), 1)
    output_list.append(output)
    conf_list.append(infered_conf)
    class_list.append(infered_class)

    return output_list, conf_list, class_list

  def temperature_scale_overall(self, logits, temp_overall):
    temperature = temp_overall.unsqueeze(1).expand(logits.size(0), logits.size(1)).to(self.device)
    return logits / temperature

  def temperature_scale_branches(self, logits, temp_branches, exit_branch):
    temperature = temp_branches[exit_branch].unsqueeze(1).expand(logits.size(0), logits.size(1)).to(self.device)
    return logits / temperature

  def forward_inference_calib_overall(self, x, p_tar, temp_overall):
    """
    This method is used to experiment of early-exit DNNs with overall calibration.
    """
    """
    This method is used to train the early-exit DNN model
    """
    output_list, conf_list, class_list  = [], [], []

    for i, exitBlock in enumerate(self.exits):
      x = self.stages[i](x)

      output_branch = exitBlock(x)
      output_branch = self.temperature_scale_overall(output_branch, temp_overall)
      conf, infered_class = torch.max(self.softmax(output_branch), 1)

      # Note that if confidence value is greater than a p_tar value, we terminate the dnn inference and returns the output
      if (conf.item() >= p_tar):
        return output_branch, conf.item(), infered_class, i

      else:
        output_list.append(output_branch)
        conf_list.append(conf.item())
        class_list.append(infered_class)

    x = self.stages[-1](x)
    
    x = torch.flatten(x, 1)

    output = self.classifier(x)
    output = self.temperature_scale_overall(output, temp_overall)
    conf, infered_class = torch.max(self.softmax(output), 1)
    
    # Note that if confidence value is greater than a p_tar value, we terminate the dnn inference and returns the output
    # This also happens in the last exit
    if (conf.item() >= p_tar):
      return output, conf.item(), infered_class, self.n_branches
    else:

      # If any exit can reach the p_tar value, the output is give by the more confidence output.
      # If evaluation, it returns max(output), max(conf) and the number of the early exit.

      conf_list.append(conf.item())
      class_list.append(infered_class)
      output_list.append(output)
      max_conf = np.argmax(conf_list)
      return output_list[max_conf], conf_list[max_conf], class_list[max_conf], self.n_branches

  def forward_inference_calib_branches(self, x, p_tar, temp_branches):
    """
    This method is used to experiment of early-exit DNNs with calibration in all the branches.
    """

    output_list, conf_list, class_list  = [], [], []

    for i, exitBlock in enumerate(self.exits):
      x = self.stages[i](x)

      output_branch = exitBlock(x)
      output_branch = self.temperature_scale_branches(output_branch, temp_branches, i)
      conf, infered_class = torch.max(self.softmax(output_branch), 1)

      # Note that if confidence value is greater than a p_tar value, we terminate the dnn inference and returns the output
      if (conf.item() >= p_tar):
        return output_branch, conf.item(), infered_class, i

      else:
        output_list.append(output_branch)
        conf_list.append(conf.item())
        class_list.append(infered_class)

    x = self.stages[-1](x)
    
    x = torch.flatten(x, 1)

    output = self.classifier(x)
    output = self.temperature_scale_branches(output, temp_branches, -1)
    
    conf, infered_class = torch.max(self.softmax(output), 1)
    
    # Note that if confidence value is greater than a p_tar value, we terminate the dnn inference and returns the output
    # This also happens in the last exit
    if (conf.item() >= p_tar):
      return output, conf.item(), infered_class, self.n_branches
    else:

      # If any exit can reach the p_tar value, the output is give by the more confidence output.
      # If evaluation, it returns max(output), max(conf) and the number of the early exit.

      conf_list.append(conf.item())
      class_list.append(infered_class)
      output_list.append(output)
      max_conf = np.argmax(conf_list)
      return output_list[max_conf], conf_list[max_conf], class_list[max_conf], self.n_branches




  def forward_inference_test(self, x, p_tar=0.5):
    """
    This method is used to experiment of early-exit DNNs.
    """
    output_list, conf_list, class_list  = [], [], []
    n_exits = self.n_branches + 1
    exit_branches = np.zeros(n_exits)
    wasClassified = False

    for i, exitBlock in enumerate(self.exits):
      x = self.stages[i](x)

      output_branch = exitBlock(x)
      conf_branch, infered_class_branch = torch.max(self.softmax(output_branch), 1)
      conf_list.append(conf_branch.item()), class_list.append(infered_class_branch)

      if (conf_branch.item() >= p_tar):
        exit_branches[i] = 1

        if (not wasClassified):
          actual_exit_branch = i
          actual_conf = conf_branch.item()
          actual_inferred_class = infered_class_branch
          wasClassified = True

    x = self.stages[-1](x)
    
    x = torch.flatten(x, 1)

    output = self.classifier(x)
    conf, infered_class = torch.max(self.softmax(output), 1)
    conf_list.append(conf.item()), class_list.append(infered_class)

    exit_branches[-1] = 1

    if (conf.item() <  p_tar):
      max_conf = np.argmax(conf_list)
      conf_list[-1] = conf_list[max_conf]
      class_list[-1] = class_list[max_conf]

    if (not wasClassified):
      actual_exit_branch = self.n_branches
      actual_conf = conf_list[-1]
      actual_inferred_class = class_list[-1]

    return actual_conf, actual_inferred_class, actual_exit_branch, conf_list, class_list, exit_branches


  def forwardEval(self, x, p_tar):
    """
    This method is used to train the early-exit DNN model
    """
    output_list, conf_list, class_list  = [], [], []

    for i, exitBlock in enumerate(self.exits):
      x = self.stages[i](x)

      output_branch = exitBlock(x)
      conf, infered_class = torch.max(self.softmax(output_branch), 1)

      # Note that if confidence value is greater than a p_tar value, we terminate the dnn inference and returns the output
      if (conf.item() >= p_tar):
        return output_branch, conf.item(), infered_class, i

      else:
        output_list.append(output_branch)
        conf_list.append(conf.item())
        class_list.append(infered_class)

    x = self.stages[-1](x)
    
    x = torch.flatten(x, 1)

    output = self.classifier(x)
    conf, infered_class = torch.max(self.softmax(output), 1)
    
    # Note that if confidence value is greater than a p_tar value, we terminate the dnn inference and returns the output
    # This also happens in the last exit
    if (conf.item() >= p_tar):
      return output, conf.item(), infered_class, self.n_branches
    else:

      # If any exit can reach the p_tar value, the output is give by the more confidence output.
      # If evaluation, it returns max(output), max(conf) and the number of the early exit.

      conf_list.append(conf.item())
      class_list.append(infered_class)
      output_list.append(output)
      max_conf = np.argmax(conf_list)
      return output_list[max_conf], conf_list[max_conf], class_list[max_conf], self.n_branches


  def forward(self, x, p_tar=0.5, training=True):
    """
    This implementation supposes that, during training, this method can receive a batch containing multiple images.
    However, during evaluation, this method supposes an only image.
    """
    if (training):
      return self.forwardTrain(x)
    else:
      return self.forwardEval(x, p_tar)

In [14]:
class BranchesModelWithTemperature(nn.Module):
  def __init__(self, model, n_branches, distortion_list, device, save_path, lr=0.01, max_iter=50):
    super(BranchesModelWithTemperature, self).__init__()
    """
    This method calibrates a early-exit DNN. The calibration goal is to turn the classification confidencer closer to the real model's accuracy.
    In this work, we apply the calibration method called Temperature Scaling.
    The paper below explains in detail: https://arxiv.org/pdf/1706.04599.pdf

    Here, we follow two approaches:
    * we find a temperature parameter for each side branch
    * we find a temperature parameter for the entire early-exit DNN model.

    """
    self.model = model            #this receives the architecture model. It is important to notice this models has already trained. 
    self.n_branches = n_branches  #the number of side branches or early exits.
    self.n_exits = self.n_branches + 1 
    self.device = device               
    self.lr = lr                  # defines the learning rate of the calibration process.
    self.max_iter = max_iter      #defines the number of iteractions to train the calibration process
    self.save_path = save_path    # indicates the path to save the temperature in the temperature scaling method
    
    # This line initiates a parameters list of the temperature 
    self.temperature_branches = [nn.Parameter(torch.ones(1)*1.5) for i in range(self.n_exits)]
    self.softmax = nn.Softmax(dim=1)
    
    # This line initiates a single temperature parameter for the entire early-exit DNN model
    self.temperature_overall = nn.Parameter(torch.ones(1)*1.5)

  def forward_branches(self, input, p_tar):
    return self.model.forward_inference_calib_branches(input, p_tar, self.temperature_branches)

  def forward_overall(self, input, p_tar):
     return self.model.forward_inference_calib_overall(input, p_tar, self.temperature_overall)

  def temperature_scale_overall(self, logits):
    temperature = self.temperature_overall.unsqueeze(1).expand(logits.size(0), logits.size(1)).to(self.device)
    return logits / temperature
    
  def temperature_scale_branches(self, logits, i):
    temperature = self.temperature_branches[i].unsqueeze(1).expand(logits.size(0), logits.size(1)).to(self.device)
    return logits / temperature
  
  def save_temperature_branches(self, p_tar, before_temperature_nll_list, after_temperature_nll_list):

    temperature_dict = {}

    df = pd.read_csv(self.save_path) if (os.path.exists(self.save_path)) else pd.DataFrame()
    
    for i in range(self.n_exits):
      temperature_dict.update({"p_tar": p_tar, "temperature_branch_%s"%(i+1): (self.temperature_branches[i].data).cpu().numpy().item(),
                               "before_nll_branch_%s"%(i+1): before_temperature_nll_list[i], 
                               "after_nll_branch_%s"%(i+1): after_temperature_nll_list[i]})
    
    df = df.append(pd.Series(temperature_dict), ignore_index=True)
    df.to_csv(self.save_path)

  def save_temperature_overall(self, p_tar, before_temperature_nll, after_temperature_nll):
    """
    This method saves the temperature in an csv file in self.save_path
    This saves: 
    p_tar: which means the threshold
    before_temperature_nll: the error before the calibration  
    after_temperature_nll: the error after the calibration
    temperature parameter:
                 
    """
    temperature_dict = {}

    df = pd.read_csv(self.save_path) if (os.path.exists(self.save_path)) else pd.DataFrame()
    
    temperature_dict.update({"p_tar": p_tar, "temperature": (self.temperature_overall.data).cpu().numpy().item(),
                             "before_nll": before_temperature_nll, "after_nll": after_temperature_nll})
    
    df = df.append(pd.Series(temperature_dict), ignore_index=True)
    df.to_csv(self.save_path)

  def calibrate_overall(self, val_loader, p_tar):
    """
    This method calibrates the entire model. In other words, this method finds a singles temperature parameter 
    for the entire early-exit DNN model
    """
    nll_criterion = nn.CrossEntropyLoss().to(self.device)
    
    logits_list = []
    labels_list = []
    exit_branch_list = np.zeros(self.n_exits)

    self.model.eval()
    with torch.no_grad():
      for i, (data, target) in enumerate(val_loader, 1):
        if(i%1000==0):
          print("Calibration Batch: %s/%s"%(i, len(val_loader)))
          
        data, target = data.to(self.device), target.to(self.device)
        
        logits, conf, infer_class, exit_branch = self.model(data, p_tar, training=False)

        logits_list.append(logits)
        labels_list.append(target)
        exit_branch_list[exit_branch] += 1

    optimizer = optim.LBFGS([self.temperature_overall], lr=self.lr, max_iter=50)

    logits_list = torch.cat(logits_list).to(self.device)
    labels_list = torch.cat(labels_list).to(self.device)

    before_temperature_nll = nll_criterion(logits_list, labels_list).item()

    def eval():
      loss = nll_criterion(self.temperature_scale_overall(logits_list), labels_list)
      loss.backward()
      return loss
      
    optimizer.step(eval)

    after_temperature_nll = nll_criterion(self.temperature_scale_overall(logits_list), labels_list).item()
    print("Before NLL: %s, After NLL: %s"%(before_temperature_nll, after_temperature_nll))
    print("Temp %s"%(self.temperature_overall.item()))
    # This saves the parameter to save the temperature parameter
    self.save_temperature_overall(p_tar, before_temperature_nll, after_temperature_nll)


  def calibrate_branches(self, val_loader, p_tar):
    """
    This method calibrates for each side branch. In other words, this method finds a temperature parameter 
    for each side branch of the early-exit DNN model.
    """

    nll_criterion = nn.CrossEntropyLoss().to(self.device)
    
    logits_list = [[] for i in range(self.n_exits)]
    labels_list = [[] for i in range(self.n_exits)]
    before_temperature_nll_list, after_temperature_nll_list = [], []

    self.model.eval()
    with torch.no_grad():
      for i, (data, target) in enumerate(val_loader, 1):
        if(i%1000==0):
          print("Calibration Batch: %s/%s"%(i, len(val_loader)))
          
        data, target = data.to(self.device), target.to(self.device)
        
        logits, conf, infer_class, exit_branch = self.model(data, p_tar, training=False)

        logits_list[exit_branch].append(logits)
        labels_list[exit_branch].append(target)


    for i in range(self.n_exits):
      if (len(logits_list[i]) == 0):
        continue
      optimizer = optim.LBFGS([self.temperature_branches[i]], lr=self.lr, max_iter=50)

      logit_branch = torch.cat(logits_list[i]).to(self.device)
      label_branch = torch.cat(labels_list[i]).to(self.device)

      before_temperature_nll = nll_criterion(logit_branch, label_branch).item()
      before_temperature_nll_list.append(before_temperature_nll)

      def eval(i):
        loss = nll_criterion(self.temperature_scale_branches(logit_branch, i), label_branch)
        loss.backward()
        return loss
      
      eval_branch = functools.partial(eval, i=i)
      optimizer.step(eval_branch)

      after_temperature_nll = nll_criterion(self.temperature_scale_branches(logit_branch, i), label_branch).item()
      after_temperature_nll_list.append(after_temperature_nll)
      print("Branch: %s, Before NLL: %s, After NLL: %s"%(i+1, before_temperature_nll, after_temperature_nll))
      print("Temp %s: %s"%(i, self.temperature_branches[i].item()))
    
    # This saves the parameter to save the temperature parameter for each side branch
    self.save_temperature_branches(p_tar, before_temperature_nll_list, after_temperature_nll_list)

    return self

In [19]:
def experiement_early_exit_inference(model, test_loader, p_tar, n_branches, device, model_type):

  n_exits = n_branches + 1
  conf_list = []
  nr_branch_exit, correct_branches = np.zeros(n_exits), np.zeros(n_exits)

  model.eval()
  with torch.no_grad():
    for i, (data, target) in enumerate(test_loader, 1):
      if (i % 1000 == 0):
        print("Batch: %s"%(i))
      
      data, target = data.to(device), target.float().to(device)

      if (model_type == "calib_overall"):
        output, conf, infered_class, branch_exit = model.forward_overall(data, p_tar)

      elif (model_type == "calib_branches"):
        output, conf, infered_class, branch_exit = model.forward_branches(data, p_tar)
      
      else:
        output, conf, infered_class, branch_exit = model(data, p_tar, training=False)
      
      
      conf_list.append(conf)
      
      nr_branch_exit[branch_exit] += 1
      isCorrect = infered_class.eq(target.view_as(infered_class)).sum().item()
      correct_branches[branch_exit] += isCorrect
      

      del data, target
      torch.cuda.empty_cache()
 

  acc_branches = 100*(correct_branches/nr_branch_exit)
  acc_avg = 100*(sum(correct_branches)/sum(nr_branch_exit))
  conf_avg = round(np.mean(conf_list), 3)

  nr_samples = sum(nr_branch_exit)
  nr_total_samples = nr_samples

  result = {"p_tar": p_tar, "avg_acc": acc_avg, "avg_conf": conf_avg}

  for i, (acc_branch, nr_branch) in enumerate(zip(acc_branches, nr_branch_exit), 1):
    result.update({"acc_branch_%s"%(i): acc_branch})
    result.update({"nr_exit_branch_%s"%(i): 100*(nr_branch/nr_samples)})
    if (i < (n_branches+1) ):
      result.update({"edge_exit_rate_branch_%s"%(i): 100*(sum(nr_branch_exit[:i])/sum(nr_branch_exit))})    
    nr_samples -= nr_branch
  
  return result


def calibrating_early_exit_dnn(model, val_loader, p_tar, n_branches, device, savePathTemperature):

  print("Calibrating ...")

  overall_calibrated_model = BranchesModelWithTemperature(model, n_branches, val_loader, device, savePathTemperature["calib_overall"])
  overall_calibrated_model.calibrate_overall(val_loader, p_tar)
    
  branches_calibrated_model = BranchesModelWithTemperature(model, n_branches, val_loader, device, savePathTemperature["calib_branches"])
  branches_calibrated_model.calibrate_branches(val_loader, p_tar)

  return overall_calibrated_model, branches_calibrated_model 


def save_results(result, save_path):
  df_result = pd.read_csv(save_path) if (os.path.exists(save_path)) else pd.DataFrame()
  df_result = df_result.append(pd.Series(result), ignore_index=True)
  df_result.to_csv(save_path)


def save_all_results(no_calib_result, calib_overall_result, calib_branches_result, save_path_dict):
    save_results(no_calib_result, save_path_dict["no_calib"])
    save_results(calib_overall_result, save_path_dict["calib_overall"])
    save_results(calib_branches_result, save_path_dict["calib_branches"])

def exp_prob_edge_inference(model, test_loader, val_loader, threshold_list, n_branches, device, save_results_dict, save_temp_dict):

  df_result = pd.DataFrame()
  df_result_samples = pd.DataFrame()

  for p_tar in threshold_list:
    print("P_tar: %s"%(p_tar))
        
    overall_calib_model, branches_calib_model = calibrating_early_exit_dnn(model, val_loader, p_tar, n_branches, device, save_temp_dict)

    no_calib_result = experiement_early_exit_inference(model, test_loader, p_tar, n_branches, device, model_type="no_calib")
    
    calib_overall_result = experiement_early_exit_inference(overall_calib_model, test_loader, p_tar, 
                                                                                          n_branches, device, model_type="calib_overall")
    
    calib_branches_result = experiement_early_exit_inference(branches_calib_model, test_loader, p_tar, 
                                                                                            n_branches, device, model_type="calib_branches")

    save_all_results(no_calib_result, calib_overall_result, calib_branches_result, save_results_dict)
    


In [16]:
model_name = "mobilenet"
dataset_name = "caltech256"
model_id = 1
img_dim = 300
input_dim = 300
batch_size_train, batch_size_test = 64, 1
split_ratio = 0.1
save_idx = False


root_dir = "./drive/MyDrive/early_exit_test" #diretório-raiz
dataset_path = "./drive/MyDrive/undistorted_datasets/Caltech256/256_ObjectCategories" #caminho em que está salvo o dataset Caltec 256

save_root_path = os.path.join(root_dir, dataset_name, model_name)
if (not os.path.exists(save_root_path)):
  os.makedirs(save_root_path)



dataset = LoadDataset(img_dim, batch_size_train, batch_size_test, save_idx, model_id)
_, val_loader, test_loader = dataset.caltech_256(dataset_path, split_ratio, save_root_path)

  cpuset_checked))


In [17]:
n_classes = 258
pretrained = True
n_branches = 5
n_exits = n_branches + 1
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
input_shape = (3, input_dim, input_dim)
distribution = "linear"
exit_type = "bnpool"

# this line indicates the path to trained early-exit DNN model. You must change it to yours trained model
model_path = "./drive/MyDrive/project_quality_magazine/caltech256/mobilenet/models/pristine_model_mobilenet_caltech256_3_5_b.pth"

early_exit_model = Early_Exit_DNN(model_name, n_classes, pretrained, n_branches, input_shape, exit_type, device, distribution=distribution)
early_exit_model = early_exit_model.to(device)
early_exit_model.exits.to(device)

# this line loads the trained model to the early_exit_model.
early_exit_model = load_early_exit_dnn_model(early_exit_model, model_path, device)

Downloading: "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v2-b0353104.pth


  0%|          | 0.00/13.6M [00:00<?, ?B/s]

Input size: (1, 3, 300, 300)
589,580,360 FLOPs or approx. 0.59 GFLOPs
Input size: (1, 3, 300, 300)
20,880,000 FLOPs or approx. 0.02 GFLOPs
Input size: (1, 3, 300, 300)
41,040,000 FLOPs or approx. 0.04 GFLOPs
Input size: (1, 3, 300, 300)
99,090,000 FLOPs or approx. 0.10 GFLOPs
Input size: (1, 3, 300, 300)
149,040,000 FLOPs or approx. 0.15 GFLOPs
Input size: (1, 3, 300, 300)
179,133,664 FLOPs or approx. 0.18 GFLOPs
Input size: (1, 3, 300, 300)
200,666,592 FLOPs or approx. 0.20 GFLOPs
Input size: (1, 3, 300, 300)
222,199,520 FLOPs or approx. 0.22 GFLOPs
Input size: (1, 3, 300, 300)
236,870,560 FLOPs or approx. 0.24 GFLOPs
Input size: (1, 3, 300, 300)
256,508,960 FLOPs or approx. 0.26 GFLOPs
Input size: (1, 3, 300, 300)
276,147,360 FLOPs or approx. 0.28 GFLOPs
Input size: (1, 3, 300, 300)
295,785,760 FLOPs or approx. 0.30 GFLOPs
Input size: (1, 3, 300, 300)
319,837,024 FLOPs or approx. 0.32 GFLOPs
Input size: (1, 3, 300, 300)
362,602,528 FLOPs or approx. 0.36 GFLOPs
Input size: (1, 3, 300,

In [20]:
result_path = os.path.join(save_root_path, "results")
temp_path = os.path.join(save_root_path, "temperature")

if (not os.path.exists(result_path)):
  os.makedirs(result_path)
if (not os.path.exists(temp_path)):
  os.makedirs(temp_path)

"""
Esse primeiro bloco define os caminhos para armazenar os resultados coletados depois avaliar o conjunto de teste. 
Por exemplo, nesses arquivos, serão armazenado resultados como acurácia do modelo, dos ramos laterais, o número total de amostras classificadas
em cada ramo lateral.  
Esses resultados são armazenados em arquivos .csv, cujos caminhos estão definidos no bloco abaixo. 
Definimos três arquivos para armazenar os resultados. 
save_no_calib_path: armazena os resultados do modelo sem calibração
save_calib_overall_path: armazena os resultados do modelo com calibração que encontra um parâmetros temperature para o modelo inteiro
save_calib_branches_path: armazena os resultados do modelo com calibração que encontra um parâmetro para cada ramo lateral. 
"""

save_no_calib_path =  os.path.join(result_path, "exp_no_calib_prob_edge_branches_%s_%s.csv"%(n_branches, model_id)) 
save_calib_overall_path =  os.path.join(result_path, "exp_calib_overall_prob_edge_branches_%s_%s.csv"%(n_branches, model_id)) 
save_calib_branches_path =  os.path.join(result_path, "exp_calib_branches_prob_edge_branches_%s_%s.csv"%(n_branches, model_id)) 


"""
Os dois próximos caminhos determinam onde salvar os parâmetros da calibração 
"""
saveTemperatureBranchesPath = os.path.join(temp_path, "branches_temp_scaling_branches_%s_%d.csv"%(n_branches, model_id))
saveTemperatureOverallPath = os.path.join(temp_path, "branches_temp_scaling_overall_%s_%d.csv"%(n_branches, model_id))


save_results_dict = {"no_calib": save_no_calib_path, 
                     "calib_overall": save_calib_overall_path, 
                     "calib_branches": save_calib_branches_path}



save_temp_dict = {"calib_overall":saveTemperatureOverallPath, "calib_branches": saveTemperatureBranchesPath}

threshold_list = [0.7, 0.75, 0.8, 0.85, 0.9]

exp_prob_edge_inference(early_exit_model, test_loader, val_loader, threshold_list, n_branches, device, save_results_dict, save_temp_dict)

P_tar: 0.7
Calibrating ...


  cpuset_checked))


Calibration Batch: 1000/2754
Calibration Batch: 2000/2754
Before NLL: 1.1887091398239136, After NLL: 1.2051316499710083
Temp 0.9741997122764587
Calibration Batch: 1000/2754
Calibration Batch: 2000/2754
Branch: 1, Before NLL: 3.897165298461914, After NLL: 3.237544059753418
Temp 0: 1.6198656558990479
Branch: 2, Before NLL: 1.8751261234283447, After NLL: 1.749464988708496
Temp 1: 1.2296142578125
Branch: 3, Before NLL: 0.9309037327766418, After NLL: 0.9208666086196899
Temp 2: 1.0436605215072632
Branch: 4, Before NLL: 0.7223101854324341, After NLL: 0.7164681553840637
Temp 3: 1.0305308103561401
Branch: 5, Before NLL: 0.8406816124916077, After NLL: 0.7830142974853516
Temp 4: 1.2107363939285278
Branch: 6, Before NLL: 2.641573667526245, After NLL: 2.1959619522094727
Temp 5: 1.7777103185653687
Batch: 1000
Batch: 2000
Batch: 3000
Batch: 1000
Batch: 2000
Batch: 3000
Batch: 1000
Batch: 2000
Batch: 3000
P_tar: 0.75
Calibrating ...


KeyboardInterrupt: ignored