In [2]:
from copy import deepcopy
from functools import reduce
from typing import Any, Callable, Dict, List, Tuple, TypeVar

import torch
import torch.nn as nn

from nncf import NNCFConfig
from nncf.experimental.torch.nas.bootstrapNAS.elasticity.elasticity_controller import ElasticityController
from nncf.experimental.torch.nas.bootstrapNAS.elasticity.elasticity_dim import ElasticityDim
from nncf.experimental.torch.nas.bootstrapNAS.elasticity.multi_elasticity_handler import SubnetConfig
from nncf.experimental.torch.nas.bootstrapNAS.training.model_creator_helpers import resume_compression_from_state
from nncf.torch.checkpoint_loading import load_state
from nncf.torch.model_creation import create_nncf_network
from nncf.torch.nncf_network import NNCFNetwork


INFO:nncf:NNCF initialized successfully. Supported frameworks detected: torch


In [14]:
def print_number_of_trainable_model_parameters(model):
    trainable_model_params = 0
    all_model_params = 0
    for _, param in model.named_parameters():
        all_model_params += param.numel()
        if param.requires_grad:
            trainable_model_params += param.numel()
    return f"trainable model parameters: {trainable_model_params}\nall model parameters: {all_model_params}\npercentage of trainable model parameters: {100 * trainable_model_params / all_model_params:.2f}%"


In [3]:
TModel = TypeVar("TModel")
ValFnType = Callable[[NNCFNetwork, Any], Any]


