## Space generator: design network candidates from user inputs

### Load Macro Generator

In [None]:
import sys
# Make the project packages (spacegen, network_designer) importable from this notebook.
sys.path.insert(0, '../')
sys.path.insert(0, '../../')
import os
# Must be set before CUDA is initialised; canonical comma-separated form, no spaces.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import numpy as np
import random
import torch

# Seed every RNG the pipeline may draw from. The generated models are torch
# modules, so weight initialisation (and presumably graph sampling) uses
# torch's RNG; seeding only numpy/random would leave runs non-reproducible.
SEED = 1
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
import logging
# Root logger: INFO and above, timestamped, with the logger name in each line.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s — %(levelname)s — %(name)s — %(message)s",
)

from spacegen import spacegen
from network_designer.graph_generator.generator import GraphGenerator

In [None]:
import re
from network_designer.design_space.extend.mastermodel import MasterModel
from ptflops import get_model_complexity_info
import math
from network_designer.dataloader import define_dataloader
# CIFAR-10 loader plus its class count and model input size, as returned by
# the project's define_dataloader helper.
dataloader, n_classes, input_size = define_dataloader(
    dataset='cifar10',
    dataset_path='../data',
)

In [None]:
class HierachicalGenerator():
    """Prompt-driven two-level network generator.

    A ``spacegen`` sequence model proposes macro architectures (an ordered
    list of cells with channels / stride / repeats) and a ``GraphGenerator``
    samples a micro graph (adjacency + ops) for every cell. The pieces are
    assembled into ``MasterModel`` instances and measured (FLOPs, parameter
    count, zero-cost score).

    The class name keeps its historical spelling so existing callers work.
    """

    # Logger for progress / diagnostics output.
    _logger = logging.getLogger(__name__)

    # ``key:value`` pairs in a prompt with an optional k/M/G scale suffix,
    # e.g. "param:3M" -> ("param", "3", "M").
    _KV_PATTERN = re.compile(r'(\w+):(\d+(?:\.\d*)?|\.\d+)([kMG]?)')
    _SCALES = {'': 1.0, 'k': 1e3, 'M': 1e6, 'G': 1e9}

    # Delimiters around the cell list in a generated macro sequence, and the
    # pattern for one parenthesised cell tuple.
    _CELLS_START = '(inputs),'
    _CELLS_END = ',outputs'
    _CELL_PATTERN = re.compile(r'\((.*?)\)')

    def __init__(self,
                 seq_gen="./trained_model_arch",
                 gvae="../experiments/extend/step_1/m1_dist_1.0_4_512_256.pt",
                 ccnf="../experiments/extend/cnf_.pt",
                 gmm="../ckpt/gmm.pkl"):
        """Load the macro sequence generator and the micro graph generator.

        Args:
            seq_gen: folder of the trained ``spacegen`` sequence model (loaded on GPU).
            gvae: first checkpoint path passed to ``GraphGenerator``
                (presumably the graph VAE, judging by the name).
            ccnf: second checkpoint path passed to ``GraphGenerator``
                (presumably a conditional CNF).
            gmm: pickle path passed to ``GraphGenerator`` (presumably a fitted GMM).
        """
        self.seq_gen = spacegen(model_folder=seq_gen, to_gpu=True)
        # Left padding keeps generated tokens directly after the prompt;
        # presumably required by the underlying model -- confirm.
        self.seq_gen.padding_side = 'left'

        self.g_gen = GraphGenerator(gvae, ccnf, gmm)
        self.g_gen.load_components()

    def metrics_validate(self, dict1, dict2, *, ratio_tol=1.2, zc_tol_pct=10.0):
        """Check measured metrics against target metrics.

        Only keys present in both dicts are compared; other keys are ignored.

        Args:
            dict1: measured metrics, e.g. ``{'param': ..., 'flops': ..., 'zc_score': ...}``.
            dict2: target metrics, e.g. as parsed by ``analysis_prompt``.
            ratio_tol: ``param`` and ``flops`` fail when measured > target * ratio_tol.
                This is an upper bound only; smaller models always pass.
            zc_tol_pct: ``zc_score`` fails when its relative deviation from the
                target exceeds this many percent.

        Returns:
            True if every compared metric is within tolerance, else False.
        """
        self._logger.debug("measured=%s target=%s", dict1, dict2)
        for key, measured in dict1.items():
            if key not in dict2:
                continue
            target = dict2[key]
            if key in ('param', 'flops'):
                # Upper bound only: at most ratio_tol x the requested budget.
                if measured > target * ratio_tol:
                    return False
            elif key == 'zc_score':
                if target == 0:
                    # Relative deviation is undefined; only an exact match passes.
                    if measured != 0:
                        return False
                # abs(target) so negative scores are judged correctly instead
                # of producing a negative percentage that always passes.
                elif abs(measured - target) / abs(target) * 100 > zc_tol_pct:
                    return False
        return True

    def analysis_prompt(self, prompt):
        """Parse ``key:value`` pairs from a design prompt into floats.

        Values may carry a scale suffix: ``k`` (1e3), ``M`` (1e6) or ``G`` (1e9).
        Later occurrences of a key overwrite earlier ones.

        Example:
            ``'num_classes:10,param:3M'`` -> ``{'num_classes': 10.0, 'param': 3000000.0}``
        """
        result = {}
        for key, number, suffix in self._KV_PATTERN.findall(prompt):
            result[key] = float(number) * self._SCALES[suffix]
        return result

    def extract_cells_info(self, sequence):
        """Parse the cell list of a generated macro sequence.

        Cells are read from the text between ``(inputs),`` and ``,outputs``.
        Each cell is a parenthesised, comma-separated tuple
        ``(<name>_<cell_type>, <num_channels>, <stride>, <repeats>)``.
        A missing start marker means "from the beginning" and a missing end
        marker means "to the end", instead of slicing at a bogus offset.

        Returns:
            Four parallel lists of ints: cell_types, num_channels, strides, repeats.

        Raises:
            ValueError / IndexError: if a cell tuple is malformed.
        """
        marker_at = sequence.find(self._CELLS_START)
        start = 0 if marker_at == -1 else marker_at + len(self._CELLS_START)
        # Search from the start marker itself: both markers may share one comma
        # ("(inputs),outputs"), in which case end < start and the slice is empty.
        end = sequence.find(self._CELLS_END, max(marker_at, 0))
        if end == -1:
            end = len(sequence)
        body = sequence[start:end]

        cell_types, num_channels, strides, repeats = [], [], [], []
        for cell_info in self._CELL_PATTERN.findall(body):
            # Tolerate both "a, b" and "a,b" separators.
            parts = [part.strip() for part in cell_info.split(',')]
            cell_types.append(int(parts[0].split('_')[-1]))
            num_channels.append(int(parts[1]))
            strides.append(int(parts[2]))
            repeats.append(int(parts[3]))
        return cell_types, num_channels, strides, repeats

    def _build_model(self, cell_types, num_channels, strides, repeats,
                     num_classes, stem_stride, gpu):
        """Sample one micro graph per cell and assemble a ``MasterModel`` on device ``gpu``."""
        adjs, opss = [], []
        for ref_id in cell_types:
            adj, ops = self.g_gen.generate_with_ref_subsetid(ref_id=ref_id)
            adjs.append(adj)
            opss.append(ops)
        model = MasterModel(num_channels, strides, repeats, adjs=adjs, opss=opss,
                            num_classes=num_classes, stem_stride=stem_stride)
        return model.to(device=gpu)

    def _measure(self, model, dataloader, gpu):
        """Return ``{'flops', 'param', 'zc_score'}`` for ``model``, or None if it cannot be measured."""
        # NOTE(review): ``input_size`` is the notebook-level global returned by
        # define_dataloader, not conditions['input_size'] -- confirm this is intended.
        try:
            macs, params = get_model_complexity_info(
                model, input_size, as_strings=False, print_per_layer_stat=False)
            # Conversions stay inside the try: a failed measurement may
            # surface as None here rather than as an exception.
            measurements = {'flops': float(macs) * 2,  # 1 MAC ~= 2 FLOPs
                            'param': float(params)}
        except Exception:
            # An invalid architecture is a rejected sample, not a fatal error.
            self._logger.debug("complexity measurement failed", exc_info=True)
            return None

        try:
            model.get_zc_info(dataloader, gpu=gpu, micro=False, input_size=input_size)
            score = model.score[0]
            if math.isnan(score):
                return None
        except Exception:
            self._logger.debug("zero-cost scoring failed", exc_info=True)
            return None
        measurements['zc_score'] = score
        return measurements

    def generate(self, prompt, n, gpu=1, dataloader=None, patiens=100, *, validate=False):
        """Generate candidate networks for a design prompt.

        Args:
            prompt: design prompt, e.g.
                ``'num_classes:10,input_size:32,param:3M, flops:1100M:(inputs)'``.
                It must contain ``num_classes`` and ``input_size``.
            n: number of macro sequences to sample from the sequence model.
            gpu: device index used both for ``model.to(device=gpu)`` and for
                the zero-cost measurement.
            dataloader: passed to ``MasterModel.get_zc_info``.
            patiens: attempt budget per macro sequence; the loop makes
                ``patiens + 1`` attempts. The spelling is kept for backward
                compatibility.
            validate: if True, keep only the first candidate per sequence that
                passes ``metrics_validate`` against the prompt targets. If False
                (the default, historical behaviour), keep every candidate that
                can be measured. All kept models stay on the device, so large
                ``n * patiens`` budgets may exhaust GPU memory.

        Returns:
            List of ``MasterModel`` instances.

        Raises:
            ValueError: if the prompt lacks ``num_classes`` or ``input_size``.
        """
        conditions = self.analysis_prompt(prompt)
        missing = [key for key in ('num_classes', 'input_size') if key not in conditions]
        if missing:
            raise ValueError(f"prompt is missing required field(s): {', '.join(missing)}")
        conditions['num_classes'] = int(conditions['num_classes'])
        conditions['input_size'] = int(conditions['input_size'])
        self._logger.info("conditions: %s", conditions)

        # Downsample in the stem only for large label spaces (>= 1000 classes,
        # presumably ImageNet-scale inputs); keep full resolution otherwise.
        stem_stride = 2 if conditions['num_classes'] >= 1000 else 1

        macro_sequences = self.seq_gen.generate(
            n=n,
            prompt=prompt,
            max_length=256,
            temperature=0.01,  # very low temperature: close to greedy decoding
            return_as_list=True,
        )

        model_list = []
        for seq in macro_sequences:
            self._logger.info("macro sequence: %s", seq)
            try:
                cell_types, num_channels, strides, repeats = self.extract_cells_info(seq)
            except (ValueError, IndexError):
                # A malformed sequence should not abort the remaining ones.
                self._logger.warning("skipping unparsable macro sequence: %r", seq)
                continue
            self._logger.info("cells: types=%s channels=%s strides=%s repeats=%s",
                              cell_types, num_channels, strides, repeats)

            time_out = patiens
            while time_out >= 0:
                self._logger.debug("attempts left: %d", time_out)
                time_out -= 1  # every attempt consumes budget, success or not

                model = self._build_model(cell_types, num_channels, strides, repeats,
                                          conditions['num_classes'], stem_stride, gpu)
                measurements = self._measure(model, dataloader, gpu)
                if measurements is None:
                    continue
                if validate and not self.metrics_validate(measurements, conditions):
                    self._logger.info("candidate rejected: %s", measurements)
                    continue

                self._logger.info("candidate accepted: measurements=%s conditions=%s",
                                  measurements, conditions)
                model_list.append(model)
                if validate:
                    break
        return model_list
                
                

                

In [None]:
# Design target: CIFAR-10 (num_classes:10, input_size:32), ~3M parameters,
# ~1.1 GFLOPs.
design_prompt = 'num_classes:10,input_size:32,param:3M, flops:1100M:(inputs)'

h_gen = HierachicalGenerator()
# NOTE: up to n * (patiens + 1) measurable candidates may be kept (and held on
# the device), so this budget can be memory-heavy.
model_list = h_gen.generate(
    prompt=design_prompt,
    n=5,
    dataloader=dataloader,
    patiens=1000,
)