In [5]:
import ast
import random
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

class EncodingBase:
    """Encodes a user parameter dictionary into the arrays a PyMoo search needs.

    ``param_dict`` maps each parameter name to an options dict with two keys:

    * ``'count'`` -- how many independent search variables ("slots") the
      parameter contributes.
    * ``'vars'`` -- the candidate values; each slot selects one of them by index.

    From this the constructor derives:

    * ``mapper`` -- one ``{index: value}`` dict per slot, in the iteration
      order of ``param_dict``. ``int``/``float`` values (and NumPy integer /
      floating scalars) are kept as-is; anything else, including ``bool``,
      is stored as ``str(value)``.
    * ``param_upperbound`` -- one inclusive upper bound per slot, equal to
      ``len(vars) - 1`` (the largest valid index).
    * ``param_count`` -- the total number of slots (the sum of all counts).

    Args:
        param_dict: The search-space description described above.
        verbose: Stored on the instance; not read by the code in this class.
        seed: Stored on the instance; this class does not seed any RNG.

    Raises:
        KeyError: If an options dict lacks ``'count'`` or ``'vars'``.
        ValueError: If a ``count`` is negative, or a parameter with a
            positive ``count`` has an empty ``vars``.
    """

    def __init__(self, param_dict: dict, verbose: bool = False, seed: int = 0):
        self.param_dict = param_dict
        self.verbose = verbose
        # Kept so the caller's value is not silently dropped; no RNG is seeded here.
        self.seed = seed
        self.mapper, self.param_upperbound, self.param_count = self.process_param_dict()

    @staticmethod
    def _mapping_value(value):
        """Returns numeric candidates unchanged and stringifies everything else."""
        # bool is an int subclass but is deliberately stringified (matching the
        # original exact-type check); NumPy integer/floating scalars are treated
        # like Python int/float instead of being turned into strings.
        if isinstance(value, (int, float, np.integer, np.floating)) and not isinstance(value, bool):
            return value
        return str(value)

    def process_param_dict(self) -> Tuple[List[Dict[int, Union[int, float, str]]], List[int], int]:
        """Builds the per-slot value mappers and the upper-bound vector for PyMoo.

        Returns:
            A ``(mapper, upperbound, count)`` tuple: one ``{index: value}`` dict
            per slot, one inclusive upper bound (``len(vars) - 1``) per slot,
            and the total number of slots.

        Raises:
            KeyError: If an options dict lacks ``'count'`` or ``'vars'``.
            ValueError: If a ``count`` is negative, or a parameter with a
                positive ``count`` has an empty ``vars``.
        """
        parameter_upperbound: List[int] = []
        parameter_mapper: List[Dict[int, Union[int, float, str]]] = []

        for parameter, options in self.param_dict.items():
            count = options['count']
            candidates = options['vars']
            if count < 0:
                raise ValueError(f"parameter {parameter!r}: 'count' must be >= 0, got {count}")
            if count == 0:
                continue
            if len(candidates) == 0:
                # An upper bound of -1 would describe an empty index range.
                raise ValueError(f"parameter {parameter!r}: 'vars' must not be empty")

            # The mapping depends only on 'vars', so build it once per parameter...
            template = {idx: self._mapping_value(value) for idx, value in enumerate(candidates)}
            upper = len(candidates) - 1
            for _ in range(count):
                parameter_upperbound.append(upper)
                # ...but give every slot its own dict so slots never alias each other.
                parameter_mapper.append(dict(template))

        # Each slot contributes exactly one upper bound, so this is the sum of counts.
        return parameter_mapper, parameter_upperbound, len(parameter_upperbound)

In [6]:
# OFA-ResNet50 search space, in the format EncodingBase expects.
# What does 'count' mean? It is the number of independent search variables
# ("slots") a parameter contributes: EncodingBase emits one upper bound
# (len(vars) - 1) and one index->value mapping per slot, so param_count is the
# sum of the counts (5 + 18 + 6 = 29 here). 'vars' lists the candidate values
# each slot picks from by index.
# NOTE(review): 'd' / 'e' / 'w' presumably stand for depth / expand ratio /
# width multiplier -- confirm against the OFA-ResNet50 supernet definition.
ofa_resnet50 = {
    'd': {'count': 5, 'vars': [0, 1, 2]},
    'e': {'count': 18, 'vars': [0.2, 0.25, 0.35]},
    'w': {'count': 6, 'vars': [0, 1, 2]},
}

In [7]:
# Encode the search space; populates mapper, param_upperbound and param_count.
ofa_resnet50_encoding = EncodingBase(param_dict=ofa_resnet50)

In [8]:
# One inclusive upper bound per search slot (len(vars) - 1 of the slot's parameter).
assert len(ofa_resnet50_encoding.param_upperbound) == ofa_resnet50_encoding.param_count
ofa_resnet50_encoding.param_upperbound

[2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2]

In [11]:
# Total number of search slots: the sum of every parameter's 'count'.
assert ofa_resnet50_encoding.param_count == sum(opt['count'] for opt in ofa_resnet50.values())
ofa_resnet50_encoding.param_count

29