In [4]:
class TrainedSuperNet:
    """
    An interface for handling pre-trained super-networks. This class can be used to quickly implement
    third party solutions for subnetwork search on existing super-networks.
    """

    def __init__(self, elastic_ctrl: ElasticityController, nncf_network: NNCFNetwork, original_torch_model: nn.Module):
        """
        Initializes the super-network interface.

        :param elastic_ctrl: Elasticity controller to activate subnetworks
        :param nncf_network: NNCFNetwork that wraps the original PyTorch model.
        """
        self._m_handler = elastic_ctrl.multi_elasticity_handler
        self._elasticity_ctrl = elastic_ctrl
        self._model = nncf_network
        self._original_torch_model = original_torch_model

    @classmethod
    def from_checkpoint(
        cls,
        model: TModel,
        nncf_config: NNCFConfig,
        supernet_elasticity_path: str,
        supernet_weights_path: str,
    ) -> "TrainedSuperNet":
        """
        Loads existing super-network weights and elasticity information, and creates the SuperNetwork interface.

        :param model: base model that was used to create the super-network.
        :param nncf_config: configuration used to create the super-network.
        :param supernet_elasticity_path: path to file containing state information about the super-network.
        :param supernet_weights_path: trained weights to resume the super-network.
        :return: SuperNetwork with wrapped functionality.
        """
        original_torch_model = deepcopy(model)
        nncf_network = create_nncf_network(model, nncf_config)
        compression_state = torch.load(supernet_elasticity_path, map_location=torch.device(nncf_config.device))
        model, elasticity_ctrl = resume_compression_from_state(nncf_network, compression_state)
        model_weights = torch.load(supernet_weights_path, map_location=torch.device(nncf_config.device))
        load_state(model, model_weights, is_resume=True)
        elasticity_ctrl.multi_elasticity_handler.activate_maximum_subnet()
        return TrainedSuperNet(elasticity_ctrl, model, original_torch_model)

    def get_search_space(self) -> Dict:
        """
        :return: dictionary with possible values for elastic configurations.
        """
        return self._m_handler.get_search_space()

    def get_design_vars_info(self) -> Tuple[int, List[int]]:
        """
        :return: number of possible values in subnet configurations and
        the number of possible values for each elastic property.
        """
        self._m_handler.get_design_vars_info()

    def eval_subnet_with_design_vars(self, design_config: List, eval_fn: ValFnType, **kwargs) -> Any:
        """

        :return: the value produced by the user's function to evaluate the subnetwork.
        """
        self._m_handler.activate_subnet_for_config(self._m_handler.get_config_from_pymoo(design_config))
        return eval_fn(self._model, **kwargs)

    def eval_active_subnet(self, eval_fn: ValFnType, **kwargs) -> Any:
        """
        :param eval_fn: user's function to evaluate the active subnetwork.
        :return: value of the user's function used to evaluate the subnetwork.
        """
        return eval_fn(self._model, **kwargs)

    def eval_subnet(self, config: SubnetConfig, eval_fn: ValFnType, **kwargs) -> Any:
        """
        :param config: subnetwork configuration.
        :param eval_fn: user's function to evaluate the active subnetwork.
        :return: value of the user's function used to evaluate the subnetwork.
        """
        self.activate_config(config)
        return self.eval_active_subnet(eval_fn, **kwargs)

    def activate_config(self, config: SubnetConfig) -> None:
        """
        :param config: subnetwork configuration to activate.
        """
        self._m_handler.activate_subnet_for_config(config)

    def activate_maximal_subnet(self) -> None:
        """
        Activates the maximal subnetwork in the super-network.
        """
        self._m_handler.activate_maximum_subnet()

    def activate_minimal_subnet(self) -> None:
        """
        Activates the minimal subnetwork in the super-network.
        """
        self._m_handler.activate_minimum_subnet()

    def get_active_config(self) -> SubnetConfig:
        """
        :return: the active configuration.
        """
        return self._m_handler.get_active_config()

    def get_macs_for_active_config(self) -> float:
        """
        :return: MACs of active subnet.
        """
        return self._m_handler.count_flops_and_weights_for_active_subnet()[0] / 2e6

    def export_active_subnet_to_onnx(self, filename: str = "subnet") -> None:
        """
        Exports the active subnetwork to ONNX format.

        :param filename: name of the output file.
        """
        self._elasticity_ctrl.export_model(f"{filename}.onnx")

    def get_config_from_pymoo(self, pymoo_config: List) -> SubnetConfig:
        """
        Converts a Pymoo subnetwork configuration into a SubnetConfig.

        :param pymoo_config: subnetwork configuration in Pymoo format.
        :return: subnetwork configuration in SubnetConfig format.
        """
        return self._m_handler.get_config_from_pymoo(pymoo_config)

    def get_active_subnet(self) -> NNCFNetwork:
        """
        :return: the nncf network with the current active configuration.
        """
        return self._model

    @torch.no_grad()
    def get_clean_subnet(self) -> nn.Module:
        """
        Remove pre-ops and post-ops by directly pruning weights shape. Returns a subnet without NNCF wrappers.
        """

        def get_module_by_name(model_, access_string):
            names = access_string.split(sep=".")
            return reduce(getattr, names, model_)

        config = self.get_active_config()
        subnet_model = deepcopy(self._model)
        torch_model = deepcopy(self._original_torch_model)

        # elastic width - update weight width
        if ElasticityDim.WIDTH in config:
            for cluster_id, _ in config[ElasticityDim.WIDTH].items():
                cluster = self._m_handler.width_handler._pruned_module_groups_info.get_cluster_by_id(cluster_id)
                for elastic_width_info in cluster.elements:
                    node_module = self._model.nncf.get_containing_module(elastic_width_info.node_name)
                    subnet_module = subnet_model.nncf.get_containing_module(elastic_width_info.node_name)
                    for op_id in node_module.pre_ops.keys():
                        node_module.pre_ops[op_id].op.get_clean_subnet_weight(subnet_module)

            for (
                node_name,
                dynamic_input_width_op,
            ) in self._m_handler.width_handler._node_name_vs_dynamic_input_width_op_map.items():
                subnet_module = subnet_model.nncf.get_containing_module(node_name)
                dynamic_input_width_op.get_clean_subnet_weight(subnet_module)

            # update weights in torch model (now only supports transformers)
            for pt_name, pt_module in torch_model.named_modules():
                if isinstance(pt_module, nn.Linear):
                    nncf_module = get_module_by_name(subnet_model, pt_name)
                    pt_module.in_features = nncf_module.in_features
                    pt_module.out_features = nncf_module.out_features
                    pt_module.weight = deepcopy(nncf_module.weight)
                    pt_module.bias = deepcopy(nncf_module.bias)
                if isinstance(pt_module, nn.Embedding):
                    nncf_module = get_module_by_name(subnet_model, pt_name)
                    pt_module.weight = deepcopy(nncf_module.weight)
                    pt_module.embedding_dim = nncf_module.embedding_dim
                if isinstance(pt_module, nn.LayerNorm):
                    nncf_module = get_module_by_name(subnet_model, pt_name)
                    pt_module.weight = deepcopy(nncf_module.weight)
                    pt_module.bias = deepcopy(nncf_module.bias)
                    pt_module.normalized_shape = nncf_module.normalized_shape

        # elastic depth - replace with identity
        if ElasticityDim.DEPTH in config or ElasticityDim.KERNEL in config:
            raise NotImplementedError

        return torch_model



# Test the APIs

In [5]:
from pathlib import Path
import jstyleson as json
from transformers import AutoModelForQuestionAnswering
from nncf.common.utils.os import safe_open


  from .autonotebook import tqdm as notebook_tqdm


