In [1]:
import os
import time
import zipfile
import math
import random
import numpy as np
import multiprocessing as mp

from pathlib import Path
from six.moves import urllib
from copy import deepcopy

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import random_split
import torch.backends.cudnn as cudnn

import torchvision
from torchvision import datasets, transforms, models
from torchvision.utils import make_grid
from torchvision.transforms.functional import InterpolationMode

from matplotlib import pyplot as plt

from ofa.utils import AverageMeter, accuracy
from torchprofile.profile import profile_macs

# DATALOADERS

In [2]:
def get_loaders(train_dataset, validation_dataset, test_dataset, batch_size):
    """Wrap the three dataset splits in ``DataLoader`` objects.

    Every loader uses 8 worker processes and pinned memory; only the training
    loader shuffles its samples.

    Args:
        train_dataset: dataset used for training (shuffled).
        validation_dataset: dataset used for validation (not shuffled).
        test_dataset: dataset used for testing (not shuffled).
        batch_size: number of samples per batch, shared by all loaders.

    Returns:
        tuple: ``(train_loader, validation_loader, test_loader)``.
    """
    shared_options = dict(num_workers=8, pin_memory=True)
    splits = (
        (train_dataset, True),
        (validation_dataset, False),
        (test_dataset, False),
    )
    train_loader, validation_loader, test_loader = [
        torch.utils.data.DataLoader(split, batch_size, shuffle=do_shuffle, **shared_options)
        for split, do_shuffle in splits
    ]
    return train_loader, validation_loader, test_loader

## dataset generators

### cifar10

In [3]:
def get_cifar10(root="./datasets", img_size=32):
    """Build the CIFAR-10 train / validation / test datasets.

    Images are resized to ``img_size x img_size`` (bicubic), converted to
    tensors and normalised per channel. The dataset is downloaded into
    ``root`` when missing.

    Args:
        root: directory where the dataset is stored / downloaded.
        img_size: side length (in pixels) of the square output images.

    Returns:
        tuple: ``(n_classes, img_channels, train_dataset, validation_dataset,
        test_dataset, classes_weights)`` where ``classes_weights`` is ``None``.
    """
    preprocessing = transforms.Compose([
        transforms.Resize(size=(img_size, img_size), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.49139968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)),
    ])

    # official train split (later divided into train + validation) and test split
    full_train = torchvision.datasets.CIFAR10(root=root, train=True, download=True, transform=preprocessing)
    test_dataset = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=preprocessing)

    # fixed global seed so the train/validation partition is reproducible
    torch.manual_seed(420)
    n_validation = 5000  # noted as 10% by the original author
    split_sizes = [len(full_train) - n_validation, n_validation]
    train_dataset, validation_dataset = random_split(full_train, split_sizes)

    return 10, 3, train_dataset, validation_dataset, test_dataset, None

### cifar100