## Load HuggingFace BERT

In [6]:
from copy import deepcopy
model_name = "bert-large-uncased-whole-word-masking"
model = AutoModelForQuestionAnswering.from_pretrained(model_name)
original_model = deepcopy(model)

Some weights of the model checkpoint at bert-large-uncased-whole-word-masking were not used when initializing BertForQuestionAnswering: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-large-uncased-whole-word-mas

## Prepare nncf model and configurations

In [9]:
supernet_weights_path = "./supernet.bin"
supernet_elasticity_path = (
    "./elasticity_state.pt"
)
nncf_config_path = "./nncf_config.json"
nncf_config_path = Path(nncf_config_path).resolve()


Load nncf configurations

In [10]:
with safe_open(nncf_config_path) as f:
    loaded_json = json.load(f)
nncf_config = NNCFConfig.from_dict(loaded_json)
nncf_config.device = "cuda" if torch.cuda.is_available() else "cpu"


## Load the TrainedSuperNet from checkpoint


In [11]:
train_supernet = TrainedSuperNet.from_checkpoint(
    model, nncf_config, supernet_elasticity_path, supernet_weights_path
)

INFO:nncf:Loaded 392/392 parameters


## Activate minimal subnet
### Check the basic info of the active minimal subnet
1. Get the minimal subnet configuration
2. Print the active minimal subnet

In [16]:
train_supernet.activate_minimal_subnet()
print(f"current config: {train_supernet.get_active_config()}")

dash_line = "-" * 80
print(dash_line)

active_subnet = train_supernet.get_active_subnet()
print(f"current subnet: {active_subnet}")

current config: OrderedDict([(<ElasticityDim.WIDTH: 'width'>, {0: 256, 1: 320, 2: 192, 3: 192, 4: 192, 5: 64, 6: 256, 7: 320, 8: 256, 9: 128, 10: 256, 11: 384, 12: 448, 13: 448, 14: 512, 15: 576, 16: 704, 17: 576, 18: 576, 19: 640, 20: 512, 21: 256, 22: 192, 23: 128, 24: 1729, 25: 1573, 26: 1681, 27: 1685, 28: 1727, 29: 1879, 30: 1836, 31: 1938, 32: 1764, 33: 1692, 34: 1631, 35: 1539, 36: 1513, 37: 1459, 38: 1259, 39: 1118, 40: 1179, 41: 1063, 42: 770, 43: 523, 44: 264, 45: 184, 46: 133, 47: 127})])
--------------------------------------------------------------------------------
current subnet: BertForQuestionAnswering(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): NNCFEmbedding(
        30522, 1024, padding_idx=0
        (pre_ops): ModuleDict()
        (post_ops): ModuleDict()
      )
      (position_embeddings): NNCFEmbedding(
        512, 1024
        (pre_ops): ModuleDict()
        (post_ops): ModuleDict()
      )
      (token_type_embeddings): NNC

## Question 1: What's the mechanics behind the active subnet?
Get the trainable model parameter of active subnet

I checked the trainable parameters, looks like all the gradient of supernet are activated 

I guess NNCF use the weights masking 


In [17]:
print(print_number_of_trainable_model_parameters(active_subnet))

trainable model parameters: 334094338
all model parameters: 334094338
percentage of trainable model parameters: 100.00%


## Active Largest Subnet

In [19]:
train_supernet.activate_maximal_subnet()
print(f"current config: {train_supernet.get_active_config()}")

print(dash_line)

active_subnet_max = train_supernet.get_active_subnet()
print(f"current subnet: {active_subnet_max}")

current config: OrderedDict([(<ElasticityDim.WIDTH: 'width'>, {0: 1024, 1: 1024, 2: 1024, 3: 1024, 4: 1024, 5: 1024, 6: 1024, 7: 1024, 8: 1024, 9: 1024, 10: 1024, 11: 1024, 12: 1024, 13: 1024, 14: 1024, 15: 1024, 16: 1024, 17: 1024, 18: 1024, 19: 1024, 20: 1024, 21: 1024, 22: 1024, 23: 1024, 24: 4096, 25: 4096, 26: 4096, 27: 4096, 28: 4096, 29: 4096, 30: 4096, 31: 4096, 32: 4096, 33: 4096, 34: 4096, 35: 4096, 36: 4096, 37: 4096, 38: 4096, 39: 4096, 40: 4096, 41: 4096, 42: 4096, 43: 4096, 44: 4096, 45: 4096, 46: 4096, 47: 4096})])
--------------------------------------------------------------------------------
current subnet: BertForQuestionAnswering(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): NNCFEmbedding(
        30522, 1024, padding_idx=0
        (pre_ops): ModuleDict()
        (post_ops): ModuleDict()
      )
      (position_embeddings): NNCFEmbedding(
        512, 1024
        (pre_ops): ModuleDict()
        (post_ops): ModuleDict()
      )
   