In [4]:
def get_cifar100(root="./datasets", img_size=32):
    """Build the CIFAR-100 train / validation / test datasets.

    Images are resized to ``img_size x img_size`` (bicubic), converted to
    tensors and normalised per channel. The dataset is downloaded into
    ``root`` when missing.

    Args:
        root: directory where the dataset is stored / downloaded.
        img_size: side length (in pixels) of the square output images.

    Returns:
        tuple: ``(n_classes, img_channels, train_dataset, validation_dataset,
        test_dataset, classes_weights)`` where ``classes_weights`` is ``None``.
    """
    # NOTE(review): these mean/std constants are the same ones used by
    # get_cifar10 -- confirm they are intended for CIFAR-100 as well.
    channel_mean = (0.49139968, 0.48215827, 0.44653124)
    channel_std = (0.24703233, 0.24348505, 0.26158768)
    preprocessing = transforms.Compose([
        transforms.Resize(size=(img_size, img_size), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    # official train split (later divided into train + validation) and test split
    full_train = torchvision.datasets.CIFAR100(root=root, train=True, download=True, transform=preprocessing)
    test_dataset = torchvision.datasets.CIFAR100(root=root, train=False, download=True, transform=preprocessing)

    # fixed global seed so the train/validation partition is reproducible
    torch.manual_seed(420)
    n_validation = 5000  # noted as 10% by the original author
    split_sizes = [len(full_train) - n_validation, n_validation]
    train_dataset, validation_dataset = random_split(full_train, split_sizes)

    return 100, 3, train_dataset, validation_dataset, test_dataset, None

### tiny imagenet

In [5]:
def rename_subfolders(root_dir, old_new_names_dict):
    """Rename every immediate sub-directory of ``root_dir`` through a mapping.

    Plain files inside ``root_dir`` are left untouched.

    Args:
        root_dir: directory whose sub-directories are renamed.
        old_new_names_dict: mapping ``current folder name -> new folder name``.

    Raises:
        KeyError: if a sub-directory name is not a key of the mapping.
    """
    # Snapshot the folder names first so nothing is renamed while scanning.
    current_names = [
        os.path.basename(entry.path)
        for entry in os.scandir(root_dir)
        if entry.is_dir()
    ]

    for current in current_names:
        source = os.path.join(root_dir, current)
        target = os.path.join(root_dir, old_new_names_dict[current])
        os.rename(source, target)

In [6]:
class TinyImagenetDownloader:
    """Download and prepare the Tiny ImageNet-200 dataset on construction.

    Instantiating the class triggers :meth:`download_dataset`, which does
    nothing when ``dataset_path`` already exists.

    Args:
        dataset_path: directory of the unzipped dataset. The archive is
            extracted into the *parent* of this path, so the last path
            component presumably has to match the archive's top-level folder
            (``tiny-imagenet-200``) -- verify against the caller.
    """

    def __init__(self, dataset_path):
        self.save_path = dataset_path
        self.train_path = os.path.join(dataset_path, "train")
        self.valid_path = os.path.join(dataset_path, "val", "images")
        self.data_url = "http://cs231n.stanford.edu/tiny-imagenet-200.zip"

        self.download_dataset()

    def create_val_img_folder(self):
        """Move the validation images into one sub-folder per class.

        Reads ``val/val_annotations.txt`` (tab separated: image file name,
        class identifier, ...) and moves every listed image that is still in
        ``val/images`` into ``val/images/<class identifier>``.

        Modified from https://github.com/DennisHanyuanXu/Tiny-ImageNet/blob/f08f6e69375a8142ecdff76afd9457620f399f48/src/data_prep.py
        """
        val_dir = os.path.join(self.save_path, 'val')
        img_dir = os.path.join(val_dir, 'images')

        # image file name -> class identifier
        val_img_dict = {}
        with open(os.path.join(val_dir, 'val_annotations.txt'), 'r') as fp:
            for line in fp:
                if not line.strip():
                    continue  # tolerate blank lines (e.g. a trailing newline)
                words = line.split('\t')
                val_img_dict[words[0]] = words[1]

        # Create folder if not present and move images into proper folders
        for img, folder in val_img_dict.items():
            newpath = os.path.join(img_dir, folder)
            os.makedirs(newpath, exist_ok=True)
            source = os.path.join(img_dir, img)
            # images already moved by a previous run are simply skipped
            if os.path.exists(source):
                os.rename(source, os.path.join(newpath, img))

    def get_class_name(self):
        """Map every class identifier to its (first) human readable name.

        Parses ``words.txt`` (tab separated: class identifier, comma separated
        names) and keeps only the first name of each class.

        Modified from https://github.com/DennisHanyuanXu/Tiny-ImageNet/blob/f08f6e69375a8142ecdff76afd9457620f399f48/src/data_prep.py

        Returns:
            dict: ``class identifier -> class name``.
        """
        class_to_name = dict()
        with open(os.path.join(self.save_path, 'words.txt'), 'r') as fp:
            for line in fp:
                if not line.strip():
                    continue  # tolerate blank lines
                words = line.strip('\n').split('\t')
                class_to_name[words[0]] = words[1].split(',')[0]
        return class_to_name

    def rename_folders(self):
        """Rename the class folders of the train and validation sets to class names."""
        class_to_name = self.get_class_name()
        rename_subfolders(self.train_path, class_to_name)
        rename_subfolders(self.valid_path, class_to_name)

    def download_dataset(self):
        """Download, unzip and reorganise the dataset unless it is already present."""
        # check if the unzipped folder already exists ...
        if os.path.exists(self.save_path):
            print("already there")
            return

        # ... if not, make sure the parent directory exists. exist_ok avoids a
        # crash when the parent is already there (e.g. after an interrupted run).
        parent_dir = Path(self.save_path).parent
        os.makedirs(parent_dir, exist_ok=True)

        # download the zip file there, streaming it in chunks so the whole
        # archive is never held in memory and both handles are always closed
        print('Downloading %s' % self.data_url)
        zip_path = os.path.join(parent_dir, "tiny-imagenet-200.zip")
        with urllib.request.urlopen(self.data_url) as data_zip, open(zip_path, 'wb') as f:
            while True:
                chunk = data_zip.read(1024 * 1024)
                if not chunk:
                    break
                f.write(chunk)
        print("Download complete")

        # unzip it
        print("Unzipping the Dataset, please wait")
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(parent_dir)
        print("Unzip complete")

        # move validation images into folders
        self.create_val_img_folder()

        # rename the train and validation folders
        self.rename_folders()

        # the zip file is intentionally kept on disk
        # os.remove(zip_path)

In [7]:
def get_tiny_imagenet(root="./datasets", img_size=64):
    """Build the Tiny ImageNet train / validation / test datasets.

    The dataset is downloaded and prepared under
    ``<root>/tiny-imagenet/tiny-imagenet-200`` when missing. The folder
    ``val/images`` is used as the test set, while the validation set is carved
    out of the training folder.

    Args:
        root: directory where the dataset is stored / downloaded.
        img_size: side length (in pixels) of the square output images.

    Returns:
        tuple: ``(n_classes, img_channels, train_dataset, validation_dataset,
        test_dataset, classes_weights)`` where ``classes_weights`` is ``None``.
    """
    dataset_path = os.path.join(root, "tiny-imagenet", "tiny-imagenet-200")
    TinyImagenetDownloader(dataset_path)

    def build_pipeline():
        # train and test images go through the same preprocessing steps
        return transforms.Compose([
            transforms.Resize(size=(img_size, img_size), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.4802, 0.4481, 0.3975], std=[0.2770, 0.2691, 0.2821]),
        ])

    train_transforms = build_pipeline()
    valid_transforms = build_pipeline()

    # generate train and test sets
    full_train = torchvision.datasets.ImageFolder(os.path.join(dataset_path, "train"), train_transforms)
    test_dataset = torchvision.datasets.ImageFolder(os.path.join(dataset_path, "val", "images"), valid_transforms)

    # fixed global seed so the train/validation partition is reproducible
    torch.manual_seed(420)
    n_validation = 10000  # noted as 10% by the original author
    split_sizes = [len(full_train) - n_validation, n_validation]
    train_dataset, validation_dataset = random_split(full_train, split_sizes)

    return 200, 3, train_dataset, validation_dataset, test_dataset, None

### fashionMNIST

In [8]:
def get_fashionMNIST(root="./datasets", img_size=28):
    """Build the Fashion-MNIST train / validation / test datasets.

    Images are resized to ``img_size x img_size`` (bicubic), converted to
    tensors and normalised with mean 0.5 and std 0.5. The dataset is
    downloaded into ``root`` when missing.

    Args:
        root: directory where the dataset is stored / downloaded.
        img_size: side length (in pixels) of the square output images.

    Returns:
        tuple: ``(n_classes, img_channels, train_dataset, validation_dataset,
        test_dataset, classes_weights)`` where ``classes_weights`` is ``None``.
    """
    preprocessing = transforms.Compose([
        transforms.Resize(size=(img_size, img_size), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(0.5, 0.5),
    ])

    # official train split (later divided into train + validation) and test split
    full_train = torchvision.datasets.FashionMNIST(root=root, train=True, download=True, transform=preprocessing)
    test_dataset = torchvision.datasets.FashionMNIST(root=root, train=False, download=True, transform=preprocessing)

    # fixed global seed so the train/validation partition is reproducible
    torch.manual_seed(420)
    n_validation = 9000  # noted as 15% by the original author
    split_sizes = [len(full_train) - n_validation, n_validation]
    train_dataset, validation_dataset = random_split(full_train, split_sizes)

    return 10, 1, train_dataset, validation_dataset, test_dataset, None

### eurosat

In [9]:
def get_eurosat(root="./datasets", img_size=64):
    """Build the EuroSAT train / validation / test datasets.

    A single dataset object is created and then randomly partitioned into the
    three splits. Images are resized to ``img_size x img_size`` (bicubic),
    converted to tensors and normalised per channel.

    Args:
        root: directory where the dataset is stored / downloaded.
        img_size: side length (in pixels) of the square output images.

    Returns:
        tuple: ``(n_classes, img_channels, train_dataset, validation_dataset,
        test_dataset, classes_weights)`` where ``classes_weights`` is a tensor
        holding one weight per class.
    """
    preprocessing = transforms.Compose([
        transforms.Resize(size=(img_size, img_size), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.25, 0.25, 0.25)),
    ])

    # the dataset object is created once and split three ways below
    whole_dataset = torchvision.datasets.EuroSAT(root=root, download=True, transform=preprocessing)

    # fixed global seed so the partition is reproducible
    torch.manual_seed(420)
    n_validation = 4000  # noted as 15% by the original author
    n_test = 5500  # noted as 20% by the original author
    n_train = len(whole_dataset) - n_validation - n_test
    train_dataset, validation_dataset, test_dataset = random_split(
        whole_dataset, [n_train, n_validation, n_test]
    )

    # one weight per class
    per_class_weights = torch.Tensor([0.9, 0.9, 0.9, 1.08, 1.08, 1.35, 1.08, 0.9, 1.08, 0.9])

    return 10, 3, train_dataset, validation_dataset, test_dataset, per_class_weights

### GTSRB

In [10]:
def get_gtsrb(root="./datasets", img_size=(1360,1024)):
    """Build the GTSRB train / validation / test datasets.

    Images are resized (bicubic), converted to tensors and normalised per
    channel. The dataset is downloaded into ``root`` when missing.

    Args:
        root: directory where the dataset is stored / downloaded.
        img_size: either an int (square ``img_size x img_size`` output) or a
            2-element ``(height, width)`` sequence. The default is a tuple,
            which previously produced an invalid nested size.

    Returns:
        tuple: ``(n_classes, img_channels, train_dataset, validation_dataset,
        test_dataset, classes_weights)`` where ``classes_weights`` is a tensor
        holding one weight per class.
    """
    n_classes = 43
    img_channels = 3

    # Resize expects a flat (h, w) pair: expand an int, pass a pair through.
    if isinstance(img_size, (tuple, list)):
        resize_to = tuple(img_size)
    else:
        resize_to = (img_size, img_size)

    transform = transforms.Compose([
        transforms.Resize(size=resize_to, interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.3337, 0.3064, 0.3171), ( 0.2672, 0.2564, 0.2629))
    ])

    # create the train and test sets
    train_val_dataset = torchvision.datasets.GTSRB(root=root, split="train", download=True, transform=transform)
    test_dataset = torchvision.datasets.GTSRB(root=root, split="test", download=True, transform=transform)

    # generate the validation set (fixed global seed => reproducible partition)
    torch.manual_seed(420)
    val_size = 6000  # noted as 15% by the original author
    train_size = len(train_val_dataset) - val_size
    train_dataset, validation_dataset = random_split(train_val_dataset, [train_size, val_size])

    # one weight per class (43 entries)
    classes_weights = torch.Tensor([
        4.130232558139535,
        0.4130232558139535,
        0.4130232558139535,
        0.6453488372093024,
        0.4693446088794926,
        0.49169435215946844,
        2.0651162790697675,
        0.6453488372093024,
        0.6453488372093024,
        0.6257928118393234,
        0.45891472868217054,
        0.6883720930232559,
        0.4393864423552697,
        0.43023255813953487,
        1.1472868217054264,
        1.4750830564784052,
        2.0651162790697675,
        0.826046511627907,
        0.7648578811369509,
        4.130232558139535,
        2.5813953488372094,
        2.5813953488372094,
        2.294573643410853,
        1.7209302325581395,
        3.441860465116279,
        0.6073871409028728,
        1.4750830564784052,
        3.441860465116279,
        1.7209302325581395,
        3.441860465116279,
        2.0651162790697675,
        1.1472868217054264,
        3.441860465116279,
        1.2906976744186047,
        2.0651162790697675,
        0.7648578811369509,
        2.294573643410853,
        4.130232558139535,
        0.448938321536906,
        2.9501661129568104,
        2.5813953488372094,
        3.441860465116279,
        3.441860465116279
    ])
    #classes_weights = None

    return  n_classes, img_channels, train_dataset, validation_dataset, test_dataset, classes_weights

## show samples

In [11]:
# show a batch of images with their respective labels
def show_samples(loader, batch_size):
    """Plot the first batch of ``loader`` as an image grid and print its labels.

    Args:
        loader: iterable yielding ``(images, labels)`` batches.
        batch_size: batch size of ``loader``; used to size the grid/figure.
    """
    # At least one row, otherwise batch sizes below 8 divide by zero.
    rows = max(1, batch_size // 8)
    columns = batch_size // rows

    # only the first batch is displayed
    for images, labels in loader:
        print('images.shape:', images.shape)
        plt.figure(figsize=(rows,columns))
        plt.axis('off')
        # make_grid returns a channel-first image; imshow wants channel-last
        plt.imshow(make_grid(images, nrow=rows).permute((1, 2, 0)))
        print(labels)
        break

## get datasets

In [12]:
def get_dataloaders(dataset, dataset_save_path, img_size, batch_size):
    """Create the train / validation / test loaders of a named dataset.

    Args:
        dataset: one of ``"cifar10"``, ``"cifar100"``, ``"tiny_imagenet"``,
            ``"fashionMNIST"``, ``"eurosatW"`` or ``"gtsrbW"``.
        dataset_save_path: directory where the dataset is stored / downloaded.
        img_size: image size forwarded to the dataset generator.
        batch_size: batch size shared by the three loaders.

    Returns:
        tuple: ``(train_loader, validation_loader, test_loader, n_classes,
        img_channels, classes_weights)``.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    # Lazy factories: only the generator of the requested dataset is called.
    factories = (
        ("cifar10", lambda: get_cifar10(dataset_save_path, img_size)),
        ("cifar100", lambda: get_cifar100(dataset_save_path, img_size)),
        ("tiny_imagenet", lambda: get_tiny_imagenet(dataset_save_path, img_size)),
        ("fashionMNIST", lambda: get_fashionMNIST(dataset_save_path, img_size)),
        ("eurosatW", lambda: get_eurosat(dataset_save_path, img_size)),
        ("gtsrbW", lambda: get_gtsrb(dataset_save_path, img_size)),
    )

    for known_name, build in factories:
        if dataset == known_name:
            (n_classes, img_channels, train_dataset,
             validation_dataset, test_dataset, classes_weights) = build()
            break
    else:
        raise ValueError("dataset not implemeneted")

    for label, split in (("training", train_dataset),
                         ("validation", validation_dataset),
                         ("test", test_dataset)):
        print(label + " images:", len(split))

    train_loader, validation_loader, test_loader = get_loaders(
        train_dataset, validation_dataset, test_dataset, batch_size
    )
    #show_samples(train_loader, batch_size)

    return train_loader, validation_loader, test_loader, n_classes, img_channels, classes_weights

# NETWORK

## networks generators

### my network

In [13]:
class MySENetwork(nn.Module):
    """Single-exit network: input layer -> sequence of stages -> one exit.

    Args:
        input: module applied first to the raw input.
        stages: iterable of modules applied one after the other.
        exit: module producing the final output from the last stage's output.
    """

    def __init__(self, input, stages, exit):
        super().__init__()

        self.input = input
        self.stages = stages
        self.exit = exit

    def forward(self, x):
        """Run ``x`` through the input layer, every stage and the exit."""
        features = self.input(x)

        # the stages are applied strictly in order
        for block in self.stages:
            features = block(features)

        return self.exit(features)

In [14]:
class MyMENetwork(nn.Module):
    """Multi-exit network: input layer -> stages, with an exit after some stages.

    Args:
        input: module applied first to the raw input.
        stages: sequence of modules applied one after the other.
        exits: sequence of exit modules, one per entry of ``exits_positions``.
        exits_positions: 1-based stage numbers; ``exits[i]`` consumes the
            output of stage number ``exits_positions[i]``.
    """

    def __init__(self, input, stages, exits, exits_positions):

        super(MyMENetwork,self).__init__()

        assert(len(exits_positions)==len(exits))
        self.input = input
        self.stages = stages
        self.exits = exits
        self.exits_positions = exits_positions

    def forward(self, x):
        """Return the list of outputs of all exits, in ``exits`` order."""
        # pass through input layer
        x = self.input(x)

        # pass through stages, remembering the output of each one
        intermediate_results = []
        for stage in self.stages:
            x = stage(x)
            intermediate_results.append(x)

        # positions are 1-based, hence the -1 when indexing stage outputs
        return [
            exit_head(intermediate_results[exit_pos - 1])
            for exit_pos, exit_head in zip(self.exits_positions, self.exits)
        ]

    def extract_subnetwork(self, selected_exits):
        """Build an independent (deep-copied) network keeping only some exits.

        Stages after the last selected exit are dropped.

        Args:
            selected_exits: iterable of exit positions (1-based stage numbers)
                to keep; duplicates are ignored. Every position must be one of
                ``self.exits_positions``.

        Returns:
            ``MySENetwork`` when a single exit is selected, otherwise a
            ``MyMENetwork`` whose ``exits_positions`` is the sorted selection.
        """
        # remove duplicates, then order (a set alone does not guarantee order)
        selected_exits = sorted(set(selected_exits))

        assert min(selected_exits)>=1   # exit number must be at least one
        assert max(selected_exits)<=len(self.stages)    # exit number cannot be more than the number of stages
        assert (set(selected_exits)).issubset(set(self.exits_positions))    # check that selected exit are among the ones present

        # keep input layer
        input = deepcopy(self.input)

        # keep (and copy) stages only up to last selected exit
        last_stage = selected_exits[-1]
        stages = deepcopy(self.stages[:last_stage])

        # Look each exit up by its position, not by ``pos - 1``: the two only
        # coincide when exits_positions is exactly [1, 2, ..., n].
        known_positions = list(self.exits_positions)
        exits = nn.ModuleList([
            deepcopy(self.exits[known_positions.index(pos)])
            for pos in selected_exits
        ])

        if len(exits)==1:
            return MySENetwork(input,stages,exits[0])
        else:
            return MyMENetwork(input,stages,exits,selected_exits)

### resnet

In [15]:
class NewResnet50Exit(nn.Module):
    """Classification exit: global average pooling followed by a linear layer.

    Args:
        features: number of channels of the incoming feature map.
        n_classes: number of output classes.
    """

    def __init__(self, features, n_classes):
        super().__init__()
        self.fc = nn.Linear(features, n_classes)

    def forward(self, x):
        """Pool the feature map to 1x1, flatten it and classify it."""
        pooled = F.adaptive_avg_pool2d(x, (1, 1))
        return self.fc(pooled.flatten(1))

In [16]:
def get_resnet50(is_single_exit, pretrained, n_classes, img_channels):
    """Build a single- or multi-exit network on top of torchvision's ResNet-50.

    The first convolution is replaced by a new one accepting ``img_channels``
    input channels; the four following top-level children of the backbone
    become the four stages.

    Args:
        is_single_exit: if true only the final exit is attached.
        pretrained: forwarded to ``models.resnet50``.
        n_classes: number of output classes of every exit.
        img_channels: number of channels of the input images.

    Returns:
        tuple: ``(network, n_exits)`` with ``n_exits == 4``.
    """
    backbone_layers = list(models.resnet50(pretrained).children())

    # new first convolution + the backbone layers that follow it
    stem = nn.Sequential(
        nn.Conv2d(img_channels, 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=False),
        nn.Sequential(*backbone_layers[1:4])
    )

    # children 4..7 are used as the four stages
    stages = nn.ModuleList(backbone_layers[4:8])

    last_exit = NewResnet50Exit(2048, n_classes)
    if is_single_exit:
        network = MySENetwork(stem, stages, last_exit)
    else:
        early_exits = [NewResnet50Exit(width, n_classes) for width in (256, 512, 1024)]
        network = MyMENetwork(stem, stages, nn.ModuleList(early_exits + [last_exit]), [1,2,3,4])

    return network, 4

### vgg

In [17]:
class NewVgg16Exit(nn.Module):
    """Classification exit: global average pooling followed by a linear layer.

    Args:
        features: number of channels of the incoming feature map.
        n_classes: number of output classes.
    """

    def __init__(self, features, n_classes):
        super().__init__()
        self.fc = nn.Linear(features, n_classes)

    def forward(self, x):
        """Pool the feature map to 1x1, flatten it and classify it."""
        pooled = F.adaptive_avg_pool2d(x, (1, 1))
        return self.fc(pooled.flatten(1))

In [18]:
def get_vgg16(is_single_exit, pretrained, n_classes, img_channels):
    """Build a 4-stage single- or multi-exit network on top of torchvision's VGG-16.

    The first convolution is replaced by a new one accepting ``img_channels``
    input channels; the remaining feature layers are grouped into four stages.

    Args:
        is_single_exit: if true only the final exit is attached.
        pretrained: forwarded to ``models.vgg16``.
        n_classes: number of output classes of every exit.
        img_channels: number of channels of the input images.

    Returns:
        tuple: ``(network, n_exits)`` with ``n_exits == 4``.
    """
    feature_layers = list(models.vgg16(pretrained).features.children())

    first_conv = nn.Conv2d(img_channels, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    # (start, stop) indices of the feature layers forming each stage
    stage_bounds = ((1, 10), (10, 17), (17, 24), (24, None))
    stages = nn.ModuleList([
        nn.Sequential(*feature_layers[start:stop]) for start, stop in stage_bounds
    ])

    last_exit = NewVgg16Exit(512, n_classes)
    if is_single_exit:
        network = MySENetwork(first_conv, stages, last_exit)
    else:
        early_exits = [NewVgg16Exit(width, n_classes) for width in (128, 256, 512)]
        network = MyMENetwork(first_conv, stages, nn.ModuleList(early_exits + [last_exit]), [1,2,3,4])

    return network, 4

In [19]:
def get_vgg16full(is_single_exit, pretrained, n_classes, img_channels):
    """Build a 5-stage single- or multi-exit network on top of torchvision's VGG-16.

    The first convolution is replaced by a new one accepting ``img_channels``
    input channels; the remaining feature layers are grouped into five stages.

    Args:
        is_single_exit: if true only the final exit is attached.
        pretrained: forwarded to ``models.vgg16``.
        n_classes: number of output classes of every exit.
        img_channels: number of channels of the input images.

    Returns:
        tuple: ``(network, n_exits)`` with ``n_exits == 5``.
    """
    feature_layers = list(models.vgg16(pretrained).features.children())

    first_conv = nn.Conv2d(img_channels, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    # (start, stop) indices of the feature layers forming each stage
    stage_bounds = ((1, 5), (5, 10), (10, 17), (17, 24), (24, None))
    stages = nn.ModuleList([
        nn.Sequential(*feature_layers[start:stop]) for start, stop in stage_bounds
    ])

    last_exit = NewVgg16Exit(512, n_classes)
    if is_single_exit:
        network = MySENetwork(first_conv, stages, last_exit)
    else:
        early_exits = [NewVgg16Exit(width, n_classes) for width in (64, 128, 256, 512)]
        network = MyMENetwork(first_conv, stages, nn.ModuleList(early_exits + [last_exit]), [1,2,3,4,5])

    return network, 5

### densenet

In [20]:
class NewDensenet169Exit(nn.Module):
    """Classification exit: ReLU, global average pooling, then a linear layer.

    Args:
        fetures: number of channels of the incoming feature map (the
            misspelled name is kept so keyword callers keep working).
        n_classes: number of output classes.
    """

    def __init__(self, fetures, n_classes):
        super(NewDensenet169Exit,self).__init__()

        self.fc= nn.Linear(fetures,n_classes)

    def forward(self, x):
        """Classify the feature map ``x`` without modifying it."""
        # Not in-place: x is a stage output that the multi-exit network keeps
        # and shares, so the exit must not overwrite the caller's tensor.
        x = F.relu(x)
        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

In [21]:
def get_densenet169(is_single_exit, pretrained, n_classes, img_channels):
    """Build a single- or multi-exit network on top of torchvision's DenseNet-169.

    The first convolution is replaced by a new one accepting ``img_channels``
    input channels; the remaining feature layers are grouped into four stages.

    Args:
        is_single_exit: if true only the final exit is attached.
        pretrained: forwarded to ``models.densenet169``.
        n_classes: number of output classes of every exit.
        img_channels: number of channels of the input images.

    Returns:
        tuple: ``(network, n_exits)`` with ``n_exits == 4``.
    """
    feature_layers = list(models.densenet169(pretrained).features.children())

    # new first convolution + the backbone layers that follow it
    stem = nn.Sequential(
        nn.Conv2d(img_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False),
        nn.Sequential(*feature_layers[1:4])
    )

    # (start, stop) indices of the feature layers forming each stage
    stage_bounds = ((4, 6), (6, 8), (8, 10), (10, None))
    stages = nn.ModuleList([
        nn.Sequential(*feature_layers[start:stop]) for start, stop in stage_bounds
    ])

    last_exit = NewDensenet169Exit(1664,n_classes)
    if is_single_exit:
        network = MySENetwork(stem, stages, last_exit)
    else:
        early_exits = [NewDensenet169Exit(width, n_classes) for width in (128, 256, 640)]
        network = MyMENetwork(stem, stages, nn.ModuleList(early_exits + [last_exit]), [1,2,3,4])

    return network, 4

### mbv3small

In [22]:
class NewMbv3smallExit(nn.Module):
    """Classification exit used with the MobileNetV3-small based networks.

    A 1x1 convolution (+ batch norm + hardswish) expands the channels by 6x,
    the result is globally average pooled and fed to a two-layer classifier
    whose hidden width is ``int(features * 1.75)``.

    Args:
        features: number of channels of the incoming feature map.
        n_classes: number of output classes.
    """

    def __init__(self, features, n_classes):
        super().__init__()

        expanded_width = features*6
        hidden_width = int(features* 1.75)

        self.convnorm = nn.Sequential(
            nn.Conv2d(features, expanded_width, kernel_size=(1, 1), stride=(1, 1), bias=False),
            nn.BatchNorm2d(expanded_width, eps=0.001, momentum=0.01, affine=True, track_running_stats=True),
            nn.Hardswish(inplace=True),
        )

        self.exit = nn.Sequential(
            nn.Linear(expanded_width, hidden_width, bias=True),
            nn.Hardswish(inplace=True),
            nn.Dropout(p=0.2, inplace=True),
            nn.Linear(hidden_width, n_classes, bias=True),
        )

    def forward(self, x):
        """Expand the channels, pool to 1x1, flatten and classify."""
        expanded = self.convnorm(x)
        pooled = F.adaptive_avg_pool2d(expanded, (1, 1))
        return self.exit(pooled.flatten(1))


In [23]:
def get_mobilenetv3small(is_single_exit, pretrained, n_classes, img_channels):
    """Build a 4-stage single- or multi-exit network on MobileNetV3-small.

    The input block (conv + batch norm + hardswish) is created from scratch so
    that it accepts ``img_channels`` channels; feature layers 1..11 of the
    torchvision model are grouped into four stages.

    Args:
        is_single_exit: if true only the final exit is attached.
        pretrained: forwarded to ``models.mobilenet_v3_small``.
        n_classes: number of output classes of every exit.
        img_channels: number of channels of the input images.

    Returns:
        tuple: ``(network, n_exits)`` with ``n_exits == 4``.
    """
    feature_layers = list(models.mobilenet_v3_small(pretrained).features.children())

    stem = nn.Sequential(
        nn.Conv2d(img_channels, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False),
        nn.BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True),
        nn.Hardswish(inplace=True)
    )

    # (start, stop) indices of the feature layers forming each stage
    stage_bounds = ((1, 4), (4, 7), (7, 9), (9, 12))
    stages = nn.ModuleList([
        nn.Sequential(*feature_layers[start:stop]) for start, stop in stage_bounds
    ])

    last_exit = NewMbv3smallExit(96,n_classes)
    if is_single_exit:
        network = MySENetwork(stem, stages, last_exit)
    else:
        early_exits = [NewMbv3smallExit(width, n_classes) for width in (24, 40, 48)]
        network = MyMENetwork(stem, stages, nn.ModuleList(early_exits + [last_exit]), [1,2,3,4])

    return network, 4

In [24]:
def get_mobilenetv3smallfull(is_single_exit, pretrained, n_classes, img_channels):
    """Build a 5-stage single- or multi-exit network on MobileNetV3-small.

    The input block (conv + batch norm + hardswish) is created from scratch so
    that it accepts ``img_channels`` channels; feature layers 1..11 of the
    torchvision model are grouped into five stages (the first stage is feature
    layer 1 on its own, not wrapped in a ``Sequential``).

    Args:
        is_single_exit: if true only the final exit is attached.
        pretrained: forwarded to ``models.mobilenet_v3_small``.
        n_classes: number of output classes of every exit.
        img_channels: number of channels of the input images.

    Returns:
        tuple: ``(network, n_exits)`` with ``n_exits == 5``.
    """
    feature_layers = list(models.mobilenet_v3_small(pretrained).features.children())

    stem = nn.Sequential(
        nn.Conv2d(img_channels, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False),
        nn.BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True),
        nn.Hardswish(inplace=True)
    )

    # the first stage is a single layer, the others are grouped slices
    grouped_bounds = ((2, 4), (4, 7), (7, 9), (9, 12))
    stages = nn.ModuleList(
        [feature_layers[1]]
        + [nn.Sequential(*feature_layers[start:stop]) for start, stop in grouped_bounds]
    )

    last_exit = NewMbv3smallExit(96,n_classes)
    if is_single_exit:
        network = MySENetwork(stem, stages, last_exit)
    else:
        early_exits = [NewMbv3smallExit(width, n_classes) for width in (16, 24, 40, 48)]
        network = MyMENetwork(stem, stages, nn.ModuleList(early_exits + [last_exit]), [1,2,3,4,5])

    return network, 5

### efficientnet

In [25]:
class NewEffB5Exit(nn.Module):
    """Classification exit used with the EfficientNet-B5 based networks.

    A 1x1 convolution (+ batch norm + SiLU) expands the channels by 4x, the
    result is globally average pooled and classified by dropout + linear.

    Args:
        features: number of channels of the incoming feature map.
        n_classes: number of output classes.
    """

    def __init__(self, features, n_classes):
        super().__init__()

        expanded_width = features * 4

        self.convNomrAct = nn.Sequential(
            nn.Conv2d(features, expanded_width, kernel_size=(1, 1), stride=(1, 1), bias=False),
            nn.BatchNorm2d(expanded_width, eps=0.001, momentum=0.01, affine=True, track_running_stats=True),
            nn.SiLU(inplace=True),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.4, inplace=True),
            nn.Linear(in_features=expanded_width, out_features=n_classes, bias=True),
        )

    def forward(self, x):
        """Expand the channels, pool to 1x1, flatten and classify."""
        expanded = self.convNomrAct(x)
        pooled = F.adaptive_avg_pool2d(expanded, (1, 1))
        return self.classifier(pooled.flatten(1))

In [26]:
def get_efficientnetB5(is_single_exit, pretrained, n_classes, img_channels):
    """Build a 4-stage single- or multi-exit network on EfficientNet-B5.

    The input block (conv + batch norm + SiLU) is created from scratch so that
    it accepts ``img_channels`` channels; feature layers 1..7 of the
    torchvision model are grouped into four stages.

    Args:
        is_single_exit: if true only the final exit is attached.
        pretrained: forwarded to ``models.efficientnet_b5``.
        n_classes: number of output classes of every exit.
        img_channels: number of channels of the input images.

    Returns:
        tuple: ``(network, n_exits)`` with ``n_exits == 4``.
    """
    feature_layers = list(models.efficientnet_b5(pretrained).features.children())

    stem = nn.Sequential(
        nn.Conv2d(img_channels, 48, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False),
        nn.BatchNorm2d(48, eps=0.001, momentum=0.01, affine=True, track_running_stats=True),
        nn.SiLU(inplace=True)
    )

    # (start, stop) indices of the feature layers forming each stage
    stage_bounds = ((1, 3), (3, 5), (5, 6), (6, 8))
    stages = nn.ModuleList([
        nn.Sequential(*feature_layers[start:stop]) for start, stop in stage_bounds
    ])

    last_exit = NewEffB5Exit(512,n_classes)
    if is_single_exit:
        network = MySENetwork(stem, stages, last_exit)
    else:
        early_exits = [NewEffB5Exit(width, n_classes) for width in (40, 128, 176)]
        network = MyMENetwork(stem, stages, nn.ModuleList(early_exits + [last_exit]), [1,2,3,4])

    return network, 4

In [27]:
def get_efficientnetB5full(is_single_exit, pretrained, n_classes, img_channels):
    """Build a 7-stage single- or multi-exit network on EfficientNet-B5.

    The input block (conv + batch norm + SiLU) is created from scratch so that
    it accepts ``img_channels`` channels; each of the feature layers 1..7 of
    the torchvision model is unpacked into its own stage.

    Args:
        is_single_exit: if true only the final exit is attached.
        pretrained: forwarded to ``models.efficientnet_b5``.
        n_classes: number of output classes of every exit.
        img_channels: number of channels of the input images.

    Returns:
        tuple: ``(network, n_exits)`` with ``n_exits == 7``.
    """
    feature_layers = list(models.efficientnet_b5(pretrained).features.children())

    stem = nn.Sequential(
        nn.Conv2d(img_channels, 48, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False),
        nn.BatchNorm2d(48, eps=0.001, momentum=0.01, affine=True, track_running_stats=True),
        nn.SiLU(inplace=True)
    )

    # every feature layer is iterated and re-wrapped in a fresh Sequential
    stages = nn.ModuleList([
        nn.Sequential(*feature_layers[layer_index]) for layer_index in range(1, 8)
    ])

    last_exit = NewEffB5Exit(512,n_classes)
    if is_single_exit:
        network = MySENetwork(stem, stages, last_exit)
    else:
        early_exits = [NewEffB5Exit(width, n_classes) for width in (24, 40, 64, 128, 176, 304)]
        network = MyMENetwork(stem, stages, nn.ModuleList(early_exits + [last_exit]), [1,2,3,4,5,6,7])

    return network, 7

## get networks

In [28]:
def get_network(network_name, is_single_exit, pretrained, n_classes, img_channels):
    """Build the network identified by ``network_name``.

    Args:
        network_name: one of ``"resnet50"``, ``"vgg16"``, ``"vgg16full"``,
            ``"densenet169"``, ``"mobilenetv3small"``,
            ``"mobilenetv3smallfull"``, ``"efficientnetB5"`` or
            ``"efficientnetB5full"``.
        is_single_exit: if true only the final exit is attached.
        pretrained: forwarded to the underlying torchvision model builder.
        n_classes: number of output classes of every exit.
        img_channels: number of channels of the input images.

    Returns:
        tuple: ``(network, n_exits)``.

    Raises:
        ValueError: if ``network_name`` is not one of the supported names.
    """
    build_args = (is_single_exit, pretrained, n_classes, img_channels)

    # resnet family
    if network_name == "resnet50":
        network, n_exits = get_resnet50(*build_args)
        return network, n_exits

    # vgg family
    if network_name == "vgg16":
        network, n_exits = get_vgg16(*build_args)
        return network, n_exits
    if network_name == "vgg16full":
        network, n_exits = get_vgg16full(*build_args)
        return network, n_exits

    # densenet family
    if network_name == "densenet169":
        network, n_exits = get_densenet169(*build_args)
        return network, n_exits

    # mobilenet family
    if network_name == "mobilenetv3small":
        network, n_exits = get_mobilenetv3small(*build_args)
        return network, n_exits
    if network_name == "mobilenetv3smallfull":
        network, n_exits = get_mobilenetv3smallfull(*build_args)
        return network, n_exits

    # efficientnet family
    if network_name == "efficientnetB5":
        network, n_exits = get_efficientnetB5(*build_args)
        return network, n_exits
    if network_name == "efficientnetB5full":
        network, n_exits = get_efficientnetB5full(*build_args)
        return network, n_exits

    raise ValueError("network not implemeneted")

## Accuracy metrics

In [29]:
def get_metric_dict():
    """Return a fresh ``{'top1': AverageMeter, 'top5': AverageMeter}`` dictionary."""
    return {metric_name: AverageMeter() for metric_name in ('top1', 'top5')}


def update_metric(metric_dict, output, labels):
    """Update the top-1 / top-5 meters with the accuracy of one batch.

    Args:
        metric_dict: dictionary with ``'top1'`` and ``'top5'`` meters.
        output: network output for the batch; ``output.size(0)`` is used as
            the number of samples the accuracies refer to.
        labels: ground-truth labels of the batch.
    """
    top1_acc, top5_acc = accuracy(output, labels, topk=(1, 5))
    for metric_name, batch_acc in (('top1', top1_acc), ('top5', top5_acc)):
        metric_dict[metric_name].update(batch_acc[0].item(), output.size(0))


def get_metric_vals(metric_dict, return_dict=False):
    """Read the ``avg`` attribute of every meter in ``metric_dict``.

    Args:
        metric_dict: mapping from metric name to an object exposing ``avg``.
        return_dict: when True return ``{name: avg}``; otherwise return the
            averages as a list in the mapping's iteration order.

    Returns:
        dict or list of the per-metric averages.
    """
    averages = {name: metric_dict[name].avg for name in metric_dict}
    return averages if return_dict else list(averages.values())

## Early stopping

In [30]:
class EarlyStoppingMeter:
    """Patience-based early-stopping tracker driven by a validation loss.

    Call the instance once per epoch with the current validation loss and
    read ``stop`` afterwards. ``stop`` becomes True once ``patience``
    consecutive calls fail to beat the best loss seen so far; it is never
    reset by this class.
    """

    def __init__(self, patience = 12):
        # best (lowest) loss observed so far; None until the first call
        self.best_loss = None
        # number of consecutive calls without improvement
        self.counter = 0
        self.patience = patience
        self.stop = False

    def __call__(self, val_loss):
        """Record ``val_loss`` and update ``counter`` / ``stop`` accordingly."""
        improved = self.best_loss is None or self.best_loss > val_loss
        if improved:
            self.best_loss = val_loss
            self.counter = 0
            return

        # equal or worse loss counts against the patience budget
        self.counter += 1
        if self.counter >= self.patience:
            self.stop = True

# Loggers

In [31]:
def log_output_se(phase, epoch, epochs, loss, top1, top5, time, log_file):
    """Print one single-exit result line and append it to ``log_file``.

    Args:
        phase: label such as "training", "validation" or "testing"; written
            upper-cased. A separator line is added after validation entries.
        epoch, epochs: current epoch and total number of epochs.
        loss, top1, top5: values to report (formatted with 3 decimals).
        time: elapsed seconds (shadows the ``time`` module inside this
            function only).
        log_file: path of the text file the entry is appended to.
    """
    pieces = [
        phase.upper() + ": ",  # TRAINING/VALIDATION/TESTING
        f"Epoch [{epoch}/{epochs}], Loss: {loss:.3f}, top1:{top1:.3f}, top5:{top5:.3f}, time:{time:.1f}s",
    ]
    if phase.lower() == "validation":
        # visual separator closing an epoch's training/validation pair
        pieces.append("\n--------------------------------------------------------------------------")
    message = "".join(pieces)

    print(message)

    with open(log_file, "a") as log:
        log.write("\n" + message)

In [32]:
def log_output_ee(phase, epoch, epochs, net_loss, branches_loss, net_top1, net_top5, branches_top1, branches_top5, time, log_file):
    """Print one multi-exit result entry and append it to ``log_file``.

    The entry has a header line with the network-level numbers followed by
    three lines listing the per-branch losses, top-1 and top-5 values.

    Args:
        phase: label such as "training", "validation" or "testing"; written
            upper-cased. A separator line is added after validation entries.
        epoch, epochs: current epoch and total number of epochs.
        net_loss, net_top1, net_top5: network-level values.
        branches_loss, branches_top1, branches_top5: per-branch sequences;
            they are zipped together, so only the common prefix is printed.
        time: elapsed seconds (shadows the ``time`` module inside this
            function only).
        log_file: path of the text file the entry is appended to.
    """
    rows = list(zip(branches_loss, branches_top1, branches_top5))

    def bracketed(column):
        # "[ a | b |" with the trailing separator swapped for the closing bracket
        cells = "[" + "".join(f" {row[column]:.3f} |" for row in rows)
        return cells[:-1] + "]"

    message = phase.upper() + ": "  # TRAINING/VALIDATION/TESTING
    message += f"Epoch [{epoch}/{epochs}], Net Loss: {net_loss:.3f}, Net top1:{net_top1:.3f}, Net top5:{net_top5:.3f}, time:{time:.1f}s\n"
    message += "Branches losses = " + bracketed(0) + "\n"
    message += "Branches top1 = " + bracketed(1) + "\n"
    message += "Branches top5 = " + bracketed(2)

    if phase.lower() == "validation":
        message += "\n--------------------------------------------------------------------------"

    print(message)

    with open(log_file, "a") as log:
        log.write("\n" + message)

## PLOTS

In [33]:
def show_plot(epochs,train_values,valid_values,output_path, loss_or_acc, title):
    """Plot a training curve against its validation curve, save and show it.

    The figure is written to ``<output_path>/<title>.png`` and then displayed.

    Args:
        epochs: number of points to plot; the x axis runs from 1 to ``epochs``.
        train_values: per-epoch training values (length ``epochs``).
        valid_values: per-epoch validation values (length ``epochs``).
        output_path: directory the PNG is saved into.
        loss_or_acc: "loss" draws the curves in red, anything else in blue;
            the string is also used in the legend and as the y-axis label.
        title: figure title and base name of the saved file.
    """
    save_path = os.path.join(output_path, title + ".png")

    x = np.arange(1, epochs+1)
    c = "r" if loss_or_acc == "loss" else "b"

    # Draw on a dedicated figure instead of the implicit "current" one, and
    # close it afterwards: this function is called many times per run, and on
    # backends that do not auto-close on show() the curves of earlier calls
    # would otherwise end up in later plots and every figure would stay alive.
    fig, ax = plt.subplots()
    ax.plot(x, train_values, color=c, label="train "+ loss_or_acc, linestyle="solid")
    ax.plot(x, valid_values, color=c, label="valid "+ loss_or_acc, linestyle="dashed")
    ax.legend(loc='upper left')

    ax.set_xlabel("epoch")
    ax.set_ylabel(loss_or_acc)
    ax.set_title(title)
    ax.grid(True)

    fig.savefig(save_path)

    plt.show()
    plt.close(fig)

# TRAINING AND TESTING

### ENSEMBLE VALIDATION

In [None]:
def fix_len(bin_str,expected_len):
    """Left-pad ``bin_str`` with "0" characters up to ``expected_len``.

    Strings that are already at least ``expected_len`` long are returned
    unchanged.
    """
    return bin_str.rjust(expected_len, "0")

In [None]:
def generate_binary_strings(n):
    """Return the binary representations of 1 .. n-1 as equal-width strings.

    The all-zero string (the value 0) is deliberately left out. The width is
    the number of bits needed to write ``n - 1``, so every returned string has
    the same length even when ``n`` is not a power of two.

    Args:
        n: exclusive upper bound, a positive integer (callers in this file
            pass ``2 ** n_exits``).

    Returns:
        list[str]: zero-padded binary strings for 1, 2, ..., n-1.

    Raises:
        ValueError: if ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError("n must be a positive integer")

    # Integer bit arithmetic instead of int(math.log2(n)): the float-log
    # version under-sizes the width whenever n is not an exact power of two.
    width = (n - 1).bit_length()
    return [format(value, f"0{width}b") for value in range(1, n)]

In [None]:
def validate_all_ensembles(network, device, loader, ensemble_weights):
    """Measure top-1 / top-5 accuracy of every non-empty combination of exits.

    Each combination is described by a binary string with one character per
    exit ("1" = exit used). The outputs of the selected exits are scaled by
    their ensemble weight and summed to form the ensemble prediction.

    Args:
        network: model whose forward pass yields one output per exit.
        device: device the batches are moved to.
        loader: iterable of ``(images, labels)`` batches.
        ensemble_weights: one weight per exit; its length fixes the number
            of exits considered.

    Returns:
        dict: binary string -> metric averages as returned by
        ``get_metric_vals`` (a list).
    """
    network.eval()

    # one mask per non-empty subset of exits
    exit_masks = generate_binary_strings(2 ** len(ensemble_weights))
    meters_per_mask = [get_metric_dict() for _ in exit_masks]

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)

            exit_outputs = network(images)
            scaled_outputs = [exit_out * w for exit_out, w in zip(exit_outputs, ensemble_weights)]

            for mask, meters in zip(exit_masks, meters_per_mask):
                selected = [scaled for bit, scaled in zip(mask, scaled_outputs) if bit == "1"]

                if len(selected) > 1:
                    combined = torch.sum(torch.stack(selected), dim=0)
                else:
                    # a single active exit is used as is
                    combined = selected[0]

                update_metric(meters, combined, labels)

    return {
        mask: get_metric_vals(meters, return_dict=False)
        for meters, mask in zip(meters_per_mask, exit_masks)
    }

In [None]:
def test_subnet_se(network, device, loader):
    """Evaluate a single-output network on ``loader`` without computing a loss.

    Args:
        network: model whose forward pass returns one prediction tensor.
        device: device the batches are moved to.
        loader: iterable of ``(images, labels)`` batches.

    Returns:
        tuple: (metric averages from ``get_metric_vals``, elapsed seconds).
    """
    network.eval()
    meters = get_metric_dict()

    with torch.no_grad():
        tic = time.time()
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            update_metric(meters, network(images), labels)
        elapsed = time.time() - tic

    return get_metric_vals(meters), elapsed

In [None]:
def test_subnet_ee(network, device, loader, ensemble_weights):
    """Evaluate the weighted ensemble of a multi-exit network on ``loader``.

    Every exit output is multiplied by its ensemble weight and the results
    are summed to obtain the prediction that is scored.

    Args:
        network: model whose forward pass yields one output per exit.
        device: device the batches are moved to.
        loader: iterable of ``(images, labels)`` batches.
        ensemble_weights: one weight per exit.

    Returns:
        tuple: (metric averages from ``get_metric_vals``, elapsed seconds).
    """
    network.eval()
    meters = get_metric_dict()

    with torch.no_grad():
        tic = time.time()
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)

            scaled = [exit_out * w for exit_out, w in zip(network(images), ensemble_weights)]
            fused = torch.sum(torch.stack(scaled), dim=0)

            update_metric(meters, fused, labels)
        elapsed = time.time() - tic

    return get_metric_vals(meters), elapsed

## TEST

In [34]:
def validate_se(network, device, criterion, loader):
    """Run one evaluation pass of a single-exit network.

    Args:
        network: model whose forward pass returns one prediction tensor.
        device: device the batches are moved to.
        criterion: loss callable applied to ``(output, labels)``.
        loader: iterable of ``(images, labels)`` batches.

    Returns:
        tuple: (average loss, metric averages from ``get_metric_vals``,
        elapsed seconds).
    """
    network.eval()

    loss_meter = AverageMeter()
    acc_meters = get_metric_dict()

    with torch.no_grad():
        tic = time.time()
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)

            logits = network(images)
            batch_loss = criterion(logits, labels)

            # each batch contributes with its own sample count
            loss_meter.update(batch_loss.item(), images.size(0))
            update_metric(acc_meters, logits, labels)
        elapsed = time.time() - tic

    return loss_meter.avg, get_metric_vals(acc_meters), elapsed

In [35]:
def validate_ee(network, device, criterion, loader, branches_weights, ensemble_weights):
    """Run one evaluation pass of a multi-exit network.

    The network loss is the sum of the per-exit losses, each multiplied by
    its entry in ``branches_weights``. The network prediction is the sum of
    the per-exit outputs, each multiplied by its entry in
    ``ensemble_weights``.

    Args:
        network: model whose forward pass yields one output per exit.
        device: device the batches are moved to.
        criterion: loss callable applied to ``(output, labels)``.
        loader: iterable of ``(images, labels)`` batches.
        branches_weights: per-exit loss weights; its length fixes how many
            per-branch meters are kept.
        ensemble_weights: per-exit output weights.

    Returns:
        tuple: (average network loss,
                list of average *weighted* losses per branch,
                network metric averages from ``get_metric_vals``,
                ``[branch_top1_list, branch_top5_list]``,
                elapsed seconds).
    """
    network.eval()

    n_branches = len(branches_weights)
    net_loss_meter = AverageMeter()
    net_acc_meters = get_metric_dict()
    branch_loss_meters = [AverageMeter() for _ in range(n_branches)]
    branch_acc_meters = [get_metric_dict() for _ in range(n_branches)]

    with torch.no_grad():
        tic = time.time()
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            n_samples = images.size(0)

            exit_outputs = network(images)

            # weighted loss of every exit; their sum is the network loss
            weighted_losses = [
                criterion(exit_out, labels) * w
                for exit_out, w in zip(exit_outputs, branches_weights)
            ]
            net_loss = sum(weighted_losses)

            # weighted sum of the exit outputs is the network prediction
            scaled_outputs = [exit_out * w for exit_out, w in zip(exit_outputs, ensemble_weights)]
            net_output = torch.sum(torch.stack(scaled_outputs), dim=0)

            net_loss_meter.update(net_loss.item(), n_samples)
            update_metric(net_acc_meters, net_output, labels)

            for meter, weighted_loss in zip(branch_loss_meters, weighted_losses):
                meter.update(weighted_loss.item(), n_samples)

            for meters, exit_out in zip(branch_acc_meters, exit_outputs):
                update_metric(meters, exit_out, labels)
        elapsed = time.time() - tic

    per_branch_accs = [get_metric_vals(meters) for meters in branch_acc_meters]
    branch_top1s = [accs[0] for accs in per_branch_accs]
    branch_top5s = [accs[1] for accs in per_branch_accs]

    return (
        net_loss_meter.avg,
        [meter.avg for meter in branch_loss_meters],
        get_metric_vals(net_acc_meters, return_dict=False),
        [branch_top1s, branch_top5s],
        elapsed,
    )