In [20]:
print(print_number_of_trainable_model_parameters(active_subnet_max))

trainable model parameters: 334094338
all model parameters: 334094338
percentage of trainable model parameters: 100.00%


## Observation 1: on minimal subnet and maximal subnet
1. The configurations is different
2. The activated weights are different, but the state dict are the same, so we can directly do model aggregations on different subnet (need verify)

## Question 2: Save the model and load the state dict
In this section, I'll save the maximal and minimal subnetwork model states to disk.

Then I'll load it to memory

In [24]:
print(dash_line)
print("Saving minimal subnet")
save_path = "./monimal_subnet.pth"
active_subnet.save_pretrained(save_path)
print("Done")

--------------------------------------------------------------------------------
Saving minimal subnet
Done


Then, load the subnet from disk

### Loading Method 1 (fails):
Use .from_pretained in nncf object 

In [25]:
print(dash_line)
print("Loading model from: {}".format(save_path))
active_subnet.from_pretrained(save_path)

--------------------------------------------------------------------------------
Loading model from: ./monimal_subnet.pth


AttributeError: 'NNCFNetworkInterface' object has no attribute '_original_unbound_forward'

### Loading Method 2 (Succeed):
Use .from_pretained in transformer object 

#### Question on this method
1. It lost all the subnetwork information.
2. Can I have a save method that save the elasticity and config as well?
3. I can directly aggregate this model weights, and then convert to supernet, and further sample subnetworks for next FL round
4. I need convert this model to nncf instance

In [27]:
model = AutoModelForQuestionAnswering.from_pretrained(save_path)
print(model)