## TRAIN

In [36]:
def train_one_epoch_se(network, device, criterion, optimizer, train_loader):
    """Train a single-exit network for one pass over ``train_loader``.

    Args:
        network: model whose forward pass returns one prediction tensor.
        device: device the batches are moved to.
        criterion: loss callable applied to ``(output, labels)``.
        optimizer: optimizer stepping the network parameters.
        train_loader: iterable of ``(images, labels)`` batches.

    Returns:
        tuple: (average loss, metric averages from ``get_metric_vals``,
        elapsed seconds).
    """
    network.train()

    loss_meter = AverageMeter()
    acc_meters = get_metric_dict()

    tic = time.time()
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        logits = network(images)
        batch_loss = criterion(logits, labels)

        # metrics are recorded from the forward pass, before the weights move
        loss_meter.update(batch_loss.item(), images.size(0))
        update_metric(acc_meters, logits, labels)

        # backpropagation
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
    elapsed = time.time() - tic

    return loss_meter.avg, get_metric_vals(acc_meters), elapsed

In [37]:
def update_history_se(history,train_loss, train_top1, train_top5, valid_loss, valid_top1, valid_top5):
    """Append one epoch of single-exit results to ``history`` in place.

    ``history`` must already hold a list under each of the keys
    "train_loss", "train_top1", "train_top5", "valid_loss", "valid_top1"
    and "valid_top5".
    """
    epoch_record = (
        ("train_loss", train_loss),
        ("train_top1", train_top1),
        ("train_top5", train_top5),
        ("valid_loss", valid_loss),
        ("valid_top1", valid_top1),
        ("valid_top5", valid_top5),
    )
    for key, value in epoch_record:
        history[key].append(value)

In [38]:
def train_se(network, device, train_criterion, valid_criterion, optimizer, epochs, train_loader, validation_loader, output_path, log_file):
    """Train a single-exit network with per-epoch validation and early stopping.

    After every epoch the network is validated; whenever the validation top-1
    improves, both the state dict ("best_ckpt.pth") and the whole model
    ("best_model.pth") are saved under ``output_path``. Training ends after
    ``epochs`` epochs or as soon as the early-stopping meter, fed with the
    validation loss, raises its stop flag.

    Returns:
        tuple: (history dict of per-epoch lists, path of the saved best model).
    """
    best_ckpt_pth = os.path.join(output_path, "best_ckpt.pth")
    best_model_pth = os.path.join(output_path, "best_model.pth")

    history = {
        key: []
        for key in ("train_loss", "train_top1", "train_top5", "valid_loss", "valid_top1", "valid_top5")
    }

    stopper = EarlyStoppingMeter()
    top_valid_acc = None

    for epoch in range(1, epochs + 1):

        train_loss, train_accs, train_time = train_one_epoch_se(network, device, train_criterion, optimizer, train_loader)
        train_top1, train_top5 = train_accs
        log_output_se("training", epoch, epochs, train_loss, train_top1, train_top5, train_time, log_file)

        valid_loss, valid_accs, valid_time = validate_se(network, device, valid_criterion, validation_loader)
        valid_top1, valid_top5 = valid_accs
        log_output_se("validation", epoch, epochs, valid_loss, valid_top1, valid_top5, valid_time, log_file)

        # checkpoint on validation top-1, stop on validation loss
        is_new_best = top_valid_acc is None or valid_top1 > top_valid_acc
        if is_new_best:
            top_valid_acc = valid_top1
            torch.save(network.state_dict(), best_ckpt_pth)
            torch.save(network, best_model_pth)

        update_history_se(history, train_loss, train_top1, train_top5, valid_loss, valid_top1, valid_top5)

        stopper(valid_loss)
        if stopper.stop:
            break

    return history, best_model_pth