BertForQuestionAnswering(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 1024, padding_idx=0)
      (position_embeddings): Embedding(512, 1024)
      (token_type_embeddings): Embedding(2, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-23): 24 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerNorm): LayerNorm((1024,), ep

### Loading Method 3 (Fails):
Use torch.load, It can only load weight tensor, not the entire model


In [30]:
import os
model = torch.load(os.path.join(save_path, 'pytorch_model.bin'))
print(model)

{'bert.embeddings.position_ids': tensor([[  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
          14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
          28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
          42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
          56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
          70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
          84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
          98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
         112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
         126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139,
         140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153,
         154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167,
         168, 169, 

### Method 4: use torch.save and torch.load(Fails)

In [32]:
pth_save_path = "./pth_saved_model"
torch.save(active_subnet,pth_save_path)
new_net = torch.load(pth_save_path)

PicklingError: Can't pickle <class 'transformers.models.bert.modeling_bert.BertForQuestionAnswering'>: it's not the same object as transformers.models.bert.modeling_bert.BertForQuestionAnswering

## Model weights aggregation

Note: **Position Embedding shouldnt aggregate (Very important!)**

In [40]:

net1_state_dict = active_subnet_max.state_dict()
net2_state_dict = active_subnet.state_dict()

#aggregate net1 and net2 state_dicts
for key in net1_state_dict.keys():
    if 'weight' or 'bias' in key:
        net1_state_dict[key] = net1_state_dict[key] + net2_state_dict[key]



In [41]:
print(net1_state_dict)


OrderedDict([('bert.embeddings.position_ids', tensor([[   0,    2,    4,    6,    8,   10,   12,   14,   16,   18,   20,   22,
           24,   26,   28,   30,   32,   34,   36,   38,   40,   42,   44,   46,
           48,   50,   52,   54,   56,   58,   60,   62,   64,   66,   68,   70,
           72,   74,   76,   78,   80,   82,   84,   86,   88,   90,   92,   94,
           96,   98,  100,  102,  104,  106,  108,  110,  112,  114,  116,  118,
          120,  122,  124,  126,  128,  130,  132,  134,  136,  138,  140,  142,
          144,  146,  148,  150,  152,  154,  156,  158,  160,  162,  164,  166,
          168,  170,  172,  174,  176,  178,  180,  182,  184,  186,  188,  190,
          192,  194,  196,  198,  200,  202,  204,  206,  208,  210,  212,  214,
          216,  218,  220,  222,  224,  226,  228,  230,  232,  234,  236,  238,
          240,  242,  244,  246,  248,  250,  252,  254,  256,  258,  260,  262,
          264,  266,  268,  270,  272,  274,  276,  278,  280, 

In [42]:
print(net2_state_dict)


OrderedDict([('bert.embeddings.position_ids', tensor([[  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
          14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
          28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
          42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
          56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
          70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
          84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
          98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
         112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
         126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139,
         140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153,
         154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167,
      

In [45]:
print(active_subnet.load_state_dict(net1_state_dict))
print(active_subnet.state_dict())

<All keys matched successfully>
OrderedDict([('bert.embeddings.position_ids', tensor([[   0,    2,    4,    6,    8,   10,   12,   14,   16,   18,   20,   22,
           24,   26,   28,   30,   32,   34,   36,   38,   40,   42,   44,   46,
           48,   50,   52,   54,   56,   58,   60,   62,   64,   66,   68,   70,
           72,   74,   76,   78,   80,   82,   84,   86,   88,   90,   92,   94,
           96,   98,  100,  102,  104,  106,  108,  110,  112,  114,  116,  118,
          120,  122,  124,  126,  128,  130,  132,  134,  136,  138,  140,  142,
          144,  146,  148,  150,  152,  154,  156,  158,  160,  162,  164,  166,
          168,  170,  172,  174,  176,  178,  180,  182,  184,  186,  188,  190,
          192,  194,  196,  198,  200,  202,  204,  206,  208,  210,  212,  214,
          216,  218,  220,  222,  224,  226,  228,  230,  232,  234,  236,  238,
          240,  242,  244,  246,  248,  250,  252,  254,  256,  258,  260,  262,
          264,  266,  268,  270

# Get Clean Network

In [48]:
train_supernet.activate_minimal_subnet()
clean_subnet = train_supernet.get_clean_subnet()
print(clean_subnet)

BertForQuestionAnswering(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 1024, padding_idx=0)
      (position_embeddings): Embedding(512, 1024)
      (token_type_embeddings): Embedding(2, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=1024, out_features=256, bias=True)
              (key): Linear(in_features=1024, out_features=256, bias=True)
              (value): Linear(in_features=1024, out_features=256, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=256, out_features=1024, bias=True)
              (LayerNorm): LayerNorm((1024,), eps=1e-12, ele

In [50]:
train_supernet.activate_maximal_subnet()
clean_subnet = train_supernet.get_clean_subnet()
print(clean_subnet)

BertForQuestionAnswering(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 1024, padding_idx=0)
      (position_embeddings): Embedding(512, 1024)
      (token_type_embeddings): Embedding(2, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-23): 24 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerNorm): LayerNorm((1024,), ep

## Others

In [51]:
space = train_supernet.get_search_space()
print(space)

{'width': {0: [1024, 320, 256], 1: [1024, 384, 320], 2: [1024, 256, 192], 3: [1024, 256, 192], 4: [1024, 256, 192], 5: [1024, 128, 64], 6: [1024, 320, 256], 7: [1024, 384, 320], 8: [1024, 320, 256], 9: [1024, 192, 128], 10: [1024, 320, 256], 11: [1024, 448, 384], 12: [1024, 512, 448], 13: [1024, 512, 448], 14: [1024, 576, 512], 15: [1024, 640, 576], 16: [1024, 768, 704], 17: [1024, 640, 576], 18: [1024, 640, 576], 19: [1024, 704, 640], 20: [1024, 576, 512], 21: [1024, 320, 256], 22: [1024, 256, 192], 23: [1024, 192, 128], 24: [4096, 1729], 25: [4096, 1573], 26: [4096, 1681], 27: [4096, 1685], 28: [4096, 1727], 29: [4096, 1879], 30: [4096, 1836], 31: [4096, 1938], 32: [4096, 1764], 33: [4096, 1692], 34: [4096, 1631], 35: [4096, 1539], 36: [4096, 1513], 37: [4096, 1459], 38: [4096, 1259], 39: [4096, 1118], 40: [4096, 1179], 41: [4096, 1063], 42: [4096, 770], 43: [4096, 523], 44: [4096, 264], 45: [4096, 184], 46: [4096, 133], 47: [4096, 127]}}


In [53]:
vars = train_supernet.get_design_vars_info()
print(vars)

None


In [56]:
train_supernet.activate_maximal_subnet()
MACs = train_supernet.get_macs_for_active_config()
print(dash_line)
print(f"maximal subnet: {MACs} MACs")
train_supernet.activate_minimal_subnet()
MACs = train_supernet.get_macs_for_active_config()
print(dash_line)
print(f"minimal subnet: {MACs} MACs")

IndexError: index out of range in self