In [39]:
def train_one_epoch_ee(network, device, criterion, optimizer, train_loader, branches_weights, ensemble_weights):
    """Train a multi-exit network for one pass over ``train_loader``.

    The optimized loss is the sum of the per-exit losses, each multiplied by
    its entry in ``branches_weights``. For reporting, the network prediction
    is the sum of the per-exit outputs scaled by ``ensemble_weights``.

    Args:
        network: model whose forward pass yields one output per exit.
        device: device the batches are moved to.
        criterion: loss callable applied to ``(output, labels)``.
        optimizer: optimizer stepping the network parameters.
        train_loader: iterable of ``(images, labels)`` batches.
        branches_weights: per-exit loss weights; its length fixes how many
            per-branch meters are kept.
        ensemble_weights: per-exit output weights.

    Returns:
        tuple: (average network loss,
                list of average *weighted* losses per branch,
                network metric averages from ``get_metric_vals``,
                ``[branch_top1_list, branch_top5_list]``,
                elapsed seconds).
    """
    network.train()

    n_branches = len(branches_weights)
    net_loss_meter = AverageMeter()
    net_acc_meters = get_metric_dict()
    branch_loss_meters = [AverageMeter() for _ in range(n_branches)]
    branch_acc_meters = [get_metric_dict() for _ in range(n_branches)]

    tic = time.time()
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        n_samples = images.size(0)

        exit_outputs = network(images)

        # weighted loss of every exit; their sum is what gets optimized
        weighted_losses = [
            criterion(exit_out, labels) * w
            for exit_out, w in zip(exit_outputs, branches_weights)
        ]
        net_loss = sum(weighted_losses)

        # weighted sum of the exit outputs is the network prediction
        scaled_outputs = [exit_out * w for exit_out, w in zip(exit_outputs, ensemble_weights)]
        net_output = torch.sum(torch.stack(scaled_outputs), dim=0)

        # backpropagation
        optimizer.zero_grad()
        net_loss.backward()
        optimizer.step()

        # metrics come from the forward pass computed above
        net_loss_meter.update(net_loss.item(), n_samples)
        update_metric(net_acc_meters, net_output, labels)

        for meter, weighted_loss in zip(branch_loss_meters, weighted_losses):
            meter.update(weighted_loss.item(), n_samples)

        for meters, exit_out in zip(branch_acc_meters, exit_outputs):
            update_metric(meters, exit_out, labels)
    elapsed = time.time() - tic

    per_branch_accs = [get_metric_vals(meters) for meters in branch_acc_meters]
    branch_top1s = [accs[0] for accs in per_branch_accs]
    branch_top5s = [accs[1] for accs in per_branch_accs]

    return (
        net_loss_meter.avg,
        [meter.avg for meter in branch_loss_meters],
        get_metric_vals(net_acc_meters, return_dict=False),
        [branch_top1s, branch_top5s],
        elapsed,
    )

In [40]:
def update_history_ee(history,tr_net_loss, tr_branches_losses, tr_net_top1, tr_net_top5, tr_branches_top1,tr_branches_top5,v_net_loss, v_branches_loss, v_net_top1, v_net_top5, v_branches_top1, v_branches_top5):
    """Append one epoch of multi-exit results to ``history`` in place.

    ``history`` must already hold a list under each "train_*" and "valid_*"
    key written below.
    """
    epoch_record = (
        # training side
        ("train_net_loss", tr_net_loss),
        ("train_branches_losses", tr_branches_losses),
        ("train_net_top1", tr_net_top1),
        ("train_net_top5", tr_net_top5),
        ("train_branches_top1", tr_branches_top1),
        ("train_branches_top5", tr_branches_top5),
        # validation side
        ("valid_net_loss", v_net_loss),
        ("valid_branches_losses", v_branches_loss),
        ("valid_net_top1", v_net_top1),
        ("valid_net_top5", v_net_top5),
        ("valid_branches_top1", v_branches_top1),
        ("valid_branches_top5", v_branches_top5),
    )
    for key, value in epoch_record:
        history[key].append(value)

In [41]:
def train_ee(network, device, train_criterion, valid_criterion, optimizer, epochs, train_loader, validation_loader, output_path, log_file, branches_weights, ensemble_weights):
    """Train a multi-exit network with per-epoch validation and early stopping.

    After every epoch the network is validated; whenever the validation
    network top-1 improves, both the state dict ("best_ckpt.pth") and the
    whole model ("best_model.pth") are saved under ``output_path``. Training
    ends after ``epochs`` epochs or as soon as the early-stopping meter, fed
    with the validation network loss, raises its stop flag.

    Returns:
        tuple: (history dict of per-epoch lists, path of the saved best model).
    """
    best_ckpt_pth = os.path.join(output_path, "best_ckpt.pth")
    best_model_pth = os.path.join(output_path, "best_model.pth")

    history = {
        key: []
        for key in (
            "train_net_loss",
            "train_branches_losses",
            "train_net_top1",
            "train_net_top5",
            "train_branches_top1",
            "train_branches_top5",
            "valid_net_loss",
            "valid_branches_losses",
            "valid_net_top1",
            "valid_net_top5",
            "valid_branches_top1",
            "valid_branches_top5",
        )
    }

    stopper = EarlyStoppingMeter()
    top_valid_acc = None

    for epoch in range(1, epochs + 1):

        train_stats = train_one_epoch_ee(network, device, train_criterion, optimizer, train_loader, branches_weights, ensemble_weights)
        tr_net_loss, tr_branches_losses, (tr_net_top1, tr_net_top5), (tr_branches_top1, tr_branches_top5), tr_time = train_stats
        log_output_ee("training", epoch, epochs, tr_net_loss, tr_branches_losses, tr_net_top1, tr_net_top5, tr_branches_top1, tr_branches_top5, tr_time, log_file)

        valid_stats = validate_ee(network, device, valid_criterion, validation_loader, branches_weights, ensemble_weights)
        v_net_loss, v_branches_loss, (v_net_top1, v_net_top5), (v_branches_top1, v_branches_top5), v_time = valid_stats
        log_output_ee("validation", epoch, epochs, v_net_loss, v_branches_loss, v_net_top1, v_net_top5, v_branches_top1, v_branches_top5, v_time, log_file)

        # checkpoint on validation network top-1, stop on validation network loss
        is_new_best = top_valid_acc is None or v_net_top1 > top_valid_acc
        if is_new_best:
            top_valid_acc = v_net_top1
            torch.save(network.state_dict(), best_ckpt_pth)
            torch.save(network, best_model_pth)

        update_history_ee(history, tr_net_loss, tr_branches_losses, tr_net_top1, tr_net_top5, tr_branches_top1, tr_branches_top5, v_net_loss, v_branches_loss, v_net_top1, v_net_top5, v_branches_top1, v_branches_top5)

        stopper(v_net_loss)
        if stopper.stop:
            break

    return history, best_model_pth

## network metrics

In [42]:
def compute_params(network):
    """Count the trainable parameters of ``network``, in millions.

    Only parameters with ``requires_grad`` set are counted.
    """
    trainable = 0
    for param in network.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable / 1e6  # Mparams

In [43]:
def compute_macs(network, dummy_data):
    """Return the MAC count reported by ``profile_macs`` for one forward pass
    of ``network`` on ``dummy_data``, in millions."""
    return profile_macs(network, dummy_data) / 1e6  # Mmacs

In [44]:
def compute_latency(network, dummy_data, iterations=1000, warmup=100):
    """Measure the mean forward-pass time of ``network`` on ``dummy_data``.

    The network is switched to eval mode, run ``warmup`` times untimed and
    then ``iterations`` times under a wall-clock timer, all without gradient
    tracking. As a side effect cuDNN is enabled and put in benchmark mode.

    Args:
        network: model to time; it is left in eval mode.
        dummy_data: input passed to ``network`` on every call.
        iterations: number of timed forward passes (default 1000).
        warmup: number of untimed forward passes run first (default 100).

    Returns:
        float: average time of one forward pass, in milliseconds.

    Raises:
        ValueError: if ``iterations`` is smaller than 1.
    """
    if iterations < 1:
        raise ValueError("iterations must be at least 1")

    cudnn.enabled = True
    cudnn.benchmark = True

    # torch.cuda.synchronize() raises on machines without CUDA, and run() may
    # legitimately pick the CPU, so only touch the CUDA runtime when it exists.
    use_cuda = torch.cuda.is_available()

    network.eval()
    with torch.no_grad():

        for _ in range(warmup):
            network(dummy_data)

        # wait for queued GPU work so the warm-up is not billed to the timer
        if use_cuda:
            torch.cuda.synchronize()

        # perf_counter is monotonic and has a finer resolution than time.time
        t_start = time.perf_counter()

        for _ in range(iterations):
            network(dummy_data)

        # GPU kernels run asynchronously: wait for them before stopping the clock
        if use_cuda:
            torch.cuda.synchronize()

        elapsed_time = time.perf_counter() - t_start
        latency = elapsed_time / iterations

    if use_cuda:
        torch.cuda.empty_cache()
    return latency * 1000 # in ms

## Weights

In [45]:
def get_weights(n_exits,ordering):
    """Build the per-exit loss weights and ensemble weights.

    Every weight is rounded to 4 decimals.

    * "UNIF": every exit gets ``1 / n_exits`` in both lists.
    * otherwise exit ``i`` (1-based) gets ``i / (1 + 2 + ... + n_exits)``,
      i.e. ascending weights, and then
        - "DESC": both lists are sorted in descending order;
        - "MIX": only the branch (loss) weights are sorted descending;
        - any other value (e.g. "ASC"): both lists stay ascending.

    Returns:
        tuple: ``(branches_weights, ensemble_weights)`` as two distinct lists.
    """
    if ordering == "UNIF":
        share = 1/n_exits
        uniform = [round(share,4) for _ in range(n_exits)]
        return uniform, uniform.copy()

    unit = 1/sum(range(1,n_exits+1))
    ascending = [round(unit*rank,4) for rank in range(1,n_exits+1)]

    if ordering == "DESC":
        return sorted(ascending,reverse=True), sorted(ascending,reverse=True)

    if ordering == "MIX":
        return sorted(ascending,reverse=True), ascending

    return ascending, ascending.copy()

## RUN 

In [46]:
def run(epochs, batch_size, learning_rate, network_name, se_or_me, weight_ordering, dataset, ft_or_tr, img_size, dataset_save_path, output_path):
    """Train, test and profile one configuration, then select its best exit ensemble.

    Steps, in order:
      1. create ``<output_path>/<run name>/`` and start a fresh ``log.txt``;
      2. seed torch, numpy and ``random``;
      3. build the data loaders, the network and the exit weights;
      4. train (single-exit or multi-exit path), plotting the learning curves;
      5. reload the best saved model, evaluate it on the test loader and log
         parameters, MACs and latency to ``log.txt``;
      6. score every combination of exits on the validation loader, keep the
         best one, extract and save the matching sub-network
         ("best_subnet.pth"), then test and profile it, logging everything to
         ``log_ensemble.txt``.

    Args:
        epochs: maximum number of training epochs.
        batch_size: batch size handed to ``get_dataloaders``.
        learning_rate: learning rate of the Adam optimizer.
        network_name: architecture name understood by ``get_network``.
        se_or_me: "SE" for a single-exit network, anything else is treated as
            multi-exit ("ME" is what adds the ordering to the run name).
        weight_ordering: ordering name forwarded to ``get_weights``.
        dataset: dataset name understood by ``get_dataloaders``.
        ft_or_tr: "FT" requests a pretrained network, anything else trains
            from scratch.
        img_size: side length of the square input images.
        dataset_save_path: directory datasets are stored in (created if missing).
        output_path: root directory for the run's outputs.

    Returns:
        None. Results are printed, logged and saved to disk.
    """
    # --- run name and output locations ---
    ordering = weight_ordering.lower() if se_or_me == "ME" else ""
    notebook_name = "{}-{}{}-{}-{}-{}".format(network_name, se_or_me, ordering, dataset, ft_or_tr, img_size)
    print("Notebook name: " + notebook_name)
    output_path = os.path.join(output_path, notebook_name)

    os.makedirs(output_path,exist_ok=True)
    os.makedirs(dataset_save_path,exist_ok=True)
    log_file = os.path.join(output_path,"log.txt")

    # --- seeds ---
    seed=420
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)

    # --- data ---
    train_loader, validation_loader, test_loader, n_classes, img_channels, classes_weights = get_dataloaders(dataset, dataset_save_path, img_size, batch_size)

    # --- device ---
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # --- network ---
    is_single_exit = se_or_me == "SE"
    pretrained = ft_or_tr == "FT"

    network, n_exits = get_network(network_name, is_single_exit, pretrained, n_classes, img_channels)
    network.to(device)

    # --- exit weights ---
    branches_weights, ensemble_weights = get_weights(n_exits,weight_ordering)
    print("N exits: ", n_exits)
    print("weights ordering: ",weight_ordering)
    print("branches weights: ", branches_weights)
    print("ensemble weights: ", ensemble_weights)

    # --- losses and optimizer ---
    # class weights (when the dataset provides them) are used for training
    # only; validation and test always use the unweighted loss
    if classes_weights is not None:
        train_criterion = nn.CrossEntropyLoss(weight=classes_weights.to(device))
    else:
        train_criterion = nn.CrossEntropyLoss()

    test_criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.Adam(network.parameters(), learning_rate)

    # --- training ---
    # start a fresh main log for this run
    with open(log_file,"w") as log:
        log.write(notebook_name+"\n\n\n")

    training_start = time.time()

    if is_single_exit:
        #train
        history, best_model_pth= train_se(network, device, train_criterion, test_criterion, optimizer, epochs, train_loader, validation_loader, output_path, log_file)
        #plots (early stopping may have ended training before `epochs`)
        epochs_trained = len(history["train_loss"])
        show_plot(epochs_trained, np.array(history["train_loss"]), np.array(history["valid_loss"]), output_path, "loss", "SE_losses")
        show_plot(epochs_trained, np.array(history["train_top1"]), np.array(history["valid_top1"]), output_path, "acc", "SE_accuracies")

    else:
        #train
        history, best_model_pth = train_ee(network, device, train_criterion, test_criterion, optimizer, epochs, train_loader, validation_loader, output_path, log_file, branches_weights, ensemble_weights)
        #plots (early stopping may have ended training before `epochs`)
        epochs_trained = len(history["train_net_loss"])
        show_plot(epochs_trained, np.array(history["train_net_loss"]), np.array(history["valid_net_loss"]), output_path, "loss", "EE_net_losses")
        show_plot(epochs_trained, np.array(history["train_net_top1"]), np.array(history["valid_net_top1"]), output_path, "acc", "EE_net_accuracies")
        for i in range(len(branches_weights)):
            show_plot(epochs_trained, np.array(history["train_branches_losses"])[:,i], np.array(history["valid_branches_losses"])[:,i], output_path, "loss", "EE_br"+str(i+1)+"_losses")
            show_plot(epochs_trained, np.array(history["train_branches_top1"])[:,i], np.array(history["valid_branches_top1"])[:,i], output_path, "acc", "EE_br"+str(i+1)+"_acc")

    training_duration = time.time()-training_start
    tds = f"training took {training_duration/3600:.2f}h or {training_duration/60:.2f}mins or {training_duration:.1f}s"

    # --- test of the best model ---
    # NOTE(review): best_model.pth is a fully pickled model; recent PyTorch
    # releases may default torch.load to weights-only loading - confirm that
    # the installed version loads it as intended.
    network = torch.load(best_model_pth)
    network.to(device)

    if is_single_exit:
        test_loss, (test_top1, test_top5), test_time = validate_se(network, device, test_criterion, test_loader)
        log_output_se("testing", 1, 1, test_loss, test_top1, test_top5, test_time, log_file)
    else:
        test_net_loss, test_branches_losses, (test_net_top1,test_net_top5), (test_branches_top1, test_branches_top5), test_time = validate_ee(network, device, test_criterion, test_loader, branches_weights, ensemble_weights)
        log_output_ee("testing", 1, 1, test_net_loss, test_branches_losses, test_net_top1, test_net_top5, test_branches_top1, test_branches_top5, test_time, log_file)

    # --- cost of the best model (batch of one image) ---
    input_size = (1, img_channels, img_size, img_size)
    dummy_data = torch.rand(*input_size).to(device)

    Mparams = compute_params(network)
    Mmacs = compute_macs(network, dummy_data)
    latency_ms = compute_latency(network, dummy_data)

    net_eval_str = "\n\n"+tds+"\n"
    net_eval_str += f"#Parameters: {Mparams:.3f}M\n"
    net_eval_str += f"#Macs: {Mmacs:.3f}M\n"
    net_eval_str += f"Latency: {latency_ms:.3f} ms"

    print(net_eval_str)
    with open(log_file,"a") as log:
        log.write("\n"+net_eval_str)


    #------------------------------------------------
    #   ENSEMBLE SELECTION
    #------------------------------------------------
    # NOTE(review): this stage also runs for single-exit configurations;
    # validate_all_ensembles zips the network output with the weights, so it
    # presumably expects one output per exit - confirm SE networks provide that.

    ens_log_file = os.path.join(output_path,"log_ensemble.txt")
    # Start a fresh ensemble log. This must open ens_log_file: opening
    # log_file in "w" mode here would erase the training/test log written above.
    with open(ens_log_file,"w") as ens_log:
        ens_log.write(notebook_name+" ensembles"+"\n\n\n")

    best_model = network

    ensembles_results = validate_all_ensembles(best_model, device, validation_loader, ensemble_weights)
    # each value is a list of metric averages, so max() compares them in order
    best_ens_key = max(ensembles_results, key=lambda key: ensembles_results[key])

    with open(ens_log_file,"a") as ens_log:
        ens_log.write("ENSEMBLES on VALIDATION:\n")
        for k,v in ensembles_results.items():
            ens_log.write(f"{k}: [{v[0]:.3f}, {v[1]:.3f}]\n")
        ens_log.write("\n\n")

    # translate the winning bit string into 1-based exit indices and weights
    best_ens_active_exits = []
    best_ens_active_weights = []
    for n,(bit, ens_weight) in enumerate(zip(best_ens_key, ensemble_weights),1):
        if bit=="1":
            best_ens_active_exits.append(n)
            best_ens_active_weights.append(ens_weight)

    log_best_ens_active_exits = "["+ ",".join(map(str,best_ens_active_exits))+"]"
    log_best_ens_active_weights = "["+ ",".join(map(str,best_ens_active_weights))+"]"

    with open(ens_log_file,"a") as ens_log:
        ens_log.write(f"BEST:    key: {best_ens_key}, Top1: {ensembles_results[best_ens_key][0]:.3f}, Top5: {ensembles_results[best_ens_key][1]:.3f}\n")
        ens_log.write("Active exits: "+log_best_ens_active_exits+"\n")
        ens_log.write("Ensemble weights: "+log_best_ens_active_weights+"\n\n\n")

    best_ensemble_net = best_model.extract_subnetwork(best_ens_active_exits)
    torch.save(best_ensemble_net,os.path.join(output_path,"best_subnet.pth"))

    if isinstance(best_ensemble_net,MySENetwork):
        (test_top1, test_top5), test_time = test_subnet_se(best_ensemble_net, device, test_loader)
    else:
        (test_top1, test_top5), test_time = test_subnet_ee(best_ensemble_net, device, test_loader, best_ens_active_weights)

    with open(ens_log_file,"a") as ens_log:
        ens_log.write(f"TEST: top1:{test_top1:.3f}, top5:{test_top5:.3f}, time:{test_time:.1f}s")

    # --- cost of the extracted sub-network ---
    input_size = (1, img_channels, img_size, img_size)
    dummy_data = torch.rand(*input_size).to(device)

    Mparams = compute_params(best_ensemble_net)
    Mmacs = compute_macs(best_ensemble_net, dummy_data)
    latency_ms = compute_latency(best_ensemble_net, dummy_data)

    net_eval_str = f"#Parameters: {Mparams:.3f}M\n"
    net_eval_str += f"#Macs: {Mmacs:.3f}M\n"
    net_eval_str += f"Latency: {latency_ms:.3f} ms"

    with open(ens_log_file,"a") as ens_log:
        ens_log.write("\n\n"+net_eval_str)

# START

In [47]:
# Training hyperparameters shared by every experiment launched below.
epochs=100  # upper bound per run; the training loops may stop earlier via early stopping
batch_size=64
learning_rate=1e-4  # forwarded to the optimizer created inside run()

In [None]:
# Locations used by every experiment.
dataset_save_path = "./datasets"
output_path = "./outputs"
weights_path = "./weights"

# torch.hub stores its downloads under weights_path
torch.hub.set_dir(weights_path)

# Experiment grid.
network_names = ["resnet50","vgg16","vgg16full","densenet169","mobilenetv3small","mobilenetv3smallfull","efficientnetB5","efficientnetB5full"]
ses_or_mes = ["SE","ME"]
weights_orderings = ["DESC","ASC", "MIX","UNIF"]
# Named dataset_names (not `datasets`) so the `datasets` module imported from
# torchvision in the first cell is not rebound to a list of strings.
dataset_names = ["cifar10","cifar100","eurosatW","fashionMNIST","gtsrbW","tiny_imagenet"]
fts_or_trs = ["TR","FT"]
img_sizes = [64,224]

for ft_or_tr in fts_or_trs:
    for img_size in img_sizes:
        for network_name in network_names:
            for se_or_me in ses_or_mes:
                for weight_ordering in weights_orderings:

                    # Single-exit runs are launched once per network (under the
                    # "DESC" pass only) and never for the "*full" variants.
                    # The test must look at the current network_name: checking
                    # the whole network_names list for the element "full" is
                    # always False and let the SE "*full" runs through.
                    # NOTE(review): presumably the "*full" variants only differ
                    # in their extra exits - confirm skipping SE is intended.
                    if se_or_me == "SE" and (weight_ordering != "DESC" or "full" in network_name):
                        continue

                    for dataset in dataset_names:
                        run(epochs, batch_size, learning_rate, network_name, se_or_me, weight_ordering, dataset, ft_or_tr, img_size, dataset_save_path, output_path)
                            